feat: support llama4 nope layers; support FP8 checkpoint loading; (#3382)

* Enable NoPE, fix a rotary embedding bug for gptj_style_rope, address PR comments, properly skip the rotary embedding for Llama4 NoPE layers

* Add support for FP8 checkpoints, fix checkpoint weight loading for FP8

* Temporarily disable min_latency_mode for llama4

---------

Co-authored-by: Yilin Fan <yilinf@nvidia.com>
Co-authored-by: Sharan Chetlur <116769508+schetlur-nv@users.noreply.github.com>
Authored by Zhihan Jiang on 2025-04-10 10:16:42 -07:00; committed by GitHub
parent a6a2ae6cc1
commit 8300218d21
3 changed files with 49 additions and 18 deletions


@@ -55,12 +55,11 @@ class LlamaAttention(Attention):
     ):
         config = model_config.pretrained_config
-        # TODO: do we need to use NoPE in any versions of LLaMA4?
-        nope_layer_interval = None
-        if nope_layer_interval:
-            is_nope_layer = (layer_idx + 1) % nope_layer_interval == 0
-        else:
-            is_nope_layer = False
+        # Note the convention no_rope_layers[layer_idx] == 0 means nope_layer
+        self.is_llama4 = config.model_type == "llama4_text"
+        is_nope_layer = False
+        if self.is_llama4:
+            is_nope_layer = config.no_rope_layers[layer_idx] == 0
         use_rope = not is_nope_layer
         if use_rope and model_config.fuse_pos_embd:
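As an aside for readers of this hunk: a minimal, self-contained illustration of the convention mentioned in the new comment, using a made-up DummyConfig (the real values come from the Llama4 text config's no_rope_layers list):

class DummyConfig:
    # Hypothetical values for illustration only; 0 marks a NoPE layer,
    # non-zero marks a regular RoPE layer.
    model_type = "llama4_text"
    no_rope_layers = [1, 1, 1, 0, 1, 1, 1, 0]

config = DummyConfig()
nope_layers = [i for i, flag in enumerate(config.no_rope_layers) if flag == 0]
print(nope_layers)  # [3, 7] -> these layers skip rotary embedding entirely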
@@ -206,6 +205,10 @@ class LlamaDecoderLayer(DecoderLayer):
         self.fusion_config.POST_MLP_FUSION = False
+        self.is_llama4 = config.model_type == "llama4_text"
+        self.is_nope_layer = False
+        if self.is_llama4:
+            self.is_nope_layer = config.no_rope_layers[layer_idx] == 0
         self.self_attn = LlamaAttention(
             model_config,
             layer_idx=layer_idx,
@@ -286,13 +289,20 @@ class LlamaDecoderLayer(DecoderLayer):
         # Only enable min-latency mode on Blackwell
         # TODO: Remove it after we fix crash on Hopper
-        major, minor = torch.cuda.get_device_capability()
-        is_blackwell = (major * 10 + minor) >= 100
-        min_latency_mode = hidden_states.size(
-            0
-        ) <= 128 and self.fusion_config.POST_MOE_FUSION and is_blackwell and self.is_quanted
+        # major, minor = torch.cuda.get_device_capability()
+        # is_blackwell = (major * 10 + minor) >= 100
+        # min_latency_mode = hidden_states.size(
+        #     0
+        # ) <= 128 and self.fusion_config.POST_MOE_FUSION and is_blackwell and self.is_quanted
+        # Temporarily disable min-latency mode for Llama4
+        min_latency_mode = False
         # Self Attention
+        # For NOPE layers (like in Llama4), the position_ids needs to be set to None, so the rotary embedding will not be applied.
+        if self.is_nope_layer:
+            position_ids = None
         hidden_states = self.self_attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
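For context on the gating that the hunk comments out: compute capability is reported as (major, minor), and encoding it as major * 10 + minor makes SM100-class (Blackwell) devices the first to reach 100. A standalone sketch of that check (requires a CUDA-capable PyTorch build):

import torch

def is_blackwell() -> bool:
    # SM version encoded as major * 10 + minor; 100 and above means Blackwell.
    major, minor = torch.cuda.get_device_capability()
    return (major * 10 + minor) >= 100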
@@ -573,6 +583,8 @@ class Llama4ForConditionalGeneration(DecoderModelForCausalLM[LlamaModel,
             if key.startswith("language_model."):
                 new_key = key[len("language_model."):]
                 new_weights[new_key] = tensor
+            else:
+                new_weights[key] = tensor
         super().load_weights(new_weights)
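The load_weights change above re-roots keys under the "language_model." prefix and now passes every other key through unchanged. A tiny standalone sketch of the same remapping, with made-up key names:

weights = {
    "language_model.model.layers.0.self_attn.q_proj.weight": "tensor_a",  # hypothetical keys
    "lm_head.weight": "tensor_b",
}
new_weights = {}
for key, tensor in weights.items():
    if key.startswith("language_model."):
        new_weights[key[len("language_model."):]] = tensor
    else:
        new_weights[key] = tensor
print(sorted(new_weights))  # ['lm_head.weight', 'model.layers.0.self_attn.q_proj.weight']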


@@ -812,9 +812,18 @@ class FusedMoE(nn.Module):
                                               dtype=torch.float32)
             tmp_fc2_input_scale = torch.empty(self.num_experts, dtype=torch.float32)
             for expert_id in range(self.num_experts):
-                w1_input_scale = weights[f"{expert_id}.w1.input_scale"]
-                w3_input_scale = weights[f"{expert_id}.w3.input_scale"]
-                w2_input_scale = weights[f"{expert_id}.w2.input_scale"]
+                if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
+                    w1_input_scale = weights[f"{expert_id}.w1.input_scale"]
+                    w3_input_scale = weights[f"{expert_id}.w3.input_scale"]
+                    w2_input_scale = weights[f"{expert_id}.w2.input_scale"]
+                elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+                    w1_input_scale = weights[f"gate_up_proj_input_scale"]
+                    w3_input_scale = weights[f"gate_up_proj_input_scale"]
+                    w2_input_scale = weights[f"down_proj_input_scale"]
+                else:
+                    raise NotImplementedError(
+                        f"Unknown weight loading mode in MoE: {self.weight_loading_mode}"
+                    )
                 load_expert_fc31_input_scale_fp8_qdq(
                     w1_input_scale, w3_input_scale, tmp_fc31_input_scale[expert_id])
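The branch added above distinguishes two checkpoint layouts: VANILLA checkpoints carry one input_scale per expert projection, while FUSED_GATE_UP_PROJ checkpoints share a single scale for the fused gate/up projection plus one for down_proj. A simplified, standalone version of that selection (the enum here is a stand-in for the repo's MoEWeightLoadingMode):

from enum import Enum

class MoEWeightLoadingMode(Enum):  # simplified stand-in for illustration
    VANILLA = 0
    FUSED_GATE_UP_PROJ = 1

def select_input_scales(weights, expert_id, mode):
    if mode == MoEWeightLoadingMode.VANILLA:
        return (weights[f"{expert_id}.w1.input_scale"],
                weights[f"{expert_id}.w3.input_scale"],
                weights[f"{expert_id}.w2.input_scale"])
    if mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
        return (weights["gate_up_proj_input_scale"],
                weights["gate_up_proj_input_scale"],
                weights["down_proj_input_scale"])
    raise NotImplementedError(f"Unknown weight loading mode in MoE: {mode}")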
@@ -878,9 +887,18 @@ class FusedMoE(nn.Module):
                     dst_w2_weight_scale.copy_(w2_weight_scale[...].reshape([]))
             for expert_id in range(self.expert_start, self.expert_end):
-                w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
-                w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
-                w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"]
+                if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
+                    w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
+                    w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
+                    w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"]
+                elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+                    w1_weight_scale = weights[f"gate_up_proj_weight_scale"]
+                    w3_weight_scale = weights[f"gate_up_proj_weight_scale"]
+                    w2_weight_scale = weights[f"down_proj_weight_scale"]
+                else:
+                    raise NotImplementedError(
+                        f"Unknown weight loading mode in MoE: {self.weight_loading_mode}"
+                    )
                 expert_idx = expert_id - self.expert_start
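For readers unfamiliar with the scales being loaded here: in an FP8 "qdq" checkpoint each tensor is stored as float8_e4m3 values plus an explicit scale, and dequantization multiplies the scale back in. A hedged sketch of that round trip, not the repo's actual loading path (requires PyTorch >= 2.1 for float8_e4m3fn):

import torch

w = torch.randn(128, 256)
scale = w.abs().max() / 448.0                 # 448 is the largest e4m3 value
w_fp8 = (w / scale).to(torch.float8_e4m3fn)   # quantize
w_dequant = w_fp8.to(torch.float32) * scale   # dequantize
print((w - w_dequant).abs().max())            # small quantization error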


@@ -181,7 +181,8 @@ class RotaryEmbedding(nn.Module):
         if use_gptj_style_rope:
             raise ValueError(
-                "Must have flashinfer installed to use gptj style RoPE")
+                "gptj style RoPE has to go through flashinfer route for correct results."
+            )
         # it is assumed all targets are of the same rank
         q_or_k = targets[0]
         remove_input_padding = (len(q_or_k.size()) == 2)
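Background for the reworded error: "gptj style" RoPE rotates interleaved (even, odd) channel pairs, while the more common "neox style" rotates the first half of the head dimension against the second half; mixing up the two layouts silently corrupts attention, which is why the non-flashinfer path refuses gptj-style inputs. An illustrative pure-PyTorch sketch of the two pairings (not the kernel flashinfer uses):

import torch

def rotate_every_two(x: torch.Tensor) -> torch.Tensor:
    # GPT-J style pairing: (x0, x1), (x2, x3), ...
    x1, x2 = x[..., 0::2], x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # GPT-NeoX style pairing: first half against second half.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

# In both cases the rotation is x * cos + rotate(x) * sin, with cos/sin
# laid out to match the chosen pairing (interleaved vs. half-split).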