diff --git a/tensorrt_llm/_torch/attention_backend/flashinfer.py b/tensorrt_llm/_torch/attention_backend/flashinfer.py
index 701d0f43b6..264c62d618 100644
--- a/tensorrt_llm/_torch/attention_backend/flashinfer.py
+++ b/tensorrt_llm/_torch/attention_backend/flashinfer.py
@@ -412,9 +412,10 @@ class FlashInferAttention(AttentionBackend[FlashInferAttentionMetadata]):
         head_dim: int,
         num_kv_heads: Optional[int] = None,
         quant_config: Optional[QuantConfig] = None,
+        **kwargs,
     ):
         super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
-                         quant_config)
+                         quant_config, **kwargs)
         self.has_fp8_kv_cache = False
 
         if quant_config and quant_config.layer_quant_mode.has_any_quant():
diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py
index 5594cb0357..82de89dad9 100644
--- a/tensorrt_llm/_torch/attention_backend/interface.py
+++ b/tensorrt_llm/_torch/attention_backend/interface.py
@@ -1,6 +1,6 @@
 import copy
-import enum
 from dataclasses import dataclass, field
+from enum import Enum, IntEnum
 from functools import lru_cache
 from typing import (Generic, List, Optional, Protocol, Tuple, Type, TypeVar,
                     Union)
@@ -24,6 +24,14 @@ class AttentionRuntimeFeatures:
     has_speculative_draft_tokens: bool = False
 
 
+# The type of requests in qkv passed to attention
+# Please keep sync with AttentionInputType in cpp/tensorrt_llm/thop/attentionOp.cpp
+class AttentionInputType(IntEnum):
+    mixed = 0  # contains both context and generation
+    context_only = 1
+    generation_only = 2
+
+
 @dataclass(kw_only=True)
 class AttentionMetadata:
     """
@@ -421,7 +429,7 @@ class PositionalEmbeddingParams:
 TMetadata = TypeVar("TMetadata", bound=AttentionMetadata)
 
 
-class PredefinedAttentionMask(str, enum.Enum):
+class PredefinedAttentionMask(str, Enum):
     """
     Predefined attention mask types
 
@@ -450,6 +458,7 @@ class AttentionBackend(Generic[TMetadata]):
         head_dim: int,
         num_kv_heads: Optional[int] = None,
         quant_config: Optional[QuantConfig] = None,
+        **kwargs,
     ):
         """
         Initialize the backend.
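
Note: with AttentionInputType relocated to interface.py, any backend can describe whether a batch is mixed, context-only, or generation-only without importing the TRT-LLM-specific backend module. A minimal usage sketch in Python, assuming hypothetical num_contexts/num_generations counters (the enum members and their values come from the hunk above; the helper itself is illustrative, not part of the change):

    from tensorrt_llm._torch.attention_backend.interface import AttentionInputType

    def select_input_type(num_contexts: int, num_generations: int) -> AttentionInputType:
        # Mixed batches carry both context-phase and generation-phase requests.
        if num_contexts > 0 and num_generations > 0:
            return AttentionInputType.mixed
        return (AttentionInputType.context_only
                if num_contexts > 0 else AttentionInputType.generation_only)

    # The IntEnum values (0/1/2) are what crosses the C++ boundary and must stay
    # in sync with AttentionInputType in cpp/tensorrt_llm/thop/attentionOp.cpp.
    assert int(AttentionInputType.mixed) == 0
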
diff --git a/tensorrt_llm/_torch/attention_backend/star_flashinfer.py b/tensorrt_llm/_torch/attention_backend/star_flashinfer.py
index 80d31815b9..29ca290258 100644
--- a/tensorrt_llm/_torch/attention_backend/star_flashinfer.py
+++ b/tensorrt_llm/_torch/attention_backend/star_flashinfer.py
@@ -304,9 +304,10 @@ class StarAttention(AttentionBackend[StarAttentionMetadata]):
         head_dim: int,
         num_kv_heads: Optional[int] = None,
         quant_config: Optional[QuantConfig] = None,
+        **kwargs,
     ):
         super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
-                         quant_config)
+                         quant_config, **kwargs)
 
     def forward(self,
                 q: torch.Tensor,
diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py
index ea5a9ec113..ee0a74a95e 100644
--- a/tensorrt_llm/_torch/attention_backend/trtllm.py
+++ b/tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -1,6 +1,4 @@
-import math
 from dataclasses import dataclass, field
-from enum import IntEnum
 from typing import Optional
 
 import torch
@@ -10,20 +8,13 @@ from tensorrt_llm.functional import (AttentionMaskType, RopeEmbeddingUtils,
 from tensorrt_llm.logger import logger
 from tensorrt_llm.models.modeling_utils import QuantConfig
 
-from .interface import (AttentionBackend, AttentionMask, AttentionMetadata,
-                        KVCacheParams, MLAParams, PositionalEmbeddingParams,
-                        PredefinedAttentionMask, RopeParams)
+from .interface import (AttentionBackend, AttentionInputType, AttentionMask,
+                        AttentionMetadata, KVCacheParams, MLAParams,
+                        PositionalEmbeddingParams, PredefinedAttentionMask,
+                        RopeParams)
 from .vanilla import VanillaAttention
 
 
-# The type of requests in qkv passed to attention
-# Please keep sync with AttentionInputType in cpp/tensorrt_llm/thop/attentionOp.cpp
-class AttentionInputType(IntEnum):
-    mixed = 0  # contains both context and generation
-    context_only = 1
-    generation_only = 2
-
-
 @dataclass(kw_only=True, init=False)
 class TrtllmAttentionWrapper:
     sequence_length: torch.Tensor
@@ -65,11 +56,11 @@ class TrtllmAttentionWrapper:
     rotary_embedding_original_max_positions: int
     use_paged_context_fmha: bool
     is_mla_enable: bool
-    q_lora_rank: int
-    kv_lora_rank: int
-    qk_rope_head_dim: int
-    qk_nope_head_dim: int
-    v_head_dim: int
+    q_lora_rank: Optional[int]
+    kv_lora_rank: Optional[int]
+    qk_rope_head_dim: Optional[int]
+    qk_nope_head_dim: Optional[int]
+    v_head_dim: Optional[int]
     kwargs: dict
 
     def __init__(
@@ -80,6 +71,7 @@ class TrtllmAttentionWrapper:
         num_kv_heads: Optional[int] = None,
         pos_embd_params: Optional[PositionalEmbeddingParams] = None,
        quant_config: Optional[QuantConfig] = None,
+        q_scaling: Optional[float] = None,
         mla_params: Optional[MLAParams] = None,
         **kwargs,
     ):
@@ -101,7 +93,7 @@ class TrtllmAttentionWrapper:
             rope_params = RopeParams()
 
         self.is_mla_enable = mla_params is not None
-        self.q_scaling = 1.0
+        self.q_scaling = q_scaling or 1.0
         self.mla_rope_params = None
         self.predicted_tokens_per_seq = 1
 
@@ -114,18 +106,6 @@ class TrtllmAttentionWrapper:
             self.predicted_tokens_per_seq = mla_params.predicted_tokens_per_seq
 
             self.rotary_embedding_dim = 0
-
-            def yarn_get_mscale(scale=1, mscale=1):
-                if scale <= 1:
-                    return 1.0
-                return 0.1 * mscale * math.log(scale) + 1.0
-
-            mscale_all_dim = rope_params.mscale_all_dim
-            scaling_factor = rope_params.scale
-            if mscale_all_dim:
-                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
-                self.q_scaling = 1.0 / (mscale * mscale)
-
             self.rotary_inv_freq, self.rotary_cos_sin = rope_params.create_deepseek_rope_const_params(
                 self.qk_rope_head_dim)
             self.rotary_embedding_scale_type = RotaryScalingType.none
@@ -165,6 +145,7 @@ class TrtllmAttentionWrapper:
         *,
         tokens_per_block: Optional[int] = None,
         max_num_requests: int,
+        max_sequence_length: int,
         max_context_length: int,
         attention_window_size: Optional[int] = None,
         sink_token_length: int = 0,
@@ -196,6 +177,7 @@ class TrtllmAttentionWrapper:
         Args:
             tokens_per_block (int): Token number per KV cache block.
             max_num_requests (int): Max request number per batch.
+            max_sequence_length (int): Max sequence length.
             max_context_length (int): Max context length per context-phase sequence.
             attention_window_size (int): Max token number attended in windowed attention.
             sink_token_length (int): Sink token number in StreamingLLM.
@@ -220,7 +202,7 @@ class TrtllmAttentionWrapper:
         self.tokens_per_block = tokens_per_block
         self.max_num_requests = max_num_requests
         self.max_context_length = max_context_length
-        self.attention_window_size = attention_window_size or max_context_length
+        self.attention_window_size = attention_window_size or max_sequence_length
         self.sink_token_length = sink_token_length
         self.beam_width = beam_width
         self.sequence_length = sequence_length
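
Note: the hunks above change the default attention window in plan(): with attention_window_size unset it now falls back to the new max_sequence_length argument instead of max_context_length, while the caller (see the wrapper.plan(...) call further below) clamps max_context_length to min(max_seq_len - 1, max_num_tokens). A toy walk-through of that arithmetic; the metadata values are made up for illustration only:

    # Illustrative values, not taken from the change.
    max_seq_len = 4096      # metadata.max_seq_len
    max_num_tokens = 8192   # metadata.max_num_tokens
    attention_window_size = None

    # New fallback: an unset window spans the whole sequence, not just the context.
    window = attention_window_size or max_seq_len                # -> 4096

    # New cap on the context length handed to plan():
    max_context_length = min(max_seq_len - 1, max_num_tokens)    # -> 4095
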
""" - super().__init__(layer_idx, num_heads, head_dim, num_kv_heads, - quant_config) + super().__init__(layer_idx, + num_heads, + head_dim, + num_kv_heads, + quant_config, + q_scaling=q_scaling, + pos_embd_params=pos_embd_params, + mla_params=mla_params, + **kwargs) self.wrapper = TrtllmAttentionWrapper( layer_idx, num_heads, @@ -633,6 +626,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]): num_kv_heads, pos_embd_params=pos_embd_params, quant_config=quant_config, + q_scaling=q_scaling, mla_params=mla_params, ) @@ -722,7 +716,9 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]): self.wrapper.plan( tokens_per_block=metadata.tokens_per_block, max_num_requests=metadata.max_num_requests, - max_context_length=metadata.max_seq_len, + max_sequence_length=metadata.max_seq_len, + max_context_length=min(metadata.max_seq_len - 1, + metadata.max_num_tokens), attention_window_size=None, sink_token_length=0, beam_width=1, diff --git a/tensorrt_llm/_torch/attention_backend/utils.py b/tensorrt_llm/_torch/attention_backend/utils.py index 0f152d4d85..a248cd12e8 100644 --- a/tensorrt_llm/_torch/attention_backend/utils.py +++ b/tensorrt_llm/_torch/attention_backend/utils.py @@ -32,42 +32,40 @@ def create_attention( num_kv_heads: Optional[int] = None, pos_embd_params: Optional[PositionalEmbeddingParams] = None, quant_config: Optional[QuantConfig] = None, - is_mla_enable: Optional[bool] = False, - q_lora_rank: Optional[int] = 0, - kv_lora_rank: Optional[int] = 0, - qk_rope_head_dim: Optional[int] = 0, - qk_nope_head_dim: Optional[int] = 0, - v_head_dim: Optional[int] = 0, + q_scaling: Optional[float] = None, + is_mla_enable: bool = False, + q_lora_rank: Optional[int] = None, + kv_lora_rank: Optional[int] = None, + qk_rope_head_dim: Optional[int] = None, + qk_nope_head_dim: Optional[int] = None, + v_head_dim: Optional[int] = None, predicted_tokens_per_seq: Optional[int] = 1, ): attn_cls = get_attention_backend(backend_name) + if is_mla_enable: assert attn_cls == TrtllmAttention assert (q_lora_rank > 0 and kv_lora_rank > 0 and qk_rope_head_dim > 0 and qk_nope_head_dim > 0 and v_head_dim > 0) - - if attn_cls == TrtllmAttention: - if is_mla_enable: - mla_params = MLAParams( - q_lora_rank=q_lora_rank, - kv_lora_rank=kv_lora_rank, - qk_rope_head_dim=qk_rope_head_dim, - qk_nope_head_dim=qk_nope_head_dim, - v_head_dim=v_head_dim, - predicted_tokens_per_seq=predicted_tokens_per_seq, - ) - else: - mla_params = None - - return TrtllmAttention( - layer_idx, - num_heads, - head_dim, - num_kv_heads, - pos_embd_params, - quant_config, - mla_params=mla_params, + mla_params = MLAParams( + q_lora_rank=q_lora_rank, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + qk_nope_head_dim=qk_nope_head_dim, + v_head_dim=v_head_dim, + predicted_tokens_per_seq=predicted_tokens_per_seq, ) + else: + mla_params = None - return attn_cls(layer_idx, num_heads, head_dim, num_kv_heads, quant_config) + return attn_cls( + layer_idx, + num_heads, + head_dim, + num_kv_heads, + quant_config=quant_config, + q_scaling=q_scaling, + pos_embd_params=pos_embd_params, + mla_params=mla_params, + ) diff --git a/tensorrt_llm/_torch/attention_backend/vanilla.py b/tensorrt_llm/_torch/attention_backend/vanilla.py index 8f672c5595..410fa31ba3 100644 --- a/tensorrt_llm/_torch/attention_backend/vanilla.py +++ b/tensorrt_llm/_torch/attention_backend/vanilla.py @@ -67,9 +67,10 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]): head_dim: int, num_kv_heads: Optional[int] = None, quant_config: 
+        **kwargs,
     ):
         super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
-                         quant_config)
+                         quant_config, **kwargs)
         self.num_key_value_groups = self.num_heads // self.num_kv_heads
 
     def _single_request_update_kv_cache(self, k, v, kv_cache_tensor, seq_len,
diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py
index 99b80da4a5..55c7a71109 100644
--- a/tensorrt_llm/_torch/modules/attention.py
+++ b/tensorrt_llm/_torch/modules/attention.py
@@ -1,3 +1,4 @@
+import math
 from typing import Optional
 
 import torch
@@ -210,6 +211,7 @@ class MLA(nn.Module):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.qk_nope_head_dim = qk_nope_head_dim
         self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
@@ -384,14 +386,25 @@ class MLA(nn.Module):
             skip_create_weights=config.skip_create_weights,
         )
 
+        def yarn_get_mscale(scale=1, mscale=1):
+            if scale <= 1:
+                return 1.0
+            return 0.1 * mscale * math.log(scale) + 1.0
+
+        mscale_all_dim = pos_embd_params.rope.mscale_all_dim
+        scaling_factor = pos_embd_params.rope.scale
+        mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+        q_scaling = 1.0 / (mscale * mscale)
+
         self.mha = create_attention(
             config.attn_backend,
             self.layer_idx,
             self.num_heads,
-            self.qk_nope_head_dim + self.qk_rope_head_dim,
-            self.num_key_value_heads,
+            head_dim=self.qk_head_dim,
+            num_kv_heads=self.num_key_value_heads,
             pos_embd_params=pos_embd_params,
             quant_config=quant_config,
+            q_scaling=q_scaling,
             is_mla_enable=True,
             q_lora_rank=self.q_lora_rank,
             kv_lora_rank=self.kv_lora_rank,
@@ -405,10 +418,11 @@ class MLA(nn.Module):
             config.attn_backend,
             self.layer_idx,
             self.num_heads,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            1,  # num_kv_heads
+            head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
+            num_kv_heads=1,
             pos_embd_params=pos_embd_params,
             quant_config=quant_config,
+            q_scaling=q_scaling,
             is_mla_enable=True,
             q_lora_rank=self.q_lora_rank,
             kv_lora_rank=self.kv_lora_rank,
diff --git a/tests/integration/defs/pytest.ini b/tests/integration/defs/pytest.ini
index 75b5de6ad2..fb57d29c42 100644
--- a/tests/integration/defs/pytest.ini
+++ b/tests/integration/defs/pytest.ini
@@ -1,5 +1,6 @@
 [pytest]
 threadleak = True
+threadleak_exclude = asyncio_\d+
 junit_family=legacy
 addopts = --ignore-glob="*perf/test_perf.py" --ignore-glob="*test_list_validation.py" --ignore-glob="*llm-test-workspace*" --durations=0 -W ignore::DeprecationWarning
 markers =
diff --git a/tests/unittest/pytest.ini b/tests/unittest/pytest.ini
index 3a97ecf0c6..974a02c292 100644
--- a/tests/unittest/pytest.ini
+++ b/tests/unittest/pytest.ini
@@ -1,5 +1,6 @@
 [pytest]
 threadleak = True
+threadleak_exclude = asyncio_\d+
 addopts = --durations=0 -W ignore::DeprecationWarning
 pythonpath = _torch/auto_deploy/_utils_test
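
Note: the q_scaling that MLA now computes and threads through create_attention (and that TrtllmAttentionWrapper picks up as "q_scaling or 1.0") follows the YaRN mscale formula shown in the modules/attention.py hunk above. A standalone sketch of that arithmetic; the rope values below (scale=40, mscale_all_dim=1.0) are assumed inputs for illustration, not values from this change:

    import math

    def yarn_get_mscale(scale=1, mscale=1):
        # Same formula as the helper added to modules/attention.py.
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    # Illustrative rope parameters (assumed, not from the diff).
    scaling_factor = 40.0   # pos_embd_params.rope.scale
    mscale_all_dim = 1.0    # pos_embd_params.rope.mscale_all_dim

    mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)   # ~1.3689
    q_scaling = 1.0 / (mscale * mscale)                        # ~0.5337
    # When scale <= 1 the helper returns 1.0, so q_scaling stays at 1.0 and the
    # wrapper's "q_scaling or 1.0" default behaves as before.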