Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[TRTLLM-5589] feat: Integrate TRT-LLM Gen FP8 Batched GEMM with Pytorch workflow kernel autotuner (#4872)
Signed-off-by: Dom Brown <3886319+DomBrown@users.noreply.github.com>
parent 1d4f748773
commit 9c012d5bf8
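The diff below wires the TRT-LLM Gen FP8 batched GEMM into the PyTorch workflow kernel autotuner. As orientation before the diff, here is a minimal sketch (not part of the commit) of how the resulting op is driven from Python; tensor creation and FP8 quantization are elided, and the tile_size/epilogue_tile_m values are illustrative, taken from the test at the end of the diff.

# Sketch only: invoking the autotunable op added by this commit.
import torch
from tensorrt_llm._torch.autotuner import autotune

def run_autotuned_bmm(a_fp8: torch.Tensor, b_fp8: torch.Tensor, scale_c: torch.Tensor):
    # The first call under autotune() profiles the valid kernel configs ("tactics")
    # reported by the runner; subsequent calls reuse the cached best tactic.
    with autotune():
        c, c_dq_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
            a_fp8,               # [B, M, K], torch.float8_e4m3fn
            b_fp8,               # [B, N, K], torch.float8_e4m3fn
            tile_size=8,         # illustrative; M must be padded to the tile size
            use_deep_seek_fp8=False,
            low_latency=False,
            epilogue_tile_m=128, # illustrative, as in the test below
            scale_c=scale_c,     # per-batch output scales for the non-DeepSeek path
            out_dtype=torch.half)
    return c, c_dq_sf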
@@ -55,11 +55,12 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
}
}

TLLM_CHECK_WITH_INFO(mPassingConfigIndices.size() != 0, "No kernel found for the given output type");
TLLM_CHECK_WITH_INFO(!mPassingConfigIndices.empty(), "No kernel found for the given options");
}

size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim)
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
std::optional<int32_t> configIndex)
{
BatchedGemmData gemmData;
gemmData.mProblemDimensions.mNumBatches = numBatches;
@@ -74,13 +75,18 @@ size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n,
gemmData.mProblemDimensions.mWorldSize = 1;
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;

selectGemmConfig(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);

auto bmm = BatchedGemmInterface();

auto const configs = bmm.getBatchedGemmConfigs();
TLLM_CHECK_WITH_INFO(
mSelectedConfigIndex.has_value(), "No valid kernel found for given param config and problem size");
auto const& config = configs[mSelectedConfigIndex.value()];

if (!configIndex.has_value())
{
mSelectedConfigIndex
= getDefaultValidConfigIndex(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
configIndex = mSelectedConfigIndex;
}

auto const& config = configs[configIndex.value()];
return bmm.getWorkspaceSizeInBytes(config, gemmData);
}

@@ -89,16 +95,22 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC, float const* scaleGateC,
void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
void* workspace, CUstream stream, int device)
void* workspace, CUstream stream, int device, std::optional<int32_t> configIndex)
{
auto bmm = BatchedGemmInterface();

BatchedGemmData gemmData;

auto const configs = bmm.getBatchedGemmConfigs();
TLLM_CHECK_WITH_INFO(
mSelectedConfigIndex.has_value(), "No valid kernel found for given param config and problem size");
auto const& config = configs[mSelectedConfigIndex.value()];

if (!configIndex.has_value())
{
TLLM_CHECK_WITH_INFO(mSelectedConfigIndex.has_value(), "Tried to use default config index but none was set");

configIndex = mSelectedConfigIndex;
}

auto const& config = configs[configIndex.value()];

TLLM_CHECK_WITH_INFO(numBatches > 0, "Batched GEMM requires numBatches > 0");
if (!mOptions.staticBatch)
@@ -170,7 +182,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto

void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
void const* a, void const* sfA, void const* b, void const* sfB, void* c, void* outSfC, void* workspace,
CUstream stream, int device)
CUstream stream, int device, std::optional<int32_t> configIndex)
{
// Dispatch with block scaling factors and with static batching.
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a, sfA, b, sfB,
@@ -178,12 +190,12 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
/* scaleC */ nullptr, /* scaleGateC */ nullptr, c, outSfC,
/* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr,
/* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr,
/* numNonExitingCtas */ nullptr, workspace, stream, device);
/* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex);
}

void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
void const* a, void const* b, float const* scaleC, float const* scaleGateC, void* c, void* workspace,
CUstream stream, int device)
CUstream stream, int device, std::optional<int32_t> configIndex)
{
// Dispatch with block scaling factors and with static batching.
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a,
@@ -191,11 +203,12 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
scaleGateC, c, /* outSfC */ nullptr,
/* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr,
/* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr,
/* numNonExitingCtas */ nullptr, workspace, stream, device);
/* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex);
}

void TrtllmGenBatchedGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim)
std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
int32_t maxNumCtasInBatchDim) const
{
auto const bmm = BatchedGemmInterface();
auto const configs = bmm.getBatchedGemmConfigs();
@@ -242,16 +255,30 @@ void TrtllmGenBatchedGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t
return optionsA.mTileM > optionsB.mTileM;
});

std::vector<int64_t> validConfigIndices;
for (auto const& configIndex : sortedIndices)
{
auto const& config = configs[configIndex];
auto isValidConfig = bmm.isValidConfig(config, gemmData);
if (isValidConfig)
{
mSelectedConfigIndex = configIndex;
return;
validConfigIndices.push_back(configIndex);
}
}

TLLM_CHECK_WITH_INFO(!validConfigIndices.empty(), "No valid config found for the given problem shape");

return validConfigIndices;
}

int64_t TrtllmGenBatchedGemmRunner::getDefaultValidConfigIndex(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
int32_t maxNumCtasInBatchDim) const
{
auto const validConfigIndices
= getValidConfigIndices(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);

return validConfigIndices[0];
}

} // namespace kernels

@@ -16,8 +16,10 @@

#pragma once

#include <cstdint>
#include <cuda.h>
#include <optional>
#include <vector>

#include "trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"

@@ -45,20 +47,40 @@ public:
explicit TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunnerOptions const& options);

[[nodiscard]] size_t getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim);
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
std::optional<int32_t> configIndex = std::nullopt);

void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
int32_t numBatches, int32_t maxNumCtasInBatchDim, void const* a, void const* sfA, void const* b,
void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC,
float const* scaleGateC, void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
void* workspace, CUstream stream, int device);
void* workspace, CUstream stream, int device, std::optional<int32_t> configIndex = std::nullopt);

void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device);
void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device,
std::optional<int32_t> configIndex = std::nullopt);

void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* b,
float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device);
float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device,
std::optional<int32_t> configIndex = std::nullopt);

// Get the list of configs that passed the validation based on the constructor options
[[nodiscard]] std::vector<int32_t> getPassingConfigIndices() const
{
return mPassingConfigIndices;
}

// Get the list of config indices that are valid for the given problem shape
[[nodiscard]] std::vector<int64_t> getValidConfigIndices(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
int32_t maxNumCtasInBatchDim) const;

// Get a default config index that is valid for the given problem shape
// This will be used as the fallback config if using auto-tuning
[[nodiscard]] int64_t getDefaultValidConfigIndex(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
int32_t maxNumCtasInBatchDim) const;

private:
void selectGemmConfig(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
@@ -66,8 +88,8 @@ private:

private:
TrtllmGenBatchedGemmRunnerOptions mOptions;
std::optional<int> mSelectedConfigIndex;
std::vector<int32_t> mPassingConfigIndices;
std::optional<int32_t> mSelectedConfigIndex;
};
} // namespace kernels
} // namespace tensorrt_llm

@@ -26,6 +26,8 @@
#include <cuda_fp16.h>

#include <cstdint>
#include <memory>
#include <tuple>

namespace
{
@@ -35,27 +37,16 @@ template <tg::Dtype outDtype>
void runBatchedGemm(at::Tensor& out, at::Tensor& outSfC, at::Tensor const& mat1, at::Tensor const& mat2,
std::optional<at::Tensor> const& dDqSfsA, std::optional<at::Tensor> const& dDqSfsB,
std::optional<at::Tensor> const& scaleC, int64_t m, int64_t n, int64_t k, int32_t tileSize, int32_t epilogueTileM,
std::vector<int32_t> const& batchedTokens, bool useDeepSeekFp8, bool lowLatencyKernel)
std::vector<int32_t> const& batchedTokens, bool useDeepSeekFp8, bool lowLatencyKernel,
tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner& runner, int32_t const configIndex)
{
auto eltType = tg::Dtype::E4m3;

tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = {.eltType = eltType,
.outputType = outDtype,
.deepSeekFp8 = useDeepSeekFp8,
.fusedAct = false,
.routeAct = false,
.staticBatch = true,
.transposeMmaOutput = lowLatencyKernel,
.tileSize = tileSize,
.epilogueTileM = epilogueTileM};

tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner runner(options);

// numTokens and maxNumCtasInBatchDim are not used for static batching
int32_t numTokens = 0;
int32_t maxNumCtasInBatchDim = 0;
int64_t const numBytesWorkspace
= runner.getWorkspaceSizeInBytes(m, n, k, batchedTokens, numTokens, batchedTokens.size(), maxNumCtasInBatchDim);
int32_t const numTokens = 0;
int32_t const maxNumCtasInBatchDim = 0;

int64_t const numBytesWorkspace = runner.getWorkspaceSizeInBytes(
m, n, k, batchedTokens, numTokens, batchedTokens.size(), maxNumCtasInBatchDim, configIndex);
at::Tensor workspace
= at::detail::empty_cuda({numBytesWorkspace}, at::ScalarType::Char, mat1.device(), std::nullopt);

@@ -66,20 +57,21 @@ void runBatchedGemm(at::Tensor& out, at::Tensor& outSfC, at::Tensor const& mat1,
float* outSfCPtr = outDtype == tg::Dtype::E4m3 ? outSfC.data_ptr<float>() : nullptr;
runner.run(m, n, k, batchedTokens, mat1.const_data_ptr(), dDqSfsA.value().const_data_ptr(),
mat2.const_data_ptr(), dDqSfsB.value().const_data_ptr(), out.data_ptr(), outSfCPtr, workspace.data_ptr(),
stream.stream(), mat1.get_device());
stream.stream(), mat1.get_device(), configIndex);
}
else
{
runner.run(m, n, k, batchedTokens, mat1.const_data_ptr(), mat2.const_data_ptr(),
reinterpret_cast<float const*>(scaleC.value().const_data_ptr()), nullptr, out.data_ptr(),
workspace.data_ptr(), stream.stream(), mat1.get_device());
workspace.data_ptr(), stream.stream(), mat1.get_device(), configIndex);
}
}

std::tuple<at::Tensor, at::Tensor> fp8_batched_gemm_sm100(at::Tensor const& mat1, at::Tensor const& mat2,
int32_t tileSize, bool useDeepSeekFp8, bool lowLatencyKernel, int64_t epilogueTileM,
std::optional<at::Tensor> const& dDqSfsA, std::optional<at::Tensor> const& dDqSfsB,
std::optional<at::Tensor> const& scaleC, std::optional<c10::ScalarType> outDtype)
std::optional<at::Tensor> const& scaleC, std::optional<c10::ScalarType> outDtype,
tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner& runner, int32_t const configIndex)
{
TORCH_CHECK(mat1.dim() == 3, "Matrix A must be of size [B, M, K]");
TORCH_CHECK(mat2.dim() == 3, "Matrix B must be of size [B, N, K]");
@@ -144,16 +136,19 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_gemm_sm100(at::Tensor const& mat1
TORCH_CHECK(scaleC.value().sizes()[0] == b, "outScalingFactor must be a 1D matrix of size B");
}

int64_t outputN = n;
int64_t const outputN = n;

// Create output tensor.
at::Tensor out = at::detail::empty_cuda({b, m, outputN}, outDtype.value(), mat1.device(), std::nullopt);
at::Tensor outSfC;
if (useDeepSeekFp8 && outDtype.value() == at::ScalarType::Float8_e4m3fn)
{
outSfC = at::detail::empty_cuda(
{outputN / dsFp8QuantBlockSize, m * b}, at::ScalarType::Float, mat1.device(), std::nullopt);
}

bool const needOutSfC = useDeepSeekFp8 && outDtype.value() == at::ScalarType::Float8_e4m3fn;

// Torch class did not support returning a default tensor so using empty instead.
int64_t const outSfCSize0 = needOutSfC ? (outputN / dsFp8QuantBlockSize) : 0;
int64_t const outSfCSize1 = needOutSfC ? (m * b) : 0;

at::Tensor outSfC
= at::detail::empty_cuda({outSfCSize0, outSfCSize1}, at::ScalarType::Float, mat1.device(), std::nullopt);

std::vector<int32_t> batchedTokens(b, m);

@@ -161,15 +156,15 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_gemm_sm100(at::Tensor const& mat1
{
case at::ScalarType::Half:
runBatchedGemm<tg::Dtype::Fp16>(out, outSfC, mat1, mat2, dDqSfsA, dDqSfsB, scaleC, m, n, k, tileSize,
epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel);
epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel, runner, configIndex);
break;
case at::ScalarType::BFloat16:
runBatchedGemm<tg::Dtype::Bfloat16>(out, outSfC, mat1, mat2, dDqSfsA, dDqSfsB, scaleC, m, n, k, tileSize,
epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel);
epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel, runner, configIndex);
break;
case at::ScalarType::Float8_e4m3fn:
runBatchedGemm<tg::Dtype::E4m3>(out, outSfC, mat1, mat2, dDqSfsA, dDqSfsB, scaleC, m, n, k, tileSize,
epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel);
epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel, runner, configIndex);
break;
default: C10_THROW_ERROR(NotImplementedError, "outDtype must be one of fp16/bf16/e4m3.");
}
@@ -181,35 +176,101 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_gemm_sm100(at::Tensor const& mat1
namespace torch_ext
{

extern std::tuple<at::Tensor, at::Tensor> fp8_batched_gemm_trtllmgen(at::Tensor const& mat1, at::Tensor const& mat2,
int64_t tileSize, bool useDeepSeekFp8, bool lowLatency, int64_t epilogueTileM,
std::optional<at::Tensor> const& dDqSfsA, std::optional<at::Tensor> const& dDqSfsB,
std::optional<at::Tensor> const& scaleC, std::optional<c10::ScalarType> outDtype)
// Wrapped the TRTLLM-Gen kernel runner in a Torch custom class to allow
// use with the torch workflow autotuner class.
class FP8BatchedGemmRunner : public torch::CustomClassHolder
{
auto const smVersion = tensorrt_llm::common::getSMVersion();
switch (smVersion)

public:
explicit FP8BatchedGemmRunner(c10::ScalarType outDtypeArg, bool useDeepSeekFp8, bool lowLatencyKernel,
int64_t tileSize, int64_t epilogueTileM)
: mOutDtypeArg(outDtypeArg)
, mUseDeepSeekFp8(useDeepSeekFp8)
, mLowLatencyKernel(lowLatencyKernel)
, mTileSize(tileSize)
, mEpilogueTileM(epilogueTileM)
{
case tensorrt_llm::kernels::kSM_100:

auto const smVersion = tensorrt_llm::common::getSMVersion();
if (smVersion != tensorrt_llm::kernels::kSM_100)
{
TLLM_THROW("Unsupported or unimplemented compute capability for fp8 batched gemm: %i", smVersion);
}

tg::Dtype outDtype = tg::Dtype::E4m3; // Default to E4m3, will be updated based on outDtypeArg

switch (outDtypeArg)
{
case at::ScalarType::Half: outDtype = tg::Dtype::Fp16; break;
case at::ScalarType::BFloat16: outDtype = tg::Dtype::Bfloat16; break;
case at::ScalarType::Float8_e4m3fn: outDtype = tg::Dtype::E4m3; break;
default: C10_THROW_ERROR(NotImplementedError, "outDtype must be one of fp16/bf16/e4m3.");
}

RunnerOptionsType const options = {.eltType = mEltType,
.outputType = outDtype,
.deepSeekFp8 = mUseDeepSeekFp8,
.fusedAct = false,
.routeAct = false,
.staticBatch = true,
.transposeMmaOutput = mLowLatencyKernel,
.tileSize = static_cast<int32_t>(mTileSize),
.epilogueTileM = static_cast<int32_t>(mEpilogueTileM)};

mRunner = std::make_unique<RunnerType>(options);
}

std::tuple<at::Tensor, at::Tensor> runBatchedGemm(at::Tensor const& mat1, at::Tensor const& mat2,
std::optional<at::Tensor> const& dDqSfsA, std::optional<at::Tensor> const& dDqSfsB,
std::optional<at::Tensor> const& scaleC, int64_t configIndex)
{
return fp8_batched_gemm_sm100(
mat1, mat2, tileSize, useDeepSeekFp8, lowLatency, epilogueTileM, dDqSfsA, dDqSfsB, scaleC, outDtype);

return fp8_batched_gemm_sm100(mat1, mat2, mTileSize, mUseDeepSeekFp8, mLowLatencyKernel, mEpilogueTileM,
dDqSfsA, dDqSfsB, scaleC, mOutDtypeArg, *mRunner, configIndex);
}
default: TLLM_THROW("Unsupported or unimplemented compute capability for fp8 batched gemm: %i", smVersion);

std::vector<int64_t> getValidConfigs(int64_t numBatches, int64_t m, int64_t n, int64_t k) const
{
// numTokens and maxNumCtasInBatchDim are not used for static batching
int32_t const numTokens = 0;
int32_t const maxNumCtasInBatchDim = 0;

std::vector<int32_t> const batchedTokens(numBatches, m);

return mRunner->getValidConfigIndices(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
}
}

int64_t getDefaultValidConfigIndex(int64_t numBatches, int64_t m, int64_t n, int64_t k) const
{
// numTokens and maxNumCtasInBatchDim are not used for static batching
int32_t const numTokens = 0;
int32_t const maxNumCtasInBatchDim = 0;

std::vector<int32_t> const batchedTokens(numBatches, m);

return mRunner->getDefaultValidConfigIndex(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
}

private:
using RunnerType = tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner;
using RunnerOptionsType = tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions;

std::unique_ptr<RunnerType> mRunner;
tg::Dtype mEltType{tg::Dtype::E4m3};
c10::ScalarType mOutDtypeArg;
bool mUseDeepSeekFp8;
bool mLowLatencyKernel;
int64_t mTileSize;
int64_t mEpilogueTileM;
};

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"fp8_batched_gemm_trtllmgen(Tensor a, Tensor b, int tile_size,"
"bool use_deep_seek_fp8=False, bool low_latency=False, "
"int epilogue_tile_m=0, Tensor? dq_sfs_a=None, Tensor? dq_sfs_b=None, "
"Tensor? scale_c=None, "
"ScalarType? out_dtype=None) -> (Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("fp8_batched_gemm_trtllmgen", &torch_ext::fp8_batched_gemm_trtllmgen);
m.class_<torch_ext::FP8BatchedGemmRunner>("FP8BatchedGemmRunner")
.def(torch::init<at::ScalarType, bool, bool, int64_t, int64_t>())
.def("get_valid_configs", &torch_ext::FP8BatchedGemmRunner::getValidConfigs)
.def("get_default_valid_config", &torch_ext::FP8BatchedGemmRunner::getDefaultValidConfigIndex)
.def("run_batched_gemm", &torch_ext::FP8BatchedGemmRunner::runBatchedGemm);
}
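The custom class registered above is what the Python TunableRunner later in this diff drives. A minimal sketch (not part of the commit) of consuming it directly, with illustrative shapes and options:

# Sketch only: mirrors FP8BatchedGemmRunner (the Python TunableRunner) below.
import torch

runner = torch.classes.trtllm.FP8BatchedGemmRunner(
    torch.half,  # out_dtype
    False,       # use_deep_seek_fp8
    False,       # low_latency
    8,           # tile_size (illustrative)
    128)         # epilogue_tile_m (illustrative)

b, m, n, k = 4, 128, 256, 512  # illustrative problem shape
tactics = runner.get_valid_configs(b, m, n, k)          # candidate config indices
fallback = runner.get_default_valid_config(b, m, n, k)  # default when not tuning

# mat1/mat2/scale_c follow the op schema; dq_sfs_a/dq_sfs_b may be None:
# out, out_sf = runner.run_batched_gemm(mat1, mat2, None, None, scale_c, tactics[0])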

@@ -340,7 +340,12 @@ class AutoTuner:
Although runners[0] with tactic=-1 is always treated as the fallback runner.
Runner authors are suggested to provide a fallback implementation for each runner to avoid potential issues.
"""
input_shapes = tuple(t.shape for t in inputs)

# Treat None tensors as size zero
# This allows the tuner to handle TRT-LLM-Gen torch ops that have optional tensor
# arguments, such as block scaling factors.
input_shapes = tuple(
(t.shape if t is not None else torch.Size((0, ))) for t in inputs)

# Early return if it's not tuning, use cache found one or fallback one
if not self.is_tuning_mode:
@@ -393,8 +398,14 @@ class AutoTuner:
time_measured = self._profile_single_kernel(
r, tensors, tac, **kwargs)
except Exception as e:
# Handle None tensors for optional inputs
shapes = [
t.size() if t is not None else torch.Size((0, ))
for t in tensors
]

logger.error(
f"[Autotuner]: Failed when profiling {r} {tac}, shapes={[t.size() for t in tensors]}. Error occurred: {e}"
f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}"
)

# Record the failed profiling combinations
@@ -471,8 +482,13 @@ class AutoTuner:
stream.synchronize()

avg_time = start.elapsed_time(end) / self.repeat

# Handle None tensors for optional inputs
shapes = [
t.size() if t is not None else torch.Size((0, )) for t in inputs
]
logger.debug(
f"[Autotuner]: profiling {runner} {tactic}, shapes={[t.size() for t in inputs]}, avg_time {avg_time}"
f"[Autotuner]: profiling {runner} {tactic}, shapes={shapes}, avg_time {avg_time}"
)

return avg_time
@@ -494,9 +510,13 @@ class AutoTuner:
combinations specified in dynamic_tensor_specs.
"""
# every dimension created from the concrete input tensor shape
# generate some dynamic dimension description based on the dynamic_tensor_specs
base_profile = OptimizationProfile([[StaticDim(x) for x in t.size()]
for t in inputs])
# generate some dynamic dimension description based on the dynamic_tensors

# Zero handles the case where a TRTLLM op has optional inputs.
base_profile = OptimizationProfile(
[[StaticDim(x)
for x in t.size()] if t is not None else [StaticDim(0)]
for t in inputs])

generated_profiles: List[OptimizationProfile] = []

@@ -1,5 +1,5 @@
from functools import lru_cache
from typing import List, Optional
from typing import List, Optional, Tuple

import torch

@@ -340,6 +340,218 @@ def _(
dtype=output_dtype)

class FP8BatchedGemmRunner(TunableRunner):

_runner_dict = dict()

def __init__(self, output_dtype: torch.dtype, use_deep_seek_fp8: bool,
low_latency_kernel: bool, tile_size: int,
epilogue_tile_m: int):

self.output_dtype = output_dtype
self.use_deep_seek_fp8 = use_deep_seek_fp8
self.low_latency_kernel = low_latency_kernel
self.tile_size = tile_size
self.epilogue_tile_m = epilogue_tile_m
self.tuning_config = self.get_tuning_config()

instance_key = (output_dtype, use_deep_seek_fp8, low_latency_kernel,
tile_size, epilogue_tile_m)

if instance_key not in FP8BatchedGemmRunner._runner_dict:
FP8BatchedGemmRunner._runner_dict[
instance_key] = torch.classes.trtllm.FP8BatchedGemmRunner(
output_dtype, use_deep_seek_fp8, low_latency_kernel,
tile_size, epilogue_tile_m)

self._kernel_runner = FP8BatchedGemmRunner._runner_dict[instance_key]

def forward(
self,
inputs: List[torch.Tensor],
tactic: int = -1,
do_preparation: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Run the batched GEMM operation with the given inputs and tactic.
"""

mat1, mat2, dq_sfs_a, dq_sfs_b, scale_c = inputs

chosen_tactic = self.get_default_valid_tactic(
inputs) if tactic == -1 else tactic

out_tensors = self._kernel_runner.run_batched_gemm(
mat1,
mat2,
dq_sfs_a,
dq_sfs_b,
scale_c,
chosen_tactic,
)

return out_tensors

def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:

mat1, mat2, _, _, _ = inputs

b = mat1.shape[0]
m = mat1.shape[1]
n = mat2.shape[1]
k = mat1.shape[2]

tactics = self._kernel_runner.get_valid_configs(b, m, n, k)

return tactics

def get_default_valid_tactic(
self,
inputs: List[torch.Tensor],
) -> int:

mat1, mat2, _, _, _ = inputs

b = mat1.shape[0]
m = mat1.shape[1]
n = mat2.shape[1]
k = mat1.shape[2]

default_tactic = self._kernel_runner.get_default_valid_config(
b, m, n, k)

return default_tactic

def get_dynamic_tensor_specs(self) -> Tuple[DynamicTensorSpec, ...]:
"""Get the dynamic tensor specs for use with the AutoTuner."""

# These indices correspond to the 0th input tensor and it's first dimension
# i.e. we are tuning M where the first input tensor is of shape [B, M, K]

MAT1_IDX = 0
TUNED_DIM = 1

# Starting at 8 as M % tile size == 0 is required
m_values = (8, 16, 32, 64, 128, 256, 512, 1024, 2048)
round_rule = lambda x: last_positive_power_of_2(x)

specs = (DynamicTensorSpec(MAT1_IDX, TUNED_DIM, m_values, round_rule), )

return specs

def get_constraint_specs(self) -> Tuple[ConstraintSpec, ...]:
"""Get the constraint specs for the dynamic tensors for use with the AutoTuner.
"""

# When using deepseek fp8, the dq_sfs_a and dq_sfs_b tensors are expected to
# have specific dimensions. As we are only tuning M, we need only constrain
# dimension 1 of dq_sfs_a
if not self.use_deep_seek_fp8:
constraint_dq_sfs_a = ()
else:

def _constrain_dq_sfs_a_dim1(shapes: Tuple[torch.Size]) -> int:
b = shapes[0][0]
m = shapes[0][1]

m_padded = (m + self.tile_size - 1) // self.tile_size
result = m_padded * self.tile_size * b

return result

SFS_A_IDX = 2
CONSTRAINED_DIM = 1

constraint_dq_sfs_a = (ConstraintSpec(SFS_A_IDX, CONSTRAINED_DIM,
_constrain_dq_sfs_a_dim1), )

return constraint_dq_sfs_a
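# Editor's note (not part of the commit): a small worked example of the
# constraint above, with illustrative values. For mat1 of shape [B=4, M=100, K=512]
# and tile_size=16, dimension 1 of dq_sfs_a is constrained to the padded M times B:
#   num_tiles = (100 + 16 - 1) // 16 = 7
#   dim1      = 7 * 16 * 4 = 448
shapes = (torch.Size((4, 100, 512)), )
tile_size = 16
num_tiles = (shapes[0][1] + tile_size - 1) // tile_size
dq_sfs_a_dim1 = num_tiles * tile_size * shapes[0][0]
assert dq_sfs_a_dim1 == 448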

def get_tuning_config(self) -> TuningConfig:
"""Get the tuning configuration for the AutoTuner."""

dynamic_tensor_specs = self.get_dynamic_tensor_specs()
constraint_specs = self.get_constraint_specs()

tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs,
constraint_specs=constraint_specs)

return tuning_config

@torch.library.custom_op("trtllm::fp8_batched_gemm_trtllmgen", mutates_args=())
def fp8_batched_gemm_trtllmgen(
mat1: torch.Tensor,
mat2: torch.Tensor,
tile_size: int,
use_deep_seek_fp8: Optional[bool] = False,
low_latency: Optional[bool] = False,
epilogue_tile_m: Optional[int] = 0,
dq_sfs_a: Optional[torch.Tensor] = None,
dq_sfs_b: Optional[torch.Tensor] = None,
scale_c: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = torch.half
) -> Tuple[torch.Tensor, torch.Tensor]:

kernel_runner = FP8BatchedGemmRunner(output_dtype=out_dtype,
use_deep_seek_fp8=use_deep_seek_fp8,
low_latency_kernel=low_latency,
tile_size=tile_size,
epilogue_tile_m=epilogue_tile_m)

tuner = AutoTuner.get()

inputs = [mat1, mat2, dq_sfs_a, dq_sfs_b, scale_c]

_, best_tactic = tuner.choose_one(
"trtllm::fp8_batched_gemm_trtllmgen::batched_gemm",
[kernel_runner],
kernel_runner.tuning_config,
inputs,
)

return kernel_runner(
inputs=inputs,
tactic=best_tactic,
)

# Allows the tunable TRTLLM-Gen FP8 batched GEMM to be
# used with torch.compile
@fp8_batched_gemm_trtllmgen.register_fake
def _(
mat1: torch.Tensor,
mat2: torch.Tensor,
tile_size: int,
use_deep_seek_fp8: Optional[bool] = False,
low_latency: Optional[bool] = False,
epilogue_tile_m: Optional[int] = 0,
dq_sfs_a: Optional[torch.Tensor] = None,
dq_sfs_b: Optional[torch.Tensor] = None,
scale_c: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None
) -> Tuple[torch.Tensor, torch.Tensor]:

b = mat1.size(0)
m = mat1.size(1)
n = mat2.size(1)

fake_out = mat1.new_empty((b, m, n), dtype=out_dtype)

if use_deep_seek_fp8:
ds_fp8_quant_block_size = 128
dim0_size = n // ds_fp8_quant_block_size
dim1_size = b * m
fake_dq_sfs_c = torch.empty((dim0_size, dim1_size), dtype=torch.float32)
else:
fake_dq_sfs_c = torch.empty((0, 0), dtype=torch.float32)

return (fake_out, fake_dq_sfs_c)
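# Editor's note (not part of the commit): the register_fake entry above is what
# lets the op trace under torch.compile; a minimal, illustrative sketch of that use:
#
#   @torch.compile
#   def bmm(a_fp8, b_fp8, scale_c):
#       out, _ = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
#           a_fp8, b_fp8, tile_size=8, scale_c=scale_c, out_dtype=torch.half)
#       return out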

@torch.library.custom_op("trtllm::attention", mutates_args=())
def attention(
q: torch.Tensor,

@@ -20,6 +20,7 @@ import pytest
import torch
from utils.util import getSMVersion

from tensorrt_llm._torch.autotuner import autotune
from tensorrt_llm.quantization.utils.fp4_utils import shuffle_matrix_a

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
@@ -246,98 +247,199 @@ def quant_fp8(x_fp32: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
),
],
)
def test_fp8_batched_gemm_trtllmgen(test_case: BatchedGemmTestCase) -> None:
torch.random.manual_seed(42)
class TestFP8BatchedGemmTRTLLMGen:

b = test_case.b
m = test_case.m
n = test_case.n
k = test_case.k
use_deep_seek_fp8 = test_case.use_deep_seek_fp8
dtype_c = test_case.dtype_c
low_latency = test_case.low_latency
tile_size = test_case.tile_size
def test_thop(self, test_case: BatchedGemmTestCase) -> None:
torch.random.manual_seed(42)

a_fp32 = torch.randn((b, m, k), device="cuda", dtype=torch.float32)
b_fp32 = torch.randn((b, n, k), device="cuda", dtype=torch.float32)
# Pad to the tile size. It is needed for the TRT-LLM Gen BMM input requirements.
if m % tile_size:
tiled_shape = ((m + tile_size - 1) // tile_size) * tile_size
a_fp32 = torch.nn.functional.pad(a_fp32, (0, 0, 0, tiled_shape - m),
"constant", 0)
m_padded = ((m + tile_size - 1) // tile_size) * tile_size
b = test_case.b
m = test_case.m
n = test_case.n
k = test_case.k
use_deep_seek_fp8 = test_case.use_deep_seek_fp8
dtype_c = test_case.dtype_c
low_latency = test_case.low_latency
tile_size = test_case.tile_size

dq_sf_a = None
dq_sf_b = None
global_dq_a = None
global_dq_b = None
out_global_scaling_factor = None
a_fp32 = torch.randn((b, m, k), device="cuda", dtype=torch.float32)
b_fp32 = torch.randn((b, n, k), device="cuda", dtype=torch.float32)
# Pad to the tile size. It is needed for the TRT-LLM Gen BMM input requirements.
if m % tile_size:
tiled_shape = ((m + tile_size - 1) // tile_size) * tile_size
a_fp32 = torch.nn.functional.pad(a_fp32, (0, 0, 0, tiled_shape - m),
"constant", 0)
m_padded = ((m + tile_size - 1) // tile_size) * tile_size

if use_deep_seek_fp8:
a_fp8, dq_sf_a = quant_ds_fp8(a_fp32, activations=True)
dq_sf_a = dq_sf_a.reshape(k // 128, -1).contiguous()
b_fp8, dq_sf_b = quant_ds_fp8(b_fp32, activations=False)
dq_sf_b = dq_sf_b.contiguous()
else:
a_fp8, global_dq_a = quant_fp8(a_fp32)
b_fp8, global_dq_b = quant_fp8(b_fp32)
out_global_scaling_factor = global_dq_a * global_dq_b
dq_sf_a = None
dq_sf_b = None
global_dq_a = None
global_dq_b = None
out_global_scaling_factor = None

# Compute reference batched matrix multiplication
output = fp8_bmm_reference(a_fp8, b_fp8, dq_sf_a, dq_sf_b, global_dq_a,
global_dq_b, dtype_c, use_deep_seek_fp8)

c_dq_sf_ref = None
if dtype_c == torch.float8_e4m3fn:
if use_deep_seek_fp8:
c_ref, c_dq_sf_ref = output
c_dq_sf_ref = c_dq_sf_ref.reshape(n // 128, -1)
a_fp8, dq_sf_a = quant_ds_fp8(a_fp32, activations=True)
dq_sf_a = dq_sf_a.reshape(k // 128, -1).contiguous()
b_fp8, dq_sf_b = quant_ds_fp8(b_fp32, activations=False)
dq_sf_b = dq_sf_b.contiguous()
else:
c_ref, c_scale = output
out_global_scaling_factor /= c_scale.cuda()
else:
c_ref = output
a_fp8, global_dq_a = quant_fp8(a_fp32)
b_fp8, global_dq_b = quant_fp8(b_fp32)
out_global_scaling_factor = global_dq_a * global_dq_b

epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
if low_latency and not use_deep_seek_fp8:
b_fp8_shuffled = []
for bi in range(b):
b_fp8_shuffled.append(
shuffle_matrix_a(b_fp8[bi].view(torch.uint8).clone(),
epilogue_tile_m))
# Compute reference batched matrix multiplication
output = fp8_bmm_reference(a_fp8, b_fp8, dq_sf_a, dq_sf_b, global_dq_a,
global_dq_b, dtype_c, use_deep_seek_fp8)

# Stack weights for all experts
b_fp8 = torch.stack(b_fp8_shuffled).view(torch.float8_e4m3fn)
c_dq_sf_ref = None
if dtype_c == torch.float8_e4m3fn:
if use_deep_seek_fp8:
c_ref, c_dq_sf_ref = output
c_dq_sf_ref = c_dq_sf_ref.reshape(n // 128, -1)
else:
c_ref, c_scale = output
out_global_scaling_factor /= c_scale.cuda()
else:
c_ref = output

if not use_deep_seek_fp8:
out_global_scaling_factor = out_global_scaling_factor.contiguous().to(
torch.float32)
epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
if low_latency and not use_deep_seek_fp8:
b_fp8_shuffled = []
for bi in range(b):
b_fp8_shuffled.append(
shuffle_matrix_a(b_fp8[bi].view(torch.uint8).clone(),
epilogue_tile_m))

c_actual, c_dq_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
a_fp8.contiguous(),
b_fp8.contiguous(),
tile_size=tile_size,
epilogue_tile_m=epilogue_tile_m,
use_deep_seek_fp8=use_deep_seek_fp8,
low_latency=low_latency,
out_dtype=dtype_c,
dq_sfs_a=dq_sf_a,
dq_sfs_b=dq_sf_b,
scale_c=out_global_scaling_factor)
# Stack weights for all experts
b_fp8 = torch.stack(b_fp8_shuffled).view(torch.float8_e4m3fn)

c_actual = c_actual.detach().cpu()
c_ref = c_ref.detach().cpu()
if not use_deep_seek_fp8:
out_global_scaling_factor = out_global_scaling_factor.contiguous(
).to(torch.float32)

torch.testing.assert_close(c_actual.to(torch.float32)[:, :m],
c_ref.to(torch.float32)[:, :m],
atol=1e-2,
rtol=1e-2)
if use_deep_seek_fp8 and dtype_c == torch.float8_e4m3fn:
c_dq_sf = c_dq_sf.detach().cpu()
for bi in range(b):
torch.testing.assert_close(
c_dq_sf[:, bi * m_padded:bi * m_padded + m].to(torch.float32),
c_dq_sf_ref[:,
c_actual, c_dq_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
a_fp8.contiguous(),
b_fp8.contiguous(),
tile_size=tile_size,
epilogue_tile_m=epilogue_tile_m,
use_deep_seek_fp8=use_deep_seek_fp8,
low_latency=low_latency,
out_dtype=dtype_c,
dq_sfs_a=dq_sf_a,
dq_sfs_b=dq_sf_b,
scale_c=out_global_scaling_factor)

c_actual = c_actual.detach().cpu()
c_ref = c_ref.detach().cpu()

torch.testing.assert_close(c_actual.to(torch.float32)[:, :m],
c_ref.to(torch.float32)[:, :m],
atol=1e-2,
rtol=1e-2)
if use_deep_seek_fp8 and dtype_c == torch.float8_e4m3fn:
c_dq_sf = c_dq_sf.detach().cpu()
for bi in range(b):
torch.testing.assert_close(
c_dq_sf[:,
bi * m_padded:bi * m_padded + m].to(torch.float32),
atol=1e-2,
rtol=1e-2)
c_dq_sf_ref[:, bi * m_padded:bi * m_padded + m].to(
torch.float32),
atol=1e-2,
rtol=1e-2)

def test_autotuned_thop(self, test_case: BatchedGemmTestCase) -> None:
torch.random.manual_seed(42)

b = test_case.b
m = test_case.m
n = test_case.n
k = test_case.k
use_deep_seek_fp8 = test_case.use_deep_seek_fp8
dtype_c = test_case.dtype_c
low_latency = test_case.low_latency
tile_size = test_case.tile_size

a_fp32 = torch.randn((b, m, k), device="cuda", dtype=torch.float32)
b_fp32 = torch.randn((b, n, k), device="cuda", dtype=torch.float32)
# Pad to the tile size. It is needed for the TRT-LLM Gen BMM input requirements.
if m % tile_size:
tiled_shape = ((m + tile_size - 1) // tile_size) * tile_size
a_fp32 = torch.nn.functional.pad(a_fp32, (0, 0, 0, tiled_shape - m),
"constant", 0)
m_padded = ((m + tile_size - 1) // tile_size) * tile_size

dq_sf_a = None
dq_sf_b = None
global_dq_a = None
global_dq_b = None
out_global_scaling_factor = None

if use_deep_seek_fp8:
a_fp8, dq_sf_a = quant_ds_fp8(a_fp32, activations=True)
dq_sf_a = dq_sf_a.reshape(k // 128, -1).contiguous()
b_fp8, dq_sf_b = quant_ds_fp8(b_fp32, activations=False)
dq_sf_b = dq_sf_b.contiguous()
else:
a_fp8, global_dq_a = quant_fp8(a_fp32)
b_fp8, global_dq_b = quant_fp8(b_fp32)
out_global_scaling_factor = global_dq_a * global_dq_b

# Compute reference batched matrix multiplication
output = fp8_bmm_reference(a_fp8, b_fp8, dq_sf_a, dq_sf_b, global_dq_a,
global_dq_b, dtype_c, use_deep_seek_fp8)

c_dq_sf_ref = None
if dtype_c == torch.float8_e4m3fn:
if use_deep_seek_fp8:
c_ref, c_dq_sf_ref = output
c_dq_sf_ref = c_dq_sf_ref.reshape(n // 128, -1)
else:
c_ref, c_scale = output
out_global_scaling_factor /= c_scale.cuda()
else:
c_ref = output

epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
if low_latency and not use_deep_seek_fp8:
b_fp8_shuffled = []
for bi in range(b):
b_fp8_shuffled.append(
shuffle_matrix_a(b_fp8[bi].view(torch.uint8).clone(),
epilogue_tile_m))

# Stack weights for all experts
b_fp8 = torch.stack(b_fp8_shuffled).view(torch.float8_e4m3fn)

if not use_deep_seek_fp8:
out_global_scaling_factor = out_global_scaling_factor.contiguous(
).to(torch.float32)

with autotune():
c_actual, c_dq_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
a_fp8.contiguous(),
b_fp8.contiguous(),
tile_size=tile_size,
epilogue_tile_m=epilogue_tile_m,
use_deep_seek_fp8=use_deep_seek_fp8,
low_latency=low_latency,
out_dtype=dtype_c,
dq_sfs_a=dq_sf_a,
dq_sfs_b=dq_sf_b,
scale_c=out_global_scaling_factor)

c_actual = c_actual.detach().cpu()
c_ref = c_ref.detach().cpu()

torch.testing.assert_close(c_actual.to(torch.float32)[:, :m],
c_ref.to(torch.float32)[:, :m],
atol=1e-2,
rtol=1e-2)
if use_deep_seek_fp8 and dtype_c == torch.float8_e4m3fn:
c_dq_sf = c_dq_sf.detach().cpu()
for bi in range(b):
torch.testing.assert_close(
c_dq_sf[:,
bi * m_padded:bi * m_padded + m].to(torch.float32),
c_dq_sf_ref[:, bi * m_padded:bi * m_padded + m].to(
torch.float32),
atol=1e-2,
rtol=1e-2)