Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[TRTLLM-10171][fix] Correct attention handling in ModelConfig and KVCacheManager (#10330)
Signed-off-by: Jaedeok Kim <jaedeokk@nvidia.com>
parent 6ba04eba06
commit a4dcc6a711
@@ -495,10 +495,15 @@ class ModelConfig(Generic[TConfig]):
         # TODO smor- currently assuming no rnn layers, no MOE
         from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
 
-        num_heads = self.pretrained_config.num_attention_heads // (
-            self.mapping.tp_size * self.mapping.cp_size)
+        # Attention DP should not shard attention heads; use attn_tp_size=1 in that case
+        # so downstream KV calculations see the full (non-partitioned) head count.
+        attn_tp_size = self.mapping.attn_tp_size if not self.mapping.enable_attention_dp else 1
+        attn_cp_size = self.mapping.attn_cp_size
 
-        hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size
+        num_heads = self.pretrained_config.num_attention_heads // (
+            attn_tp_size * attn_cp_size)
+
+        hidden_size = self.pretrained_config.hidden_size // attn_tp_size
 
         model_config_cpp = ModelConfigCpp(
             vocab_size=self.pretrained_config.vocab_size,
@@ -523,13 +528,12 @@ class ModelConfig(Generic[TConfig]):
         if isinstance(num_key_value_heads, (list, tuple)):
             # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
             num_kv_heads_per_layer = [
-                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
+                kv_heads // (attn_tp_size * attn_cp_size)
                 for kv_heads in num_key_value_heads
             ]
             model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
         else:
-            num_kv_heads = num_key_value_heads // (self.mapping.tp_size *
-                                                   self.mapping.cp_size)
+            num_kv_heads = num_key_value_heads // (attn_tp_size * attn_cp_size)
             model_config_cpp.set_num_kv_heads(num_kv_heads)
 
         mlp_hidden_size = None
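As a worked illustration of the sharding change above, here is a minimal sketch of the head-count arithmetic. The numbers and the standalone variables are hypothetical; the attn_tp_size/attn_cp_size override mirrors the lines added in the diff, everything else is made up for the example.

# Illustrative sketch only; numbers are hypothetical, not taken from a real config.
num_attention_heads = 16
tp_size, cp_size = 4, 1
enable_attention_dp = True

# Before the fix: heads were always sharded by tp_size * cp_size.
old_num_heads = num_attention_heads // (tp_size * cp_size)  # 16 // 4 = 4

# After the fix: with attention DP, attention heads are not sharded
# (attn_tp_size is treated as 1), so the full head count is kept.
attn_tp_size = 1 if enable_attention_dp else tp_size
attn_cp_size = cp_size  # assumed equal to cp_size for this sketch
new_num_heads = num_attention_heads // (attn_tp_size * attn_cp_size)  # 16 // 1 = 16

assert (old_num_heads, new_num_heads) == (4, 16)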
@@ -1138,7 +1138,8 @@ class KVCacheManager(BaseResourceManager):
             tensor_parallelism=self.mapping.tp_size,
             pipeline_parallelism=self.mapping.pp_size,
             rank=self.mapping.rank,
-            gpus_per_node=self.mapping.gpus_per_node)
+            gpus_per_node=self.mapping.gpus_per_node,
+            enable_attention_dp=self.mapping.enable_attention_dp)
 
         window_size_to_layers = self._get_window_size_to_layers()
         logger.debug(f"window_size_to_layers: {window_size_to_layers}")
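A rough sketch of why the flag has to reach the KV-cache side as well: the per-rank KV head count used for cache sizing should not be divided by the attention TP degree when attention DP is on. The helper below is hypothetical and only illustrates that rule; it is not part of the KVCacheManager API.

# Hypothetical helper, for illustration only (not part of KVCacheManager).
def kv_heads_per_rank(total_kv_heads: int, attn_tp_size: int,
                      enable_attention_dp: bool) -> int:
    # With attention DP each rank keeps all KV heads; otherwise they are
    # sharded across the attention TP group.
    effective_tp = 1 if enable_attention_dp else attn_tp_size
    return total_kv_heads // effective_tp

assert kv_heads_per_rank(8, attn_tp_size=4, enable_attention_dp=False) == 2
assert kv_heads_per_rank(8, attn_tp_size=4, enable_attention_dp=True) == 8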
@@ -17,6 +17,7 @@ l0_a10:
 - unittest/_torch/sampler/test_torch_sampler.py
 - unittest/_torch/sampler/test_torch_multi_arange.py
 - unittest/utils/test_util.py
+- unittest/_torch/test_model_config.py
 - unittest/_torch/modeling/test_modeling_mistral.py
 - unittest/_torch/modeling/test_modeling_pixtral.py
 - unittest/_torch/sampler/test_trtllm_sampler.py
tests/unittest/_torch/test_model_config.py (new file, 90 lines)
@@ -0,0 +1,90 @@
+import types
+
+import pytest
+import torch
+
+from tensorrt_llm._torch.model_config import ModelConfig
+from tensorrt_llm.mapping import Mapping
+
+
+def make_pretrained_config(
+    *,
+    num_attention_heads: int = 16,
+    num_key_value_heads=8,
+    head_dim: int | None = None,
+    num_hidden_layers: int = 1,
+    vocab_size: int = 3000,
+):
+    # A minimal config object that provides the attributes used by
+    # ModelConfig.get_bindings_model_config().
+    hidden_size = head_dim * num_attention_heads
+    intermediate_size = hidden_size * 4
+
+    return types.SimpleNamespace(
+        architectures=["DummyArchitecture"],
+        num_attention_heads=num_attention_heads,
+        hidden_size=hidden_size,
+        intermediate_size=intermediate_size,
+        num_key_value_heads=num_key_value_heads,
+        head_dim=head_dim,
+        num_hidden_layers=num_hidden_layers,
+        vocab_size=vocab_size,
+        torch_dtype=torch.float16,
+    )
+
+
+@pytest.mark.parametrize(
+    "num_key_value_heads",
+    [
+        pytest.param(8, id="kv_heads_scalar"),
+        pytest.param([8, 20], id="kv_heads_per_layer_varied"),
+    ],
+)
+@pytest.mark.parametrize("enable_attention_dp", [False, True])
+@pytest.mark.parametrize(
+    "mapping_kwargs",
+    [
+        # Same tp/cp sizes, but different ways of setting attention TP:
+        # - No explicit attn_tp_size: Mapping infers it.
+        # - Explicit attn_tp_size: Mapping uses the provided value.
+        dict(world_size=8, tp_size=4, cp_size=2),
+        dict(world_size=4, tp_size=2, cp_size=2, attn_tp_size=4),
+    ],
+)
+def test_get_bindings_model_config_attention_dp_attn_tp_override(
+    enable_attention_dp, mapping_kwargs, num_key_value_heads
+):
+    mapping = Mapping(enable_attention_dp=enable_attention_dp, **mapping_kwargs)
+    cfg = make_pretrained_config(
+        # Keep values consistent:
+        # hidden_size = num_attention_heads * head_dim.
+        num_attention_heads=16,
+        head_dim=4,
+        num_key_value_heads=num_key_value_heads,
+        num_hidden_layers=2,
+    )
+    model_config = ModelConfig(pretrained_config=cfg, mapping=mapping)
+
+    tokens_per_block = 32
+    bindings_cfg = model_config.get_bindings_model_config(tokens_per_block=tokens_per_block)
+
+    # bindings num_heads is sharded by attn_tp_size and attn_cp_size.
+    attn_tp_size = mapping.attn_tp_size if not mapping.enable_attention_dp else 1
+    attn_cp_size = mapping.attn_cp_size
+    assert bindings_cfg.num_heads == cfg.num_attention_heads // (attn_tp_size * attn_cp_size)
+    # bindings hidden_size is sharded by attn_tp_size.
+    assert bindings_cfg.hidden_size == cfg.hidden_size // attn_tp_size
+    if isinstance(cfg.num_key_value_heads, (list, tuple)):
+        expected_num_kv_heads_per_layer = [
+            kv // (attn_tp_size * attn_cp_size) for kv in cfg.num_key_value_heads
+        ]
+        assert list(bindings_cfg.num_kv_heads_per_layer) == expected_num_kv_heads_per_layer
+        assert bindings_cfg.num_kv_heads(0) == expected_num_kv_heads_per_layer[0]
+    else:
+        assert bindings_cfg.num_kv_heads(0) == cfg.num_key_value_heads // (
+            attn_tp_size * attn_cp_size
+        )
+
+    # tp_size-dependent value (uses mapping.tp_size, not attn_tp_size).
+    assert bindings_cfg.mlp_hidden_size == (cfg.intermediate_size // mapping.tp_size)
+    assert bindings_cfg.tokens_per_block == tokens_per_block
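To run only this new test module (assuming the repository root as the working directory, as the added l0_a10 test-list entry suggests), pytest can be invoked programmatically, for example:

# One way to run just the new test file; path taken from the diff above.
import pytest

pytest.main(["-q", "tests/unittest/_torch/test_model_config.py"])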