chore: Refine attention backend interface. (#3271)

Refine attention backend interface.

Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Authored by yuxianq on 2025-04-09 02:34:53 +08:00; committed by GitHub
parent 7199588796
commit 7225bd8b91
9 changed files with 97 additions and 75 deletions
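
Taken together, the changes below follow one pattern: every attention backend constructor now accepts **kwargs and forwards it to the AttentionBackend base class, so options such as q_scaling, pos_embd_params, and mla_params can be passed uniformly through create_attention to whichever backend is selected. A minimal standalone sketch of that constructor pattern follows; the class names mirror the diff, but the bodies are illustrative only, not the actual TensorRT-LLM implementations.

```python
# Minimal standalone sketch of the **kwargs-forwarding pattern used below.
# Illustrative only: the real AttentionBackend/VanillaAttention live in
# tensorrt_llm and do more than this.
from typing import Optional


class AttentionBackend:
    def __init__(self, layer_idx: int, num_heads: int, head_dim: int,
                 num_kv_heads: Optional[int] = None, quant_config=None,
                 **kwargs):
        # Backend-specific extras (q_scaling, pos_embd_params, mla_params, ...)
        # arrive via **kwargs; backends that do not need them can ignore them.
        self.layer_idx = layer_idx
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.num_kv_heads = num_kv_heads or num_heads
        self.quant_config = quant_config


class VanillaAttention(AttentionBackend):
    def __init__(self, layer_idx, num_heads, head_dim, num_kv_heads=None,
                 quant_config=None, **kwargs):
        # Forward unknown keyword arguments instead of rejecting them.
        super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
                         quant_config, **kwargs)


# Callers can now pass backend-specific options unconditionally:
attn = VanillaAttention(0, 32, 128, num_kv_heads=8,
                        q_scaling=1.0, pos_embd_params=None)
```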


@@ -412,9 +412,10 @@ class FlashInferAttention(AttentionBackend[FlashInferAttentionMetadata]):
head_dim: int,
num_kv_heads: Optional[int] = None,
quant_config: Optional[QuantConfig] = None,
**kwargs,
):
super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
quant_config)
quant_config, **kwargs)
self.has_fp8_kv_cache = False
if quant_config and quant_config.layer_quant_mode.has_any_quant():


@@ -1,6 +1,6 @@
import copy
import enum
from dataclasses import dataclass, field
from enum import Enum, IntEnum
from functools import lru_cache
from typing import (Generic, List, Optional, Protocol, Tuple, Type, TypeVar,
Union)
@@ -24,6 +24,14 @@ class AttentionRuntimeFeatures:
has_speculative_draft_tokens: bool = False
# The type of requests in qkv passed to attention
# Please keep sync with AttentionInputType in cpp/tensorrt_llm/thop/attentionOp.cpp
class AttentionInputType(IntEnum):
mixed = 0 # contains both context and generation
context_only = 1
generation_only = 2
@dataclass(kw_only=True)
class AttentionMetadata:
"""
@@ -421,7 +429,7 @@ class PositionalEmbeddingParams:
TMetadata = TypeVar("TMetadata", bound=AttentionMetadata)
class PredefinedAttentionMask(str, enum.Enum):
class PredefinedAttentionMask(str, Enum):
"""
Predefined attention mask types
@@ -450,6 +458,7 @@ class AttentionBackend(Generic[TMetadata]):
head_dim: int,
num_kv_heads: Optional[int] = None,
quant_config: Optional[QuantConfig] = None,
**kwargs,
):
"""
Initialize the backend.

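The hunks above move the AttentionInputType enum into the shared interface module and add **kwargs to AttentionBackend.__init__. A small illustration of how a caller might pick an AttentionInputType for a batch; the enum members and values come from the diff, while the helper function is hypothetical and not part of TensorRT-LLM.

```python
from enum import IntEnum


class AttentionInputType(IntEnum):  # mirrors the enum added to interface.py
    mixed = 0            # batch contains both context and generation requests
    context_only = 1
    generation_only = 2


def classify_batch(num_contexts: int, num_generations: int) -> AttentionInputType:
    """Hypothetical helper: map batch composition onto AttentionInputType."""
    if num_contexts and num_generations:
        return AttentionInputType.mixed
    return (AttentionInputType.context_only
            if num_contexts else AttentionInputType.generation_only)


assert classify_batch(4, 0) is AttentionInputType.context_only
assert classify_batch(2, 6) is AttentionInputType.mixed
```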

@@ -304,9 +304,10 @@ class StarAttention(AttentionBackend[StarAttentionMetadata]):
head_dim: int,
num_kv_heads: Optional[int] = None,
quant_config: Optional[QuantConfig] = None,
**kwargs,
):
super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
quant_config)
quant_config, **kwargs)
def forward(self,
q: torch.Tensor,


@@ -1,6 +1,4 @@
import math
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Optional
import torch
@@ -10,20 +8,13 @@ from tensorrt_llm.functional import (AttentionMaskType, RopeEmbeddingUtils,
from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantConfig
from .interface import (AttentionBackend, AttentionMask, AttentionMetadata,
KVCacheParams, MLAParams, PositionalEmbeddingParams,
PredefinedAttentionMask, RopeParams)
from .interface import (AttentionBackend, AttentionInputType, AttentionMask,
AttentionMetadata, KVCacheParams, MLAParams,
PositionalEmbeddingParams, PredefinedAttentionMask,
RopeParams)
from .vanilla import VanillaAttention
# The type of requests in qkv passed to attention
# Please keep sync with AttentionInputType in cpp/tensorrt_llm/thop/attentionOp.cpp
class AttentionInputType(IntEnum):
mixed = 0 # contains both context and generation
context_only = 1
generation_only = 2
@dataclass(kw_only=True, init=False)
class TrtllmAttentionWrapper:
sequence_length: torch.Tensor
@@ -65,11 +56,11 @@ class TrtllmAttentionWrapper:
rotary_embedding_original_max_positions: int
use_paged_context_fmha: bool
is_mla_enable: bool
q_lora_rank: int
kv_lora_rank: int
qk_rope_head_dim: int
qk_nope_head_dim: int
v_head_dim: int
q_lora_rank: Optional[int]
kv_lora_rank: Optional[int]
qk_rope_head_dim: Optional[int]
qk_nope_head_dim: Optional[int]
v_head_dim: Optional[int]
kwargs: dict
def __init__(
@@ -80,6 +71,7 @@ class TrtllmAttentionWrapper:
num_kv_heads: Optional[int] = None,
pos_embd_params: Optional[PositionalEmbeddingParams] = None,
quant_config: Optional[QuantConfig] = None,
q_scaling: Optional[float] = None,
mla_params: Optional[MLAParams] = None,
**kwargs,
):
@@ -101,7 +93,7 @@ class TrtllmAttentionWrapper:
rope_params = RopeParams()
self.is_mla_enable = mla_params is not None
self.q_scaling = 1.0
self.q_scaling = q_scaling or 1.0
self.mla_rope_params = None
self.predicted_tokens_per_seq = 1
@@ -114,18 +106,6 @@ class TrtllmAttentionWrapper:
self.predicted_tokens_per_seq = mla_params.predicted_tokens_per_seq
self.rotary_embedding_dim = 0
def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
mscale_all_dim = rope_params.mscale_all_dim
scaling_factor = rope_params.scale
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.q_scaling = 1.0 / (mscale * mscale)
self.rotary_inv_freq, self.rotary_cos_sin = rope_params.create_deepseek_rope_const_params(
self.qk_rope_head_dim)
self.rotary_embedding_scale_type = RotaryScalingType.none
@@ -165,6 +145,7 @@ class TrtllmAttentionWrapper:
*,
tokens_per_block: Optional[int] = None,
max_num_requests: int,
max_sequence_length: int,
max_context_length: int,
attention_window_size: Optional[int] = None,
sink_token_length: int = 0,
@@ -196,6 +177,7 @@ class TrtllmAttentionWrapper:
Args:
tokens_per_block (int): Token number per KV cache block.
max_num_requests (int): Max request number per batch.
max_sequence_length (int): Max sequence length.
max_context_length (int): Max context length per context-phase sequence.
attention_window_size (int): Max token number attended in windowed attention.
sink_token_length (int): Sink token number in StreamingLLM.
@@ -220,7 +202,7 @@ class TrtllmAttentionWrapper:
self.tokens_per_block = tokens_per_block
self.max_num_requests = max_num_requests
self.max_context_length = max_context_length
self.attention_window_size = attention_window_size or max_context_length
self.attention_window_size = attention_window_size or max_sequence_length
self.sink_token_length = sink_token_length
self.beam_width = beam_width
self.sequence_length = sequence_length
@@ -608,9 +590,11 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
num_heads: int,
head_dim: int,
num_kv_heads: Optional[int] = None,
pos_embd_params: Optional[PositionalEmbeddingParams] = None,
quant_config: Optional[QuantConfig] = None,
q_scaling: Optional[float] = None,
pos_embd_params: Optional[PositionalEmbeddingParams] = None,
mla_params: Optional[MLAParams] = None,
**kwargs,
):
"""
Initialize the backend.
@@ -619,13 +603,22 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
num_heads (int): The number of query heads.
head_dim (int): The size of each attention head (hidden_size // num_heads).
num_kv_heads (int): The number of kv heads. Defaults to num_heads if None.
quant_config (QuantConfig): Optional quantization configuration. If None, no quantization is applied.
q_scaling (float): Scaling factor for QK. Defaults to 1.0 if None.
pos_embd_params (PositionalEmbeddingParams): Optional parameters defining how positional embedding should be applied.
If None, positional embedding should be applied by the model before calling the backend.
Otherwise, the backend is in-charge of applying positional embedding and may cache K without embedding it first.
quant_config (QuantConfig): Optional quantization configuration. If None, no quantization is applied.
mla_params (MLAParams): Optional parameters for MLA. If None, MLA is not enabled.
"""
super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
quant_config)
super().__init__(layer_idx,
num_heads,
head_dim,
num_kv_heads,
quant_config,
q_scaling=q_scaling,
pos_embd_params=pos_embd_params,
mla_params=mla_params,
**kwargs)
self.wrapper = TrtllmAttentionWrapper(
layer_idx,
num_heads,
@@ -633,6 +626,7 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
num_kv_heads,
pos_embd_params=pos_embd_params,
quant_config=quant_config,
q_scaling=q_scaling,
mla_params=mla_params,
)
@@ -722,7 +716,9 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
self.wrapper.plan(
tokens_per_block=metadata.tokens_per_block,
max_num_requests=metadata.max_num_requests,
max_context_length=metadata.max_seq_len,
max_sequence_length=metadata.max_seq_len,
max_context_length=min(metadata.max_seq_len - 1,
metadata.max_num_tokens),
attention_window_size=None,
sink_token_length=0,
beam_width=1,

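In the hunks above, plan() gains an explicit max_sequence_length, attention_window_size now falls back to max_sequence_length instead of max_context_length, and TrtllmAttention caps max_context_length at min(max_seq_len - 1, max_num_tokens). A tiny numeric sketch of that derivation; the metadata values are made up for illustration.

```python
# Made-up metadata values; only the fallback/min() arithmetic mirrors the diff.
max_seq_len = 4096            # metadata.max_seq_len
max_num_tokens = 2048         # metadata.max_num_tokens
attention_window_size = None  # no windowed attention requested

max_sequence_length = max_seq_len
max_context_length = min(max_seq_len - 1, max_num_tokens)  # 2047
window = attention_window_size or max_sequence_length      # 4096

print(max_sequence_length, max_context_length, window)
```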

@@ -32,42 +32,40 @@ def create_attention(
num_kv_heads: Optional[int] = None,
pos_embd_params: Optional[PositionalEmbeddingParams] = None,
quant_config: Optional[QuantConfig] = None,
is_mla_enable: Optional[bool] = False,
q_lora_rank: Optional[int] = 0,
kv_lora_rank: Optional[int] = 0,
qk_rope_head_dim: Optional[int] = 0,
qk_nope_head_dim: Optional[int] = 0,
v_head_dim: Optional[int] = 0,
q_scaling: Optional[float] = None,
is_mla_enable: bool = False,
q_lora_rank: Optional[int] = None,
kv_lora_rank: Optional[int] = None,
qk_rope_head_dim: Optional[int] = None,
qk_nope_head_dim: Optional[int] = None,
v_head_dim: Optional[int] = None,
predicted_tokens_per_seq: Optional[int] = 1,
):
attn_cls = get_attention_backend(backend_name)
if is_mla_enable:
assert attn_cls == TrtllmAttention
assert (q_lora_rank > 0 and kv_lora_rank > 0 and qk_rope_head_dim > 0
and qk_nope_head_dim > 0 and v_head_dim > 0)
if attn_cls == TrtllmAttention:
if is_mla_enable:
mla_params = MLAParams(
q_lora_rank=q_lora_rank,
kv_lora_rank=kv_lora_rank,
qk_rope_head_dim=qk_rope_head_dim,
qk_nope_head_dim=qk_nope_head_dim,
v_head_dim=v_head_dim,
predicted_tokens_per_seq=predicted_tokens_per_seq,
)
else:
mla_params = None
return TrtllmAttention(
layer_idx,
num_heads,
head_dim,
num_kv_heads,
pos_embd_params,
quant_config,
mla_params=mla_params,
mla_params = MLAParams(
q_lora_rank=q_lora_rank,
kv_lora_rank=kv_lora_rank,
qk_rope_head_dim=qk_rope_head_dim,
qk_nope_head_dim=qk_nope_head_dim,
v_head_dim=v_head_dim,
predicted_tokens_per_seq=predicted_tokens_per_seq,
)
else:
mla_params = None
return attn_cls(layer_idx, num_heads, head_dim, num_kv_heads, quant_config)
return attn_cls(
layer_idx,
num_heads,
head_dim,
num_kv_heads,
quant_config=quant_config,
q_scaling=q_scaling,
pos_embd_params=pos_embd_params,
mla_params=mla_params,
)

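With create_attention reworked as above, every backend is constructed through the same keyword-based call, and the MLA dimensions default to None rather than 0. A hedged usage example follows; the import path, backend name string, and concrete sizes are assumptions, while the keyword names come from the new signature.

```python
# Assumed module path; adjust to wherever create_attention lives in your tree.
from tensorrt_llm._torch.attention_backend.utils import create_attention

# Non-MLA backend: the MLA arguments simply stay at their None defaults.
attn = create_attention(
    "TRTLLM",          # backend_name (string value is an assumption)
    0,                 # layer_idx
    32,                # num_heads
    head_dim=128,
    num_kv_heads=8,
    quant_config=None,
    q_scaling=None,    # backend falls back to 1.0 when None
)

# MLA-enabled backend: all MLA dimensions must now be passed explicitly,
# since their defaults changed from 0 to None (example sizes only).
mla_attn = create_attention(
    "TRTLLM",
    0,
    128,
    head_dim=192,
    num_kv_heads=128,
    q_scaling=0.5337,
    is_mla_enable=True,
    q_lora_rank=1536,
    kv_lora_rank=512,
    qk_rope_head_dim=64,
    qk_nope_head_dim=128,
    v_head_dim=128,
)
```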

@@ -67,9 +67,10 @@ class VanillaAttention(AttentionBackend[VanillaAttentionMetadata]):
head_dim: int,
num_kv_heads: Optional[int] = None,
quant_config: Optional[QuantConfig] = None,
**kwargs,
):
super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
quant_config)
quant_config, **kwargs)
self.num_key_value_groups = self.num_heads // self.num_kv_heads
def _single_request_update_kv_cache(self, k, v, kv_cache_tensor, seq_len,


@@ -1,3 +1,4 @@
import math
from typing import Optional
import torch
@@ -210,6 +211,7 @@ class MLA(nn.Module):
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
@@ -384,14 +386,25 @@ class MLA(nn.Module):
skip_create_weights=config.skip_create_weights,
)
def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0
mscale_all_dim = pos_embd_params.rope.mscale_all_dim
scaling_factor = pos_embd_params.rope.scale
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
q_scaling = 1.0 / (mscale * mscale)
self.mha = create_attention(
config.attn_backend,
self.layer_idx,
self.num_heads,
self.qk_nope_head_dim + self.qk_rope_head_dim,
self.num_key_value_heads,
head_dim=self.qk_head_dim,
num_kv_heads=self.num_key_value_heads,
pos_embd_params=pos_embd_params,
quant_config=quant_config,
q_scaling=q_scaling,
is_mla_enable=True,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
@@ -405,10 +418,11 @@ class MLA(nn.Module):
config.attn_backend,
self.layer_idx,
self.num_heads,
self.kv_lora_rank + self.qk_rope_head_dim,
1, # num_kv_heads
head_dim=self.kv_lora_rank + self.qk_rope_head_dim,
num_kv_heads=1,
pos_embd_params=pos_embd_params,
quant_config=quant_config,
q_scaling=q_scaling,
is_mla_enable=True,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,

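The hunks above move the YaRN mscale computation out of TrtllmAttentionWrapper and into the MLA module, which now passes the resulting q_scaling to create_attention. A quick numeric check of the formula; the rope scale and mscale_all_dim values are made-up examples, not taken from any model config.

```python
import math


def yarn_get_mscale(scale=1, mscale=1):
    # Same formula as in the diff above.
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


scaling_factor = 40.0  # example rope scale
mscale_all_dim = 1.0   # example mscale_all_dim
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)  # ~1.3689
q_scaling = 1.0 / (mscale * mscale)                       # ~0.5337
print(round(mscale, 4), round(q_scaling, 4))
```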

@@ -1,5 +1,6 @@
[pytest]
threadleak = True
threadleak_exclude = asyncio_\d+
junit_family=legacy
addopts = --ignore-glob="*perf/test_perf.py" --ignore-glob="*test_list_validation.py" --ignore-glob="*llm-test-workspace*" --durations=0 -W ignore::DeprecationWarning
markers =


@@ -1,5 +1,6 @@
[pytest]
threadleak = True
threadleak_exclude = asyncio_\d+
addopts = --durations=0 -W ignore::DeprecationWarning
pythonpath =
_torch/auto_deploy/_utils_test
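
Both pytest configs above add threadleak_exclude = asyncio_\d+, presumably so the thread-leak checker ignores asyncio's default executor workers, which CPython names asyncio_0, asyncio_1, and so on, and which outlive the test that first touches the executor. A small illustration of where those names come from:

```python
import asyncio
import threading


async def main():
    # Touching the default executor spawns worker threads named
    # "asyncio_<n>", which the threadleak_exclude pattern matches.
    await asyncio.get_running_loop().run_in_executor(None, lambda: None)
    print([t.name for t in threading.enumerate()
           if t.name.startswith("asyncio_")])


asyncio.run(main())
```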