[TRTLLM-5589] feat: Integrate TRT-LLM Gen FP8 Batched GEMM with Pytorch workflow kernel autotuner (#4872)

Signed-off-by: Dom Brown <3886319+DomBrown@users.noreply.github.com>
Dom Brown 2025-06-09 11:02:48 +01:00 committed by GitHub
parent 1d4f748773
commit 9c012d5bf8
6 changed files with 609 additions and 165 deletions
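
In practice, the change lets the TRT-LLM Gen FP8 batched GEMM op pick its kernel configuration through the PyTorch workflow autotuner. A minimal usage sketch in the spirit of the new unit test follows; the shapes, tile size, and epilogue tile size are illustrative assumptions rather than values taken from this commit.

import torch
from tensorrt_llm._torch.autotuner import autotune

# Assumed problem shape: A is [B, M, K], B is [B, N, K], both FP8 E4M3.
b, m, n, k = 4, 128, 256, 512
a_fp8 = torch.randn(b, m, k, device="cuda").to(torch.float8_e4m3fn)
b_fp8 = torch.randn(b, n, k, device="cuda").to(torch.float8_e4m3fn)
scale_c = torch.ones(b, device="cuda", dtype=torch.float32)  # per-batch output scale

# Inside autotune() the valid kernel configs are profiled and the best tactic is cached;
# outside it, the cached (or default) tactic is used.
with autotune():
    out, _ = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
        a_fp8, b_fp8,
        tile_size=8, epilogue_tile_m=128,
        use_deep_seek_fp8=False, low_latency=False,
        dq_sfs_a=None, dq_sfs_b=None,
        scale_c=scale_c, out_dtype=torch.half)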

View File

@ -55,11 +55,12 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
}
}
TLLM_CHECK_WITH_INFO(mPassingConfigIndices.size() != 0, "No kernel found for the given output type");
TLLM_CHECK_WITH_INFO(!mPassingConfigIndices.empty(), "No kernel found for the given options");
}
size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim)
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
std::optional<int32_t> configIndex)
{
BatchedGemmData gemmData;
gemmData.mProblemDimensions.mNumBatches = numBatches;
@ -74,13 +75,18 @@ size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n,
gemmData.mProblemDimensions.mWorldSize = 1;
gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
selectGemmConfig(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
auto bmm = BatchedGemmInterface();
auto const configs = bmm.getBatchedGemmConfigs();
TLLM_CHECK_WITH_INFO(
mSelectedConfigIndex.has_value(), "No valid kernel found for given param config and problem size");
auto const& config = configs[mSelectedConfigIndex.value()];
if (!configIndex.has_value())
{
mSelectedConfigIndex
= getDefaultValidConfigIndex(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
configIndex = mSelectedConfigIndex;
}
auto const& config = configs[configIndex.value()];
return bmm.getWorkspaceSizeInBytes(config, gemmData);
}
@ -89,16 +95,22 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC, float const* scaleGateC,
void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
void* workspace, CUstream stream, int device)
void* workspace, CUstream stream, int device, std::optional<int32_t> configIndex)
{
auto bmm = BatchedGemmInterface();
BatchedGemmData gemmData;
auto const configs = bmm.getBatchedGemmConfigs();
TLLM_CHECK_WITH_INFO(
mSelectedConfigIndex.has_value(), "No valid kernel found for given param config and problem size");
auto const& config = configs[mSelectedConfigIndex.value()];
if (!configIndex.has_value())
{
TLLM_CHECK_WITH_INFO(mSelectedConfigIndex.has_value(), "Tried to use default config index but none was set");
configIndex = mSelectedConfigIndex;
}
auto const& config = configs[configIndex.value()];
TLLM_CHECK_WITH_INFO(numBatches > 0, "Batched GEMM requires numBatches > 0");
if (!mOptions.staticBatch)
@ -170,7 +182,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
void const* a, void const* sfA, void const* b, void const* sfB, void* c, void* outSfC, void* workspace,
CUstream stream, int device)
CUstream stream, int device, std::optional<int32_t> configIndex)
{
// Dispatch with block scaling factors and with static batching.
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a, sfA, b, sfB,
@ -178,12 +190,12 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
/* scaleC */ nullptr, /* scaleGateC */ nullptr, c, outSfC,
/* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr,
/* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr,
/* numNonExitingCtas */ nullptr, workspace, stream, device);
/* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex);
}
void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
void const* a, void const* b, float const* scaleC, float const* scaleGateC, void* c, void* workspace,
CUstream stream, int device)
CUstream stream, int device, std::optional<int32_t> configIndex)
{
// Dispatch without block scaling factors and with static batching.
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a,
@ -191,11 +203,12 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
scaleGateC, c, /* outSfC */ nullptr,
/* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr,
/* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr,
/* numNonExitingCtas */ nullptr, workspace, stream, device);
/* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex);
}
void TrtllmGenBatchedGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim)
std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
int32_t maxNumCtasInBatchDim) const
{
auto const bmm = BatchedGemmInterface();
auto const configs = bmm.getBatchedGemmConfigs();
@ -242,16 +255,30 @@ void TrtllmGenBatchedGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t
return optionsA.mTileM > optionsB.mTileM;
});
std::vector<int64_t> validConfigIndices;
for (auto const& configIndex : sortedIndices)
{
auto const& config = configs[configIndex];
auto isValidConfig = bmm.isValidConfig(config, gemmData);
if (isValidConfig)
{
mSelectedConfigIndex = configIndex;
return;
validConfigIndices.push_back(configIndex);
}
}
TLLM_CHECK_WITH_INFO(!validConfigIndices.empty(), "No valid config found for the given problem shape");
return validConfigIndices;
}
int64_t TrtllmGenBatchedGemmRunner::getDefaultValidConfigIndex(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
int32_t maxNumCtasInBatchDim) const
{
auto const validConfigIndices
= getValidConfigIndices(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
return validConfigIndices[0];
}
} // namespace kernels

View File

@ -16,8 +16,10 @@
#pragma once
#include <cstdint>
#include <cuda.h>
#include <optional>
#include <vector>
#include "trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"
@ -45,20 +47,40 @@ public:
explicit TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunnerOptions const& options);
[[nodiscard]] size_t getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim);
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
std::optional<int32_t> configIndex = std::nullopt);
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
int32_t numBatches, int32_t maxNumCtasInBatchDim, void const* a, void const* sfA, void const* b,
void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC,
float const* scaleGateC, void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
void* workspace, CUstream stream, int device);
void* workspace, CUstream stream, int device, std::optional<int32_t> configIndex = std::nullopt);
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device);
void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device,
std::optional<int32_t> configIndex = std::nullopt);
void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* b,
float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device);
float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device,
std::optional<int32_t> configIndex = std::nullopt);
// Get the list of configs that passed the validation based on the constructor options
[[nodiscard]] std::vector<int32_t> getPassingConfigIndices() const
{
return mPassingConfigIndices;
}
// Get the list of config indices that are valid for the given problem shape
[[nodiscard]] std::vector<int64_t> getValidConfigIndices(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
int32_t maxNumCtasInBatchDim) const;
// Get a default config index that is valid for the given problem shape
// This will be used as the fallback config if using auto-tuning
[[nodiscard]] int64_t getDefaultValidConfigIndex(int32_t m, int32_t n, int32_t k,
std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
int32_t maxNumCtasInBatchDim) const;
private:
void selectGemmConfig(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
@ -66,8 +88,8 @@ private:
private:
TrtllmGenBatchedGemmRunnerOptions mOptions;
std::optional<int> mSelectedConfigIndex;
std::vector<int32_t> mPassingConfigIndices;
std::optional<int32_t> mSelectedConfigIndex;
};
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -26,6 +26,8 @@
#include <cuda_fp16.h>
#include <cstdint>
#include <memory>
#include <tuple>
namespace
{
@ -35,27 +37,16 @@ template <tg::Dtype outDtype>
void runBatchedGemm(at::Tensor& out, at::Tensor& outSfC, at::Tensor const& mat1, at::Tensor const& mat2,
std::optional<at::Tensor> const& dDqSfsA, std::optional<at::Tensor> const& dDqSfsB,
std::optional<at::Tensor> const& scaleC, int64_t m, int64_t n, int64_t k, int32_t tileSize, int32_t epilogueTileM,
std::vector<int32_t> const& batchedTokens, bool useDeepSeekFp8, bool lowLatencyKernel)
std::vector<int32_t> const& batchedTokens, bool useDeepSeekFp8, bool lowLatencyKernel,
tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner& runner, int32_t const configIndex)
{
auto eltType = tg::Dtype::E4m3;
tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = {.eltType = eltType,
.outputType = outDtype,
.deepSeekFp8 = useDeepSeekFp8,
.fusedAct = false,
.routeAct = false,
.staticBatch = true,
.transposeMmaOutput = lowLatencyKernel,
.tileSize = tileSize,
.epilogueTileM = epilogueTileM};
tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner runner(options);
// numTokens and maxNumCtasInBatchDim are not used for static batching
int32_t numTokens = 0;
int32_t maxNumCtasInBatchDim = 0;
int64_t const numBytesWorkspace
= runner.getWorkspaceSizeInBytes(m, n, k, batchedTokens, numTokens, batchedTokens.size(), maxNumCtasInBatchDim);
int32_t const numTokens = 0;
int32_t const maxNumCtasInBatchDim = 0;
int64_t const numBytesWorkspace = runner.getWorkspaceSizeInBytes(
m, n, k, batchedTokens, numTokens, batchedTokens.size(), maxNumCtasInBatchDim, configIndex);
at::Tensor workspace
= at::detail::empty_cuda({numBytesWorkspace}, at::ScalarType::Char, mat1.device(), std::nullopt);
@ -66,20 +57,21 @@ void runBatchedGemm(at::Tensor& out, at::Tensor& outSfC, at::Tensor const& mat1,
float* outSfCPtr = outDtype == tg::Dtype::E4m3 ? outSfC.data_ptr<float>() : nullptr;
runner.run(m, n, k, batchedTokens, mat1.const_data_ptr(), dDqSfsA.value().const_data_ptr(),
mat2.const_data_ptr(), dDqSfsB.value().const_data_ptr(), out.data_ptr(), outSfCPtr, workspace.data_ptr(),
stream.stream(), mat1.get_device());
stream.stream(), mat1.get_device(), configIndex);
}
else
{
runner.run(m, n, k, batchedTokens, mat1.const_data_ptr(), mat2.const_data_ptr(),
reinterpret_cast<float const*>(scaleC.value().const_data_ptr()), nullptr, out.data_ptr(),
workspace.data_ptr(), stream.stream(), mat1.get_device());
workspace.data_ptr(), stream.stream(), mat1.get_device(), configIndex);
}
}
std::tuple<at::Tensor, at::Tensor> fp8_batched_gemm_sm100(at::Tensor const& mat1, at::Tensor const& mat2,
int32_t tileSize, bool useDeepSeekFp8, bool lowLatencyKernel, int64_t epilogueTileM,
std::optional<at::Tensor> const& dDqSfsA, std::optional<at::Tensor> const& dDqSfsB,
std::optional<at::Tensor> const& scaleC, std::optional<c10::ScalarType> outDtype)
std::optional<at::Tensor> const& scaleC, std::optional<c10::ScalarType> outDtype,
tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner& runner, int32_t const configIndex)
{
TORCH_CHECK(mat1.dim() == 3, "Matrix A must be of size [B, M, K]");
TORCH_CHECK(mat2.dim() == 3, "Matrix B must be of size [B, N, K]");
@ -144,16 +136,19 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_gemm_sm100(at::Tensor const& mat1
TORCH_CHECK(scaleC.value().sizes()[0] == b, "outScalingFactor must be a 1D matrix of size B");
}
int64_t outputN = n;
int64_t const outputN = n;
// Create output tensor.
at::Tensor out = at::detail::empty_cuda({b, m, outputN}, outDtype.value(), mat1.device(), std::nullopt);
at::Tensor outSfC;
if (useDeepSeekFp8 && outDtype.value() == at::ScalarType::Float8_e4m3fn)
{
outSfC = at::detail::empty_cuda(
{outputN / dsFp8QuantBlockSize, m * b}, at::ScalarType::Float, mat1.device(), std::nullopt);
}
bool const needOutSfC = useDeepSeekFp8 && outDtype.value() == at::ScalarType::Float8_e4m3fn;
// The Torch custom class cannot return a default-constructed tensor, so an empty tensor is used instead.
int64_t const outSfCSize0 = needOutSfC ? (outputN / dsFp8QuantBlockSize) : 0;
int64_t const outSfCSize1 = needOutSfC ? (m * b) : 0;
at::Tensor outSfC
= at::detail::empty_cuda({outSfCSize0, outSfCSize1}, at::ScalarType::Float, mat1.device(), std::nullopt);
std::vector<int32_t> batchedTokens(b, m);
@ -161,15 +156,15 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_gemm_sm100(at::Tensor const& mat1
{
case at::ScalarType::Half:
runBatchedGemm<tg::Dtype::Fp16>(out, outSfC, mat1, mat2, dDqSfsA, dDqSfsB, scaleC, m, n, k, tileSize,
epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel);
epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel, runner, configIndex);
break;
case at::ScalarType::BFloat16:
runBatchedGemm<tg::Dtype::Bfloat16>(out, outSfC, mat1, mat2, dDqSfsA, dDqSfsB, scaleC, m, n, k, tileSize,
epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel);
epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel, runner, configIndex);
break;
case at::ScalarType::Float8_e4m3fn:
runBatchedGemm<tg::Dtype::E4m3>(out, outSfC, mat1, mat2, dDqSfsA, dDqSfsB, scaleC, m, n, k, tileSize,
epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel);
epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel, runner, configIndex);
break;
default: C10_THROW_ERROR(NotImplementedError, "outDtype must be one of fp16/bf16/e4m3.");
}
@ -181,35 +176,101 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_gemm_sm100(at::Tensor const& mat1
namespace torch_ext
{
extern std::tuple<at::Tensor, at::Tensor> fp8_batched_gemm_trtllmgen(at::Tensor const& mat1, at::Tensor const& mat2,
int64_t tileSize, bool useDeepSeekFp8, bool lowLatency, int64_t epilogueTileM,
std::optional<at::Tensor> const& dDqSfsA, std::optional<at::Tensor> const& dDqSfsB,
std::optional<at::Tensor> const& scaleC, std::optional<c10::ScalarType> outDtype)
// Wrapped the TRTLLM-Gen kernel runner in a Torch custom class to allow
// use with the torch workflow autotuner class.
class FP8BatchedGemmRunner : public torch::CustomClassHolder
{
auto const smVersion = tensorrt_llm::common::getSMVersion();
switch (smVersion)
public:
explicit FP8BatchedGemmRunner(c10::ScalarType outDtypeArg, bool useDeepSeekFp8, bool lowLatencyKernel,
int64_t tileSize, int64_t epilogueTileM)
: mOutDtypeArg(outDtypeArg)
, mUseDeepSeekFp8(useDeepSeekFp8)
, mLowLatencyKernel(lowLatencyKernel)
, mTileSize(tileSize)
, mEpilogueTileM(epilogueTileM)
{
case tensorrt_llm::kernels::kSM_100:
auto const smVersion = tensorrt_llm::common::getSMVersion();
if (smVersion != tensorrt_llm::kernels::kSM_100)
{
TLLM_THROW("Unsupported or unimplemented compute capability for fp8 batched gemm: %i", smVersion);
}
tg::Dtype outDtype = tg::Dtype::E4m3; // Default to E4m3, will be updated based on outDtypeArg
switch (outDtypeArg)
{
case at::ScalarType::Half: outDtype = tg::Dtype::Fp16; break;
case at::ScalarType::BFloat16: outDtype = tg::Dtype::Bfloat16; break;
case at::ScalarType::Float8_e4m3fn: outDtype = tg::Dtype::E4m3; break;
default: C10_THROW_ERROR(NotImplementedError, "outDtype must be one of fp16/bf16/e4m3.");
}
RunnerOptionsType const options = {.eltType = mEltType,
.outputType = outDtype,
.deepSeekFp8 = mUseDeepSeekFp8,
.fusedAct = false,
.routeAct = false,
.staticBatch = true,
.transposeMmaOutput = mLowLatencyKernel,
.tileSize = static_cast<int32_t>(mTileSize),
.epilogueTileM = static_cast<int32_t>(mEpilogueTileM)};
mRunner = std::make_unique<RunnerType>(options);
}
std::tuple<at::Tensor, at::Tensor> runBatchedGemm(at::Tensor const& mat1, at::Tensor const& mat2,
std::optional<at::Tensor> const& dDqSfsA, std::optional<at::Tensor> const& dDqSfsB,
std::optional<at::Tensor> const& scaleC, int64_t configIndex)
{
return fp8_batched_gemm_sm100(
mat1, mat2, tileSize, useDeepSeekFp8, lowLatency, epilogueTileM, dDqSfsA, dDqSfsB, scaleC, outDtype);
return fp8_batched_gemm_sm100(mat1, mat2, mTileSize, mUseDeepSeekFp8, mLowLatencyKernel, mEpilogueTileM,
dDqSfsA, dDqSfsB, scaleC, mOutDtypeArg, *mRunner, configIndex);
}
default: TLLM_THROW("Unsupported or unimplemented compute capability for fp8 batched gemm: %i", smVersion);
std::vector<int64_t> getValidConfigs(int64_t numBatches, int64_t m, int64_t n, int64_t k) const
{
// numTokens and maxNumCtasInBatchDim are not used for static batching
int32_t const numTokens = 0;
int32_t const maxNumCtasInBatchDim = 0;
std::vector<int32_t> const batchedTokens(numBatches, m);
return mRunner->getValidConfigIndices(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
}
}
int64_t getDefaultValidConfigIndex(int64_t numBatches, int64_t m, int64_t n, int64_t k) const
{
// numTokens and maxNumCtasInBatchDim are not used for static batching
int32_t const numTokens = 0;
int32_t const maxNumCtasInBatchDim = 0;
std::vector<int32_t> const batchedTokens(numBatches, m);
return mRunner->getDefaultValidConfigIndex(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
}
private:
using RunnerType = tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner;
using RunnerOptionsType = tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions;
std::unique_ptr<RunnerType> mRunner;
tg::Dtype mEltType{tg::Dtype::E4m3};
c10::ScalarType mOutDtypeArg;
bool mUseDeepSeekFp8;
bool mLowLatencyKernel;
int64_t mTileSize;
int64_t mEpilogueTileM;
};
} // namespace torch_ext
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"fp8_batched_gemm_trtllmgen(Tensor a, Tensor b, int tile_size,"
"bool use_deep_seek_fp8=False, bool low_latency=False, "
"int epilogue_tile_m=0, Tensor? dq_sfs_a=None, Tensor? dq_sfs_b=None, "
"Tensor? scale_c=None, "
"ScalarType? out_dtype=None) -> (Tensor, Tensor)");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("fp8_batched_gemm_trtllmgen", &torch_ext::fp8_batched_gemm_trtllmgen);
m.class_<torch_ext::FP8BatchedGemmRunner>("FP8BatchedGemmRunner")
.def(torch::init<at::ScalarType, bool, bool, int64_t, int64_t>())
.def("get_valid_configs", &torch_ext::FP8BatchedGemmRunner::getValidConfigs)
.def("get_default_valid_config", &torch_ext::FP8BatchedGemmRunner::getDefaultValidConfigIndex)
.def("run_batched_gemm", &torch_ext::FP8BatchedGemmRunner::runBatchedGemm);
}
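
For reference, a minimal sketch of driving the custom class registered above directly from Python. The constructor arguments mirror torch::init<at::ScalarType, bool, bool, int64_t, int64_t>; the shapes and tile/epilogue sizes are illustrative assumptions, and the non-DeepSeek path shown here passes None for the block scaling factors.

import torch

b, m, n, k = 4, 128, 256, 512
mat1 = torch.randn(b, m, k, device="cuda").to(torch.float8_e4m3fn)  # [B, M, K]
mat2 = torch.randn(b, n, k, device="cuda").to(torch.float8_e4m3fn)  # [B, N, K]
scale_c = torch.ones(b, device="cuda", dtype=torch.float32)

# (out_dtype, use_deep_seek_fp8, low_latency_kernel, tile_size, epilogue_tile_m)
runner = torch.classes.trtllm.FP8BatchedGemmRunner(torch.half, False, False, 8, 128)

tactics = runner.get_valid_configs(b, m, n, k)          # config indices valid for this shape
fallback = runner.get_default_valid_config(b, m, n, k)  # default/fallback config index
out, out_sf = runner.run_batched_gemm(mat1, mat2, None, None, scale_c, fallback)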

View File

@ -340,7 +340,12 @@ class AutoTuner:
Although runners[0] with tactic=-1 is always treated as the fallback runner.
Runner authors are suggested to provide a fallback implementation for each runner to avoid potential issues.
"""
input_shapes = tuple(t.shape for t in inputs)
# Treat None tensors as size zero
# This allows the tuner to handle TRT-LLM-Gen torch ops that have optional tensor
# arguments, such as block scaling factors.
input_shapes = tuple(
(t.shape if t is not None else torch.Size((0, ))) for t in inputs)
# Early return if it's not tuning, use cache found one or fallback one
if not self.is_tuning_mode:
@ -393,8 +398,14 @@ class AutoTuner:
time_measured = self._profile_single_kernel(
r, tensors, tac, **kwargs)
except Exception as e:
# Handle None tensors for optional inputs
shapes = [
t.size() if t is not None else torch.Size((0, ))
for t in tensors
]
logger.error(
f"[Autotuner]: Failed when profiling {r} {tac}, shapes={[t.size() for t in tensors]}. Error occurred: {e}"
f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}"
)
# Record the failed profiling combinations
@ -471,8 +482,13 @@ class AutoTuner:
stream.synchronize()
avg_time = start.elapsed_time(end) / self.repeat
# Handle None tensors for optional inputs
shapes = [
t.size() if t is not None else torch.Size((0, )) for t in inputs
]
logger.debug(
f"[Autotuner]: profiling {runner} {tactic}, shapes={[t.size() for t in inputs]}, avg_time {avg_time}"
f"[Autotuner]: profiling {runner} {tactic}, shapes={shapes}, avg_time {avg_time}"
)
return avg_time
@ -494,9 +510,13 @@ class AutoTuner:
combinations specified in dynamic_tensor_specs.
"""
# every dimension created from the concrete input tensor shape
# generate some dynamic dimension description based on the dynamic_tensor_specs
base_profile = OptimizationProfile([[StaticDim(x) for x in t.size()]
for t in inputs])
# generate some dynamic dimension description based on the dynamic_tensors
# A zero-size StaticDim handles the case where a TRT-LLM op has optional (None) inputs.
base_profile = OptimizationProfile(
[[StaticDim(x)
for x in t.size()] if t is not None else [StaticDim(0)]
for t in inputs])
generated_profiles: List[OptimizationProfile] = []
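
As a small illustration of the None handling introduced above, a hypothetical input list in which the optional scaling-factor tensors are omitted:

import torch

inputs = [torch.empty(4, 128, 512), torch.empty(4, 256, 512), None, None, None]
# Missing optional tensors contribute zero-size shapes to cache keys and profiles.
input_shapes = tuple(t.shape if t is not None else torch.Size((0, )) for t in inputs)
# -> (torch.Size([4, 128, 512]), torch.Size([4, 256, 512]), torch.Size([0]), torch.Size([0]), torch.Size([0]))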

View File

@ -1,5 +1,5 @@
from functools import lru_cache
from typing import List, Optional
from typing import List, Optional, Tuple
import torch
@ -340,6 +340,218 @@ def _(
dtype=output_dtype)
class FP8BatchedGemmRunner(TunableRunner):
_runner_dict = dict()
def __init__(self, output_dtype: torch.dtype, use_deep_seek_fp8: bool,
low_latency_kernel: bool, tile_size: int,
epilogue_tile_m: int):
self.output_dtype = output_dtype
self.use_deep_seek_fp8 = use_deep_seek_fp8
self.low_latency_kernel = low_latency_kernel
self.tile_size = tile_size
self.epilogue_tile_m = epilogue_tile_m
self.tuning_config = self.get_tuning_config()
instance_key = (output_dtype, use_deep_seek_fp8, low_latency_kernel,
tile_size, epilogue_tile_m)
if instance_key not in FP8BatchedGemmRunner._runner_dict:
FP8BatchedGemmRunner._runner_dict[
instance_key] = torch.classes.trtllm.FP8BatchedGemmRunner(
output_dtype, use_deep_seek_fp8, low_latency_kernel,
tile_size, epilogue_tile_m)
self._kernel_runner = FP8BatchedGemmRunner._runner_dict[instance_key]
def forward(
self,
inputs: List[torch.Tensor],
tactic: int = -1,
do_preparation: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Run the batched GEMM operation with the given inputs and tactic.
"""
mat1, mat2, dq_sfs_a, dq_sfs_b, scale_c = inputs
chosen_tactic = self.get_default_valid_tactic(
inputs) if tactic == -1 else tactic
out_tensors = self._kernel_runner.run_batched_gemm(
mat1,
mat2,
dq_sfs_a,
dq_sfs_b,
scale_c,
chosen_tactic,
)
return out_tensors
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
) -> List[int]:
mat1, mat2, _, _, _ = inputs
b = mat1.shape[0]
m = mat1.shape[1]
n = mat2.shape[1]
k = mat1.shape[2]
tactics = self._kernel_runner.get_valid_configs(b, m, n, k)
return tactics
def get_default_valid_tactic(
self,
inputs: List[torch.Tensor],
) -> int:
mat1, mat2, _, _, _ = inputs
b = mat1.shape[0]
m = mat1.shape[1]
n = mat2.shape[1]
k = mat1.shape[2]
default_tactic = self._kernel_runner.get_default_valid_config(
b, m, n, k)
return default_tactic
def get_dynamic_tensor_specs(self) -> Tuple[DynamicTensorSpec, ...]:
"""Get the dynamic tensor specs for use with the AutoTuner."""
# These indices correspond to the 0th input tensor and its first dimension
# i.e. we are tuning M where the first input tensor is of shape [B, M, K]
MAT1_IDX = 0
TUNED_DIM = 1
# Starting at 8 as M % tile size == 0 is required
m_values = (8, 16, 32, 64, 128, 256, 512, 1024, 2048)
round_rule = lambda x: last_positive_power_of_2(x)
specs = (DynamicTensorSpec(MAT1_IDX, TUNED_DIM, m_values, round_rule), )
return specs
def get_constraint_specs(self) -> Tuple[ConstraintSpec, ...]:
"""Get the constraint specs for the dynamic tensors for use with the AutoTuner.
"""
# When using deepseek fp8, the dq_sfs_a and dq_sfs_b tensors are expected to
# have specific dimensions. As we are only tuning M, we need only constrain
# dimension 1 of dq_sfs_a
if not self.use_deep_seek_fp8:
constraint_dq_sfs_a = ()
else:
def _constrain_dq_sfs_a_dim1(shapes: Tuple[torch.Size]) -> int:
b = shapes[0][0]
m = shapes[0][1]
m_padded = (m + self.tile_size - 1) // self.tile_size
result = m_padded * self.tile_size * b
return result
SFS_A_IDX = 2
CONSTRAINED_DIM = 1
constraint_dq_sfs_a = (ConstraintSpec(SFS_A_IDX, CONSTRAINED_DIM,
_constrain_dq_sfs_a_dim1), )
return constraint_dq_sfs_a
def get_tuning_config(self) -> TuningConfig:
"""Get the tuning configuration for the AutoTuner."""
dynamic_tensor_specs = self.get_dynamic_tensor_specs()
constraint_specs = self.get_constraint_specs()
tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs,
constraint_specs=constraint_specs)
return tuning_config
@torch.library.custom_op("trtllm::fp8_batched_gemm_trtllmgen", mutates_args=())
def fp8_batched_gemm_trtllmgen(
mat1: torch.Tensor,
mat2: torch.Tensor,
tile_size: int,
use_deep_seek_fp8: Optional[bool] = False,
low_latency: Optional[bool] = False,
epilogue_tile_m: Optional[int] = 0,
dq_sfs_a: Optional[torch.Tensor] = None,
dq_sfs_b: Optional[torch.Tensor] = None,
scale_c: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = torch.half
) -> Tuple[torch.Tensor, torch.Tensor]:
kernel_runner = FP8BatchedGemmRunner(output_dtype=out_dtype,
use_deep_seek_fp8=use_deep_seek_fp8,
low_latency_kernel=low_latency,
tile_size=tile_size,
epilogue_tile_m=epilogue_tile_m)
tuner = AutoTuner.get()
inputs = [mat1, mat2, dq_sfs_a, dq_sfs_b, scale_c]
_, best_tactic = tuner.choose_one(
"trtllm::fp8_batched_gemm_trtllmgen::batched_gemm",
[kernel_runner],
kernel_runner.tuning_config,
inputs,
)
return kernel_runner(
inputs=inputs,
tactic=best_tactic,
)
# Allows the tunable TRTLLM-Gen FP8 batched GEMM to be
# used with torch.compile
@fp8_batched_gemm_trtllmgen.register_fake
def _(
mat1: torch.Tensor,
mat2: torch.Tensor,
tile_size: int,
use_deep_seek_fp8: Optional[bool] = False,
low_latency: Optional[bool] = False,
epilogue_tile_m: Optional[int] = 0,
dq_sfs_a: Optional[torch.Tensor] = None,
dq_sfs_b: Optional[torch.Tensor] = None,
scale_c: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
b = mat1.size(0)
m = mat1.size(1)
n = mat2.size(1)
fake_out = mat1.new_empty((b, m, n), dtype=out_dtype)
if use_deep_seek_fp8:
ds_fp8_quant_block_size = 128
dim0_size = n // ds_fp8_quant_block_size
dim1_size = b * m
fake_dq_sfs_c = torch.empty((dim0_size, dim1_size), dtype=torch.float32)
else:
fake_dq_sfs_c = torch.empty((0, 0), dtype=torch.float32)
return (fake_out, fake_dq_sfs_c)
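
Because the fake implementation above gives the op a meta path, the tunable GEMM can also be used inside a torch.compile'd region. A hedged sketch, reusing the assumed shapes and settings from the earlier examples (the helper name bmm_fp8 is hypothetical):

import torch

def bmm_fp8(a, b, scale_c):
    out, _ = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
        a, b, tile_size=8, epilogue_tile_m=128,
        use_deep_seek_fp8=False, low_latency=False,
        dq_sfs_a=None, dq_sfs_b=None,
        scale_c=scale_c, out_dtype=torch.half)
    return out

# The registered fake lets the compiler trace output shapes without running the kernel.
compiled_bmm_fp8 = torch.compile(bmm_fp8)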
@torch.library.custom_op("trtllm::attention", mutates_args=())
def attention(
q: torch.Tensor,

View File

@ -20,6 +20,7 @@ import pytest
import torch
from utils.util import getSMVersion
from tensorrt_llm._torch.autotuner import autotune
from tensorrt_llm.quantization.utils.fp4_utils import shuffle_matrix_a
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
@ -246,98 +247,199 @@ def quant_fp8(x_fp32: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
),
],
)
def test_fp8_batched_gemm_trtllmgen(test_case: BatchedGemmTestCase) -> None:
torch.random.manual_seed(42)
class TestFP8BatchedGemmTRTLLMGen:
b = test_case.b
m = test_case.m
n = test_case.n
k = test_case.k
use_deep_seek_fp8 = test_case.use_deep_seek_fp8
dtype_c = test_case.dtype_c
low_latency = test_case.low_latency
tile_size = test_case.tile_size
def test_thop(self, test_case: BatchedGemmTestCase) -> None:
torch.random.manual_seed(42)
a_fp32 = torch.randn((b, m, k), device="cuda", dtype=torch.float32)
b_fp32 = torch.randn((b, n, k), device="cuda", dtype=torch.float32)
# Pad to the tile size. It is needed for the TRT-LLM Gen BMM input requirements.
if m % tile_size:
tiled_shape = ((m + tile_size - 1) // tile_size) * tile_size
a_fp32 = torch.nn.functional.pad(a_fp32, (0, 0, 0, tiled_shape - m),
"constant", 0)
m_padded = ((m + tile_size - 1) // tile_size) * tile_size
b = test_case.b
m = test_case.m
n = test_case.n
k = test_case.k
use_deep_seek_fp8 = test_case.use_deep_seek_fp8
dtype_c = test_case.dtype_c
low_latency = test_case.low_latency
tile_size = test_case.tile_size
dq_sf_a = None
dq_sf_b = None
global_dq_a = None
global_dq_b = None
out_global_scaling_factor = None
a_fp32 = torch.randn((b, m, k), device="cuda", dtype=torch.float32)
b_fp32 = torch.randn((b, n, k), device="cuda", dtype=torch.float32)
# Pad to the tile size. It is needed for the TRT-LLM Gen BMM input requirements.
if m % tile_size:
tiled_shape = ((m + tile_size - 1) // tile_size) * tile_size
a_fp32 = torch.nn.functional.pad(a_fp32, (0, 0, 0, tiled_shape - m),
"constant", 0)
m_padded = ((m + tile_size - 1) // tile_size) * tile_size
if use_deep_seek_fp8:
a_fp8, dq_sf_a = quant_ds_fp8(a_fp32, activations=True)
dq_sf_a = dq_sf_a.reshape(k // 128, -1).contiguous()
b_fp8, dq_sf_b = quant_ds_fp8(b_fp32, activations=False)
dq_sf_b = dq_sf_b.contiguous()
else:
a_fp8, global_dq_a = quant_fp8(a_fp32)
b_fp8, global_dq_b = quant_fp8(b_fp32)
out_global_scaling_factor = global_dq_a * global_dq_b
dq_sf_a = None
dq_sf_b = None
global_dq_a = None
global_dq_b = None
out_global_scaling_factor = None
# Compute reference batched matrix multiplication
output = fp8_bmm_reference(a_fp8, b_fp8, dq_sf_a, dq_sf_b, global_dq_a,
global_dq_b, dtype_c, use_deep_seek_fp8)
c_dq_sf_ref = None
if dtype_c == torch.float8_e4m3fn:
if use_deep_seek_fp8:
c_ref, c_dq_sf_ref = output
c_dq_sf_ref = c_dq_sf_ref.reshape(n // 128, -1)
a_fp8, dq_sf_a = quant_ds_fp8(a_fp32, activations=True)
dq_sf_a = dq_sf_a.reshape(k // 128, -1).contiguous()
b_fp8, dq_sf_b = quant_ds_fp8(b_fp32, activations=False)
dq_sf_b = dq_sf_b.contiguous()
else:
c_ref, c_scale = output
out_global_scaling_factor /= c_scale.cuda()
else:
c_ref = output
a_fp8, global_dq_a = quant_fp8(a_fp32)
b_fp8, global_dq_b = quant_fp8(b_fp32)
out_global_scaling_factor = global_dq_a * global_dq_b
epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
if low_latency and not use_deep_seek_fp8:
b_fp8_shuffled = []
for bi in range(b):
b_fp8_shuffled.append(
shuffle_matrix_a(b_fp8[bi].view(torch.uint8).clone(),
epilogue_tile_m))
# Compute reference batched matrix multiplication
output = fp8_bmm_reference(a_fp8, b_fp8, dq_sf_a, dq_sf_b, global_dq_a,
global_dq_b, dtype_c, use_deep_seek_fp8)
# Stack weights for all experts
b_fp8 = torch.stack(b_fp8_shuffled).view(torch.float8_e4m3fn)
c_dq_sf_ref = None
if dtype_c == torch.float8_e4m3fn:
if use_deep_seek_fp8:
c_ref, c_dq_sf_ref = output
c_dq_sf_ref = c_dq_sf_ref.reshape(n // 128, -1)
else:
c_ref, c_scale = output
out_global_scaling_factor /= c_scale.cuda()
else:
c_ref = output
if not use_deep_seek_fp8:
out_global_scaling_factor = out_global_scaling_factor.contiguous().to(
torch.float32)
epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
if low_latency and not use_deep_seek_fp8:
b_fp8_shuffled = []
for bi in range(b):
b_fp8_shuffled.append(
shuffle_matrix_a(b_fp8[bi].view(torch.uint8).clone(),
epilogue_tile_m))
c_actual, c_dq_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
a_fp8.contiguous(),
b_fp8.contiguous(),
tile_size=tile_size,
epilogue_tile_m=epilogue_tile_m,
use_deep_seek_fp8=use_deep_seek_fp8,
low_latency=low_latency,
out_dtype=dtype_c,
dq_sfs_a=dq_sf_a,
dq_sfs_b=dq_sf_b,
scale_c=out_global_scaling_factor)
# Stack weights for all experts
b_fp8 = torch.stack(b_fp8_shuffled).view(torch.float8_e4m3fn)
c_actual = c_actual.detach().cpu()
c_ref = c_ref.detach().cpu()
if not use_deep_seek_fp8:
out_global_scaling_factor = out_global_scaling_factor.contiguous(
).to(torch.float32)
torch.testing.assert_close(c_actual.to(torch.float32)[:, :m],
c_ref.to(torch.float32)[:, :m],
atol=1e-2,
rtol=1e-2)
if use_deep_seek_fp8 and dtype_c == torch.float8_e4m3fn:
c_dq_sf = c_dq_sf.detach().cpu()
for bi in range(b):
torch.testing.assert_close(
c_dq_sf[:, bi * m_padded:bi * m_padded + m].to(torch.float32),
c_dq_sf_ref[:,
c_actual, c_dq_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
a_fp8.contiguous(),
b_fp8.contiguous(),
tile_size=tile_size,
epilogue_tile_m=epilogue_tile_m,
use_deep_seek_fp8=use_deep_seek_fp8,
low_latency=low_latency,
out_dtype=dtype_c,
dq_sfs_a=dq_sf_a,
dq_sfs_b=dq_sf_b,
scale_c=out_global_scaling_factor)
c_actual = c_actual.detach().cpu()
c_ref = c_ref.detach().cpu()
torch.testing.assert_close(c_actual.to(torch.float32)[:, :m],
c_ref.to(torch.float32)[:, :m],
atol=1e-2,
rtol=1e-2)
if use_deep_seek_fp8 and dtype_c == torch.float8_e4m3fn:
c_dq_sf = c_dq_sf.detach().cpu()
for bi in range(b):
torch.testing.assert_close(
c_dq_sf[:,
bi * m_padded:bi * m_padded + m].to(torch.float32),
atol=1e-2,
rtol=1e-2)
c_dq_sf_ref[:, bi * m_padded:bi * m_padded + m].to(
torch.float32),
atol=1e-2,
rtol=1e-2)
def test_autotuned_thop(self, test_case: BatchedGemmTestCase) -> None:
torch.random.manual_seed(42)
b = test_case.b
m = test_case.m
n = test_case.n
k = test_case.k
use_deep_seek_fp8 = test_case.use_deep_seek_fp8
dtype_c = test_case.dtype_c
low_latency = test_case.low_latency
tile_size = test_case.tile_size
a_fp32 = torch.randn((b, m, k), device="cuda", dtype=torch.float32)
b_fp32 = torch.randn((b, n, k), device="cuda", dtype=torch.float32)
# Pad to the tile size. It is needed for the TRT-LLM Gen BMM input requirements.
if m % tile_size:
tiled_shape = ((m + tile_size - 1) // tile_size) * tile_size
a_fp32 = torch.nn.functional.pad(a_fp32, (0, 0, 0, tiled_shape - m),
"constant", 0)
m_padded = ((m + tile_size - 1) // tile_size) * tile_size
dq_sf_a = None
dq_sf_b = None
global_dq_a = None
global_dq_b = None
out_global_scaling_factor = None
if use_deep_seek_fp8:
a_fp8, dq_sf_a = quant_ds_fp8(a_fp32, activations=True)
dq_sf_a = dq_sf_a.reshape(k // 128, -1).contiguous()
b_fp8, dq_sf_b = quant_ds_fp8(b_fp32, activations=False)
dq_sf_b = dq_sf_b.contiguous()
else:
a_fp8, global_dq_a = quant_fp8(a_fp32)
b_fp8, global_dq_b = quant_fp8(b_fp32)
out_global_scaling_factor = global_dq_a * global_dq_b
# Compute reference batched matrix multiplication
output = fp8_bmm_reference(a_fp8, b_fp8, dq_sf_a, dq_sf_b, global_dq_a,
global_dq_b, dtype_c, use_deep_seek_fp8)
c_dq_sf_ref = None
if dtype_c == torch.float8_e4m3fn:
if use_deep_seek_fp8:
c_ref, c_dq_sf_ref = output
c_dq_sf_ref = c_dq_sf_ref.reshape(n // 128, -1)
else:
c_ref, c_scale = output
out_global_scaling_factor /= c_scale.cuda()
else:
c_ref = output
epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
if low_latency and not use_deep_seek_fp8:
b_fp8_shuffled = []
for bi in range(b):
b_fp8_shuffled.append(
shuffle_matrix_a(b_fp8[bi].view(torch.uint8).clone(),
epilogue_tile_m))
# Stack weights for all experts
b_fp8 = torch.stack(b_fp8_shuffled).view(torch.float8_e4m3fn)
if not use_deep_seek_fp8:
out_global_scaling_factor = out_global_scaling_factor.contiguous(
).to(torch.float32)
with autotune():
c_actual, c_dq_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
a_fp8.contiguous(),
b_fp8.contiguous(),
tile_size=tile_size,
epilogue_tile_m=epilogue_tile_m,
use_deep_seek_fp8=use_deep_seek_fp8,
low_latency=low_latency,
out_dtype=dtype_c,
dq_sfs_a=dq_sf_a,
dq_sfs_b=dq_sf_b,
scale_c=out_global_scaling_factor)
c_actual = c_actual.detach().cpu()
c_ref = c_ref.detach().cpu()
torch.testing.assert_close(c_actual.to(torch.float32)[:, :m],
c_ref.to(torch.float32)[:, :m],
atol=1e-2,
rtol=1e-2)
if use_deep_seek_fp8 and dtype_c == torch.float8_e4m3fn:
c_dq_sf = c_dq_sf.detach().cpu()
for bi in range(b):
torch.testing.assert_close(
c_dq_sf[:,
bi * m_padded:bi * m_padded + m].to(torch.float32),
c_dq_sf_ref[:, bi * m_padded:bi * m_padded + m].to(
torch.float32),
atol=1e-2,
rtol=1e-2)