Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Bugfix/fix nemotron nas lora support (#6380)
Signed-off-by: Shahar Mor <17088876+shaharmor98@users.noreply.github.com>
Parent: baece56758
Commit: 0c42f54a39
@@ -299,48 +299,6 @@ class ModelConfig(Generic[TConfig]):
         num_heads = self.pretrained_config.num_attention_heads // (
             self.mapping.tp_size * self.mapping.cp_size)

-        # Handle both uniform and per-layer KV heads
-        num_kv_heads_per_layer = getattr(self.pretrained_config,
-                                         'num_kv_heads_per_layer', None)
-        if num_kv_heads_per_layer is not None:
-            # For models with per-layer KV heads, like nemotron-nas
-            kv_heads_per_layer_raw = num_kv_heads_per_layer
-            use_per_layer_kv_heads = True
-        else:
-            # Check if num_key_value_heads is a list (per-layer) or scalar (uniform)
-            num_kv_heads_raw = getattr(self.pretrained_config,
-                                       'num_key_value_heads', None)
-
-            if num_kv_heads_raw is not None and isinstance(
-                    num_kv_heads_raw, list):
-                # num_key_value_heads is a list - treat as per-layer KV heads
-                kv_heads_per_layer_raw = num_kv_heads_raw
-                use_per_layer_kv_heads = True
-            else:
-                # num_key_value_heads is scalar or None - treat as uniform KV heads
-                if num_kv_heads_raw is None:
-                    # For uniform models, check: num_key_value_heads (standard) -> num_query_groups (NeMo) -> num_attention_heads
-                    num_kv_heads_raw = getattr(
-                        self.pretrained_config, 'num_query_groups',
-                        self.pretrained_config.num_attention_heads)
-
-                num_kv_heads = num_kv_heads_raw // (self.mapping.tp_size *
-                                                    self.mapping.cp_size)
-                use_per_layer_kv_heads = False
-
-        if use_per_layer_kv_heads:
-            # TRT-LLM LoRA requires uniform KV heads across layers
-            if self.lora_config is not None and len(
-                    set(kv_heads_per_layer_raw)) > 1:
-                raise ValueError(
-                    f"TRT-LLM LoRA requires uniform KV heads across layers, "
-                    f"got: {kv_heads_per_layer_raw}")
-            # Apply TP/CP scaling to each layer
-            num_kv_heads_per_layer = [
-                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
-                for kv_heads in kv_heads_per_layer_raw
-            ]
-
         hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size

         model_config_cpp = ModelConfigCpp(
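A standalone sketch of what the removed branch above did; the head counts and TP/CP sizes are made up for illustration and are not taken from this diff.

from typing import List


def scale_kv_heads(kv_heads_per_layer_raw: List[int], tp_size: int,
                   cp_size: int, lora_enabled: bool) -> List[int]:
    # Mirrors the removed check: LoRA used to require uniform KV heads across layers.
    if lora_enabled and len(set(kv_heads_per_layer_raw)) > 1:
        raise ValueError(
            f"TRT-LLM LoRA requires uniform KV heads across layers, "
            f"got: {kv_heads_per_layer_raw}")
    # Each layer's KV heads are split across TP * CP ranks.
    return [kv // (tp_size * cp_size) for kv in kv_heads_per_layer_raw]


print(scale_kv_heads([8, 8, 8, 8], tp_size=2, cp_size=1, lora_enabled=True))
# [4, 4, 4, 4]
# scale_kv_heads([8, 8, 4, 8], tp_size=2, cp_size=1, lora_enabled=True)
# would raise: TRT-LLM LoRA requires uniform KV heads across layers, got: [8, 8, 4, 8]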
@@ -361,9 +319,18 @@ class ModelConfig(Generic[TConfig]):
         else:
             model_config_cpp.tokens_per_block = tokens_per_block

-        if use_per_layer_kv_heads:
+        num_key_value_heads = getattr(self.pretrained_config,
+                                      "num_key_value_heads", num_heads)
+        if isinstance(num_key_value_heads, (list, tuple)):
+            # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
+            num_kv_heads_per_layer = [
+                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
+                for kv_heads in num_key_value_heads
+            ]
+            model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
+        else:
+            num_kv_heads = num_key_value_heads // (self.mapping.tp_size *
+                                                   self.mapping.cp_size)
+            model_config_cpp.set_num_kv_heads(num_kv_heads)

         mlp_hidden_size = None
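The added code keys off whether num_key_value_heads in the pretrained config is a per-layer list or a single integer. A minimal standalone sketch of that dispatch, with hypothetical values:

from typing import List, Union


def resolve_kv_heads(num_key_value_heads: Union[int, List[int]], tp_size: int,
                     cp_size: int) -> Union[int, List[int]]:
    # List -> per-layer KV heads (Nemotron-NAS style): scale each entry.
    # Scalar -> uniform KV heads: scale the single value.
    if isinstance(num_key_value_heads, (list, tuple)):
        return [kv // (tp_size * cp_size) for kv in num_key_value_heads]
    return num_key_value_heads // (tp_size * cp_size)


print(resolve_kv_heads([8, 8, 4, 8], tp_size=2, cp_size=1))  # [4, 4, 2, 4]
print(resolve_kv_heads(8, tp_size=2, cp_size=1))             # 4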
@@ -451,18 +451,16 @@ def create_py_executor_instance(

     num_experts = _try_infer_num_experts(model_engine.model.model_config)

-    num_attn_layers = model_binding_config.num_attention_layers()
-    per_layer_kv_heads = [
-        model_binding_config.num_kv_heads(i) for i in range(num_attn_layers)
-    ]
-    num_kv_attention_heads = max(per_layer_kv_heads)
-    if len(set(per_layer_kv_heads)) > 1:
-        # NOTE: This code-path is currently untested and not validated. Can fail!
-        # This support is tracked in TRTLLM-6561
-        logger.warning(
-            f"Non-uniform KV heads per layer detected, using max ({num_kv_attention_heads}) for LoRA. "
-            "This code-path is currently untested and not validated. May fail!"
-        )
+    num_kv_attention_heads_per_layer = model_binding_config.num_kv_heads_per_layer
+    if max(num_kv_attention_heads_per_layer) != min(
+            num_kv_attention_heads_per_layer):
+        logger.warning(
+            "Defining LORA with per-layer KV heads is not supported for LORA, using the max number of KV heads per layer"
+        )
+        num_kv_attention_heads = max(num_kv_attention_heads_per_layer)
+    else:
+        # all layers have the same number of KV heads
+        num_kv_attention_heads = num_kv_attention_heads_per_layer[0]

     lora_modules = LoraModule.create_lora_modules(
         lora_module_names=lora_config.lora_target_modules,
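On the executor side, the added code collapses the per-layer KV head list to a single value for LoRA, warning when the layers disagree. A minimal standalone sketch of that reduction, with a hypothetical list:

import logging
from typing import List

logger = logging.getLogger(__name__)


def kv_heads_for_lora(num_kv_heads_per_layer: List[int]) -> int:
    # Uniform per-layer values collapse to that value; otherwise warn and
    # fall back to the maximum, matching the added branch above.
    if max(num_kv_heads_per_layer) != min(num_kv_heads_per_layer):
        logger.warning(
            "Defining LORA with per-layer KV heads is not supported for LORA, "
            "using the max number of KV heads per layer")
        return max(num_kv_heads_per_layer)
    # all layers have the same number of KV heads
    return num_kv_heads_per_layer[0]


print(kv_heads_for_lora([4, 4, 2, 4]))  # 4 (logs a warning)
print(kv_heads_for_lora([4, 4, 4, 4]))  # 4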
@@ -350,7 +350,6 @@ def test_llama_7b_lora_config_overrides_peft_cache_config():

# TODO smor: currently Nemotron-Super-49B-v1 with LoRA memory consumption is overly high
# https://jirasw.nvidia.com/browse/TRTLLM-5045
@pytest.mark.skip(reason="https://nvbugs/5401210")
@skip_gpu_memory_less_than_138gb
def test_nemotron_nas_lora() -> None:
    lora_config = LoraConfig(lora_dir=[
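The test above is gated on GPU memory via skip_gpu_memory_less_than_138gb, whose definition is not shown in this diff. A plausible sketch of such a marker, an assumption rather than the repo's actual helper:

import pytest
import torch


def _total_gpu_memory_gib() -> float:
    # Report 0 when no GPU is visible so the skip condition triggers.
    if not torch.cuda.is_available():
        return 0.0
    return torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)


# Hypothetical reimplementation; the repo's marker may be defined differently.
skip_gpu_memory_less_than_138gb = pytest.mark.skipif(
    _total_gpu_memory_gib() < 138,
    reason="Needs a GPU with at least 138 GB of memory")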