[https://nvbugs/5394685][fix] using static scheduler 2CTA MLA as WAR for an accuracy issue (#6896)

Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>
This commit is contained in:
Perkz Zheng 2025-08-15 08:51:04 +08:00 committed by GitHub
parent 5346eb7bc5
commit 11d89a3732
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -102,9 +102,9 @@ public:
int headDimPerCtaV, int headDimQk, int headDimV, int tileSizeKv, int numTokensPerPage,
int maxNumHeadsQPerKvInCta, bool reuseSmemKForV, bool uses2CtaMma) const
{
TLLM_CHECK_WITH_INFO((headDimPerCtaV >= 32) && (headDimQk >= 32) && (headDimV >= 32) && (headDimPerCtaV <= 2048)
&& (headDimQk <= 2048) && (headDimV <= 2048) && (numTokensPerPage <= 128),
"Expect (32 <= headDim <= 2048) && (numTokensPerPage <= 128), got headDimPerCtaV=%d, headDimQk=%d, "
TLLM_CHECK_WITH_INFO((headDimPerCtaV >= 32) && (headDimQk >= 32) && (headDimV >= 32) && (headDimPerCtaV <= 1024)
&& (headDimQk <= 1024) && (headDimV <= 1024) && (numTokensPerPage <= 128),
"Expect (32 <= headDim <= 1024) && (numTokensPerPage <= 128), got headDimPerCtaV=%d, headDimQk=%d, "
"headDimV=%d, numTokensPerPage=%d",
headDimPerCtaV, headDimQk, headDimV, numTokensPerPage);
TLLM_CHECK_WITH_INFO(maxNumHeadsQPerKvInCta <= 128, "The maxNumHeadsQPerKvInCta <= 128 is required.");
@ -115,19 +115,19 @@ public:
// Bit 8 - 11: kernelType.
// Bit 12 - 15: tileScheduler.
// Bit 16 - 17: multiCtasKvMode.
// Bit 18 - 24: (headDimPerCtaV >> 5).
// Bit 25 - 31: (headDimQk >> 5).
// Bit 32 - 38: (headDimV >> 5).
// Bit 39 - 40: (tileSizeKv >> 6).
// Bit 41 - 48: numTokensPerPage.
// Bit 18 - 25: (headDimPerCtaV >> 3).
// Bit 26 - 33: (headDimQk >> 3).
// Bit 34 - 41: (headDimV >> 3).
// Bit 42 - 43: (tileSizeKv >> 6).
// Bit 44 - 48: (numTokensPerPage >> 3).
// Bit 49 - 56: maxNumHeadsQPerKvInCta.
// Bit 57 - 57: reuseSmemKForV.
// Bit 58 - 58: uses2CtaMma.
return (static_cast<uint64_t>(qkvLayout) << 0) | (static_cast<uint64_t>(maskType) << 4)
| (static_cast<uint64_t>(kernelType) << 8) | (static_cast<uint64_t>(scheduler) << 12)
| (static_cast<uint64_t>(multiCtasKvMode) << 16) | (static_cast<uint64_t>(headDimPerCtaV >> 5) << 18)
| (static_cast<uint64_t>(headDimQk >> 5) << 25) | (static_cast<uint64_t>(headDimV >> 5) << 32)
| (static_cast<uint64_t>(tileSizeKv >> 6) << 39) | (static_cast<uint64_t>(numTokensPerPage) << 41)
| (static_cast<uint64_t>(multiCtasKvMode) << 16) | (static_cast<uint64_t>(headDimPerCtaV >> 3) << 18)
| (static_cast<uint64_t>(headDimQk >> 3) << 26) | (static_cast<uint64_t>(headDimV >> 3) << 34)
| (static_cast<uint64_t>(tileSizeKv >> 6) << 42) | (static_cast<uint64_t>(numTokensPerPage >> 3) << 44)
| (static_cast<uint64_t>(maxNumHeadsQPerKvInCta) << 49) | (static_cast<uint64_t>(reuseSmemKForV) << 57)
| (static_cast<uint64_t>(uses2CtaMma) << 58);
}
@ -142,6 +142,17 @@ public:
std::pair<bool, std::string> checkIfKernelExist(RunnerParams const& params) const
{
// Some conditions to check if the kernel is supported.
// This is meant to avoid occupying unnecessary hashId bits.
if (params.mHeadDimQk % 8 != 0 || params.mHeadDimV % 8 != 0)
{
return std::make_pair(false, "HeadDimQk and HeadDimV must be divisible by 8");
}
if (params.mNumTokensPerPage % 8 != 0)
{
return std::make_pair(false, "NumTokensPerPage must be divisible by 8");
}
// The selectKernelParams that might be updated.
SelectKernelParams selectKernelParams{params};
auto [hashId, info] = hashFromRunnerParams(params, selectKernelParams);
@ -347,6 +358,11 @@ private:
selectKernelParams.mTileScheduler = TileScheduler::Persistent;
// Need to select a different kernel.
selectKernelParams.mSelectNewKernel = true;
// FIXME(perkz): use static scheduler instead as WAR for https://nvbugspro.nvidia.com/bug/5394685.
if (selectKernelParams.mUses2CtaMma)
{
selectKernelParams.mTileScheduler = TileScheduler::Static;
}
}
else if (totalNumCtas < params.mMultiProcessorCount && isMlaGenKernel(params)
&& selectKernelParams.mTileSizeKv == 128 && tensorrt_llm::common::getEnvUseTileSizeKv64ForTrtllmGen())