Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[https://nvbugs/5624818][fix] Work around accuracy issue by enforcing paged_context_fmha on Hopper for fmha_v2 (#11192)
Signed-off-by: eopXD <yuehtingc@nvidia.com>
commit f6fff18142 (parent 3d8c1a51bd)
@@ -1549,6 +1549,11 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
             or metadata.runtime_features.has_speculative_draft_tokens
         ) if metadata.runtime_features else False
 
+        # This is a workaround for https://nvbugs/5624818
+        # Paged context FMHA is forced on SM90 for correctness
+        if get_sm_version() == 90:
+            use_paged_context_fmha = True
+
         return self.wrapper.is_nvfp4_output_kernel_available(
             tokens_per_block=metadata.tokens_per_block,
             attention_mask=attention_mask,
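For context, a minimal standalone sketch of the gating pattern this hunk adds, assuming get_sm_version() returns the GPU compute capability as an integer (90 on Hopper). The function wrapper and the kv_cache_reuse field are hypothetical, since the hunk shows only the tail of the original expression:

# Sketch only, not the actual TensorRT-LLM code.
def resolve_paged_context_fmha(runtime_features, sm_version: int) -> bool:
    # Base decision (the first operand is hypothetical; the hunk shows
    # only the speculative-draft-token disjunct).
    use_paged_context_fmha = (
        runtime_features.kv_cache_reuse
        or runtime_features.has_speculative_draft_tokens
    ) if runtime_features else False

    # Workaround for https://nvbugs/5624818: force paged context FMHA
    # on Hopper (SM90) for correctness.
    if sm_version == 90:
        use_paged_context_fmha = True

    return use_paged_context_fmha

Gating on the SM version keeps the workaround scoped to Hopper; every other architecture keeps the original runtime-feature heuristic.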
@@ -1648,6 +1653,11 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
             or metadata.runtime_features.has_speculative_draft_tokens
         ) if metadata.runtime_features else False
 
+        # This is a workaround for https://nvbugs/5624818
+        # Paged context FMHA is forced on SM90 for correctness
+        if get_sm_version() == 90:
+            use_paged_context_fmha = True
+
         if self.is_mla_enable:
             # Context MLA uses separate qkv instead of paged_context_fmha
             use_paged_context_fmha = False
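Note the ordering at this second call site: the MLA branch runs after the SM90 workaround, so the flag ends up False for context MLA even on Hopper. A sketch of that precedence (the function and its parameters are hypothetical):

# Sketch only, not the actual TensorRT-LLM code.
def resolve_with_mla(base: bool, sm_version: int, is_mla_enable: bool) -> bool:
    use_paged_context_fmha = base
    if sm_version == 90:
        # https://nvbugs/5624818 workaround: force paged context FMHA on Hopper.
        use_paged_context_fmha = True
    if is_mla_enable:
        # Context MLA uses separate qkv instead of paged_context_fmha,
        # overriding the SM90 workaround above.
        use_paged_context_fmha = False
    return use_paged_context_fmha

# The MLA override wins over the Hopper workaround:
assert resolve_with_mla(base=False, sm_version=90, is_mla_enable=True) is False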