From f02948d956473e5374e160c1e789538efc76df23 Mon Sep 17 00:00:00 2001
From: Mike Iovine
Date: Fri, 16 Jan 2026 17:14:02 -0500
Subject: [PATCH] [https://nvbugs/5803813][fix] Fix llama 4 min latency (#10724)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
---
 .../_torch/models/modeling_llama_min_latency.py | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/tensorrt_llm/_torch/models/modeling_llama_min_latency.py b/tensorrt_llm/_torch/models/modeling_llama_min_latency.py
index 9af2d99f0b..ae3d1601fb 100644
--- a/tensorrt_llm/_torch/models/modeling_llama_min_latency.py
+++ b/tensorrt_llm/_torch/models/modeling_llama_min_latency.py
@@ -88,9 +88,10 @@ class Llama4MinLatencyLinear(Linear):
         self.enable_trtllm_gen = enable_trtllm_gen
         self.position_ids = None
 
-    def load_weights(self, weights: List[Dict]):
+    def load_weights(self, weights: List[Dict], allow_partial_loading: bool):
 
-        super().load_weights(weights)
+        super().load_weights(weights,
+                             allow_partial_loading=allow_partial_loading)
 
         # After loading weights, calculate the combined scale (input_scale * weight_scale) for special kernels and
         # trtllm-gen kernels.
@@ -384,11 +385,6 @@ class Llama4MinLatencyAttention(Llama4Attention):
             and self.floor_scale == 8192.0 \
             and self.attn_scale == 0.1
 
-        qkv_shard_indices_mapping = {
-            "q": (0, self.q_size),
-            "k": (self.q_size, self.kv_size),
-            "v": (self.q_size + self.kv_size, self.kv_size),
-        }
         # When min-latency QKV gemm is enabled, override qkv_proj.
         self.qkv_proj = Llama4MinLatencyLinear(
             self.hidden_size,
@@ -405,7 +401,6 @@ class Llama4MinLatencyAttention(Llama4Attention):
             enable_fused_gemm_attn_scaling=self.
            enable_fused_gemm_attn_scaling,
             enable_trtllm_gen=True,
-            fused_weight_shard_indices_mapping=qkv_shard_indices_mapping,
         )
 
     def _forward_nope(
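
For context, a minimal sketch of the pattern this patch applies: when a base-class load_weights gains an allow_partial_loading parameter, every override must accept and forward it rather than drop it. The BaseLinear and MinLatencyLinear classes, the combined_scale placeholder, and the default value on the flag below are illustrative assumptions, not TensorRT-LLM's actual Linear API.

# Minimal sketch (not TensorRT-LLM code): a subclass override that forwards a
# newly added keyword to the base class, mirroring the pattern in this patch.
# BaseLinear, combined_scale, and the weight layout are hypothetical.
from typing import Dict, List


class BaseLinear:

    def load_weights(self, weights: List[Dict],
                     allow_partial_loading: bool = False) -> None:
        # Hypothetical base behavior: store whatever shards are present; when
        # allow_partial_loading is False, a missing shard would be an error.
        self.loaded = [dict(w) for w in weights]


class MinLatencyLinear(BaseLinear):

    def load_weights(self, weights: List[Dict],
                     allow_partial_loading: bool = False) -> None:
        # Keep the override's signature in sync with the base class and
        # forward the flag instead of dropping it.
        super().load_weights(weights,
                             allow_partial_loading=allow_partial_loading)
        # Post-processing hook, analogous to computing the combined scale
        # (input_scale * weight_scale) after the weights are in place.
        self.combined_scale = 1.0


if __name__ == "__main__":
    layer = MinLatencyLinear()
    layer.load_weights([{"weight": [[0.0]]}], allow_partial_loading=True)
    print(layer.combined_scale)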