feat: add PositionEmbeddingType=0 to XQA support (#4934)

Signed-off-by: Jiying Dong <87510204+dongjiyingdjy@users.noreply.github.com>
This commit is contained in:
dongjiyingdjy 2025-06-05 21:50:42 +08:00 committed by GitHub
parent bfa877a22e
commit 51652b9b2b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 5 deletions

View File

@@ -68,15 +68,14 @@ CubinObj CompileEngine::compile() const
case PositionEmbeddingType::kROPE_GPTJ: ropeStyle = tllmXqaJitRopeStyle::TLLM_XQA_JIT_ROPE_GPTJ; break;
case PositionEmbeddingType::kROPE_GPT_NEOX:
case PositionEmbeddingType::kLONG_ROPE: ropeStyle = tllmXqaJitRopeStyle::TLLM_XQA_JIT_ROPE_NEOX; break;
// For kROPE_M, set ropeStyle to TLLM_XQA_JIT_ROPE_NONE to let XQA kernel not apply RoPE.
// At runtime, a separate kernel (see invokeQKVPreprocessing) will be launched to apply RoPE.
case PositionEmbeddingType::kROPE_M: ropeStyle = tllmXqaJitRopeStyle::TLLM_XQA_JIT_ROPE_NONE; break;
default: TLLM_THROW("TllmXqaJit: Bad RoPE type");
}
}
else
{
// Make it explicit that Ampere-style kernel doesn't apply RoPE in the kernel.
// For kROPE_M, set ropeStyle to TLLM_XQA_JIT_ROPE_NONE to let XQA kernel not apply RoPE.
// At runtime, a separate kernel (see invokeQKVPreprocessing) will be launched to apply RoPE.
ropeStyle = tllmXqaJitRopeStyle::TLLM_XQA_JIT_ROPE_NONE;
}
if (applyRoPEInXqaKernel)

View File

@@ -62,8 +62,9 @@ bool supportConfigCommon(XQAParams const& xqaParams, bool forConfigurePlugin)
// TODO: remove this when the kernel bug for num_kv_heads <= 128 gets fixed.
return false;
}
if (!contains({PositionEmbeddingType::kROPE_GPTJ, PositionEmbeddingType::kROPE_GPT_NEOX,
PositionEmbeddingType::kROPE_M, PositionEmbeddingType::kLONG_ROPE},
if (!contains(
{PositionEmbeddingType::kROPE_GPTJ, PositionEmbeddingType::kROPE_GPT_NEOX, PositionEmbeddingType::kROPE_M,
PositionEmbeddingType::kLONG_ROPE, PositionEmbeddingType::kLEARNED_ABSOLUTE},
xqaParams.position_embedding_type))
{
return false;