mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[None][fix] Enable FP8 ContextMLA on GB300 (#8080)
Signed-off-by: Jonas Li <6110159+longlee0622@users.noreply.github.com>
This commit is contained in:
parent
7da4b05289
commit
76a47c7bef
@ -647,8 +647,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<to
|
|||||||
static_cast<int>(v_head_dim.value()), static_cast<int>(predicted_tokens_per_seq),
|
static_cast<int>(v_head_dim.value()), static_cast<int>(predicted_tokens_per_seq),
|
||||||
static_cast<int>(layer_num)};
|
static_cast<int>(layer_num)};
|
||||||
|
|
||||||
op->mFP8ContextMLA = (tensorrt_llm::common::getSMVersion() == 90 || tensorrt_llm::common::getSMVersion() == 100
|
op->mFP8ContextMLA
|
||||||
|| tensorrt_llm::common::getSMVersion() == 120)
|
= (tensorrt_llm::common::getSMVersion() == 90 || tensorrt_llm::common::getSMVersion() == 100
|
||||||
|
|| tensorrt_llm::common::getSMVersion() == 103 || tensorrt_llm::common::getSMVersion() == 120)
|
||||||
&& op->mKVCacheQuantMode.hasFp8KvCache();
|
&& op->mKVCacheQuantMode.hasFp8KvCache();
|
||||||
op->mIsGenerationMLA = head_size == op->mMLAParams.kv_lora_rank + op->mMLAParams.qk_rope_head_dim;
|
op->mIsGenerationMLA = head_size == op->mMLAParams.kv_lora_rank + op->mMLAParams.qk_rope_head_dim;
|
||||||
op->mFP8GenerationMLA = op->mKVCacheQuantMode.hasFp8KvCache();
|
op->mFP8GenerationMLA = op->mKVCacheQuantMode.hasFp8KvCache();
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user