[TRTLLM-5589] feat: Integrate TRT-LLM Gen FP8 Batched GEMM with Pytorch workflow kernel autotuner (#4872)
Signed-off-by: Dom Brown <3886319+DomBrown@users.noreply.github.com>
This commit is contained in: parent 1d4f748773 → commit 9c012d5bf8
@@ -55,11 +55,12 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
         }
     }

-    TLLM_CHECK_WITH_INFO(mPassingConfigIndices.size() != 0, "No kernel found for the given output type");
+    TLLM_CHECK_WITH_INFO(!mPassingConfigIndices.empty(), "No kernel found for the given options");
 }

 size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k,
-    std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim)
+    std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
+    std::optional<int32_t> configIndex)
 {
     BatchedGemmData gemmData;
     gemmData.mProblemDimensions.mNumBatches = numBatches;
@@ -74,13 +75,18 @@ size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n,
     gemmData.mProblemDimensions.mWorldSize = 1;
     gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;

-    selectGemmConfig(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
-
     auto bmm = BatchedGemmInterface();

     auto const configs = bmm.getBatchedGemmConfigs();
-    TLLM_CHECK_WITH_INFO(
-        mSelectedConfigIndex.has_value(), "No valid kernel found for given param config and problem size");
-    auto const& config = configs[mSelectedConfigIndex.value()];
+    if (!configIndex.has_value())
+    {
+        mSelectedConfigIndex
+            = getDefaultValidConfigIndex(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
+        configIndex = mSelectedConfigIndex;
+    }
+
+    auto const& config = configs[configIndex.value()];

     return bmm.getWorkspaceSizeInBytes(config, gemmData);
 }

@@ -89,16 +95,22 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
     void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC, float const* scaleGateC,
     void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
     int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
-    void* workspace, CUstream stream, int device)
+    void* workspace, CUstream stream, int device, std::optional<int32_t> configIndex)
 {
     auto bmm = BatchedGemmInterface();

     BatchedGemmData gemmData;

     auto const configs = bmm.getBatchedGemmConfigs();
-    TLLM_CHECK_WITH_INFO(
-        mSelectedConfigIndex.has_value(), "No valid kernel found for given param config and problem size");
-    auto const& config = configs[mSelectedConfigIndex.value()];
+    if (!configIndex.has_value())
+    {
+        TLLM_CHECK_WITH_INFO(mSelectedConfigIndex.has_value(), "Tried to use default config index but none was set");
+
+        configIndex = mSelectedConfigIndex;
+    }
+
+    auto const& config = configs[configIndex.value()];

     TLLM_CHECK_WITH_INFO(numBatches > 0, "Batched GEMM requires numBatches > 0");
     if (!mOptions.staticBatch)
@@ -170,7 +182,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto

 void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
     void const* a, void const* sfA, void const* b, void const* sfB, void* c, void* outSfC, void* workspace,
-    CUstream stream, int device)
+    CUstream stream, int device, std::optional<int32_t> configIndex)
 {
     // Dispatch with block scaling factors and with static batching.
     run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a, sfA, b, sfB,
@@ -178,12 +190,12 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
         /* scaleC */ nullptr, /* scaleGateC */ nullptr, c, outSfC,
         /* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr,
         /* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr,
-        /* numNonExitingCtas */ nullptr, workspace, stream, device);
+        /* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex);
 }

 void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens,
     void const* a, void const* b, float const* scaleC, float const* scaleGateC, void* c, void* workspace,
-    CUstream stream, int device)
+    CUstream stream, int device, std::optional<int32_t> configIndex)
 {
     // Dispatch with block scaling factors and with static batching.
     run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a,
@@ -191,11 +203,12 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
         scaleGateC, c, /* outSfC */ nullptr,
         /* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr,
         /* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr,
-        /* numNonExitingCtas */ nullptr, workspace, stream, device);
+        /* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex);
 }

-void TrtllmGenBatchedGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t k,
-    std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim)
+std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m, int32_t n, int32_t k,
+    std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
+    int32_t maxNumCtasInBatchDim) const
 {
     auto const bmm = BatchedGemmInterface();
     auto const configs = bmm.getBatchedGemmConfigs();
@@ -242,16 +255,30 @@ void TrtllmGenBatchedGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t
             return optionsA.mTileM > optionsB.mTileM;
         });

+    std::vector<int64_t> validConfigIndices;
     for (auto const& configIndex : sortedIndices)
     {
         auto const& config = configs[configIndex];
         auto isValidConfig = bmm.isValidConfig(config, gemmData);
         if (isValidConfig)
         {
-            mSelectedConfigIndex = configIndex;
-            return;
+            validConfigIndices.push_back(configIndex);
         }
     }

+    TLLM_CHECK_WITH_INFO(!validConfigIndices.empty(), "No valid config found for the given problem shape");
+
+    return validConfigIndices;
+}
+
+int64_t TrtllmGenBatchedGemmRunner::getDefaultValidConfigIndex(int32_t m, int32_t n, int32_t k,
+    std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
+    int32_t maxNumCtasInBatchDim) const
+{
+    auto const validConfigIndices
+        = getValidConfigIndices(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
+
+    return validConfigIndices[0];
 }

 } // namespace kernels
@@ -16,8 +16,10 @@

 #pragma once

+#include <cstdint>
 #include <cuda.h>
 #include <optional>
+#include <vector>

 #include "trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"

@@ -45,20 +47,40 @@ public:
     explicit TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunnerOptions const& options);

     [[nodiscard]] size_t getWorkspaceSizeInBytes(int32_t m, int32_t n, int32_t k,
-        std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim);
+        std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
+        std::optional<int32_t> configIndex = std::nullopt);

     void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
         int32_t numBatches, int32_t maxNumCtasInBatchDim, void const* a, void const* sfA, void const* b,
         void const* sfB, void const* perTokensSfA, void const* perTokensSfB, float const* scaleC,
         float const* scaleGateC, void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens,
         int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas,
-        void* workspace, CUstream stream, int device);
+        void* workspace, CUstream stream, int device, std::optional<int32_t> configIndex = std::nullopt);

     void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* sfA,
-        void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device);
+        void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device,
+        std::optional<int32_t> configIndex = std::nullopt);

     void run(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, void const* a, void const* b,
-        float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device);
+        float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device,
+        std::optional<int32_t> configIndex = std::nullopt);
+
+    // Get the list of configs that passed the validation based on the constructor options
+    [[nodiscard]] std::vector<int32_t> getPassingConfigIndices() const
+    {
+        return mPassingConfigIndices;
+    }
+
+    // Get the list of config indices that are valid for the given problem shape
+    [[nodiscard]] std::vector<int64_t> getValidConfigIndices(int32_t m, int32_t n, int32_t k,
+        std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
+        int32_t maxNumCtasInBatchDim) const;
+
+    // Get a default config index that is valid for the given problem shape
+    // This will be used as the fallback config if using auto-tuning
+    [[nodiscard]] int64_t getDefaultValidConfigIndex(int32_t m, int32_t n, int32_t k,
+        std::vector<int32_t> const& batchedTokens, int32_t numTokens, int32_t numBatches,
+        int32_t maxNumCtasInBatchDim) const;

 private:
     void selectGemmConfig(int32_t m, int32_t n, int32_t k, std::vector<int32_t> const& batchedTokens, int32_t numTokens,
@@ -66,8 +88,8 @@ private:

 private:
     TrtllmGenBatchedGemmRunnerOptions mOptions;
-    std::optional<int> mSelectedConfigIndex;
     std::vector<int32_t> mPassingConfigIndices;
+    std::optional<int32_t> mSelectedConfigIndex;
 };
 } // namespace kernels
 } // namespace tensorrt_llm
@@ -26,6 +26,8 @@
 #include <cuda_fp16.h>

 #include <cstdint>
+#include <memory>
+#include <tuple>

 namespace
 {
@@ -35,27 +37,16 @@ template <tg::Dtype outDtype>
 void runBatchedGemm(at::Tensor& out, at::Tensor& outSfC, at::Tensor const& mat1, at::Tensor const& mat2,
     std::optional<at::Tensor> const& dDqSfsA, std::optional<at::Tensor> const& dDqSfsB,
     std::optional<at::Tensor> const& scaleC, int64_t m, int64_t n, int64_t k, int32_t tileSize, int32_t epilogueTileM,
-    std::vector<int32_t> const& batchedTokens, bool useDeepSeekFp8, bool lowLatencyKernel)
+    std::vector<int32_t> const& batchedTokens, bool useDeepSeekFp8, bool lowLatencyKernel,
+    tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner& runner, int32_t const configIndex)
 {
-    auto eltType = tg::Dtype::E4m3;
-
-    tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = {.eltType = eltType,
-        .outputType = outDtype,
-        .deepSeekFp8 = useDeepSeekFp8,
-        .fusedAct = false,
-        .routeAct = false,
-        .staticBatch = true,
-        .transposeMmaOutput = lowLatencyKernel,
-        .tileSize = tileSize,
-        .epilogueTileM = epilogueTileM};
-
-    tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner runner(options);
-
     // numTokens and maxNumCtasInBatchDim are not used for static batching
-    int32_t numTokens = 0;
-    int32_t maxNumCtasInBatchDim = 0;
-    int64_t const numBytesWorkspace
-        = runner.getWorkspaceSizeInBytes(m, n, k, batchedTokens, numTokens, batchedTokens.size(), maxNumCtasInBatchDim);
+    int32_t const numTokens = 0;
+    int32_t const maxNumCtasInBatchDim = 0;
+    int64_t const numBytesWorkspace = runner.getWorkspaceSizeInBytes(
+        m, n, k, batchedTokens, numTokens, batchedTokens.size(), maxNumCtasInBatchDim, configIndex);
     at::Tensor workspace
         = at::detail::empty_cuda({numBytesWorkspace}, at::ScalarType::Char, mat1.device(), std::nullopt);

@@ -66,20 +57,21 @@ void runBatchedGemm(at::Tensor& out, at::Tensor& outSfC, at::Tensor const& mat1,
         float* outSfCPtr = outDtype == tg::Dtype::E4m3 ? outSfC.data_ptr<float>() : nullptr;
         runner.run(m, n, k, batchedTokens, mat1.const_data_ptr(), dDqSfsA.value().const_data_ptr(),
             mat2.const_data_ptr(), dDqSfsB.value().const_data_ptr(), out.data_ptr(), outSfCPtr, workspace.data_ptr(),
-            stream.stream(), mat1.get_device());
+            stream.stream(), mat1.get_device(), configIndex);
     }
     else
     {
         runner.run(m, n, k, batchedTokens, mat1.const_data_ptr(), mat2.const_data_ptr(),
             reinterpret_cast<float const*>(scaleC.value().const_data_ptr()), nullptr, out.data_ptr(),
-            workspace.data_ptr(), stream.stream(), mat1.get_device());
+            workspace.data_ptr(), stream.stream(), mat1.get_device(), configIndex);
     }
 }

 std::tuple<at::Tensor, at::Tensor> fp8_batched_gemm_sm100(at::Tensor const& mat1, at::Tensor const& mat2,
     int32_t tileSize, bool useDeepSeekFp8, bool lowLatencyKernel, int64_t epilogueTileM,
     std::optional<at::Tensor> const& dDqSfsA, std::optional<at::Tensor> const& dDqSfsB,
-    std::optional<at::Tensor> const& scaleC, std::optional<c10::ScalarType> outDtype)
+    std::optional<at::Tensor> const& scaleC, std::optional<c10::ScalarType> outDtype,
+    tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner& runner, int32_t const configIndex)
 {
     TORCH_CHECK(mat1.dim() == 3, "Matrix A must be of size [B, M, K]");
     TORCH_CHECK(mat2.dim() == 3, "Matrix B must be of size [B, N, K]");
@@ -144,16 +136,19 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_gemm_sm100(at::Tensor const& mat1
         TORCH_CHECK(scaleC.value().sizes()[0] == b, "outScalingFactor must be a 1D matrix of size B");
     }

-    int64_t outputN = n;
+    int64_t const outputN = n;

     // Create output tensor.
     at::Tensor out = at::detail::empty_cuda({b, m, outputN}, outDtype.value(), mat1.device(), std::nullopt);
-    at::Tensor outSfC;
-    if (useDeepSeekFp8 && outDtype.value() == at::ScalarType::Float8_e4m3fn)
-    {
-        outSfC = at::detail::empty_cuda(
-            {outputN / dsFp8QuantBlockSize, m * b}, at::ScalarType::Float, mat1.device(), std::nullopt);
-    }
+    bool const needOutSfC = useDeepSeekFp8 && outDtype.value() == at::ScalarType::Float8_e4m3fn;
+
+    // Torch class did not support returning a default tensor so using empty instead.
+    int64_t const outSfCSize0 = needOutSfC ? (outputN / dsFp8QuantBlockSize) : 0;
+    int64_t const outSfCSize1 = needOutSfC ? (m * b) : 0;
+
+    at::Tensor outSfC
+        = at::detail::empty_cuda({outSfCSize0, outSfCSize1}, at::ScalarType::Float, mat1.device(), std::nullopt);

     std::vector<int32_t> batchedTokens(b, m);

@@ -161,15 +156,15 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_gemm_sm100(at::Tensor const& mat1
     {
     case at::ScalarType::Half:
         runBatchedGemm<tg::Dtype::Fp16>(out, outSfC, mat1, mat2, dDqSfsA, dDqSfsB, scaleC, m, n, k, tileSize,
-            epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel);
+            epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel, runner, configIndex);
         break;
     case at::ScalarType::BFloat16:
         runBatchedGemm<tg::Dtype::Bfloat16>(out, outSfC, mat1, mat2, dDqSfsA, dDqSfsB, scaleC, m, n, k, tileSize,
-            epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel);
+            epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel, runner, configIndex);
         break;
     case at::ScalarType::Float8_e4m3fn:
         runBatchedGemm<tg::Dtype::E4m3>(out, outSfC, mat1, mat2, dDqSfsA, dDqSfsB, scaleC, m, n, k, tileSize,
-            epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel);
+            epilogueTileM, batchedTokens, useDeepSeekFp8, lowLatencyKernel, runner, configIndex);
         break;
     default: C10_THROW_ERROR(NotImplementedError, "outDtype must be one of fp16/bf16/e4m3.");
     }
@@ -181,35 +176,101 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_gemm_sm100(at::Tensor const& mat1
 namespace torch_ext
 {

-extern std::tuple<at::Tensor, at::Tensor> fp8_batched_gemm_trtllmgen(at::Tensor const& mat1, at::Tensor const& mat2,
-    int64_t tileSize, bool useDeepSeekFp8, bool lowLatency, int64_t epilogueTileM,
-    std::optional<at::Tensor> const& dDqSfsA, std::optional<at::Tensor> const& dDqSfsB,
-    std::optional<at::Tensor> const& scaleC, std::optional<c10::ScalarType> outDtype)
-{
-    auto const smVersion = tensorrt_llm::common::getSMVersion();
-    switch (smVersion)
-    {
-    case tensorrt_llm::kernels::kSM_100:
-    {
-        return fp8_batched_gemm_sm100(
-            mat1, mat2, tileSize, useDeepSeekFp8, lowLatency, epilogueTileM, dDqSfsA, dDqSfsB, scaleC, outDtype);
-    }
-    default: TLLM_THROW("Unsupported or unimplemented compute capability for fp8 batched gemm: %i", smVersion);
-    }
-}
+// Wrapped the TRTLLM-Gen kernel runner in a Torch custom class to allow
+// use with the torch workflow autotuner class.
+class FP8BatchedGemmRunner : public torch::CustomClassHolder
+{
+public:
+    explicit FP8BatchedGemmRunner(c10::ScalarType outDtypeArg, bool useDeepSeekFp8, bool lowLatencyKernel,
+        int64_t tileSize, int64_t epilogueTileM)
+        : mOutDtypeArg(outDtypeArg)
+        , mUseDeepSeekFp8(useDeepSeekFp8)
+        , mLowLatencyKernel(lowLatencyKernel)
+        , mTileSize(tileSize)
+        , mEpilogueTileM(epilogueTileM)
+    {
+        auto const smVersion = tensorrt_llm::common::getSMVersion();
+        if (smVersion != tensorrt_llm::kernels::kSM_100)
+        {
+            TLLM_THROW("Unsupported or unimplemented compute capability for fp8 batched gemm: %i", smVersion);
+        }
+
+        tg::Dtype outDtype = tg::Dtype::E4m3; // Default to E4m3, will be updated based on outDtypeArg
+
+        switch (outDtypeArg)
+        {
+        case at::ScalarType::Half: outDtype = tg::Dtype::Fp16; break;
+        case at::ScalarType::BFloat16: outDtype = tg::Dtype::Bfloat16; break;
+        case at::ScalarType::Float8_e4m3fn: outDtype = tg::Dtype::E4m3; break;
+        default: C10_THROW_ERROR(NotImplementedError, "outDtype must be one of fp16/bf16/e4m3.");
+        }
+
+        RunnerOptionsType const options = {.eltType = mEltType,
+            .outputType = outDtype,
+            .deepSeekFp8 = mUseDeepSeekFp8,
+            .fusedAct = false,
+            .routeAct = false,
+            .staticBatch = true,
+            .transposeMmaOutput = mLowLatencyKernel,
+            .tileSize = static_cast<int32_t>(mTileSize),
+            .epilogueTileM = static_cast<int32_t>(mEpilogueTileM)};
+
+        mRunner = std::make_unique<RunnerType>(options);
+    }
+
+    std::tuple<at::Tensor, at::Tensor> runBatchedGemm(at::Tensor const& mat1, at::Tensor const& mat2,
+        std::optional<at::Tensor> const& dDqSfsA, std::optional<at::Tensor> const& dDqSfsB,
+        std::optional<at::Tensor> const& scaleC, int64_t configIndex)
+    {
+        return fp8_batched_gemm_sm100(mat1, mat2, mTileSize, mUseDeepSeekFp8, mLowLatencyKernel, mEpilogueTileM,
+            dDqSfsA, dDqSfsB, scaleC, mOutDtypeArg, *mRunner, configIndex);
+    }
+
+    std::vector<int64_t> getValidConfigs(int64_t numBatches, int64_t m, int64_t n, int64_t k) const
+    {
+        // numTokens and maxNumCtasInBatchDim are not used for static batching
+        int32_t const numTokens = 0;
+        int32_t const maxNumCtasInBatchDim = 0;
+
+        std::vector<int32_t> const batchedTokens(numBatches, m);
+
+        return mRunner->getValidConfigIndices(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
+    }
+
+    int64_t getDefaultValidConfigIndex(int64_t numBatches, int64_t m, int64_t n, int64_t k) const
+    {
+        // numTokens and maxNumCtasInBatchDim are not used for static batching
+        int32_t const numTokens = 0;
+        int32_t const maxNumCtasInBatchDim = 0;
+
+        std::vector<int32_t> const batchedTokens(numBatches, m);
+
+        return mRunner->getDefaultValidConfigIndex(m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
+    }
+
+private:
+    using RunnerType = tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner;
+    using RunnerOptionsType = tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions;
+
+    std::unique_ptr<RunnerType> mRunner;
+    tg::Dtype mEltType{tg::Dtype::E4m3};
+    c10::ScalarType mOutDtypeArg;
+    bool mUseDeepSeekFp8;
+    bool mLowLatencyKernel;
+    int64_t mTileSize;
+    int64_t mEpilogueTileM;
+};

 } // namespace torch_ext

 TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
-    m.def(
-        "fp8_batched_gemm_trtllmgen(Tensor a, Tensor b, int tile_size,"
-        "bool use_deep_seek_fp8=False, bool low_latency=False, "
-        "int epilogue_tile_m=0, Tensor? dq_sfs_a=None, Tensor? dq_sfs_b=None, "
-        "Tensor? scale_c=None, "
-        "ScalarType? out_dtype=None) -> (Tensor, Tensor)");
-}
-
-TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
-{
-    m.impl("fp8_batched_gemm_trtllmgen", &torch_ext::fp8_batched_gemm_trtllmgen);
+    m.class_<torch_ext::FP8BatchedGemmRunner>("FP8BatchedGemmRunner")
+        .def(torch::init<at::ScalarType, bool, bool, int64_t, int64_t>())
+        .def("get_valid_configs", &torch_ext::FP8BatchedGemmRunner::getValidConfigs)
+        .def("get_default_valid_config", &torch_ext::FP8BatchedGemmRunner::getDefaultValidConfigIndex)
+        .def("run_batched_gemm", &torch_ext::FP8BatchedGemmRunner::runBatchedGemm);
 }
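For orientation, a minimal Python sketch of how the custom class registered above can be driven directly. This is not part of the commit: the shapes, dtypes, and tile sizes are illustrative placeholders, and it assumes the TensorRT-LLM extension containing this class is built and loaded.

import torch
import tensorrt_llm  # assumed: loads the extension that registers trtllm.FP8BatchedGemmRunner

# Constructor mirrors torch::init<at::ScalarType, bool, bool, int64_t, int64_t>:
# (out_dtype, use_deep_seek_fp8, low_latency_kernel, tile_size, epilogue_tile_m).
runner = torch.classes.trtllm.FP8BatchedGemmRunner(torch.half, False, False, 8, 128)

b, m, n, k = 4, 128, 256, 512
configs = runner.get_valid_configs(b, m, n, k)  # candidate kernel config indices for this shape
default_config = runner.get_default_valid_config(b, m, n, k)  # fallback when nothing is tuned

mat1 = torch.randn(b, m, k, device="cuda").to(torch.float8_e4m3fn)
mat2 = torch.randn(b, n, k, device="cuda").to(torch.float8_e4m3fn)
scale_c = torch.ones(b, device="cuda", dtype=torch.float32)

# run_batched_gemm(mat1, mat2, dq_sfs_a, dq_sfs_b, scale_c, config_index) -> (out, out_sf_c)
out, out_sf_c = runner.run_batched_gemm(mat1, mat2, None, None, scale_c, default_config)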
@@ -340,7 +340,12 @@ class AutoTuner:
         Although runners[0] with tactic=-1 is always treated as the fallback runner.
         Runner authors are suggested to provide a fallback implementation for each runner to avoid potential issues.
         """
-        input_shapes = tuple(t.shape for t in inputs)
+        # Treat None tensors as size zero
+        # This allows the tuner to handle TRT-LLM-Gen torch ops that have optional tensor
+        # arguments, such as block scaling factors.
+        input_shapes = tuple(
+            (t.shape if t is not None else torch.Size((0, ))) for t in inputs)

         # Early return if it's not tuning, use cache found one or fallback one
         if not self.is_tuning_mode:
@@ -393,8 +398,14 @@ class AutoTuner:
                             time_measured = self._profile_single_kernel(
                                 r, tensors, tac, **kwargs)
                         except Exception as e:
+                            # Handle None tensors for optional inputs
+                            shapes = [
+                                t.size() if t is not None else torch.Size((0, ))
+                                for t in tensors
+                            ]
+
                             logger.error(
-                                f"[Autotuner]: Failed when profiling {r} {tac}, shapes={[t.size() for t in tensors]}. Error occurred: {e}"
+                                f"[Autotuner]: Failed when profiling {r} {tac}, shapes={shapes}. Error occurred: {e}"
                             )

                             # Record the failed profiling combinations
@@ -471,8 +482,13 @@ class AutoTuner:
         stream.synchronize()

         avg_time = start.elapsed_time(end) / self.repeat

+        # Handle None tensors for optional inputs
+        shapes = [
+            t.size() if t is not None else torch.Size((0, )) for t in inputs
+        ]
         logger.debug(
-            f"[Autotuner]: profiling {runner} {tactic}, shapes={[t.size() for t in inputs]}, avg_time {avg_time}"
+            f"[Autotuner]: profiling {runner} {tactic}, shapes={shapes}, avg_time {avg_time}"
         )

         return avg_time
@@ -494,9 +510,13 @@ class AutoTuner:
        combinations specified in dynamic_tensor_specs.
        """
         # every dimension created from the concrete input tensor shape
-        # generate some dynamic dimension description based on the dynamic_tensor_specs
-        base_profile = OptimizationProfile([[StaticDim(x) for x in t.size()]
-                                            for t in inputs])
+        # generate some dynamic dimension description based on the dynamic_tensors
+        # Zero handles the case where a TRTLLM op has optional inputs.
+        base_profile = OptimizationProfile(
+            [[StaticDim(x)
+              for x in t.size()] if t is not None else [StaticDim(0)]
+             for t in inputs])

         generated_profiles: List[OptimizationProfile] = []

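The None-handling added above is easy to exercise in isolation. A small standalone sketch follows; the helper name is illustrative and not part of the AutoTuner API, it only shows how optional inputs collapse to a zero-size placeholder when shapes are collected.

import torch

def shapes_for_cache_key(inputs):
    # Optional tensors (e.g. block scaling factors) arrive as None; treat them as
    # size zero so they still contribute a stable entry to the cache key.
    return tuple(t.shape if t is not None else torch.Size((0, )) for t in inputs)

mat1 = torch.empty(4, 128, 512)
mat2 = torch.empty(4, 256, 512)
print(shapes_for_cache_key([mat1, mat2, None, None, None]))
# (torch.Size([4, 128, 512]), torch.Size([4, 256, 512]), torch.Size([0]), torch.Size([0]), torch.Size([0]))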
@@ -1,5 +1,5 @@
 from functools import lru_cache
-from typing import List, Optional
+from typing import List, Optional, Tuple

 import torch

@@ -340,6 +340,218 @@ def _(
         dtype=output_dtype)


+class FP8BatchedGemmRunner(TunableRunner):
+
+    _runner_dict = dict()
+
+    def __init__(self, output_dtype: torch.dtype, use_deep_seek_fp8: bool,
+                 low_latency_kernel: bool, tile_size: int,
+                 epilogue_tile_m: int):
+
+        self.output_dtype = output_dtype
+        self.use_deep_seek_fp8 = use_deep_seek_fp8
+        self.low_latency_kernel = low_latency_kernel
+        self.tile_size = tile_size
+        self.epilogue_tile_m = epilogue_tile_m
+        self.tuning_config = self.get_tuning_config()
+
+        instance_key = (output_dtype, use_deep_seek_fp8, low_latency_kernel,
+                        tile_size, epilogue_tile_m)
+
+        if instance_key not in FP8BatchedGemmRunner._runner_dict:
+            FP8BatchedGemmRunner._runner_dict[
+                instance_key] = torch.classes.trtllm.FP8BatchedGemmRunner(
+                    output_dtype, use_deep_seek_fp8, low_latency_kernel,
+                    tile_size, epilogue_tile_m)
+
+        self._kernel_runner = FP8BatchedGemmRunner._runner_dict[instance_key]
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = -1,
+        do_preparation: bool = False,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Run the batched GEMM operation with the given inputs and tactic.
+        """
+
+        mat1, mat2, dq_sfs_a, dq_sfs_b, scale_c = inputs
+
+        chosen_tactic = self.get_default_valid_tactic(
+            inputs) if tactic == -1 else tactic
+
+        out_tensors = self._kernel_runner.run_batched_gemm(
+            mat1,
+            mat2,
+            dq_sfs_a,
+            dq_sfs_b,
+            scale_c,
+            chosen_tactic,
+        )
+
+        return out_tensors
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+    ) -> List[int]:
+
+        mat1, mat2, _, _, _ = inputs
+
+        b = mat1.shape[0]
+        m = mat1.shape[1]
+        n = mat2.shape[1]
+        k = mat1.shape[2]
+
+        tactics = self._kernel_runner.get_valid_configs(b, m, n, k)
+
+        return tactics
+
+    def get_default_valid_tactic(
+        self,
+        inputs: List[torch.Tensor],
+    ) -> int:
+
+        mat1, mat2, _, _, _ = inputs
+
+        b = mat1.shape[0]
+        m = mat1.shape[1]
+        n = mat2.shape[1]
+        k = mat1.shape[2]
+
+        default_tactic = self._kernel_runner.get_default_valid_config(
+            b, m, n, k)
+
+        return default_tactic
+
+    def get_dynamic_tensor_specs(self) -> Tuple[DynamicTensorSpec, ...]:
+        """Get the dynamic tensor specs for use with the AutoTuner."""
+
+        # These indices correspond to the 0th input tensor and it's first dimension
+        # i.e. we are tuning M where the first input tensor is of shape [B, M, K]
+        MAT1_IDX = 0
+        TUNED_DIM = 1
+
+        # Starting at 8 as M % tile size == 0 is required
+        m_values = (8, 16, 32, 64, 128, 256, 512, 1024, 2048)
+        round_rule = lambda x: last_positive_power_of_2(x)
+
+        specs = (DynamicTensorSpec(MAT1_IDX, TUNED_DIM, m_values, round_rule), )
+
+        return specs
+
+    def get_constraint_specs(self) -> Tuple[ConstraintSpec, ...]:
+        """Get the constraint specs for the dynamic tensors for use with the AutoTuner.
+        """
+
+        # When using deepseek fp8, the dq_sfs_a and dq_sfs_b tensors are expected to
+        # have specific dimensions. As we are only tuning M, we need only constrain
+        # dimension 1 of dq_sfs_a
+        if not self.use_deep_seek_fp8:
+            constraint_dq_sfs_a = ()
+        else:
+
+            def _constrain_dq_sfs_a_dim1(shapes: Tuple[torch.Size]) -> int:
+                b = shapes[0][0]
+                m = shapes[0][1]
+
+                m_padded = (m + self.tile_size - 1) // self.tile_size
+                result = m_padded * self.tile_size * b
+
+                return result
+
+            SFS_A_IDX = 2
+            CONSTRAINED_DIM = 1
+
+            constraint_dq_sfs_a = (ConstraintSpec(SFS_A_IDX, CONSTRAINED_DIM,
+                                                  _constrain_dq_sfs_a_dim1), )
+
+        return constraint_dq_sfs_a
+
+    def get_tuning_config(self) -> TuningConfig:
+        """Get the tuning configuration for the AutoTuner."""
+
+        dynamic_tensor_specs = self.get_dynamic_tensor_specs()
+        constraint_specs = self.get_constraint_specs()
+
+        tuning_config = TuningConfig(dynamic_tensor_specs=dynamic_tensor_specs,
+                                     constraint_specs=constraint_specs)
+
+        return tuning_config
+
+
+@torch.library.custom_op("trtllm::fp8_batched_gemm_trtllmgen", mutates_args=())
+def fp8_batched_gemm_trtllmgen(
+    mat1: torch.Tensor,
+    mat2: torch.Tensor,
+    tile_size: int,
+    use_deep_seek_fp8: Optional[bool] = False,
+    low_latency: Optional[bool] = False,
+    epilogue_tile_m: Optional[int] = 0,
+    dq_sfs_a: Optional[torch.Tensor] = None,
+    dq_sfs_b: Optional[torch.Tensor] = None,
+    scale_c: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = torch.half
+) -> Tuple[torch.Tensor, torch.Tensor]:
+
+    kernel_runner = FP8BatchedGemmRunner(output_dtype=out_dtype,
+                                         use_deep_seek_fp8=use_deep_seek_fp8,
+                                         low_latency_kernel=low_latency,
+                                         tile_size=tile_size,
+                                         epilogue_tile_m=epilogue_tile_m)
+
+    tuner = AutoTuner.get()
+
+    inputs = [mat1, mat2, dq_sfs_a, dq_sfs_b, scale_c]
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::fp8_batched_gemm_trtllmgen::batched_gemm",
+        [kernel_runner],
+        kernel_runner.tuning_config,
+        inputs,
+    )
+
+    return kernel_runner(
+        inputs=inputs,
+        tactic=best_tactic,
+    )
+
+
+# Allows the tunable TRTLLM-Gen FP8 batched GEMM to be
+# used with torch.compile
+@fp8_batched_gemm_trtllmgen.register_fake
+def _(
+    mat1: torch.Tensor,
+    mat2: torch.Tensor,
+    tile_size: int,
+    use_deep_seek_fp8: Optional[bool] = False,
+    low_latency: Optional[bool] = False,
+    epilogue_tile_m: Optional[int] = 0,
+    dq_sfs_a: Optional[torch.Tensor] = None,
+    dq_sfs_b: Optional[torch.Tensor] = None,
+    scale_c: Optional[torch.Tensor] = None,
+    out_dtype: Optional[torch.dtype] = None
+) -> Tuple[torch.Tensor, torch.Tensor]:
+
+    b = mat1.size(0)
+    m = mat1.size(1)
+    n = mat2.size(1)
+
+    fake_out = mat1.new_empty((b, m, n), dtype=out_dtype)
+
+    if use_deep_seek_fp8:
+        ds_fp8_quant_block_size = 128
+        dim0_size = n // ds_fp8_quant_block_size
+        dim1_size = b * m
+        fake_dq_sfs_c = torch.empty((dim0_size, dim1_size), dtype=torch.float32)
+    else:
+        fake_dq_sfs_c = torch.empty((0, 0), dtype=torch.float32)
+
+    return (fake_out, fake_dq_sfs_c)
+
+
 @torch.library.custom_op("trtllm::attention", mutates_args=())
 def attention(
     q: torch.Tensor,
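To make the intended call pattern concrete, here is a hedged usage sketch of the tunable op defined above. Shapes and tile sizes are illustrative, and it assumes the TensorRT-LLM torch ops are registered by importing the package.

import torch
import tensorrt_llm  # assumed: registers torch.ops.trtllm.*
from tensorrt_llm._torch.autotuner import autotune

b, m, n, k, tile_size = 4, 128, 256, 512, 8
mat1 = torch.randn(b, m, k, device="cuda").to(torch.float8_e4m3fn)
mat2 = torch.randn(b, n, k, device="cuda").to(torch.float8_e4m3fn)
scale_c = torch.ones(b, device="cuda", dtype=torch.float32)

# One call inside autotune() profiles the valid kernel configs for this shape and
# caches the best tactic; later calls outside the context reuse the cached choice
# (or fall back to the default valid config if nothing was cached).
with autotune():
    out, _ = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
        mat1, mat2, tile_size=tile_size, use_deep_seek_fp8=False,
        low_latency=False, epilogue_tile_m=128, scale_c=scale_c,
        out_dtype=torch.half)

out, _ = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
    mat1, mat2, tile_size=tile_size, use_deep_seek_fp8=False,
    low_latency=False, epilogue_tile_m=128, scale_c=scale_c,
    out_dtype=torch.half)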
@@ -20,6 +20,7 @@ import pytest
 import torch
 from utils.util import getSMVersion

+from tensorrt_llm._torch.autotuner import autotune
 from tensorrt_llm.quantization.utils.fp4_utils import shuffle_matrix_a

 sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
@@ -246,98 +247,199 @@ def quant_fp8(x_fp32: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         ),
     ],
 )
-def test_fp8_batched_gemm_trtllmgen(test_case: BatchedGemmTestCase) -> None:
-    torch.random.manual_seed(42)
-
-    b = test_case.b
-    m = test_case.m
-    n = test_case.n
-    k = test_case.k
-    use_deep_seek_fp8 = test_case.use_deep_seek_fp8
-    dtype_c = test_case.dtype_c
-    low_latency = test_case.low_latency
-    tile_size = test_case.tile_size
-
-    a_fp32 = torch.randn((b, m, k), device="cuda", dtype=torch.float32)
-    b_fp32 = torch.randn((b, n, k), device="cuda", dtype=torch.float32)
-    # Pad to the tile size. It is needed for the TRT-LLM Gen BMM input requirements.
-    if m % tile_size:
-        tiled_shape = ((m + tile_size - 1) // tile_size) * tile_size
-        a_fp32 = torch.nn.functional.pad(a_fp32, (0, 0, 0, tiled_shape - m),
-                                         "constant", 0)
-    m_padded = ((m + tile_size - 1) // tile_size) * tile_size
-
-    dq_sf_a = None
-    dq_sf_b = None
-    global_dq_a = None
-    global_dq_b = None
-    out_global_scaling_factor = None
-
-    if use_deep_seek_fp8:
-        a_fp8, dq_sf_a = quant_ds_fp8(a_fp32, activations=True)
-        dq_sf_a = dq_sf_a.reshape(k // 128, -1).contiguous()
-        b_fp8, dq_sf_b = quant_ds_fp8(b_fp32, activations=False)
-        dq_sf_b = dq_sf_b.contiguous()
-    else:
-        a_fp8, global_dq_a = quant_fp8(a_fp32)
-        b_fp8, global_dq_b = quant_fp8(b_fp32)
-        out_global_scaling_factor = global_dq_a * global_dq_b
-
-    # Compute reference batched matrix multiplication
-    output = fp8_bmm_reference(a_fp8, b_fp8, dq_sf_a, dq_sf_b, global_dq_a,
-                               global_dq_b, dtype_c, use_deep_seek_fp8)
-
-    c_dq_sf_ref = None
-    if dtype_c == torch.float8_e4m3fn:
-        if use_deep_seek_fp8:
-            c_ref, c_dq_sf_ref = output
-            c_dq_sf_ref = c_dq_sf_ref.reshape(n // 128, -1)
-        else:
-            c_ref, c_scale = output
-            out_global_scaling_factor /= c_scale.cuda()
-    else:
-        c_ref = output
-
-    epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
-    if low_latency and not use_deep_seek_fp8:
-        b_fp8_shuffled = []
-        for bi in range(b):
-            b_fp8_shuffled.append(
-                shuffle_matrix_a(b_fp8[bi].view(torch.uint8).clone(),
-                                 epilogue_tile_m))
-
-        # Stack weights for all experts
-        b_fp8 = torch.stack(b_fp8_shuffled).view(torch.float8_e4m3fn)
-
-    if not use_deep_seek_fp8:
-        out_global_scaling_factor = out_global_scaling_factor.contiguous().to(
-            torch.float32)
-
-    c_actual, c_dq_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
-        a_fp8.contiguous(),
-        b_fp8.contiguous(),
-        tile_size=tile_size,
-        epilogue_tile_m=epilogue_tile_m,
-        use_deep_seek_fp8=use_deep_seek_fp8,
-        low_latency=low_latency,
-        out_dtype=dtype_c,
-        dq_sfs_a=dq_sf_a,
-        dq_sfs_b=dq_sf_b,
-        scale_c=out_global_scaling_factor)
-
-    c_actual = c_actual.detach().cpu()
-    c_ref = c_ref.detach().cpu()
-
-    torch.testing.assert_close(c_actual.to(torch.float32)[:, :m],
-                               c_ref.to(torch.float32)[:, :m],
-                               atol=1e-2,
-                               rtol=1e-2)
-    if use_deep_seek_fp8 and dtype_c == torch.float8_e4m3fn:
-        c_dq_sf = c_dq_sf.detach().cpu()
-        for bi in range(b):
-            torch.testing.assert_close(
-                c_dq_sf[:, bi * m_padded:bi * m_padded + m].to(torch.float32),
-                c_dq_sf_ref[:,
-                            bi * m_padded:bi * m_padded + m].to(torch.float32),
-                atol=1e-2,
-                rtol=1e-2)
+class TestFP8BatchedGemmTRTLLMGen:
+
+    def test_thop(self, test_case: BatchedGemmTestCase) -> None:
+        torch.random.manual_seed(42)
+
+        b = test_case.b
+        m = test_case.m
+        n = test_case.n
+        k = test_case.k
+        use_deep_seek_fp8 = test_case.use_deep_seek_fp8
+        dtype_c = test_case.dtype_c
+        low_latency = test_case.low_latency
+        tile_size = test_case.tile_size
+
+        a_fp32 = torch.randn((b, m, k), device="cuda", dtype=torch.float32)
+        b_fp32 = torch.randn((b, n, k), device="cuda", dtype=torch.float32)
+        # Pad to the tile size. It is needed for the TRT-LLM Gen BMM input requirements.
+        if m % tile_size:
+            tiled_shape = ((m + tile_size - 1) // tile_size) * tile_size
+            a_fp32 = torch.nn.functional.pad(a_fp32, (0, 0, 0, tiled_shape - m),
+                                             "constant", 0)
+        m_padded = ((m + tile_size - 1) // tile_size) * tile_size
+
+        dq_sf_a = None
+        dq_sf_b = None
+        global_dq_a = None
+        global_dq_b = None
+        out_global_scaling_factor = None
+
+        if use_deep_seek_fp8:
+            a_fp8, dq_sf_a = quant_ds_fp8(a_fp32, activations=True)
+            dq_sf_a = dq_sf_a.reshape(k // 128, -1).contiguous()
+            b_fp8, dq_sf_b = quant_ds_fp8(b_fp32, activations=False)
+            dq_sf_b = dq_sf_b.contiguous()
+        else:
+            a_fp8, global_dq_a = quant_fp8(a_fp32)
+            b_fp8, global_dq_b = quant_fp8(b_fp32)
+            out_global_scaling_factor = global_dq_a * global_dq_b
+
+        # Compute reference batched matrix multiplication
+        output = fp8_bmm_reference(a_fp8, b_fp8, dq_sf_a, dq_sf_b, global_dq_a,
+                                   global_dq_b, dtype_c, use_deep_seek_fp8)
+
+        c_dq_sf_ref = None
+        if dtype_c == torch.float8_e4m3fn:
+            if use_deep_seek_fp8:
+                c_ref, c_dq_sf_ref = output
+                c_dq_sf_ref = c_dq_sf_ref.reshape(n // 128, -1)
+            else:
+                c_ref, c_scale = output
+                out_global_scaling_factor /= c_scale.cuda()
+        else:
+            c_ref = output
+
+        epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
+        if low_latency and not use_deep_seek_fp8:
+            b_fp8_shuffled = []
+            for bi in range(b):
+                b_fp8_shuffled.append(
+                    shuffle_matrix_a(b_fp8[bi].view(torch.uint8).clone(),
+                                     epilogue_tile_m))
+
+            # Stack weights for all experts
+            b_fp8 = torch.stack(b_fp8_shuffled).view(torch.float8_e4m3fn)
+
+        if not use_deep_seek_fp8:
+            out_global_scaling_factor = out_global_scaling_factor.contiguous(
+            ).to(torch.float32)
+
+        c_actual, c_dq_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
+            a_fp8.contiguous(),
+            b_fp8.contiguous(),
+            tile_size=tile_size,
+            epilogue_tile_m=epilogue_tile_m,
+            use_deep_seek_fp8=use_deep_seek_fp8,
+            low_latency=low_latency,
+            out_dtype=dtype_c,
+            dq_sfs_a=dq_sf_a,
+            dq_sfs_b=dq_sf_b,
+            scale_c=out_global_scaling_factor)
+
+        c_actual = c_actual.detach().cpu()
+        c_ref = c_ref.detach().cpu()
+
+        torch.testing.assert_close(c_actual.to(torch.float32)[:, :m],
+                                   c_ref.to(torch.float32)[:, :m],
+                                   atol=1e-2,
+                                   rtol=1e-2)
+        if use_deep_seek_fp8 and dtype_c == torch.float8_e4m3fn:
+            c_dq_sf = c_dq_sf.detach().cpu()
+            for bi in range(b):
+                torch.testing.assert_close(
+                    c_dq_sf[:,
+                            bi * m_padded:bi * m_padded + m].to(torch.float32),
+                    c_dq_sf_ref[:, bi * m_padded:bi * m_padded + m].to(
+                        torch.float32),
+                    atol=1e-2,
+                    rtol=1e-2)
+
+    def test_autotuned_thop(self, test_case: BatchedGemmTestCase) -> None:
+        torch.random.manual_seed(42)
+
+        b = test_case.b
+        m = test_case.m
+        n = test_case.n
+        k = test_case.k
+        use_deep_seek_fp8 = test_case.use_deep_seek_fp8
+        dtype_c = test_case.dtype_c
+        low_latency = test_case.low_latency
+        tile_size = test_case.tile_size
+
+        a_fp32 = torch.randn((b, m, k), device="cuda", dtype=torch.float32)
+        b_fp32 = torch.randn((b, n, k), device="cuda", dtype=torch.float32)
+        # Pad to the tile size. It is needed for the TRT-LLM Gen BMM input requirements.
+        if m % tile_size:
+            tiled_shape = ((m + tile_size - 1) // tile_size) * tile_size
+            a_fp32 = torch.nn.functional.pad(a_fp32, (0, 0, 0, tiled_shape - m),
+                                             "constant", 0)
+        m_padded = ((m + tile_size - 1) // tile_size) * tile_size
+
+        dq_sf_a = None
+        dq_sf_b = None
+        global_dq_a = None
+        global_dq_b = None
+        out_global_scaling_factor = None
+
+        if use_deep_seek_fp8:
+            a_fp8, dq_sf_a = quant_ds_fp8(a_fp32, activations=True)
+            dq_sf_a = dq_sf_a.reshape(k // 128, -1).contiguous()
+            b_fp8, dq_sf_b = quant_ds_fp8(b_fp32, activations=False)
+            dq_sf_b = dq_sf_b.contiguous()
+        else:
+            a_fp8, global_dq_a = quant_fp8(a_fp32)
+            b_fp8, global_dq_b = quant_fp8(b_fp32)
+            out_global_scaling_factor = global_dq_a * global_dq_b
+
+        # Compute reference batched matrix multiplication
+        output = fp8_bmm_reference(a_fp8, b_fp8, dq_sf_a, dq_sf_b, global_dq_a,
+                                   global_dq_b, dtype_c, use_deep_seek_fp8)
+
+        c_dq_sf_ref = None
+        if dtype_c == torch.float8_e4m3fn:
+            if use_deep_seek_fp8:
+                c_ref, c_dq_sf_ref = output
+                c_dq_sf_ref = c_dq_sf_ref.reshape(n // 128, -1)
+            else:
+                c_ref, c_scale = output
+                out_global_scaling_factor /= c_scale.cuda()
+        else:
+            c_ref = output
+
+        epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
+        if low_latency and not use_deep_seek_fp8:
+            b_fp8_shuffled = []
+            for bi in range(b):
+                b_fp8_shuffled.append(
+                    shuffle_matrix_a(b_fp8[bi].view(torch.uint8).clone(),
+                                     epilogue_tile_m))
+
+            # Stack weights for all experts
+            b_fp8 = torch.stack(b_fp8_shuffled).view(torch.float8_e4m3fn)
+
+        if not use_deep_seek_fp8:
+            out_global_scaling_factor = out_global_scaling_factor.contiguous(
+            ).to(torch.float32)
+
+        with autotune():
+            c_actual, c_dq_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
+                a_fp8.contiguous(),
+                b_fp8.contiguous(),
+                tile_size=tile_size,
+                epilogue_tile_m=epilogue_tile_m,
+                use_deep_seek_fp8=use_deep_seek_fp8,
+                low_latency=low_latency,
+                out_dtype=dtype_c,
+                dq_sfs_a=dq_sf_a,
+                dq_sfs_b=dq_sf_b,
+                scale_c=out_global_scaling_factor)
+
+        c_actual = c_actual.detach().cpu()
+        c_ref = c_ref.detach().cpu()
+
+        torch.testing.assert_close(c_actual.to(torch.float32)[:, :m],
+                                   c_ref.to(torch.float32)[:, :m],
+                                   atol=1e-2,
+                                   rtol=1e-2)
+        if use_deep_seek_fp8 and dtype_c == torch.float8_e4m3fn:
+            c_dq_sf = c_dq_sf.detach().cpu()
+            for bi in range(b):
+                torch.testing.assert_close(
+                    c_dq_sf[:,
+                            bi * m_padded:bi * m_padded + m].to(torch.float32),
+                    c_dq_sf_ref[:, bi * m_padded:bi * m_padded + m].to(
+                        torch.float32),
+                    atol=1e-2,
+                    rtol=1e-2)