[TRTLLM-10171][fix] Correct attention handling in ModelConfig and KVCacheManager (#10330)

Signed-off-by: Jaedeok Kim <jaedeokk@nvidia.com>
Jaedeok Kim 2026-01-04 20:07:30 +09:00 committed by GitHub
parent 6ba04eba06
commit a4dcc6a711
4 changed files with 103 additions and 7 deletions


@@ -495,10 +495,15 @@ class ModelConfig(Generic[TConfig]):
         # TODO smor- currently assuming no rnn layers, no MOE
         from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
 
-        num_heads = self.pretrained_config.num_attention_heads // (
-            self.mapping.tp_size * self.mapping.cp_size)
-        hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size
+        # Attention DP should not shard attention heads; use attn_tp_size=1 in that case
+        # so downstream KV calculations see the full (non-partitioned) head count.
+        attn_tp_size = self.mapping.attn_tp_size if not self.mapping.enable_attention_dp else 1
+        attn_cp_size = self.mapping.attn_cp_size
+        num_heads = self.pretrained_config.num_attention_heads // (
+            attn_tp_size * attn_cp_size)
+        hidden_size = self.pretrained_config.hidden_size // attn_tp_size
 
         model_config_cpp = ModelConfigCpp(
             vocab_size=self.pretrained_config.vocab_size,
@@ -523,13 +528,12 @@ class ModelConfig(Generic[TConfig]):
         if isinstance(num_key_value_heads, (list, tuple)):
             # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
             num_kv_heads_per_layer = [
-                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
+                kv_heads // (attn_tp_size * attn_cp_size)
                 for kv_heads in num_key_value_heads
             ]
             model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
         else:
-            num_kv_heads = num_key_value_heads // (self.mapping.tp_size *
-                                                   self.mapping.cp_size)
+            num_kv_heads = num_key_value_heads // (attn_tp_size * attn_cp_size)
             model_config_cpp.set_num_kv_heads(num_kv_heads)
 
         mlp_hidden_size = None
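
The comment added in the hunk above captures the core of the fix: with attention DP enabled, attention heads are replicated rather than sharded across ranks, so the head counts handed to the C++ bindings must use an effective attention-TP factor of 1. The snippet below is a standalone sketch of that selection logic and the resulting per-rank head counts; the attention_shard_factors helper and the SimpleNamespace "mapping" objects are illustrative stand-ins, not TensorRT-LLM API.

# Illustrative sketch of the selection logic from the diff above; the helper
# and the SimpleNamespace "mapping" objects are stand-ins, not TensorRT-LLM API.
from types import SimpleNamespace


def attention_shard_factors(mapping):
    # With attention DP every rank keeps the full head count, so the
    # effective attention-TP factor is 1; otherwise use attn_tp_size.
    attn_tp_size = 1 if mapping.enable_attention_dp else mapping.attn_tp_size
    return attn_tp_size, mapping.attn_cp_size


plain_tp = SimpleNamespace(attn_tp_size=4, attn_cp_size=1, enable_attention_dp=False)
attn_dp = SimpleNamespace(attn_tp_size=4, attn_cp_size=1, enable_attention_dp=True)

for name, m in [("plain TP", plain_tp), ("attention DP", attn_dp)]:
    tp, cp = attention_shard_factors(m)
    # With 32 query heads and 8 KV heads in total:
    # plain TP -> 8 query / 2 KV heads per rank; attention DP -> 32 / 8.
    print(name, 32 // (tp * cp), 8 // (tp * cp))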


@@ -1138,7 +1138,8 @@ class KVCacheManager(BaseResourceManager):
            tensor_parallelism=self.mapping.tp_size,
            pipeline_parallelism=self.mapping.pp_size,
            rank=self.mapping.rank,
-            gpus_per_node=self.mapping.gpus_per_node)
+            gpus_per_node=self.mapping.gpus_per_node,
+            enable_attention_dp=self.mapping.enable_attention_dp)
 
        window_size_to_layers = self._get_window_size_to_layers()
        logger.debug(f"window_size_to_layers: {window_size_to_layers}")
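
Forwarding enable_attention_dp here lets the KV-cache sizing see the same per-rank KV head count that ModelConfig now reports. As a rough illustration of why that matters, below is a minimal sketch of per-rank bytes per cache block; the function name, parameters, and the usual 2x factor for K and V are assumptions of the sketch, not TensorRT-LLM's actual sizing code.

# Illustrative sketch only: under attention DP each rank stores every KV head,
# so each cache block is correspondingly larger than under plain attention TP.
def kv_bytes_per_block(total_kv_heads: int, head_dim: int, tokens_per_block: int,
                       dtype_bytes: int, attn_tp_size: int,
                       enable_attention_dp: bool) -> int:
    # Attention DP replicates KV heads; plain TP splits them across the group.
    kv_heads_per_rank = total_kv_heads if enable_attention_dp else total_kv_heads // attn_tp_size
    # 2x accounts for the separate K and V tensors.
    return 2 * kv_heads_per_rank * head_dim * tokens_per_block * dtype_bytes


# Example: 8 KV heads, head_dim=128, 32 tokens per block, fp16, attention TP of 4.
print(kv_bytes_per_block(8, 128, 32, 2, 4, False))  # 32768 bytes per block
print(kv_bytes_per_block(8, 128, 32, 2, 4, True))   # 131072 bytes per block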


@@ -17,6 +17,7 @@ l0_a10:
 - unittest/_torch/sampler/test_torch_sampler.py
 - unittest/_torch/sampler/test_torch_multi_arange.py
 - unittest/utils/test_util.py
+- unittest/_torch/test_model_config.py
 - unittest/_torch/modeling/test_modeling_mistral.py
 - unittest/_torch/modeling/test_modeling_pixtral.py
 - unittest/_torch/sampler/test_trtllm_sampler.py


@@ -0,0 +1,90 @@
import types

import pytest
import torch

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm.mapping import Mapping


def make_pretrained_config(
    *,
    num_attention_heads: int = 16,
    num_key_value_heads=8,
    head_dim: int | None = None,
    num_hidden_layers: int = 1,
    vocab_size: int = 3000,
):
    # A minimal config object that provides the attributes used by
    # ModelConfig.get_bindings_model_config().
    hidden_size = head_dim * num_attention_heads
    intermediate_size = hidden_size * 4
    return types.SimpleNamespace(
        architectures=["DummyArchitecture"],
        num_attention_heads=num_attention_heads,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        num_key_value_heads=num_key_value_heads,
        head_dim=head_dim,
        num_hidden_layers=num_hidden_layers,
        vocab_size=vocab_size,
        torch_dtype=torch.float16,
    )


@pytest.mark.parametrize(
    "num_key_value_heads",
    [
        pytest.param(8, id="kv_heads_scalar"),
        pytest.param([8, 20], id="kv_heads_per_layer_varied"),
    ],
)
@pytest.mark.parametrize("enable_attention_dp", [False, True])
@pytest.mark.parametrize(
    "mapping_kwargs",
    [
        # Same tp/cp sizes, but different ways of setting attention TP:
        # - No explicit attn_tp_size: Mapping infers it.
        # - Explicit attn_tp_size: Mapping uses the provided value.
        dict(world_size=8, tp_size=4, cp_size=2),
        dict(world_size=4, tp_size=2, cp_size=2, attn_tp_size=4),
    ],
)
def test_get_bindings_model_config_attention_dp_attn_tp_override(
        enable_attention_dp, mapping_kwargs, num_key_value_heads):
    mapping = Mapping(enable_attention_dp=enable_attention_dp, **mapping_kwargs)
    cfg = make_pretrained_config(
        # Keep values consistent:
        # hidden_size = num_attention_heads * head_dim.
        num_attention_heads=16,
        head_dim=4,
        num_key_value_heads=num_key_value_heads,
        num_hidden_layers=2,
    )
    model_config = ModelConfig(pretrained_config=cfg, mapping=mapping)

    tokens_per_block = 32
    bindings_cfg = model_config.get_bindings_model_config(
        tokens_per_block=tokens_per_block)

    # bindings num_heads is sharded by attn_tp_size and attn_cp_size.
    attn_tp_size = mapping.attn_tp_size if not mapping.enable_attention_dp else 1
    attn_cp_size = mapping.attn_cp_size
    assert bindings_cfg.num_heads == cfg.num_attention_heads // (
        attn_tp_size * attn_cp_size)

    # bindings hidden_size is sharded by attn_tp_size.
    assert bindings_cfg.hidden_size == cfg.hidden_size // attn_tp_size

    if isinstance(cfg.num_key_value_heads, (list, tuple)):
        expected_num_kv_heads_per_layer = [
            kv // (attn_tp_size * attn_cp_size)
            for kv in cfg.num_key_value_heads
        ]
        assert list(
            bindings_cfg.num_kv_heads_per_layer) == expected_num_kv_heads_per_layer
        assert bindings_cfg.num_kv_heads(0) == expected_num_kv_heads_per_layer[0]
    else:
        assert bindings_cfg.num_kv_heads(0) == cfg.num_key_value_heads // (
            attn_tp_size * attn_cp_size)

    # tp_size-dependent value (uses mapping.tp_size, not attn_tp_size).
    assert bindings_cfg.mlp_hidden_size == (cfg.intermediate_size //
                                            mapping.tp_size)
    assert bindings_cfg.tokens_per_block == tokens_per_block
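
For reference, the arithmetic the test asserts can be walked through with explicit shard factors, independent of how Mapping infers attn_tp_size and attn_cp_size. The values below (attn_tp_size=2, attn_cp_size=1, tp_size=2, the [8, 20] per-layer KV heads) are chosen purely for illustration and reuse the same selection rule as the ModelConfig diff above.

# Self-contained walk-through of the expected bindings values (illustration only).
num_attention_heads, head_dim = 16, 4
num_key_value_heads = [8, 20]
hidden_size = num_attention_heads * head_dim      # 64
intermediate_size = hidden_size * 4               # 256
tp_size, attn_tp_size, attn_cp_size = 2, 2, 1

for enable_attention_dp in (False, True):
    # Same selection rule as in get_bindings_model_config().
    eff_attn_tp = 1 if enable_attention_dp else attn_tp_size
    shard = eff_attn_tp * attn_cp_size
    print(
        enable_attention_dp,
        num_attention_heads // shard,                 # expected bindings num_heads
        hidden_size // eff_attn_tp,                   # expected bindings hidden_size
        [kv // shard for kv in num_key_value_heads],  # expected per-layer KV heads
        intermediate_size // tp_size,                 # mlp_hidden_size still uses tp_size
    )
# enable_attention_dp=False -> 8 heads, hidden_size 32, KV [4, 10], mlp 128
# enable_attention_dp=True  -> 16 heads, hidden_size 64, KV [8, 20], mlp 128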