mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
[TRTLLM-10171][fix] Correct attention handling in ModelConfig and KVCacheManager (#10330)
Signed-off-by: Jaedeok Kim <jaedeokk@nvidia.com>
This commit is contained in:
parent
6ba04eba06
commit
a4dcc6a711
@ -495,10 +495,15 @@ class ModelConfig(Generic[TConfig]):
|
||||
# TODO smor- currently assuming no rnn layers, no MOE
|
||||
from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
|
||||
|
||||
num_heads = self.pretrained_config.num_attention_heads // (
|
||||
self.mapping.tp_size * self.mapping.cp_size)
|
||||
# Attention DP should not shard attention heads; use attn_tp_size=1 in that case
|
||||
# so downstream KV calculations see the full (non-partitioned) head count.
|
||||
attn_tp_size = self.mapping.attn_tp_size if not self.mapping.enable_attention_dp else 1
|
||||
attn_cp_size = self.mapping.attn_cp_size
|
||||
|
||||
hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size
|
||||
num_heads = self.pretrained_config.num_attention_heads // (
|
||||
attn_tp_size * attn_cp_size)
|
||||
|
||||
hidden_size = self.pretrained_config.hidden_size // attn_tp_size
|
||||
|
||||
model_config_cpp = ModelConfigCpp(
|
||||
vocab_size=self.pretrained_config.vocab_size,
|
||||
@ -523,13 +528,12 @@ class ModelConfig(Generic[TConfig]):
|
||||
if isinstance(num_key_value_heads, (list, tuple)):
|
||||
# Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
|
||||
num_kv_heads_per_layer = [
|
||||
kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
|
||||
kv_heads // (attn_tp_size * attn_cp_size)
|
||||
for kv_heads in num_key_value_heads
|
||||
]
|
||||
model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
|
||||
else:
|
||||
num_kv_heads = num_key_value_heads // (self.mapping.tp_size *
|
||||
self.mapping.cp_size)
|
||||
num_kv_heads = num_key_value_heads // (attn_tp_size * attn_cp_size)
|
||||
model_config_cpp.set_num_kv_heads(num_kv_heads)
|
||||
|
||||
mlp_hidden_size = None
|
||||
|
||||
@ -1138,7 +1138,8 @@ class KVCacheManager(BaseResourceManager):
|
||||
tensor_parallelism=self.mapping.tp_size,
|
||||
pipeline_parallelism=self.mapping.pp_size,
|
||||
rank=self.mapping.rank,
|
||||
gpus_per_node=self.mapping.gpus_per_node)
|
||||
gpus_per_node=self.mapping.gpus_per_node,
|
||||
enable_attention_dp=self.mapping.enable_attention_dp)
|
||||
|
||||
window_size_to_layers = self._get_window_size_to_layers()
|
||||
logger.debug(f"window_size_to_layers: {window_size_to_layers}")
|
||||
|
||||
@ -17,6 +17,7 @@ l0_a10:
|
||||
- unittest/_torch/sampler/test_torch_sampler.py
|
||||
- unittest/_torch/sampler/test_torch_multi_arange.py
|
||||
- unittest/utils/test_util.py
|
||||
- unittest/_torch/test_model_config.py
|
||||
- unittest/_torch/modeling/test_modeling_mistral.py
|
||||
- unittest/_torch/modeling/test_modeling_pixtral.py
|
||||
- unittest/_torch/sampler/test_trtllm_sampler.py
|
||||
|
||||
90
tests/unittest/_torch/test_model_config.py
Normal file
90
tests/unittest/_torch/test_model_config.py
Normal file
@ -0,0 +1,90 @@
|
||||
import types
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
|
||||
def make_pretrained_config(
|
||||
*,
|
||||
num_attention_heads: int = 16,
|
||||
num_key_value_heads=8,
|
||||
head_dim: int | None = None,
|
||||
num_hidden_layers: int = 1,
|
||||
vocab_size: int = 3000,
|
||||
):
|
||||
# A minimal config object that provides the attributes used by
|
||||
# ModelConfig.get_bindings_model_config().
|
||||
hidden_size = head_dim * num_attention_heads
|
||||
intermediate_size = hidden_size * 4
|
||||
|
||||
return types.SimpleNamespace(
|
||||
architectures=["DummyArchitecture"],
|
||||
num_attention_heads=num_attention_heads,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
vocab_size=vocab_size,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"num_key_value_heads",
|
||||
[
|
||||
pytest.param(8, id="kv_heads_scalar"),
|
||||
pytest.param([8, 20], id="kv_heads_per_layer_varied"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("enable_attention_dp", [False, True])
|
||||
@pytest.mark.parametrize(
|
||||
"mapping_kwargs",
|
||||
[
|
||||
# Same tp/cp sizes, but different ways of setting attention TP:
|
||||
# - No explicit attn_tp_size: Mapping infers it.
|
||||
# - Explicit attn_tp_size: Mapping uses the provided value.
|
||||
dict(world_size=8, tp_size=4, cp_size=2),
|
||||
dict(world_size=4, tp_size=2, cp_size=2, attn_tp_size=4),
|
||||
],
|
||||
)
|
||||
def test_get_bindings_model_config_attention_dp_attn_tp_override(
|
||||
enable_attention_dp, mapping_kwargs, num_key_value_heads
|
||||
):
|
||||
mapping = Mapping(enable_attention_dp=enable_attention_dp, **mapping_kwargs)
|
||||
cfg = make_pretrained_config(
|
||||
# Keep values consistent:
|
||||
# hidden_size = num_attention_heads * head_dim.
|
||||
num_attention_heads=16,
|
||||
head_dim=4,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
num_hidden_layers=2,
|
||||
)
|
||||
model_config = ModelConfig(pretrained_config=cfg, mapping=mapping)
|
||||
|
||||
tokens_per_block = 32
|
||||
bindings_cfg = model_config.get_bindings_model_config(tokens_per_block=tokens_per_block)
|
||||
|
||||
# bindings hidden_size is sharded by attn_tp_size and attn_cp_size.
|
||||
attn_tp_size = mapping.attn_tp_size if not mapping.enable_attention_dp else 1
|
||||
attn_cp_size = mapping.attn_cp_size
|
||||
assert bindings_cfg.num_heads == cfg.num_attention_heads // (attn_tp_size * attn_cp_size)
|
||||
# bindings hidden_size is sharded by attn_tp_size.
|
||||
assert bindings_cfg.hidden_size == cfg.hidden_size // attn_tp_size
|
||||
if isinstance(cfg.num_key_value_heads, (list, tuple)):
|
||||
expected_num_kv_heads_per_layer = [
|
||||
kv // (attn_tp_size * attn_cp_size) for kv in cfg.num_key_value_heads
|
||||
]
|
||||
assert list(bindings_cfg.num_kv_heads_per_layer) == expected_num_kv_heads_per_layer
|
||||
assert bindings_cfg.num_kv_heads(0) == expected_num_kv_heads_per_layer[0]
|
||||
else:
|
||||
assert bindings_cfg.num_kv_heads(0) == cfg.num_key_value_heads // (
|
||||
attn_tp_size * attn_cp_size
|
||||
)
|
||||
|
||||
# tp_size-dependent value (uses mapping.tp_size, not attn_tp_size).
|
||||
assert bindings_cfg.mlp_hidden_size == (cfg.intermediate_size // mapping.tp_size)
|
||||
assert bindings_cfg.tokens_per_block == tokens_per_block
|
||||
Loading…
Reference in New Issue
Block a user