[https://nvbugs/5803813][fix] Fix llama 4 min latency (#10724)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Authored by Mike Iovine on 2026-01-16 17:14:02 -05:00, committed by Yanchao Lu
parent 93e7ae73ea
commit f02948d956


@@ -88,9 +88,10 @@ class Llama4MinLatencyLinear(Linear):
         self.enable_trtllm_gen = enable_trtllm_gen
         self.position_ids = None
-    def load_weights(self, weights: List[Dict]):
-        super().load_weights(weights)
+    def load_weights(self, weights: List[Dict], allow_partial_loading: bool):
+        super().load_weights(weights,
+                             allow_partial_loading=allow_partial_loading)
         # After loading weights, calculate the combined scale (input_scale * weight_scale) for special kernels and
         # trtllm-gen kernels.
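For context, a minimal sketch of the pattern this hunk fixes: the override's signature must stay in sync with the base class and forward the new allow_partial_loading flag before doing its own post-load work (the combined input_scale * weight_scale). BaseLinear, MinLatencyLinear, and combined_scale below are illustrative names, not the actual TensorRT-LLM classes.

from typing import Dict, List

import torch


class BaseLinear:
    """Stand-in for the parent Linear class; the signature is assumed."""

    def __init__(self) -> None:
        self.weight = None
        self.input_scale = torch.tensor(1.0)
        self.weight_scale = torch.tensor(0.5)

    def load_weights(self, weights: List[Dict], allow_partial_loading: bool = False):
        # When allow_partial_loading is True, silently skip shards that are absent.
        for shard in weights:
            if "weight" in shard:
                self.weight = shard["weight"]
            elif not allow_partial_loading:
                raise KeyError("missing 'weight' shard")


class MinLatencyLinear(BaseLinear):
    def load_weights(self, weights: List[Dict], allow_partial_loading: bool = False):
        # Keep the override in sync with the base class and forward the flag.
        super().load_weights(weights, allow_partial_loading=allow_partial_loading)
        # After loading, precompute the combined scale used by the fused kernels.
        self.combined_scale = self.input_scale * self.weight_scale


layer = MinLatencyLinear()
layer.load_weights([{"weight": torch.zeros(8, 8)}], allow_partial_loading=True)
print(layer.combined_scale)  # tensor(0.5000)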
@@ -384,11 +385,6 @@ class Llama4MinLatencyAttention(Llama4Attention):
             and self.floor_scale == 8192.0 \
             and self.attn_scale == 0.1
-        qkv_shard_indices_mapping = {
-            "q": (0, self.q_size),
-            "k": (self.q_size, self.kv_size),
-            "v": (self.q_size + self.kv_size, self.kv_size),
-        }
         # When min-latency QKV gemm is enabled, override qkv_proj.
         self.qkv_proj = Llama4MinLatencyLinear(
             self.hidden_size,
@@ -405,7 +401,6 @@ class Llama4MinLatencyAttention(Llama4Attention):
             enable_fused_gemm_attn_scaling=self.
             enable_fused_gemm_attn_scaling,
             enable_trtllm_gen=True,
-            fused_weight_shard_indices_mapping=qkv_shard_indices_mapping,
         )
     def _forward_nope(
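For reference, the removed qkv_shard_indices_mapping paired each of "q", "k", and "v" with a (start_offset, size) tuple along the output dimension of the fused QKV weight, and the fused_weight_shard_indices_mapping= argument that consumed it is dropped in the same change. Below is a hedged sketch of that slicing convention only; split_fused_qkv and the tensor shapes are hypothetical and not TensorRT-LLM APIs.

import torch


def split_fused_qkv(fused_weight: torch.Tensor, q_size: int, kv_size: int):
    # Same (start_offset, size) convention as the removed mapping.
    mapping = {
        "q": (0, q_size),
        "k": (q_size, kv_size),
        "v": (q_size + kv_size, kv_size),
    }
    # Slice rows [start, start + size) of the fused weight for each projection.
    return {
        name: fused_weight[start:start + size]
        for name, (start, size) in mapping.items()
    }


hidden = 16
q_size, kv_size = 8, 4
fused = torch.randn(q_size + 2 * kv_size, hidden)
shards = split_fused_qkv(fused, q_size, kv_size)
print({k: tuple(v.shape) for k, v in shards.items()})  # q: (8, 16), k/v: (4, 16)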