mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-18 16:55:08 +08:00
phi4 fp4
Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com> Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
This commit is contained in:
parent
12c6f0769d
commit
50ba4193f4
@ -225,18 +225,37 @@ class Phi3ForCausalLM(DecoderModelForCausalLM[Phi3Model, Phi3Config]):
|
||||
|
||||
# Get the scale factor for the fused QKV projection
|
||||
qkv_scale = module_weights.get('weight_scale', None)
|
||||
input_scale = module_weights.get('input_scale', None)
|
||||
weight_scale_2 = module_weights.get(
|
||||
'weight_scale_2', None)
|
||||
|
||||
q_dict = {'weight': q_weight}
|
||||
if qkv_scale is not None:
|
||||
q_dict['weight_scale'] = qkv_scale
|
||||
q_dict['weight_scale'] = qkv_scale[:hidden_size, :]
|
||||
if input_scale is not None:
|
||||
q_dict['input_scale'] = input_scale
|
||||
if weight_scale_2 is not None:
|
||||
q_dict['weight_scale_2'] = weight_scale_2
|
||||
|
||||
k_dict = {'weight': k_weight}
|
||||
if qkv_scale is not None:
|
||||
k_dict['weight_scale'] = qkv_scale # Use same scale
|
||||
k_dict['weight_scale'] = qkv_scale[
|
||||
hidden_size:hidden_size +
|
||||
num_kv_heads * head_dim, :]
|
||||
if input_scale is not None:
|
||||
k_dict['input_scale'] = input_scale
|
||||
if weight_scale_2 is not None:
|
||||
k_dict['weight_scale_2'] = weight_scale_2
|
||||
|
||||
v_dict = {'weight': v_weight}
|
||||
if qkv_scale is not None:
|
||||
v_dict['weight_scale'] = qkv_scale # Use same scale
|
||||
v_dict['weight_scale'] = qkv_scale[hidden_size +
|
||||
num_kv_heads *
|
||||
head_dim:, :]
|
||||
if input_scale is not None:
|
||||
v_dict['input_scale'] = input_scale
|
||||
if weight_scale_2 is not None:
|
||||
v_dict['weight_scale_2'] = weight_scale_2
|
||||
|
||||
module.load_weights(weights=[q_dict, k_dict, v_dict])
|
||||
elif "mlp.gate_up_proj" in name:
|
||||
@ -248,14 +267,28 @@ class Phi3ForCausalLM(DecoderModelForCausalLM[Phi3Model, Phi3Config]):
|
||||
|
||||
# Get the scale factors if they exist
|
||||
gate_up_scale = module_weights.get('weight_scale', None)
|
||||
input_scale = module_weights.get('input_scale', None)
|
||||
weight_scale_2 = module_weights.get(
|
||||
'weight_scale_2', None)
|
||||
|
||||
gate_dict = {'weight': gate_weight}
|
||||
if gate_up_scale is not None:
|
||||
gate_dict['weight_scale'] = gate_up_scale
|
||||
gate_dict[
|
||||
'weight_scale'] = gate_up_scale[:
|
||||
intermediate_size, :]
|
||||
if input_scale is not None:
|
||||
gate_dict['input_scale'] = input_scale
|
||||
if weight_scale_2 is not None:
|
||||
gate_dict['weight_scale_2'] = weight_scale_2
|
||||
|
||||
up_dict = {'weight': up_weight}
|
||||
if gate_up_scale is not None:
|
||||
up_dict['weight_scale'] = gate_up_scale
|
||||
up_dict['weight_scale'] = gate_up_scale[
|
||||
intermediate_size:, :]
|
||||
if input_scale is not None:
|
||||
up_dict['input_scale'] = input_scale
|
||||
if weight_scale_2 is not None:
|
||||
up_dict['weight_scale_2'] = weight_scale_2
|
||||
|
||||
module.load_weights(weights=[gate_dict, up_dict])
|
||||
else:
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
import copy
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
@ -64,10 +65,10 @@ def _load_phi4mm_classes(local_path):
|
||||
# Add parent folder to sys.path to enable relative import.
|
||||
original_sys_path = sys.path.copy()
|
||||
package_folder = Path(local_path)
|
||||
package_name = package_folder.name
|
||||
parent_folder = str(package_folder.parent)
|
||||
if parent_folder not in sys.path:
|
||||
sys.path.insert(0, parent_folder)
|
||||
|
||||
try:
|
||||
# Import Phi4MMConfig from configuration_phi4mm.py.
|
||||
config_path = os.path.join(local_path, 'configuration_phi4mm.py')
|
||||
@ -87,8 +88,7 @@ def _load_phi4mm_classes(local_path):
|
||||
# `Phi-4-multimodal-instruct` as the package name to avoid relative import errors.
|
||||
# `hf_modeling_phi4mm` as the module name to avoid name conflicts.
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
"Phi-4-multimodal-instruct.hf_modeling_phi4mm",
|
||||
modeling_phi4mm_path)
|
||||
f"{package_name}.hf_modeling_phi4mm", modeling_phi4mm_path)
|
||||
hf_modeling_phi4mm = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(hf_modeling_phi4mm)
|
||||
Phi4MMAudioEmbedding = hf_modeling_phi4mm.Phi4MMAudioEmbedding
|
||||
@ -515,13 +515,17 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
|
||||
# Load weights into HFPhi4MultimodalEncoder.
|
||||
if not _is_disagg():
|
||||
filtered_weights = {}
|
||||
pattern = r"(audio_embed.encoder.encoders..*.conv.glu.b[12]|image_embed.glb_GN|image_embed.img_processor.head.probe)"
|
||||
for k, v in weights.items():
|
||||
if k.startswith("model.embed_tokens."):
|
||||
new_k = k.replace("model.embed_tokens.", "embed_tokens.")
|
||||
filtered_weights[new_k] = v
|
||||
elif k.startswith("model.embed_tokens_extend."):
|
||||
new_k = k.replace("model.embed_tokens_extend.", "")
|
||||
filtered_weights[new_k] = v
|
||||
new_k = k.replace("model.embed_tokeqns_extend.", "")
|
||||
if re.match(pattern, new_k):
|
||||
filtered_weights[new_k] = v.unsqueeze(0)
|
||||
else:
|
||||
filtered_weights[new_k] = v
|
||||
self.hf_phi4mm_model.load_state_dict(filtered_weights, strict=True)
|
||||
|
||||
# Filter out non-language model weights.
|
||||
@ -539,6 +543,15 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
|
||||
if 'base_layer.weight' in k:
|
||||
new_k = k.replace('base_layer.weight', 'weight')
|
||||
updated_weights[new_k] = weights[k]
|
||||
elif 'base_layer.input_scale' in k:
|
||||
new_k = k.replace('base_layer.input_scale', 'input_scale')
|
||||
updated_weights[new_k] = weights[k]
|
||||
elif 'base_layer.weight_scale' in k:
|
||||
new_k = k.replace('base_layer.weight_scale', 'weight_scale')
|
||||
updated_weights[new_k] = weights[k]
|
||||
elif 'base_layer.weight_scale_2' in k:
|
||||
new_k = k.replace('base_layer.weight_scale_2', 'weight_scale_2')
|
||||
updated_weights[new_k] = weights[k]
|
||||
else:
|
||||
updated_weights[k] = weights[k]
|
||||
weights = updated_weights
|
||||
|
||||
Loading…
Reference in New Issue
Block a user