TensorRT-LLMs/tests/unittest/_torch/test_model_config.py
2026-01-04 06:07:30 -05:00

91 lines
3.3 KiB
Python

import types
import pytest
import torch
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm.mapping import Mapping
def make_pretrained_config(
*,
num_attention_heads: int = 16,
num_key_value_heads=8,
head_dim: int | None = None,
num_hidden_layers: int = 1,
vocab_size: int = 3000,
):
# A minimal config object that provides the attributes used by
# ModelConfig.get_bindings_model_config().
hidden_size = head_dim * num_attention_heads
intermediate_size = hidden_size * 4
return types.SimpleNamespace(
architectures=["DummyArchitecture"],
num_attention_heads=num_attention_heads,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
num_hidden_layers=num_hidden_layers,
vocab_size=vocab_size,
torch_dtype=torch.float16,
)
@pytest.mark.parametrize(
"num_key_value_heads",
[
pytest.param(8, id="kv_heads_scalar"),
pytest.param([8, 20], id="kv_heads_per_layer_varied"),
],
)
@pytest.mark.parametrize("enable_attention_dp", [False, True])
@pytest.mark.parametrize(
"mapping_kwargs",
[
# Same tp/cp sizes, but different ways of setting attention TP:
# - No explicit attn_tp_size: Mapping infers it.
# - Explicit attn_tp_size: Mapping uses the provided value.
dict(world_size=8, tp_size=4, cp_size=2),
dict(world_size=4, tp_size=2, cp_size=2, attn_tp_size=4),
],
)
def test_get_bindings_model_config_attention_dp_attn_tp_override(
enable_attention_dp, mapping_kwargs, num_key_value_heads
):
mapping = Mapping(enable_attention_dp=enable_attention_dp, **mapping_kwargs)
cfg = make_pretrained_config(
# Keep values consistent:
# hidden_size = num_attention_heads * head_dim.
num_attention_heads=16,
head_dim=4,
num_key_value_heads=num_key_value_heads,
num_hidden_layers=2,
)
model_config = ModelConfig(pretrained_config=cfg, mapping=mapping)
tokens_per_block = 32
bindings_cfg = model_config.get_bindings_model_config(tokens_per_block=tokens_per_block)
# bindings hidden_size is sharded by attn_tp_size and attn_cp_size.
attn_tp_size = mapping.attn_tp_size if not mapping.enable_attention_dp else 1
attn_cp_size = mapping.attn_cp_size
assert bindings_cfg.num_heads == cfg.num_attention_heads // (attn_tp_size * attn_cp_size)
# bindings hidden_size is sharded by attn_tp_size.
assert bindings_cfg.hidden_size == cfg.hidden_size // attn_tp_size
if isinstance(cfg.num_key_value_heads, (list, tuple)):
expected_num_kv_heads_per_layer = [
kv // (attn_tp_size * attn_cp_size) for kv in cfg.num_key_value_heads
]
assert list(bindings_cfg.num_kv_heads_per_layer) == expected_num_kv_heads_per_layer
assert bindings_cfg.num_kv_heads(0) == expected_num_kv_heads_per_layer[0]
else:
assert bindings_cfg.num_kv_heads(0) == cfg.num_key_value_heads // (
attn_tp_size * attn_cp_size
)
# tp_size-dependent value (uses mapping.tp_size, not attn_tp_size).
assert bindings_cfg.mlp_hidden_size == (cfg.intermediate_size // mapping.tp_size)
assert bindings_cfg.tokens_per_block == tokens_per_block