[https://nvbugs/5670469][fix] Filter 0s and choose min of kv_head for Nemotron model (#10206)

Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
This commit is contained in:
Faraz 2026-01-04 19:42:53 -05:00 committed by GitHub
parent e2f5455533
commit 8e2065b4d9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -1070,7 +1070,12 @@ class PyTorchModelEngine(ModelEngine):
num_attention_heads = getattr(config, 'num_attention_heads', None)
num_key_value_heads = getattr(config, 'num_key_value_heads', None)
if num_attention_heads is not None and num_key_value_heads is not None:
# Calculate the number of attention heads per KV head (GQA ratio)
if isinstance(num_key_value_heads, (list, tuple)):
# Filter out invalid KV heads, default to 0 if no valid KV heads are found
num_key_value_heads = min(
(kv for kv in num_key_value_heads if kv and kv > 0), default=0)
if num_attention_heads and num_key_value_heads:
num_heads_per_kv = num_attention_heads // num_key_value_heads
else:
num_heads_per_kv = 1