Bugfix/fix nemotron nas lora support (#6380)

Signed-off-by: Shahar Mor <17088876+shaharmor98@users.noreply.github.com>
Authored by shaharmor98 on 2025-07-31 20:39:35 +03:00; committed by GitHub
parent baece56758
commit 0c42f54a39
3 changed files with 18 additions and 54 deletions


@@ -299,48 +299,6 @@ class ModelConfig(Generic[TConfig]):
         num_heads = self.pretrained_config.num_attention_heads // (
             self.mapping.tp_size * self.mapping.cp_size)
-        # Handle both uniform and per-layer KV heads
-        num_kv_heads_per_layer = getattr(self.pretrained_config,
-                                         'num_kv_heads_per_layer', None)
-        if num_kv_heads_per_layer is not None:
-            # For models with per-layer KV heads, like nemotron-nas
-            kv_heads_per_layer_raw = num_kv_heads_per_layer
-            use_per_layer_kv_heads = True
-        else:
-            # Check if num_key_value_heads is a list (per-layer) or scalar (uniform)
-            num_kv_heads_raw = getattr(self.pretrained_config,
-                                       'num_key_value_heads', None)
-            if num_kv_heads_raw is not None and isinstance(
-                    num_kv_heads_raw, list):
-                # num_key_value_heads is a list - treat as per-layer KV heads
-                kv_heads_per_layer_raw = num_kv_heads_raw
-                use_per_layer_kv_heads = True
-            else:
-                # num_key_value_heads is scalar or None - treat as uniform KV heads
-                if num_kv_heads_raw is None:
-                    # For uniform models, check: num_key_value_heads (standard) -> num_query_groups (NeMo) -> num_attention_heads
-                    num_kv_heads_raw = getattr(
-                        self.pretrained_config, 'num_query_groups',
-                        self.pretrained_config.num_attention_heads)
-                num_kv_heads = num_kv_heads_raw // (self.mapping.tp_size *
-                                                    self.mapping.cp_size)
-                use_per_layer_kv_heads = False
-        if use_per_layer_kv_heads:
-            # TRT-LLM LoRA requires uniform KV heads across layers
-            if self.lora_config is not None and len(
-                    set(kv_heads_per_layer_raw)) > 1:
-                raise ValueError(
-                    f"TRT-LLM LoRA requires uniform KV heads across layers, "
-                    f"got: {kv_heads_per_layer_raw}")
-            # Apply TP/CP scaling to each layer
-            num_kv_heads_per_layer = [
-                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
-                for kv_heads in kv_heads_per_layer_raw
-            ]
         hidden_size = self.pretrained_config.hidden_size // self.mapping.tp_size
         model_config_cpp = ModelConfigCpp(
@@ -361,9 +319,18 @@ class ModelConfig(Generic[TConfig]):
         else:
             model_config_cpp.tokens_per_block = tokens_per_block
-        if use_per_layer_kv_heads:
+        num_key_value_heads = getattr(self.pretrained_config,
+                                      "num_key_value_heads", num_heads)
+        if isinstance(num_key_value_heads, (list, tuple)):
+            # Per-layer KV heads (e.g., Nemotron-NAS, variable GQA models)
+            num_kv_heads_per_layer = [
+                kv_heads // (self.mapping.tp_size * self.mapping.cp_size)
+                for kv_heads in num_key_value_heads
+            ]
             model_config_cpp.num_kv_heads_per_layer = num_kv_heads_per_layer
         else:
+            num_kv_heads = num_key_value_heads // (self.mapping.tp_size *
+                                                   self.mapping.cp_size)
             model_config_cpp.set_num_kv_heads(num_kv_heads)
         mlp_hidden_size = None
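
For readers skimming the diff above: this file now derives KV heads from a single
num_key_value_heads check instead of the removed per-layer/uniform detection block.
A minimal standalone sketch of that behavior (illustrative only; resolve_kv_heads is
a hypothetical helper, not part of this commit, and it assumes a Hugging Face-style
config object):

def resolve_kv_heads(pretrained_config, tp_size, cp_size):
    parallel = tp_size * cp_size
    # num_key_value_heads may be absent (fall back to num_attention_heads),
    # a scalar (uniform GQA), or a per-layer list (e.g. Nemotron-NAS).
    kv_heads = getattr(pretrained_config, "num_key_value_heads",
                       pretrained_config.num_attention_heads)
    if isinstance(kv_heads, (list, tuple)):
        # Per-layer KV heads: scale every layer's count by the TP*CP degree.
        return [h // parallel for h in kv_heads]
    # Uniform KV heads: a single scaled value.
    return kv_heads // parallel

For example, with tp_size=2, cp_size=1 and num_key_value_heads=[8, 8, 4, 8] this
returns [4, 4, 2, 4], the same scaling the new code applies before handing the list
to ModelConfigCpp.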


@@ -451,18 +451,16 @@ def create_py_executor_instance(
         num_experts = _try_infer_num_experts(model_engine.model.model_config)
-        num_attn_layers = model_binding_config.num_attention_layers()
-        per_layer_kv_heads = [
-            model_binding_config.num_kv_heads(i) for i in range(num_attn_layers)
-        ]
-        num_kv_attention_heads = max(per_layer_kv_heads)
-        if len(set(per_layer_kv_heads)) > 1:
-            # NOTE: This code-path is currently untested and not validated. Can fail!
-            # This support is tracked in TRTLLM-6561
+        num_kv_attention_heads_per_layer = model_binding_config.num_kv_heads_per_layer
+        if max(num_kv_attention_heads_per_layer) != min(
+                num_kv_attention_heads_per_layer):
             logger.warning(
-                f"Non-uniform KV heads per layer detected, using max ({num_kv_attention_heads}) for LoRA. "
-                "This code-path is currently untested and not validated. May fail!"
+                "Defining LORA with per-layer KV heads is not supported for LORA, using the max number of KV heads per layer"
             )
+            num_kv_attention_heads = max(num_kv_attention_heads_per_layer)
+        else:
+            # all layers have the same number of KV heads
+            num_kv_attention_heads = num_kv_attention_heads_per_layer[0]
         lora_modules = LoraModule.create_lora_modules(
             lora_module_names=lora_config.lora_target_modules,
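
The change above keeps LoRA module creation on a single KV-head count even when the
bindings report per-layer values. A hedged sketch of that fallback, assuming
per_layer_kv_heads is the list exposed as model_binding_config.num_kv_heads_per_layer
(pick_lora_kv_heads is a hypothetical name, not part of this commit):

import logging

logger = logging.getLogger(__name__)

def pick_lora_kv_heads(per_layer_kv_heads):
    if max(per_layer_kv_heads) != min(per_layer_kv_heads):
        # Non-uniform KV heads are not supported for LoRA; fall back to the
        # maximum across layers, as the warning path above does.
        logger.warning("Per-layer KV heads are not supported for LoRA; "
                       "using the max number of KV heads per layer")
        return max(per_layer_kv_heads)
    # All layers agree, so any entry works.
    return per_layer_kv_heads[0]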


@@ -350,7 +350,6 @@ def test_llama_7b_lora_config_overrides_peft_cache_config():
 # TODO smor: currently Nemotron-Super-49B-v1 with LoRA memory consumption is overly high
 # https://jirasw.nvidia.com/browse/TRTLLM-5045
-@pytest.mark.skip(reason="https://nvbugs/5401210")
 @skip_gpu_memory_less_than_138gb
 def test_nemotron_nas_lora() -> None:
     lora_config = LoraConfig(lora_dir=[
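
The test hunk is truncated here and its LoraConfig arguments are not shown. For
orientation only, a hedged sketch of how a Nemotron-NAS LoRA run is typically wired
through the LLM API (paths and rank are placeholders, not the values this test uses,
and the exact keyword set of LoraConfig may differ between versions):

from tensorrt_llm import LLM
from tensorrt_llm.lora_manager import LoraConfig

# Placeholder checkpoint and adapter locations.
lora_config = LoraConfig(lora_dir=["/path/to/nemotron-nas-lora-adapter"],
                         max_lora_rank=64)
llm = LLM(model="/path/to/nemotron-nas-checkpoint", lora_config=lora_config)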