[None][fix] Enable FP8 ContextMLA on GB300 (#8080)

Signed-off-by: Jonas Li <6110159+longlee0622@users.noreply.github.com>
This commit is contained in:
Jonas Li 2025-10-10 10:20:46 +08:00 committed by GitHub
parent 7da4b05289
commit 76a47c7bef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -647,8 +647,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
        static_cast<int>(v_head_dim.value()), static_cast<int>(predicted_tokens_per_seq),
        static_cast<int>(layer_num)};
-    op->mFP8ContextMLA = (tensorrt_llm::common::getSMVersion() == 90 || tensorrt_llm::common::getSMVersion() == 100
-        || tensorrt_llm::common::getSMVersion() == 120)
+    op->mFP8ContextMLA
+        = (tensorrt_llm::common::getSMVersion() == 90 || tensorrt_llm::common::getSMVersion() == 100
+            || tensorrt_llm::common::getSMVersion() == 103 || tensorrt_llm::common::getSMVersion() == 120)
        && op->mKVCacheQuantMode.hasFp8KvCache();
    op->mIsGenerationMLA = head_size == op->mMLAParams.kv_lora_rank + op->mMLAParams.qk_rope_head_dim;
    op->mFP8GenerationMLA = op->mKVCacheQuantMode.hasFp8KvCache();