diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py
index c47cff9e8c..d39fdaff73 100644
--- a/tensorrt_llm/_torch/models/modeling_llama.py
+++ b/tensorrt_llm/_torch/models/modeling_llama.py
@@ -55,12 +55,11 @@ class LlamaAttention(Attention):
     ):
         config = model_config.pretrained_config
 
-        # TODO: do we need to use NoPE in any versions of LLaMA4?
-        nope_layer_interval = None
-        if nope_layer_interval:
-            is_nope_layer = (layer_idx + 1) % nope_layer_interval == 0
-        else:
-            is_nope_layer = False
+        # Convention: no_rope_layers[layer_idx] == 0 marks a NoPE layer.
+        self.is_llama4 = config.model_type == "llama4_text"
+        is_nope_layer = False
+        if self.is_llama4:
+            is_nope_layer = config.no_rope_layers[layer_idx] == 0
 
         use_rope = not is_nope_layer
         if use_rope and model_config.fuse_pos_embd:
@@ -206,6 +205,10 @@ class LlamaDecoderLayer(DecoderLayer):
         self.fusion_config.POST_MLP_FUSION = False
 
         self.is_llama4 = config.model_type == "llama4_text"
+        self.is_nope_layer = False
+        if self.is_llama4:
+            self.is_nope_layer = config.no_rope_layers[layer_idx] == 0
+
         self.self_attn = LlamaAttention(
             model_config,
             layer_idx=layer_idx,
@@ -286,13 +289,20 @@ class LlamaDecoderLayer(DecoderLayer):
 
         # Only enable min-latency mode on Blackwell
         # TODO: Remove it after we fix crash on Hopper
-        major, minor = torch.cuda.get_device_capability()
-        is_blackwell = (major * 10 + minor) >= 100
-        min_latency_mode = hidden_states.size(
-            0
-        ) <= 128 and self.fusion_config.POST_MOE_FUSION and is_blackwell and self.is_quanted
+        # major, minor = torch.cuda.get_device_capability()
+        # is_blackwell = (major * 10 + minor) >= 100
+        # min_latency_mode = hidden_states.size(
+        #     0
+        # ) <= 128 and self.fusion_config.POST_MOE_FUSION and is_blackwell and self.is_quanted
+
+        # Temporarily disable min-latency mode for Llama4
+        min_latency_mode = False
 
         # Self Attention
+        # For NoPE layers (as in Llama4), position_ids must be None so that the rotary embedding is not applied.
+        if self.is_nope_layer:
+            position_ids = None
+
         hidden_states = self.self_attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
@@ -573,6 +583,8 @@ class Llama4ForConditionalGeneration(DecoderModelForCausalLM[LlamaModel,
             if key.startswith("language_model."):
                 new_key = key[len("language_model."):]
                 new_weights[new_key] = tensor
+            else:
+                new_weights[key] = tensor
 
         super().load_weights(new_weights)
 
diff --git a/tensorrt_llm/_torch/modules/fused_moe.py b/tensorrt_llm/_torch/modules/fused_moe.py
index 65077a4bf9..5991e89cd8 100755
--- a/tensorrt_llm/_torch/modules/fused_moe.py
+++ b/tensorrt_llm/_torch/modules/fused_moe.py
@@ -812,9 +812,18 @@ class FusedMoE(nn.Module):
                                            dtype=torch.float32)
         tmp_fc2_input_scale = torch.empty(self.num_experts,
                                           dtype=torch.float32)
         for expert_id in range(self.num_experts):
-            w1_input_scale = weights[f"{expert_id}.w1.input_scale"]
-            w3_input_scale = weights[f"{expert_id}.w3.input_scale"]
-            w2_input_scale = weights[f"{expert_id}.w2.input_scale"]
+            if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
+                w1_input_scale = weights[f"{expert_id}.w1.input_scale"]
+                w3_input_scale = weights[f"{expert_id}.w3.input_scale"]
+                w2_input_scale = weights[f"{expert_id}.w2.input_scale"]
+            elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+                w1_input_scale = weights["gate_up_proj_input_scale"]
+                w3_input_scale = weights["gate_up_proj_input_scale"]
+                w2_input_scale = weights["down_proj_input_scale"]
+            else:
+                raise NotImplementedError(
+                    f"Unknown weight loading mode in MoE: {self.weight_loading_mode}"
+                )
             load_expert_fc31_input_scale_fp8_qdq(
                 w1_input_scale, w3_input_scale, tmp_fc31_input_scale[expert_id])
@@ -878,9 +887,18 @@ class FusedMoE(nn.Module):
             dst_w2_weight_scale.copy_(w2_weight_scale[...].reshape([]))
 
         for expert_id in range(self.expert_start, self.expert_end):
-            w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
-            w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
-            w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"]
+            if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
+                w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
+                w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
+                w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"]
+            elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+                w1_weight_scale = weights["gate_up_proj_weight_scale"]
+                w3_weight_scale = weights["gate_up_proj_weight_scale"]
+                w2_weight_scale = weights["down_proj_weight_scale"]
+            else:
+                raise NotImplementedError(
+                    f"Unknown weight loading mode in MoE: {self.weight_loading_mode}"
+                )
 
             expert_idx = expert_id - self.expert_start
 
diff --git a/tensorrt_llm/_torch/modules/rotary_embedding.py b/tensorrt_llm/_torch/modules/rotary_embedding.py
index c2bc194b15..2101e0732b 100644
--- a/tensorrt_llm/_torch/modules/rotary_embedding.py
+++ b/tensorrt_llm/_torch/modules/rotary_embedding.py
@@ -181,7 +181,8 @@ class RotaryEmbedding(nn.Module):
 
         if use_gptj_style_rope:
             raise ValueError(
-                "Must have flashinfer installed to use gptj style RoPE")
+                "GPT-J style RoPE must go through the flashinfer path to produce correct results."
+            )
         # it is assumed all targets are of the same rank
         q_or_k = targets[0]
         remove_input_padding = (len(q_or_k.size()) == 2)
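
For reference, a minimal sketch (not part of the patch) of the NoPE convention used above, assuming a Llama4-style text config that exposes no_rope_layers; the config values below are made up for illustration only.

# Illustrative sketch: config fields and values here are assumptions for the example.
from types import SimpleNamespace

# A Llama4-style text config exposes no_rope_layers; an entry of 0 marks a NoPE layer.
config = SimpleNamespace(
    model_type="llama4_text",
    no_rope_layers=[1, 1, 1, 0, 1, 1, 1, 0],  # hypothetical pattern for the example
)

for layer_idx, flag in enumerate(config.no_rope_layers):
    is_nope_layer = config.model_type == "llama4_text" and flag == 0
    # NoPE layers pass position_ids=None to attention so the rotary embedding is skipped.
    position_ids = None if is_nope_layer else [0, 1, 2, 3]
    print(f"layer {layer_idx}: {'NoPE' if is_nope_layer else 'RoPE'}, position_ids={position_ids}")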