mirror of https://github.com/NVIDIA/TensorRT-LLM.git
chore: Refine attention backend interface. (#3271)

Refine attention backend interface.

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>

parent 7199588796
commit 7225bd8b91
@@ -412,9 +412,10 @@ class FlashInferAttention(AttentionBackend[FlashInferAttentionMetadata]):
         head_dim: int,
         num_kv_heads: Optional[int] = None,
         quant_config: Optional[QuantConfig] = None,
+        **kwargs,
     ):
         super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
-                         quant_config)
+                         quant_config, **kwargs)
 
         self.has_fp8_kv_cache = False
         if quant_config and quant_config.layer_quant_mode.has_any_quant():
@@ -1,6 +1,6 @@
 import copy
-import enum
 from dataclasses import dataclass, field
+from enum import Enum, IntEnum
 from functools import lru_cache
 from typing import (Generic, List, Optional, Protocol, Tuple, Type, TypeVar,
                     Union)
@@ -24,6 +24,14 @@ class AttentionRuntimeFeatures:
     has_speculative_draft_tokens: bool = False
 
 
+# The type of requests in qkv passed to attention
+# Please keep sync with AttentionInputType in cpp/tensorrt_llm/thop/attentionOp.cpp
+class AttentionInputType(IntEnum):
+    mixed = 0  # contains both context and generation
+    context_only = 1
+    generation_only = 2
+
+
 @dataclass(kw_only=True)
 class AttentionMetadata:
     """
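
Note: the enum is backend-agnostic, which is why it moves into the interface module (the copy previously defined in the TRT-LLM backend is removed further down). A standalone sketch of how an IntEnum like this can be consumed; only the values above are taken from the diff, and the requirement to mirror AttentionInputType in cpp/tensorrt_llm/thop/attentionOp.cpp comes from the source comment:

    # Standalone sketch; only the enum values are taken from the diff above.
    from enum import IntEnum

    class AttentionInputType(IntEnum):
        mixed = 0  # contains both context and generation requests
        context_only = 1
        generation_only = 2

    def describe(input_type: AttentionInputType) -> str:
        # IntEnum members compare equal to plain ints, so the value can be
        # handed straight to an op that expects an integer flag.
        return f"{int(input_type)} -> {input_type.name}"

    assert AttentionInputType.context_only == 1
    print(describe(AttentionInputType.generation_only))  # 2 -> generation_only
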
@@ -421,7 +429,7 @@ class PositionalEmbeddingParams:
 TMetadata = TypeVar("TMetadata", bound=AttentionMetadata)
 
 
-class PredefinedAttentionMask(str, enum.Enum):
+class PredefinedAttentionMask(str, Enum):
     """
     Predefined attention mask types
 
@@ -450,6 +458,7 @@ class AttentionBackend(Generic[TMetadata]):
         head_dim: int,
         num_kv_heads: Optional[int] = None,
         quant_config: Optional[QuantConfig] = None,
+        **kwargs,
     ):
         """
         Initialize the backend.
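
The same **kwargs plumbing is applied to every backend in this change (FlashInfer above, StarAttention, Vanilla and Trtllm below). A minimal standalone sketch of the pattern with hypothetical class names, to show why the shared __init__ now accepts extra keywords:

    # Hypothetical names; only the **kwargs-forwarding pattern mirrors the diff.
    from typing import Optional

    class BackendBase:
        def __init__(self, layer_idx: int, num_heads: int, head_dim: int,
                     num_kv_heads: Optional[int] = None, **kwargs):
            self.layer_idx = layer_idx
            self.num_heads = num_heads
            self.head_dim = head_dim
            self.num_kv_heads = num_kv_heads or num_heads
            # Options a given backend does not understand are kept instead of
            # raising TypeError, so one call site can serve all backends.
            self.extra_options = kwargs

    class SimpleBackend(BackendBase):
        def __init__(self, layer_idx, num_heads, head_dim,
                     num_kv_heads=None, **kwargs):
            super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
                             **kwargs)

    backend = SimpleBackend(0, 32, 128, num_kv_heads=8, q_scaling=0.5)
    assert backend.extra_options["q_scaling"] == 0.5
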
@@ -304,9 +304,10 @@ class StarAttention(AttentionBackend[StarAttentionMetadata]):
         head_dim: int,
         num_kv_heads: Optional[int] = None,
         quant_config: Optional[QuantConfig] = None,
+        **kwargs,
     ):
         super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
-                         quant_config)
+                         quant_config, **kwargs)
 
     def forward(self,
                 q: torch.Tensor,
@@ -1,6 +1,4 @@
-import math
 from dataclasses import dataclass, field
-from enum import IntEnum
 from typing import Optional
 
 import torch
@@ -10,20 +8,13 @@ from tensorrt_llm.functional import (AttentionMaskType, RopeEmbeddingUtils,
 from tensorrt_llm.logger import logger
 from tensorrt_llm.models.modeling_utils import QuantConfig
 
-from .interface import (AttentionBackend, AttentionMask, AttentionMetadata,
-                        KVCacheParams, MLAParams, PositionalEmbeddingParams,
-                        PredefinedAttentionMask, RopeParams)
+from .interface import (AttentionBackend, AttentionInputType, AttentionMask,
+                        AttentionMetadata, KVCacheParams, MLAParams,
+                        PositionalEmbeddingParams, PredefinedAttentionMask,
+                        RopeParams)
 from .vanilla import VanillaAttention
 
 
-# The type of requests in qkv passed to attention
-# Please keep sync with AttentionInputType in cpp/tensorrt_llm/thop/attentionOp.cpp
-class AttentionInputType(IntEnum):
-    mixed = 0  # contains both context and generation
-    context_only = 1
-    generation_only = 2
-
-
 @dataclass(kw_only=True, init=False)
 class TrtllmAttentionWrapper:
     sequence_length: torch.Tensor
@@ -65,11 +56,11 @@ class TrtllmAttentionWrapper:
     rotary_embedding_original_max_positions: int
     use_paged_context_fmha: bool
     is_mla_enable: bool
-    q_lora_rank: int
-    kv_lora_rank: int
-    qk_rope_head_dim: int
-    qk_nope_head_dim: int
-    v_head_dim: int
+    q_lora_rank: Optional[int]
+    kv_lora_rank: Optional[int]
+    qk_rope_head_dim: Optional[int]
+    qk_nope_head_dim: Optional[int]
+    v_head_dim: Optional[int]
     kwargs: dict
 
     def __init__(
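
The MLA-related fields (and the matching create_attention defaults further down) switch from int with a 0 sentinel to Optional[int] with None, so "not configured" is no longer conflated with a literal zero. A tiny standalone illustration:

    # Standalone illustration of why None is the better "unset" marker here.
    from typing import Optional

    def mla_configured(q_lora_rank: Optional[int]) -> bool:
        return q_lora_rank is not None

    assert mla_configured(None) is False  # unset stays unset
    assert mla_configured(0) is True      # an explicit 0 still counts as a value
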
@@ -80,6 +71,7 @@ class TrtllmAttentionWrapper:
         num_kv_heads: Optional[int] = None,
         pos_embd_params: Optional[PositionalEmbeddingParams] = None,
         quant_config: Optional[QuantConfig] = None,
+        q_scaling: Optional[float] = None,
         mla_params: Optional[MLAParams] = None,
         **kwargs,
     ):
@@ -101,7 +93,7 @@ class TrtllmAttentionWrapper:
             rope_params = RopeParams()
 
         self.is_mla_enable = mla_params is not None
-        self.q_scaling = 1.0
+        self.q_scaling = q_scaling or 1.0
         self.mla_rope_params = None
         self.predicted_tokens_per_seq = 1
 
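
Note the or-fallback above: any falsy q_scaling, including an explicit 0.0, resolves to 1.0. A one-line standalone check of that semantics:

    # Standalone check of the fallback used in the diff above.
    for q_scaling, expected in [(None, 1.0), (0.0, 1.0), (0.75, 0.75)]:
        assert (q_scaling or 1.0) == expected
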
@@ -114,18 +106,6 @@ class TrtllmAttentionWrapper:
             self.predicted_tokens_per_seq = mla_params.predicted_tokens_per_seq
 
             self.rotary_embedding_dim = 0
-
-            def yarn_get_mscale(scale=1, mscale=1):
-                if scale <= 1:
-                    return 1.0
-                return 0.1 * mscale * math.log(scale) + 1.0
-
-            mscale_all_dim = rope_params.mscale_all_dim
-            scaling_factor = rope_params.scale
-            if mscale_all_dim:
-                mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
-                self.q_scaling = 1.0 / (mscale * mscale)
-
             self.rotary_inv_freq, self.rotary_cos_sin = rope_params.create_deepseek_rope_const_params(
                 self.qk_rope_head_dim)
             self.rotary_embedding_scale_type = RotaryScalingType.none
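
The YaRN mscale correction is deleted from the wrapper here; as the MLA hunks further down show, the computation moves into the MLA module, which now passes the resulting q_scaling to the backend explicitly. For reference, a standalone copy of the logic being relocated, with the same formula and the same mscale_all_dim guard as the removed lines above:

    import math

    def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
        if scale <= 1:
            return 1.0
        return 0.1 * mscale * math.log(scale) + 1.0

    def q_scaling_from_rope(scale: float, mscale_all_dim: float) -> float:
        # Mirrors the removed wrapper logic: no correction when
        # mscale_all_dim is falsy.
        if not mscale_all_dim:
            return 1.0
        mscale = yarn_get_mscale(scale, mscale_all_dim)
        return 1.0 / (mscale * mscale)

    print(q_scaling_from_rope(scale=40.0, mscale_all_dim=1.0))  # ~0.534
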
@@ -165,6 +145,7 @@ class TrtllmAttentionWrapper:
         *,
         tokens_per_block: Optional[int] = None,
         max_num_requests: int,
+        max_sequence_length: int,
         max_context_length: int,
         attention_window_size: Optional[int] = None,
         sink_token_length: int = 0,
@@ -196,6 +177,7 @@ class TrtllmAttentionWrapper:
         Args:
             tokens_per_block (int): Token number per KV cache block.
             max_num_requests (int): Max request number per batch.
+            max_sequence_length (int): Max sequence length.
             max_context_length (int): Max context length per context-phase sequence.
             attention_window_size (int): Max token number attended in windowed attention.
             sink_token_length (int): Sink token number in StreamingLLM.
@@ -220,7 +202,7 @@ class TrtllmAttentionWrapper:
         self.tokens_per_block = tokens_per_block
         self.max_num_requests = max_num_requests
         self.max_context_length = max_context_length
-        self.attention_window_size = attention_window_size or max_context_length
+        self.attention_window_size = attention_window_size or max_sequence_length
         self.sink_token_length = sink_token_length
         self.beam_width = beam_width
         self.sequence_length = sequence_length
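
With max_sequence_length now threaded through plan(), the default attention window falls back to the full sequence length rather than the context length, presumably so that generated tokens are also covered by the default window. A standalone sketch of the new fallback:

    # Standalone sketch of the fallback applied above.
    from typing import Optional

    def resolve_attention_window(attention_window_size: Optional[int],
                                 max_sequence_length: int) -> int:
        # None falls back to attending over the whole sequence.
        return attention_window_size or max_sequence_length

    assert resolve_attention_window(None, 8192) == 8192
    assert resolve_attention_window(2048, 8192) == 2048
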
@@ -608,9 +590,11 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
         num_heads: int,
         head_dim: int,
         num_kv_heads: Optional[int] = None,
-        pos_embd_params: Optional[PositionalEmbeddingParams] = None,
         quant_config: Optional[QuantConfig] = None,
+        q_scaling: Optional[float] = None,
+        pos_embd_params: Optional[PositionalEmbeddingParams] = None,
         mla_params: Optional[MLAParams] = None,
+        **kwargs,
     ):
         """
         Initialize the backend.
@@ -619,13 +603,22 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
             num_heads (int): The number of query heads.
             head_dim (int): The size of each attention head (hidden_size // num_heads).
             num_kv_heads (int): The number of kv heads. Defaults to num_heads if None.
-            quant_config (QuantConfig): Optional quantization configuration. If None, no quantization is applied.
+            q_scaling (float): Scaling factor for QK. Defaults to 1.0 if None.
             pos_embd_params (PositionalEmbeddingParams): Optional parameters defining how positional embedding should be applied.
                 If None, positional embedding should be applied by the model before calling the backend.
                 Otherwise, the backend is in-charge of applying positional embedding and may cache K without embedding it first.
+            quant_config (QuantConfig): Optional quantization configuration. If None, no quantization is applied.
             mla_params (MLAParams): Optional parameters for MLA. If None, MLA is not enabled.
         """
-        super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
-                         quant_config)
+        super().__init__(layer_idx,
+                         num_heads,
+                         head_dim,
+                         num_kv_heads,
+                         quant_config,
+                         q_scaling=q_scaling,
+                         pos_embd_params=pos_embd_params,
+                         mla_params=mla_params,
+                         **kwargs)
         self.wrapper = TrtllmAttentionWrapper(
             layer_idx,
             num_heads,
@@ -633,6 +626,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
             num_kv_heads,
             pos_embd_params=pos_embd_params,
             quant_config=quant_config,
+            q_scaling=q_scaling,
             mla_params=mla_params,
         )
 
@@ -722,7 +716,9 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
         self.wrapper.plan(
             tokens_per_block=metadata.tokens_per_block,
             max_num_requests=metadata.max_num_requests,
-            max_context_length=metadata.max_seq_len,
+            max_sequence_length=metadata.max_seq_len,
+            max_context_length=min(metadata.max_seq_len - 1,
+                                   metadata.max_num_tokens),
             attention_window_size=None,
             sink_token_length=0,
             beam_width=1,
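
The planner now receives both limits: max_sequence_length comes straight from metadata.max_seq_len, while max_context_length is capped by both max_seq_len - 1 and max_num_tokens. A standalone sketch of that mapping with a hypothetical metadata container:

    # Hypothetical metadata container; the min() mirrors the expression above.
    from dataclasses import dataclass

    @dataclass
    class FakeMetadata:
        max_seq_len: int
        max_num_tokens: int

    def plan_lengths(metadata: FakeMetadata) -> dict:
        return {
            "max_sequence_length": metadata.max_seq_len,
            "max_context_length": min(metadata.max_seq_len - 1,
                                      metadata.max_num_tokens),
        }

    print(plan_lengths(FakeMetadata(max_seq_len=4096, max_num_tokens=8192)))
    # {'max_sequence_length': 4096, 'max_context_length': 4095}
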
@@ -32,42 +32,40 @@ def create_attention(
     num_kv_heads: Optional[int] = None,
     pos_embd_params: Optional[PositionalEmbeddingParams] = None,
     quant_config: Optional[QuantConfig] = None,
-    is_mla_enable: Optional[bool] = False,
-    q_lora_rank: Optional[int] = 0,
-    kv_lora_rank: Optional[int] = 0,
-    qk_rope_head_dim: Optional[int] = 0,
-    qk_nope_head_dim: Optional[int] = 0,
-    v_head_dim: Optional[int] = 0,
+    q_scaling: Optional[float] = None,
+    is_mla_enable: bool = False,
+    q_lora_rank: Optional[int] = None,
+    kv_lora_rank: Optional[int] = None,
+    qk_rope_head_dim: Optional[int] = None,
+    qk_nope_head_dim: Optional[int] = None,
+    v_head_dim: Optional[int] = None,
     predicted_tokens_per_seq: Optional[int] = 1,
 ):
 
     attn_cls = get_attention_backend(backend_name)
 
     if is_mla_enable:
         assert attn_cls == TrtllmAttention
-        assert (q_lora_rank > 0 and kv_lora_rank > 0 and qk_rope_head_dim > 0
-                and qk_nope_head_dim > 0 and v_head_dim > 0)
-
-    if attn_cls == TrtllmAttention:
-        if is_mla_enable:
-            mla_params = MLAParams(
-                q_lora_rank=q_lora_rank,
-                kv_lora_rank=kv_lora_rank,
-                qk_rope_head_dim=qk_rope_head_dim,
-                qk_nope_head_dim=qk_nope_head_dim,
-                v_head_dim=v_head_dim,
-                predicted_tokens_per_seq=predicted_tokens_per_seq,
-            )
-        else:
-            mla_params = None
-
-        return TrtllmAttention(
-            layer_idx,
-            num_heads,
-            head_dim,
-            num_kv_heads,
-            pos_embd_params,
-            quant_config,
-            mla_params=mla_params,
-        )
+        mla_params = MLAParams(
+            q_lora_rank=q_lora_rank,
+            kv_lora_rank=kv_lora_rank,
+            qk_rope_head_dim=qk_rope_head_dim,
+            qk_nope_head_dim=qk_nope_head_dim,
+            v_head_dim=v_head_dim,
+            predicted_tokens_per_seq=predicted_tokens_per_seq,
+        )
+    else:
+        mla_params = None
 
-    return attn_cls(layer_idx, num_heads, head_dim, num_kv_heads, quant_config)
+    return attn_cls(
+        layer_idx,
+        num_heads,
+        head_dim,
+        num_kv_heads,
+        quant_config=quant_config,
+        q_scaling=q_scaling,
+        pos_embd_params=pos_embd_params,
+        mla_params=mla_params,
+    )
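
After this refactor every backend is built through one keyword-based call, and MLAParams is packed whenever is_mla_enable is set rather than only on the TrtllmAttention branch. A standalone sketch of the new control flow with hypothetical stand-ins for the backend class and MLAParams:

    # Hypothetical stand-ins; only the control flow mirrors the diff above.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeMLAParams:
        q_lora_rank: Optional[int] = None
        kv_lora_rank: Optional[int] = None

    class FakeBackend:
        def __init__(self, layer_idx, num_heads, head_dim, num_kv_heads=None,
                     quant_config=None, q_scaling=None, pos_embd_params=None,
                     mla_params: Optional[FakeMLAParams] = None, **kwargs):
            self.q_scaling = q_scaling
            self.mla_params = mla_params

    def create_fake_attention(layer_idx, num_heads, head_dim, *,
                              num_kv_heads=None, q_scaling=None,
                              is_mla_enable=False, q_lora_rank=None,
                              kv_lora_rank=None):
        mla_params = (FakeMLAParams(q_lora_rank, kv_lora_rank)
                      if is_mla_enable else None)
        # A single return path serves both the MLA and the non-MLA case.
        return FakeBackend(layer_idx, num_heads, head_dim, num_kv_heads,
                           q_scaling=q_scaling, mla_params=mla_params)

    assert create_fake_attention(0, 32, 192, is_mla_enable=True,
                                 q_lora_rank=1536,
                                 kv_lora_rank=512).mla_params is not None
    assert create_fake_attention(0, 32, 128).mla_params is None
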
@@ -67,9 +67,10 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
         head_dim: int,
         num_kv_heads: Optional[int] = None,
         quant_config: Optional[QuantConfig] = None,
+        **kwargs,
     ):
         super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
-                         quant_config)
+                         quant_config, **kwargs)
         self.num_key_value_groups = self.num_heads // self.num_kv_heads
 
     def _single_request_update_kv_cache(self, k, v, kv_cache_tensor, seq_len,
@@ -1,3 +1,4 @@
+import math
 from typing import Optional
 
 import torch
@@ -210,6 +211,7 @@ class MLA(nn.Module):
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.qk_nope_head_dim = qk_nope_head_dim
         self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
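
The cached qk_head_dim is just the sum of the rope-less and rope parts of the query/key head. For example, with a 128/64 split (a common DeepSeek-style MLA configuration, used here purely as an illustration) the backend below would be created with head_dim = 192:

    # Example numbers only; the 128/64 split is not taken from this diff.
    qk_nope_head_dim, qk_rope_head_dim = 128, 64
    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
    assert qk_head_dim == 192
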
@@ -384,14 +386,25 @@ class MLA(nn.Module):
             skip_create_weights=config.skip_create_weights,
         )
 
+        def yarn_get_mscale(scale=1, mscale=1):
+            if scale <= 1:
+                return 1.0
+            return 0.1 * mscale * math.log(scale) + 1.0
+
+        mscale_all_dim = pos_embd_params.rope.mscale_all_dim
+        scaling_factor = pos_embd_params.rope.scale
+        mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+        q_scaling = 1.0 / (mscale * mscale)
+
         self.mha = create_attention(
             config.attn_backend,
             self.layer_idx,
             self.num_heads,
-            self.qk_nope_head_dim + self.qk_rope_head_dim,
-            self.num_key_value_heads,
+            head_dim=self.qk_head_dim,
+            num_kv_heads=self.num_key_value_heads,
             pos_embd_params=pos_embd_params,
             quant_config=quant_config,
+            q_scaling=q_scaling,
             is_mla_enable=True,
             q_lora_rank=self.q_lora_rank,
             kv_lora_rank=self.kv_lora_rank,
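
A worked example of the q_scaling value computed above, using a YaRN scaling factor of 40 and mscale_all_dim = 1.0 (numbers in the ballpark of the DeepSeek-V3 rope configuration, assumed here for illustration only):

    import math

    scale, mscale_all_dim = 40.0, 1.0
    mscale = 0.1 * mscale_all_dim * math.log(scale) + 1.0  # ~1.3689
    q_scaling = 1.0 / (mscale * mscale)                    # ~0.534
    print(f"mscale={mscale:.4f}, q_scaling={q_scaling:.4f}")
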
@@ -405,10 +418,11 @@ class MLA(nn.Module):
             config.attn_backend,
             self.layer_idx,
             self.num_heads,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            1,  # num_kv_heads
+            head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
+            num_kv_heads=1,
             pos_embd_params=pos_embd_params,
             quant_config=quant_config,
+            q_scaling=q_scaling,
             is_mla_enable=True,
             q_lora_rank=self.q_lora_rank,
             kv_lora_rank=self.kv_lora_rank,
@@ -1,5 +1,6 @@
 [pytest]
 threadleak = True
+threadleak_exclude = asyncio_\d+
 junit_family=legacy
 addopts = --ignore-glob="*perf/test_perf.py" --ignore-glob="*test_list_validation.py" --ignore-glob="*llm-test-workspace*" --durations=0 -W ignore::DeprecationWarning
 markers =
@@ -1,5 +1,6 @@
 [pytest]
 threadleak = True
+threadleak_exclude = asyncio_\d+
 addopts = --durations=0 -W ignore::DeprecationWarning
 pythonpath =
     _torch/auto_deploy/_utils_test
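
The new threadleak_exclude entry is the pytest-threadleak option for ignoring leaked threads whose names match a regex (assumed semantics, based on the plugin's ini option), and the pattern here covers the names asyncio gives its default executor threads (asyncio_0, asyncio_1, ...). A quick standalone check of the pattern:

    # Standalone check that the pattern matches asyncio's default executor
    # thread names and nothing else.
    import re

    pattern = re.compile(r"asyncio_\d+")
    assert pattern.fullmatch("asyncio_0")
    assert pattern.fullmatch("asyncio_12")
    assert not pattern.fullmatch("worker-1")
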