diff --git a/tensorrt_llm/_torch/models/modeling_llama_min_latency.py b/tensorrt_llm/_torch/models/modeling_llama_min_latency.py
index 9af2d99f0b..ae3d1601fb 100644
--- a/tensorrt_llm/_torch/models/modeling_llama_min_latency.py
+++ b/tensorrt_llm/_torch/models/modeling_llama_min_latency.py
@@ -88,9 +88,10 @@ class Llama4MinLatencyLinear(Linear):
         self.enable_trtllm_gen = enable_trtllm_gen
         self.position_ids = None

-    def load_weights(self, weights: List[Dict]):
+    def load_weights(self, weights: List[Dict], allow_partial_loading: bool):
-        super().load_weights(weights)
+        super().load_weights(weights,
+                             allow_partial_loading=allow_partial_loading)

         # After loading weights, calculate the combined scale (input_scale * weight_scale) for special kernels and
         # trtllm-gen kernels.
@@ -384,11 +385,6 @@ class Llama4MinLatencyAttention(Llama4Attention):
             and self.floor_scale == 8192.0 \
             and self.attn_scale == 0.1

-        qkv_shard_indices_mapping = {
-            "q": (0, self.q_size),
-            "k": (self.q_size, self.kv_size),
-            "v": (self.q_size + self.kv_size, self.kv_size),
-        }
         # When min-latency QKV gemm is enabled, override qkv_proj.
         self.qkv_proj = Llama4MinLatencyLinear(
             self.hidden_size,
@@ -405,7 +401,6 @@ class Llama4MinLatencyAttention(Llama4Attention):
             enable_fused_gemm_attn_scaling=self.
             enable_fused_gemm_attn_scaling,
             enable_trtllm_gen=True,
-            fused_weight_shard_indices_mapping=qkv_shard_indices_mapping,
         )

     def _forward_nope(
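
For context, a minimal sketch of the pattern the first hunk follows: a Linear subclass that post-processes weights overrides load_weights, forwards the new allow_partial_loading flag to its parent, and only then derives its combined scales. This is not the TensorRT-LLM implementation; BaseLinear, ScaledLinear, and _compute_combined_scales are hypothetical names used only for illustration.

from typing import Dict, List


class BaseLinear:
    """Stand-in for the parent Linear class (hypothetical)."""

    def load_weights(self, weights: List[Dict], allow_partial_loading: bool):
        # The parent copies the provided shards into its parameters; with
        # allow_partial_loading=True it would tolerate missing shards.
        ...


class ScaledLinear(BaseLinear):
    """Stand-in for a subclass that derives extra scales after loading."""

    def load_weights(self, weights: List[Dict], allow_partial_loading: bool):
        # Forward the flag unchanged so partially loaded checkpoints still work.
        super().load_weights(weights,
                             allow_partial_loading=allow_partial_loading)
        # Post-process only after the parent has populated the tensors,
        # e.g. combining input_scale * weight_scale for fused kernels.
        self._compute_combined_scales()

    def _compute_combined_scales(self):
        pass  # Placeholder for the kernel-specific scale math.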