mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[https://nvbugs/5394685][fix] Use the static tile scheduler for 2-CTA MLA kernels as a workaround (WAR) for an accuracy issue (#6896)
Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>
This commit is contained in:
parent
5346eb7bc5
commit
11d89a3732
@ -102,9 +102,9 @@ public:
|
||||
int headDimPerCtaV, int headDimQk, int headDimV, int tileSizeKv, int numTokensPerPage,
|
||||
int maxNumHeadsQPerKvInCta, bool reuseSmemKForV, bool uses2CtaMma) const
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO((headDimPerCtaV >= 32) && (headDimQk >= 32) && (headDimV >= 32) && (headDimPerCtaV <= 2048)
|
||||
&& (headDimQk <= 2048) && (headDimV <= 2048) && (numTokensPerPage <= 128),
|
||||
"Expect (32 <= headDim <= 2048) && (numTokensPerPage <= 128), got headDimPerCtaV=%d, headDimQk=%d, "
|
||||
TLLM_CHECK_WITH_INFO((headDimPerCtaV >= 32) && (headDimQk >= 32) && (headDimV >= 32) && (headDimPerCtaV <= 1024)
|
||||
&& (headDimQk <= 1024) && (headDimV <= 1024) && (numTokensPerPage <= 128),
|
||||
"Expect (32 <= headDim <= 1024) && (numTokensPerPage <= 128), got headDimPerCtaV=%d, headDimQk=%d, "
|
||||
"headDimV=%d, numTokensPerPage=%d",
|
||||
headDimPerCtaV, headDimQk, headDimV, numTokensPerPage);
|
||||
TLLM_CHECK_WITH_INFO(maxNumHeadsQPerKvInCta <= 128, "The maxNumHeadsQPerKvInCta <= 128 is required.");
|
||||
@ -115,19 +115,19 @@ public:
|
||||
// Bit 8 - 11: kernelType.
|
||||
// Bit 12 - 15: tileScheduler.
|
||||
// Bit 16 - 17: multiCtasKvMode.
|
||||
// Bit 18 - 24: (headDimPerCtaV >> 5).
|
||||
// Bit 25 - 31: (headDimQk >> 5).
|
||||
// Bit 32 - 38: (headDimV >> 5).
|
||||
// Bit 39 - 40: (tileSizeKv >> 6).
|
||||
// Bit 41 - 48: numTokensPerPage.
|
||||
// Bit 18 - 25: (headDimPerCtaV >> 3).
|
||||
// Bit 26 - 33: (headDimQk >> 3).
|
||||
// Bit 34 - 41: (headDimV >> 3).
|
||||
// Bit 42 - 43: (tileSizeKv >> 6).
|
||||
// Bit 44 - 48: (numTokensPerPage >> 3).
|
||||
// Bit 49 - 56: maxNumHeadsQPerKvInCta.
|
||||
// Bit 57 - 57: reuseSmemKForV.
|
||||
// Bit 58 - 58: uses2CtaMma.
|
||||
return (static_cast<uint64_t>(qkvLayout) << 0) | (static_cast<uint64_t>(maskType) << 4)
|
||||
| (static_cast<uint64_t>(kernelType) << 8) | (static_cast<uint64_t>(scheduler) << 12)
|
||||
| (static_cast<uint64_t>(multiCtasKvMode) << 16) | (static_cast<uint64_t>(headDimPerCtaV >> 5) << 18)
|
||||
| (static_cast<uint64_t>(headDimQk >> 5) << 25) | (static_cast<uint64_t>(headDimV >> 5) << 32)
|
||||
| (static_cast<uint64_t>(tileSizeKv >> 6) << 39) | (static_cast<uint64_t>(numTokensPerPage) << 41)
|
||||
| (static_cast<uint64_t>(multiCtasKvMode) << 16) | (static_cast<uint64_t>(headDimPerCtaV >> 3) << 18)
|
||||
| (static_cast<uint64_t>(headDimQk >> 3) << 26) | (static_cast<uint64_t>(headDimV >> 3) << 34)
|
||||
| (static_cast<uint64_t>(tileSizeKv >> 6) << 42) | (static_cast<uint64_t>(numTokensPerPage >> 3) << 44)
|
||||
| (static_cast<uint64_t>(maxNumHeadsQPerKvInCta) << 49) | (static_cast<uint64_t>(reuseSmemKForV) << 57)
|
||||
| (static_cast<uint64_t>(uses2CtaMma) << 58);
|
||||
}
|
||||
@ -142,6 +142,17 @@ public:
|
||||
|
||||
std::pair<bool, std::string> checkIfKernelExist(RunnerParams const& params) const
|
||||
{
|
||||
// Some conditions to check if the kernel is supported.
|
||||
// This is meant to avoid occupying unnecessary hashId bits.
|
||||
if (params.mHeadDimQk % 8 != 0 || params.mHeadDimV % 8 != 0)
|
||||
{
|
||||
return std::make_pair(false, "HeadDimQk and HeadDimV must be divisible by 8");
|
||||
}
|
||||
if (params.mNumTokensPerPage % 8 != 0)
|
||||
{
|
||||
return std::make_pair(false, "NumTokensPerPage must be divisible by 8");
|
||||
}
|
||||
|
||||
// The selectKernelParams that might be updated.
|
||||
SelectKernelParams selectKernelParams{params};
|
||||
auto [hashId, info] = hashFromRunnerParams(params, selectKernelParams);
|
||||
@ -347,6 +358,11 @@ private:
|
||||
selectKernelParams.mTileScheduler = TileScheduler::Persistent;
|
||||
// Need to select a different kernel.
|
||||
selectKernelParams.mSelectNewKernel = true;
|
||||
// FIXME(perkz): use static scheduler instead as WAR for https://nvbugspro.nvidia.com/bug/5394685.
|
||||
if (selectKernelParams.mUses2CtaMma)
|
||||
{
|
||||
selectKernelParams.mTileScheduler = TileScheduler::Static;
|
||||
}
|
||||
}
|
||||
else if (totalNumCtas < params.mMultiProcessorCount && isMlaGenKernel(params)
|
||||
&& selectKernelParams.mTileSizeKv == 128 && tensorrt_llm::common::getEnvUseTileSizeKv64ForTrtllmGen())
|
||||
|
||||
Loading…
Reference in New Issue
Block a user