Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-19 01:05:12 +08:00
feat: support llama4 nope layers; support FP8 checkpoint loading; (#3382)
* Enable NoPE; fix a rotary embedding bug for gptj_style_rope; address PR comments; properly skip the rotary embedding for Llama4 NoPE layers
* Add support for FP8 checkpoints; fix checkpoint weight loading for FP8
* Temporarily disable min_latency_mode for Llama4

Co-authored-by: Yilin Fan <yilinf@nvidia.com>
Co-authored-by: Sharan Chetlur <116769508+schetlur-nv@users.noreply.github.com>
This commit is contained in:
parent a6a2ae6cc1
commit 8300218d21
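The convention at the core of this change (see the LlamaAttention and LlamaDecoderLayer hunks below) is that config.no_rope_layers[layer_idx] == 0 marks a NoPE layer, and NoPE layers skip rotary embedding entirely by calling attention with position_ids=None. A minimal standalone sketch of that convention, using hypothetical per-layer flags rather than a real Llama4 config:

# Sketch of the NoPE convention adopted in this commit.
# Assumption: no_rope_layers holds one flag per decoder layer, and a value of
# 0 marks a NoPE layer (no rotary positional embedding for that layer).
no_rope_layers = [1, 1, 1, 0, 1, 1, 1, 0]  # hypothetical per-layer flags

for layer_idx, flag in enumerate(no_rope_layers):
    is_nope_layer = flag == 0
    # NoPE layers are given position_ids=None so rotary embedding is skipped.
    position_ids = None if is_nope_layer else list(range(4))
    print(f"layer {layer_idx}: {'NoPE (rotary skipped)' if is_nope_layer else 'RoPE'}")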
@@ -55,12 +55,11 @@ class LlamaAttention(Attention):
     ):
         config = model_config.pretrained_config

-        # TODO: do we need to use NoPE in any versions of LLaMA4?
-        nope_layer_interval = None
-        if nope_layer_interval:
-            is_nope_layer = (layer_idx + 1) % nope_layer_interval == 0
-        else:
-            is_nope_layer = False
+        # Note the convention no_rope_layers[layer_idx] == 0 means nope_layer
+        self.is_llama4 = config.model_type == "llama4_text"
+        is_nope_layer = False
+        if self.is_llama4:
+            is_nope_layer = config.no_rope_layers[layer_idx] == 0

         use_rope = not is_nope_layer
         if use_rope and model_config.fuse_pos_embd:
@@ -206,6 +205,10 @@ class LlamaDecoderLayer(DecoderLayer):
         self.fusion_config.POST_MLP_FUSION = False

+        self.is_llama4 = config.model_type == "llama4_text"
+        self.is_nope_layer = False
+        if self.is_llama4:
+            self.is_nope_layer = config.no_rope_layers[layer_idx] == 0

         self.self_attn = LlamaAttention(
             model_config,
             layer_idx=layer_idx,
@@ -286,13 +289,20 @@ class LlamaDecoderLayer(DecoderLayer):

         # Only enable min-latency mode on Blackwell
         # TODO: Remove it after we fix crash on Hopper
-        major, minor = torch.cuda.get_device_capability()
-        is_blackwell = (major * 10 + minor) >= 100
-        min_latency_mode = hidden_states.size(
-            0
-        ) <= 128 and self.fusion_config.POST_MOE_FUSION and is_blackwell and self.is_quanted
+        # major, minor = torch.cuda.get_device_capability()
+        # is_blackwell = (major * 10 + minor) >= 100
+        # min_latency_mode = hidden_states.size(
+        #     0
+        # ) <= 128 and self.fusion_config.POST_MOE_FUSION and is_blackwell and self.is_quanted
+
+        # Temporarily disable min-latency mode for Llama4
+        min_latency_mode = False

         # Self Attention
+        # For NOPE layers (like in Llama4), the position_ids needs to be set to None, so the rotary embedding will not be applied.
+        if self.is_nope_layer:
+            position_ids = None
+
         hidden_states = self.self_attn(
             position_ids=position_ids,
             hidden_states=hidden_states,
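For context on the gating that is commented out above: Blackwell GPUs report CUDA compute capability 10.x, so major * 10 + minor >= 100 distinguishes them from Hopper (9.x). A small standalone sketch of that check (not TensorRT-LLM code; it only reports something useful on a machine with a CUDA-enabled PyTorch build):

import torch

# Sketch of the Blackwell check used by the (now commented-out) gate above.
# torch.cuda.get_device_capability() returns (major, minor), e.g. (9, 0) on
# Hopper and (10, 0) on Blackwell.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    is_blackwell = (major * 10 + minor) >= 100
    print(f"SM {major}.{minor}, Blackwell: {is_blackwell}")
else:
    print("No CUDA device visible; skipping the capability check.")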
@@ -573,6 +583,8 @@ class Llama4ForConditionalGeneration(DecoderModelForCausalLM[LlamaModel,
             if key.startswith("language_model."):
                 new_key = key[len("language_model."):]
                 new_weights[new_key] = tensor
+            else:
+                new_weights[key] = tensor

         super().load_weights(new_weights)

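The two lines added above make the checkpoint remapping keep non-language-model keys instead of handling only the "language_model." prefix path. A minimal sketch of the same renaming over a plain dict (the keys and values here are hypothetical placeholders, not real Llama4 checkpoint entries):

# Hypothetical checkpoint fragment; strings stand in for weight tensors.
weights = {
    "language_model.model.layers.0.self_attn.q_proj.weight": "tensor_a",
    "vision_model.patch_embedding.weight": "tensor_b",
}

new_weights = {}
for key, tensor in weights.items():
    if key.startswith("language_model."):
        # Strip the "language_model." prefix for language-model weights.
        new_weights[key[len("language_model."):]] = tensor
    else:
        # Everything else is now kept under its original key.
        new_weights[key] = tensor

print(sorted(new_weights))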
@@ -812,9 +812,18 @@ class FusedMoE(nn.Module):
                                          dtype=torch.float32)
         tmp_fc2_input_scale = torch.empty(self.num_experts, dtype=torch.float32)
         for expert_id in range(self.num_experts):
-            w1_input_scale = weights[f"{expert_id}.w1.input_scale"]
-            w3_input_scale = weights[f"{expert_id}.w3.input_scale"]
-            w2_input_scale = weights[f"{expert_id}.w2.input_scale"]
+            if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
+                w1_input_scale = weights[f"{expert_id}.w1.input_scale"]
+                w3_input_scale = weights[f"{expert_id}.w3.input_scale"]
+                w2_input_scale = weights[f"{expert_id}.w2.input_scale"]
+            elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+                w1_input_scale = weights[f"gate_up_proj_input_scale"]
+                w3_input_scale = weights[f"gate_up_proj_input_scale"]
+                w2_input_scale = weights[f"down_proj_input_scale"]
+            else:
+                raise NotImplementedError(
+                    f"Unknown weight loading mode in MoE: {self.weight_loading_mode}"
+                )

             load_expert_fc31_input_scale_fp8_qdq(
                 w1_input_scale, w3_input_scale, tmp_fc31_input_scale[expert_id])
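The branch added above (and the matching one in the weight-scale hunk that follows) lets a single loader handle two FP8 checkpoint layouts: per-expert keys such as "{expert_id}.w1.input_scale" in VANILLA mode, versus shared fused keys "gate_up_proj_input_scale" and "down_proj_input_scale" in FUSED_GATE_UP_PROJ mode. A hedged sketch of just that key-selection step, with a stand-in enum and made-up scale values rather than the real FusedMoE internals:

from enum import Enum, auto

class MoEWeightLoadingMode(Enum):  # stand-in mirroring the names in the diff
    VANILLA = auto()
    FUSED_GATE_UP_PROJ = auto()

def select_input_scales(weights, expert_id, mode):
    """Pick (w1, w3, w2) input scales for one expert under either layout."""
    if mode == MoEWeightLoadingMode.VANILLA:
        return (weights[f"{expert_id}.w1.input_scale"],
                weights[f"{expert_id}.w3.input_scale"],
                weights[f"{expert_id}.w2.input_scale"])
    elif mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
        # Fused checkpoints share one gate/up scale and one down-proj scale.
        return (weights["gate_up_proj_input_scale"],
                weights["gate_up_proj_input_scale"],
                weights["down_proj_input_scale"])
    raise NotImplementedError(f"Unknown weight loading mode in MoE: {mode}")

# Hypothetical fused-layout checkpoint fragment.
ckpt = {"gate_up_proj_input_scale": 0.02, "down_proj_input_scale": 0.04}
print(select_input_scales(ckpt, expert_id=0, mode=MoEWeightLoadingMode.FUSED_GATE_UP_PROJ))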
@@ -878,9 +887,18 @@ class FusedMoE(nn.Module):
         dst_w2_weight_scale.copy_(w2_weight_scale[...].reshape([]))

         for expert_id in range(self.expert_start, self.expert_end):
-            w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
-            w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
-            w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"]
+            if self.weight_loading_mode == MoEWeightLoadingMode.VANILLA:
+                w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"]
+                w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"]
+                w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"]
+            elif self.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ:
+                w1_weight_scale = weights[f"gate_up_proj_weight_scale"]
+                w3_weight_scale = weights[f"gate_up_proj_weight_scale"]
+                w2_weight_scale = weights[f"down_proj_weight_scale"]
+            else:
+                raise NotImplementedError(
+                    f"Unknown weight loading mode in MoE: {self.weight_loading_mode}"
+                )

             expert_idx = expert_id - self.expert_start

@@ -181,7 +181,8 @@ class RotaryEmbedding(nn.Module):

         if use_gptj_style_rope:
             raise ValueError(
-                "Must have flashinfer installed to use gptj style RoPE")
+                "gptj style RoPE has to go through flashinfer route for correct results."
+            )
         # it is assumed all targets are of the same rank
         q_or_k = targets[0]
         remove_input_padding = (len(q_or_k.size()) == 2)
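Background on why gptj-style RoPE gets its own flashinfer path: it rotates interleaved (even, odd) channel pairs, while the more common gptneox-style rotation pairs the first half of the head dimension with the second half, so the two layouts need different kernels. A small illustrative sketch of the two rotation layouts (general RoPE background, not code from this diff):

import torch

def rope_gptj(x, cos, sin):
    # GPT-J style: channels (0, 1), (2, 3), ... form the rotated pairs.
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

def rope_gptneox(x, cos, sin):
    # GPT-NeoX style: channel i pairs with channel i + head_dim // 2.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)

# Tiny example: one token, head_dim = 4, one rotation angle per channel pair.
x = torch.arange(4, dtype=torch.float32)
theta = torch.tensor([0.5, 0.1])
print(rope_gptj(x, theta.cos(), theta.sin()))
print(rope_gptneox(x, theta.cos(), theta.sin()))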