[None][feat] Initial PR for trtllm-gen attention backend (#10784)

Signed-off-by: Yihan Wang <yihwang@nvidia.com>
Yihan Wang 2026-02-11 17:16:52 +08:00 committed by GitHub
parent 18c992efb1
commit e8b860965b
2 changed files with 1471 additions and 80 deletions

@@ -11,6 +11,8 @@ if TYPE_CHECKING:
from ..speculative.interface import SpecMetadata
from ..speculative.spec_tree_manager import SpecTreeManager
from tensorrt_llm._torch.attention_backend import trtllm_gen
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.bindings.internal import thop
from tensorrt_llm.functional import AttentionMaskType
@@ -24,6 +26,11 @@ from .interface import (AttentionBackend, AttentionInputType, AttentionMask,
AttentionMetadata, KVCacheParams, MLAParams,
PositionalEmbeddingParams, PredefinedAttentionMask,
RopeParams)
from .trtllm_gen import trtllm_gen_attention
# Enable TRTLLM-Gen attention backend via environment variable.
_TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION = os.environ.get(
"TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION", "0") == "1"
@dataclass(kw_only=True, init=False)
@@ -86,6 +93,8 @@ class TrtllmAttentionWrapper:
helix_position_offsets: Optional[torch.Tensor]
helix_is_inactive_rank: Optional[torch.Tensor]
attention_input_type: Optional[torch.Tensor]
quant_config: Optional[QuantConfig]
kv_cache_manager: Optional[KVCacheManager]
kwargs: dict
def __init__(
@@ -219,6 +228,8 @@ class TrtllmAttentionWrapper:
skip_softmax_threshold_scale_factor_decode: Optional[float] = None,
helix_position_offsets: Optional[torch.Tensor] = None,
helix_is_inactive_rank: Optional[torch.Tensor] = None,
quant_config: Optional[QuantConfig] = None,
kv_cache_manager: Optional[KVCacheManager] = None,
**kwargs,
):
"""
@@ -266,6 +277,8 @@ class TrtllmAttentionWrapper:
skip_softmax_threshold_scale_factor_decode (float): The scale factor for the skip softmax threshold in decode phase.
helix_position_offsets (torch.Tensor): The tensor to store the helix position offsets, with shape (num_tokens) on GPU.
helix_is_inactive_rank (torch.Tensor): For Helix: whether the current rank is inactive, with shape (batch_size) on GPU.
quant_config (Optional[QuantConfig]): The quantization configuration.
kv_cache_manager (Optional[KVCacheManager]): The KV cache manager.
"""
self.layer_idx = layer_idx
self.tokens_per_block = tokens_per_block
@@ -326,6 +339,8 @@ class TrtllmAttentionWrapper:
self.chunked_prefill_buffer_batch_size = chunked_prefill_buffer_batch_size
self.skip_softmax_threshold_scale_factor_prefill = skip_softmax_threshold_scale_factor_prefill
self.skip_softmax_threshold_scale_factor_decode = skip_softmax_threshold_scale_factor_decode
self.quant_config = quant_config
self.kv_cache_manager = kv_cache_manager
self.kwargs.update(kwargs)
def create_output(
@@ -493,86 +508,195 @@ class TrtllmAttentionWrapper:
if self.print_skip_softmax_stat:
self.skip_softmax_stat.zero_()
thop.attention(
q,
k,
v,
output,
output_sf,
self.workspace,
self.sequence_length,
self.host_past_key_value_lengths,
self.host_total_kv_lens,
self.context_lengths,
self.host_context_lengths,
self.host_request_types,
self.kv_cache_block_offsets,
self.host_kv_cache_pool_pointers,
self.host_kv_cache_pool_mapping,
self.cache_indirection,
self.kv_scale_orig_quant,
self.kv_scale_quant_orig,
self.out_scale_sf if self.use_nvfp4_output else self.out_scale,
self.rotary_inv_freq,
self.rotary_cos_sin,
self.latent_cache,
self.q_pe,
self.block_ids_per_seq,
self.attention_sinks,
is_fused_qkv,
update_kv_cache,
self.predicted_tokens_per_seq,
self.layer_idx,
self.num_heads,
self.num_kv_heads,
self.head_size,
self.tokens_per_block,
self.max_num_requests,
self.max_context_length,
self.attention_window_size,
self.sink_token_length,
self.beam_width,
int(mask_type),
self.quant_mode,
self.q_scaling,
self.position_embedding_type,
self.rotary_embedding_dim,
self.rotary_embedding_base,
self.rotary_embedding_scale_type,
rotary_embedding_scales,
rotary_embedding_max_position_info,
self.use_paged_context_fmha,
self.attention_input_type,
self.is_mla_enable,
self.chunked_prefill_buffer_batch_size,
self.q_lora_rank,
self.kv_lora_rank,
self.qk_nope_head_dim,
self.qk_rope_head_dim,
self.v_head_dim,
self.mrope_rotary_cos_sin,
self.mrope_position_deltas,
mla_tensor_params,
self.attention_chunk_size,
self.softmax_stats_tensor,
spec_decoding_bool_params,
spec_decoding_tensor_params,
self.sparse_kv_indices,
self.sparse_kv_offsets,
self.sparse_attn_indices,
self.sparse_attn_offsets,
self.sparse_attn_indices_block_size,
self.sparse_mla_topk,
self.skip_softmax_threshold_scale_factor_prefill,
self.skip_softmax_threshold_scale_factor_decode,
self.skip_softmax_stat,
cu_q_seqlens,
cu_kv_seqlens,
fmha_scheduler_counter,
mla_bmm1_scale,
mla_bmm2_scale,
quant_q_buffer,
)
out_scale = self.out_scale_sf if self.use_nvfp4_output else self.out_scale
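# Select the output scale once so both dispatch paths below use the same value.
# Take the TRTLLM-Gen path only when the environment flag is set and is_supported()
# confirms this configuration; otherwise fall back to thop.attention.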
if _TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION and trtllm_gen.is_supported(
q=q,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
out_dtype=output.dtype,
mask_type=int(mask_type),
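# Assumed: position_embedding_type values 4 and 5 are the ALiBi variants.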
has_alibi=(self.position_embedding_type == 4
or self.position_embedding_type == 5),
is_padded=False,
use_paged_kv_cache=(self.kv_cache_block_offsets is not None),
tokens_per_block=self.tokens_per_block,
beam_width=self.beam_width,
position_shift_enabled=False,
sink_token_length=self.sink_token_length,
cross_attention=False,
is_spec_decoding=self.is_spec_decoding_enabled,
is_mla_enable=self.is_mla_enable,
is_fused_qkv=is_fused_qkv,
update_kv_cache=update_kv_cache,
has_cross_kv=False,
quant_config=self.quant_config,
kv_cache_manager=self.kv_cache_manager,
)[0]:
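# trtllm_gen_attention takes the same argument list as thop.attention,
# with quant_config and kv_cache_manager appended at the end.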
trtllm_gen_attention(
q,
k,
v,
output,
output_sf,
self.workspace,
self.sequence_length,
self.host_past_key_value_lengths,
self.host_total_kv_lens,
self.context_lengths,
self.host_context_lengths,
self.host_request_types,
self.kv_cache_block_offsets,
self.host_kv_cache_pool_pointers,
self.host_kv_cache_pool_mapping,
self.cache_indirection,
self.kv_scale_orig_quant,
self.kv_scale_quant_orig,
out_scale,
self.rotary_inv_freq,
self.rotary_cos_sin,
self.latent_cache,
self.q_pe,
self.block_ids_per_seq,
self.attention_sinks,
is_fused_qkv,
update_kv_cache,
self.predicted_tokens_per_seq,
self.layer_idx,
self.num_heads,
self.num_kv_heads,
self.head_size,
self.tokens_per_block,
self.max_num_requests,
self.max_context_length,
self.attention_window_size,
self.sink_token_length,
self.beam_width,
int(mask_type),
self.quant_mode,
self.q_scaling,
self.position_embedding_type,
self.rotary_embedding_dim,
self.rotary_embedding_base,
self.rotary_embedding_scale_type,
rotary_embedding_scales,
rotary_embedding_max_position_info,
self.use_paged_context_fmha,
self.attention_input_type,
self.is_mla_enable,
self.chunked_prefill_buffer_batch_size,
self.q_lora_rank,
self.kv_lora_rank,
self.qk_nope_head_dim,
self.qk_rope_head_dim,
self.v_head_dim,
self.mrope_rotary_cos_sin,
self.mrope_position_deltas,
mla_tensor_params,
self.attention_chunk_size,
self.softmax_stats_tensor,
spec_decoding_bool_params,
spec_decoding_tensor_params,
self.sparse_kv_indices,
self.sparse_kv_offsets,
self.sparse_attn_indices,
self.sparse_attn_offsets,
self.sparse_attn_indices_block_size,
self.sparse_mla_topk,
self.skip_softmax_threshold_scale_factor_prefill,
self.skip_softmax_threshold_scale_factor_decode,
self.skip_softmax_stat,
cu_q_seqlens,
cu_kv_seqlens,
fmha_scheduler_counter,
mla_bmm1_scale,
mla_bmm2_scale,
quant_q_buffer,
self.quant_config,
self.kv_cache_manager,
)
else:
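# Fall back to the existing thop.attention op.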
thop.attention(
q,
k,
v,
output,
output_sf,
self.workspace,
self.sequence_length,
self.host_past_key_value_lengths,
self.host_total_kv_lens,
self.context_lengths,
self.host_context_lengths,
self.host_request_types,
self.kv_cache_block_offsets,
self.host_kv_cache_pool_pointers,
self.host_kv_cache_pool_mapping,
self.cache_indirection,
self.kv_scale_orig_quant,
self.kv_scale_quant_orig,
out_scale,
self.rotary_inv_freq,
self.rotary_cos_sin,
self.latent_cache,
self.q_pe,
self.block_ids_per_seq,
self.attention_sinks,
is_fused_qkv,
update_kv_cache,
self.predicted_tokens_per_seq,
self.layer_idx,
self.num_heads,
self.num_kv_heads,
self.head_size,
self.tokens_per_block,
self.max_num_requests,
self.max_context_length,
self.attention_window_size,
self.sink_token_length,
self.beam_width,
int(mask_type),
self.quant_mode,
self.q_scaling,
self.position_embedding_type,
self.rotary_embedding_dim,
self.rotary_embedding_base,
self.rotary_embedding_scale_type,
rotary_embedding_scales,
rotary_embedding_max_position_info,
self.use_paged_context_fmha,
self.attention_input_type,
self.is_mla_enable,
self.chunked_prefill_buffer_batch_size,
self.q_lora_rank,
self.kv_lora_rank,
self.qk_nope_head_dim,
self.qk_rope_head_dim,
self.v_head_dim,
self.mrope_rotary_cos_sin,
self.mrope_position_deltas,
mla_tensor_params,
self.attention_chunk_size,
self.softmax_stats_tensor,
spec_decoding_bool_params,
spec_decoding_tensor_params,
self.sparse_kv_indices,
self.sparse_kv_offsets,
self.sparse_attn_indices,
self.sparse_attn_offsets,
self.sparse_attn_indices_block_size,
self.sparse_mla_topk,
self.skip_softmax_threshold_scale_factor_prefill,
self.skip_softmax_threshold_scale_factor_decode,
self.skip_softmax_stat,
cu_q_seqlens,
cu_kv_seqlens,
fmha_scheduler_counter,
mla_bmm1_scale,
mla_bmm2_scale,
quant_q_buffer,
)
if self.print_skip_softmax_stat:
(total_blocks, skipped_blocks) = self.skip_softmax_stat
@@ -1809,6 +1933,8 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
skip_softmax_threshold_scale_factor_decode,
helix_position_offsets=metadata.helix_position_offsets,
helix_is_inactive_rank=metadata.helix_is_inactive_rank,
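# Forward the quantization config and KV cache manager so the wrapper
# can check whether the TRTLLM-Gen kernel supports this configuration.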
quant_config=self.quant_config,
kv_cache_manager=metadata.kv_cache_manager,
)
self.wrapper.run(q,

File diff suppressed because it is too large.