Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[None][feat] Initial PR for trtllm-gen attention backend (#10784)
Signed-off-by: Yihan Wang <yihwang@nvidia.com>
commit e8b860965b
parent 18c992efb1
@@ -11,6 +11,8 @@ if TYPE_CHECKING:
 from ..speculative.interface import SpecMetadata
 from ..speculative.spec_tree_manager import SpecTreeManager

+from tensorrt_llm._torch.attention_backend import trtllm_gen
+from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.bindings.internal import thop
 from tensorrt_llm.functional import AttentionMaskType
@@ -24,6 +26,11 @@ from .interface import (AttentionBackend, AttentionInputType, AttentionMask,
 AttentionMetadata, KVCacheParams, MLAParams,
 PositionalEmbeddingParams, PredefinedAttentionMask,
 RopeParams)
+from .trtllm_gen import trtllm_gen_attention
+
+# Enable TRTLLM-Gen attention backend via environment variable.
+_TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION = os.environ.get(
+"TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION", "0") == "1"


 @dataclass(kw_only=True, init=False)
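Because the flag above is captured into a module-level constant at import time, it has to be set in the environment before the tensorrt_llm modules are imported. A minimal sketch of opting in (plain standard-library Python, not TensorRT-LLM API):

import os

# Opt in to the TRTLLM-Gen attention backend; this must happen before the
# module shown in this diff is imported, since the flag is read only once.
os.environ["TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION"] = "1"

# Equivalent of the module-level check added in this hunk:
enabled = os.environ.get("TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION", "0") == "1"
assert enabled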
@@ -86,6 +93,8 @@ class TrtllmAttentionWrapper:
 helix_position_offsets: Optional[torch.Tensor]
 helix_is_inactive_rank: Optional[torch.Tensor]
 attention_input_type: Optional[torch.Tensor]
+quant_config: Optional[QuantConfig]
+kv_cache_manager: Optional[KVCacheManager]
 kwargs: dict

 def __init__(
@@ -219,6 +228,8 @@ class TrtllmAttentionWrapper:
 skip_softmax_threshold_scale_factor_decode: Optional[float] = None,
 helix_position_offsets: Optional[torch.Tensor] = None,
 helix_is_inactive_rank: Optional[torch.Tensor] = None,
+quant_config: Optional[QuantConfig] = None,
+kv_cache_manager: Optional[KVCacheManager] = None,
 **kwargs,
 ):
 """
@@ -266,6 +277,8 @@ class TrtllmAttentionWrapper:
 skip_softmax_threshold_scale_factor_decode (float): The scale factor for the skip softmax threshold in decode phase.
 helix_position_offsets (torch.Tensor): The tensor to store the helix position offsets, with shape (num_tokens) on GPU.
 helix_is_inactive_rank (torch.Tensor): For Helix: whether the current rank is inactive, with shape (batch_size) on GPU.
+quant_config (Optional[QuantConfig]): The quantization configuration.
+kv_cache_manager (Optional[KVCacheManager]): The KV cache manager.
 """
 self.layer_idx = layer_idx
 self.tokens_per_block = tokens_per_block
@@ -326,6 +339,8 @@ class TrtllmAttentionWrapper:
 self.chunked_prefill_buffer_batch_size = chunked_prefill_buffer_batch_size
 self.skip_softmax_threshold_scale_factor_prefill = skip_softmax_threshold_scale_factor_prefill
 self.skip_softmax_threshold_scale_factor_decode = skip_softmax_threshold_scale_factor_decode
+self.quant_config = quant_config
+self.kv_cache_manager = kv_cache_manager
 self.kwargs.update(kwargs)

 def create_output(
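The two hunks above follow the wrapper's existing pattern: fields are declared on the kw_only, init=False dataclass and then populated by hand in the planning call (presumably plan(), given the wrapper.run call at the end of this diff). A minimal, self-contained sketch of that pattern with only the two new fields (generic names, not the actual wrapper):

from dataclasses import dataclass
from typing import Any, Optional

@dataclass(kw_only=True, init=False)
class WrapperSketch:
    # Declared for type checking only; init=False means no generated __init__,
    # so the values are filled in later by the planning call, as in the diff.
    quant_config: Optional[Any]
    kv_cache_manager: Optional[Any]

    def __init__(self) -> None:
        self.quant_config = None
        self.kv_cache_manager = None

    def plan(self, *, quant_config: Optional[Any] = None,
             kv_cache_manager: Optional[Any] = None, **kwargs: Any) -> None:
        # New in this commit: the wrapper keeps the quantization config and the
        # KV-cache manager so the TRTLLM-Gen path can consult them at run time.
        self.quant_config = quant_config
        self.kv_cache_manager = kv_cache_manager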
@@ -493,86 +508,195 @@ class TrtllmAttentionWrapper:
 if self.print_skip_softmax_stat:
 self.skip_softmax_stat.zero_()

-thop.attention(
-q,
-k,
-v,
-output,
-output_sf,
-self.workspace,
-self.sequence_length,
-self.host_past_key_value_lengths,
-self.host_total_kv_lens,
-self.context_lengths,
-self.host_context_lengths,
-self.host_request_types,
-self.kv_cache_block_offsets,
-self.host_kv_cache_pool_pointers,
-self.host_kv_cache_pool_mapping,
-self.cache_indirection,
-self.kv_scale_orig_quant,
-self.kv_scale_quant_orig,
-self.out_scale_sf if self.use_nvfp4_output else self.out_scale,
-self.rotary_inv_freq,
-self.rotary_cos_sin,
-self.latent_cache,
-self.q_pe,
-self.block_ids_per_seq,
-self.attention_sinks,
-is_fused_qkv,
-update_kv_cache,
-self.predicted_tokens_per_seq,
-self.layer_idx,
-self.num_heads,
-self.num_kv_heads,
-self.head_size,
-self.tokens_per_block,
-self.max_num_requests,
-self.max_context_length,
-self.attention_window_size,
-self.sink_token_length,
-self.beam_width,
-int(mask_type),
-self.quant_mode,
-self.q_scaling,
-self.position_embedding_type,
-self.rotary_embedding_dim,
-self.rotary_embedding_base,
-self.rotary_embedding_scale_type,
-rotary_embedding_scales,
-rotary_embedding_max_position_info,
-self.use_paged_context_fmha,
-self.attention_input_type,
-self.is_mla_enable,
-self.chunked_prefill_buffer_batch_size,
-self.q_lora_rank,
-self.kv_lora_rank,
-self.qk_nope_head_dim,
-self.qk_rope_head_dim,
-self.v_head_dim,
-self.mrope_rotary_cos_sin,
-self.mrope_position_deltas,
-mla_tensor_params,
-self.attention_chunk_size,
-self.softmax_stats_tensor,
-spec_decoding_bool_params,
-spec_decoding_tensor_params,
-self.sparse_kv_indices,
-self.sparse_kv_offsets,
-self.sparse_attn_indices,
-self.sparse_attn_offsets,
-self.sparse_attn_indices_block_size,
-self.sparse_mla_topk,
-self.skip_softmax_threshold_scale_factor_prefill,
-self.skip_softmax_threshold_scale_factor_decode,
-self.skip_softmax_stat,
-cu_q_seqlens,
-cu_kv_seqlens,
-fmha_scheduler_counter,
-mla_bmm1_scale,
-mla_bmm2_scale,
-quant_q_buffer,
-)
+out_scale = self.out_scale_sf if self.use_nvfp4_output else self.out_scale
+
+if _TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION and trtllm_gen.is_supported(
+q=q,
+num_heads=self.num_heads,
+num_kv_heads=self.num_kv_heads,
+head_size=self.head_size,
+out_dtype=output.dtype,
+mask_type=int(mask_type),
+has_alibi=(self.position_embedding_type == 4
+or self.position_embedding_type == 5),
+is_padded=False,
+use_paged_kv_cache=(self.kv_cache_block_offsets is not None),
+tokens_per_block=self.tokens_per_block,
+beam_width=self.beam_width,
+position_shift_enabled=False,
+sink_token_length=self.sink_token_length,
+cross_attention=False,
+is_spec_decoding=self.is_spec_decoding_enabled,
+is_mla_enable=self.is_mla_enable,
+is_fused_qkv=is_fused_qkv,
+update_kv_cache=update_kv_cache,
+has_cross_kv=False,
+quant_config=self.quant_config,
+kv_cache_manager=self.kv_cache_manager,
+)[0]:
+trtllm_gen_attention(
+q,
+k,
+v,
+output,
+output_sf,
+self.workspace,
+self.sequence_length,
+self.host_past_key_value_lengths,
+self.host_total_kv_lens,
+self.context_lengths,
+self.host_context_lengths,
+self.host_request_types,
+self.kv_cache_block_offsets,
+self.host_kv_cache_pool_pointers,
+self.host_kv_cache_pool_mapping,
+self.cache_indirection,
+self.kv_scale_orig_quant,
+self.kv_scale_quant_orig,
+out_scale,
+self.rotary_inv_freq,
+self.rotary_cos_sin,
+self.latent_cache,
+self.q_pe,
+self.block_ids_per_seq,
+self.attention_sinks,
+is_fused_qkv,
+update_kv_cache,
+self.predicted_tokens_per_seq,
+self.layer_idx,
+self.num_heads,
+self.num_kv_heads,
+self.head_size,
+self.tokens_per_block,
+self.max_num_requests,
+self.max_context_length,
+self.attention_window_size,
+self.sink_token_length,
+self.beam_width,
+int(mask_type),
+self.quant_mode,
+self.q_scaling,
+self.position_embedding_type,
+self.rotary_embedding_dim,
+self.rotary_embedding_base,
+self.rotary_embedding_scale_type,
+rotary_embedding_scales,
+rotary_embedding_max_position_info,
+self.use_paged_context_fmha,
+self.attention_input_type,
+self.is_mla_enable,
+self.chunked_prefill_buffer_batch_size,
+self.q_lora_rank,
+self.kv_lora_rank,
+self.qk_nope_head_dim,
+self.qk_rope_head_dim,
+self.v_head_dim,
+self.mrope_rotary_cos_sin,
+self.mrope_position_deltas,
+mla_tensor_params,
+self.attention_chunk_size,
+self.softmax_stats_tensor,
+spec_decoding_bool_params,
+spec_decoding_tensor_params,
+self.sparse_kv_indices,
+self.sparse_kv_offsets,
+self.sparse_attn_indices,
+self.sparse_attn_offsets,
+self.sparse_attn_indices_block_size,
+self.sparse_mla_topk,
+self.skip_softmax_threshold_scale_factor_prefill,
+self.skip_softmax_threshold_scale_factor_decode,
+self.skip_softmax_stat,
+cu_q_seqlens,
+cu_kv_seqlens,
+fmha_scheduler_counter,
+mla_bmm1_scale,
+mla_bmm2_scale,
+quant_q_buffer,
+self.quant_config,
+self.kv_cache_manager,
+)
+else:
+thop.attention(
+q,
+k,
+v,
+output,
+output_sf,
+self.workspace,
+self.sequence_length,
+self.host_past_key_value_lengths,
+self.host_total_kv_lens,
+self.context_lengths,
+self.host_context_lengths,
+self.host_request_types,
+self.kv_cache_block_offsets,
+self.host_kv_cache_pool_pointers,
+self.host_kv_cache_pool_mapping,
+self.cache_indirection,
+self.kv_scale_orig_quant,
+self.kv_scale_quant_orig,
+out_scale,
+self.rotary_inv_freq,
+self.rotary_cos_sin,
+self.latent_cache,
+self.q_pe,
+self.block_ids_per_seq,
+self.attention_sinks,
+is_fused_qkv,
+update_kv_cache,
+self.predicted_tokens_per_seq,
+self.layer_idx,
+self.num_heads,
+self.num_kv_heads,
+self.head_size,
+self.tokens_per_block,
+self.max_num_requests,
+self.max_context_length,
+self.attention_window_size,
+self.sink_token_length,
+self.beam_width,
+int(mask_type),
+self.quant_mode,
+self.q_scaling,
+self.position_embedding_type,
+self.rotary_embedding_dim,
+self.rotary_embedding_base,
+self.rotary_embedding_scale_type,
+rotary_embedding_scales,
+rotary_embedding_max_position_info,
+self.use_paged_context_fmha,
+self.attention_input_type,
+self.is_mla_enable,
+self.chunked_prefill_buffer_batch_size,
+self.q_lora_rank,
+self.kv_lora_rank,
+self.qk_nope_head_dim,
+self.qk_rope_head_dim,
+self.v_head_dim,
+self.mrope_rotary_cos_sin,
+self.mrope_position_deltas,
+mla_tensor_params,
+self.attention_chunk_size,
+self.softmax_stats_tensor,
+spec_decoding_bool_params,
+spec_decoding_tensor_params,
+self.sparse_kv_indices,
+self.sparse_kv_offsets,
+self.sparse_attn_indices,
+self.sparse_attn_offsets,
+self.sparse_attn_indices_block_size,
+self.sparse_mla_topk,
+self.skip_softmax_threshold_scale_factor_prefill,
+self.skip_softmax_threshold_scale_factor_decode,
+self.skip_softmax_stat,
+cu_q_seqlens,
+cu_kv_seqlens,
+fmha_scheduler_counter,
+mla_bmm1_scale,
+mla_bmm2_scale,
+quant_q_buffer,
+)

 if self.print_skip_softmax_stat:
 (total_blocks, skipped_blocks) = self.skip_softmax_stat
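The new control flow in the hunk above is a capability-gated dispatch: the environment flag plus a per-call trtllm_gen.is_supported(...) probe (only element [0] of its return value is consulted) decide whether the new kernel or the existing thop.attention runs, and both paths receive the same argument list apart from the two extra trailing arguments. A minimal sketch of that pattern with stand-in callables (generic names, not the TensorRT-LLM API):

from typing import Any, Callable, Dict, Tuple

def dispatch_attention(call_args: Dict[str, Any], *,
                       gen_enabled: bool,
                       gen_is_supported: Callable[..., Tuple],
                       gen_kernel: Callable[..., None],
                       fallback_kernel: Callable[..., None]) -> str:
    # Use the new backend only when the opt-in flag is set AND the probe
    # reports the configuration as supported; otherwise fall back unchanged.
    if gen_enabled and gen_is_supported(**call_args)[0]:
        gen_kernel(**call_args)
        return "trtllm_gen"
    fallback_kernel(**call_args)
    return "thop"

# Tiny self-check with stand-in kernels.
chosen = dispatch_attention(
    {"num_heads": 8, "head_size": 128},
    gen_enabled=True,
    gen_is_supported=lambda **cfg: (cfg["head_size"] in (64, 128),),
    gen_kernel=lambda **cfg: None,
    fallback_kernel=lambda **cfg: None,
)
assert chosen == "trtllm_gen"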
@@ -1809,6 +1933,8 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
 skip_softmax_threshold_scale_factor_decode,
 helix_position_offsets=metadata.helix_position_offsets,
 helix_is_inactive_rank=metadata.helix_is_inactive_rank,
+quant_config=self.quant_config,
+kv_cache_manager=metadata.kv_cache_manager,
 )

 self.wrapper.run(q,
tensorrt_llm/_torch/attention_backend/trtllm_gen.py (new file, 1265 lines added)
File diff suppressed because it is too large
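The new module itself is suppressed in this view, so only its call sites are visible. From those call sites, its public surface appears to be an is_supported(...) probe returning a tuple whose first element is a bool, and a trtllm_gen_attention(...) entry point mirroring thop.attention's argument list with quant_config and kv_cache_manager appended. A hedged stub of that inferred interface (the signatures are assumptions, not the actual file):

from typing import Any

def is_supported(**config: Any) -> tuple:
    """Probe whether the given attention configuration can run on the
    TRTLLM-Gen kernels; the caller above uses only element [0] (a bool)."""
    raise NotImplementedError

def trtllm_gen_attention(*args: Any) -> None:
    """Entry point mirroring thop.attention's positional arguments, with
    quant_config and kv_cache_manager appended as the final two arguments."""
    raise NotImplementedError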