Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com>
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
This commit is contained in:
Pamela 2025-09-03 15:21:13 +00:00 committed by Faraz Khoubsirat
parent 12c6f0769d
commit 50ba4193f4
No known key found for this signature in database
GPG Key ID: 15733A5323348457
2 changed files with 56 additions and 10 deletions

View File

@@ -225,18 +225,37 @@ class Phi3ForCausalLM(DecoderModelForCausalLM[Phi3Model, Phi3Config]):
# Get the scale factor for the fused QKV projection
qkv_scale = module_weights.get('weight_scale', None)
input_scale = module_weights.get('input_scale', None)
weight_scale_2 = module_weights.get(
'weight_scale_2', None)
q_dict = {'weight': q_weight}
if qkv_scale is not None:
q_dict['weight_scale'] = qkv_scale
q_dict['weight_scale'] = qkv_scale[:hidden_size, :]
if input_scale is not None:
q_dict['input_scale'] = input_scale
if weight_scale_2 is not None:
q_dict['weight_scale_2'] = weight_scale_2
k_dict = {'weight': k_weight}
if qkv_scale is not None:
k_dict['weight_scale'] = qkv_scale # Use same scale
k_dict['weight_scale'] = qkv_scale[
hidden_size:hidden_size +
num_kv_heads * head_dim, :]
if input_scale is not None:
k_dict['input_scale'] = input_scale
if weight_scale_2 is not None:
k_dict['weight_scale_2'] = weight_scale_2
v_dict = {'weight': v_weight}
if qkv_scale is not None:
v_dict['weight_scale'] = qkv_scale # Use same scale
v_dict['weight_scale'] = qkv_scale[hidden_size +
num_kv_heads *
head_dim:, :]
if input_scale is not None:
v_dict['input_scale'] = input_scale
if weight_scale_2 is not None:
v_dict['weight_scale_2'] = weight_scale_2
module.load_weights(weights=[q_dict, k_dict, v_dict])
elif "mlp.gate_up_proj" in name:
@@ -248,14 +267,28 @@ class Phi3ForCausalLM(DecoderModelForCausalLM[Phi3Model, Phi3Config]):
# Get the scale factors if they exist
gate_up_scale = module_weights.get('weight_scale', None)
input_scale = module_weights.get('input_scale', None)
weight_scale_2 = module_weights.get(
'weight_scale_2', None)
gate_dict = {'weight': gate_weight}
if gate_up_scale is not None:
gate_dict['weight_scale'] = gate_up_scale
gate_dict[
'weight_scale'] = gate_up_scale[:
intermediate_size, :]
if input_scale is not None:
gate_dict['input_scale'] = input_scale
if weight_scale_2 is not None:
gate_dict['weight_scale_2'] = weight_scale_2
up_dict = {'weight': up_weight}
if gate_up_scale is not None:
up_dict['weight_scale'] = gate_up_scale
up_dict['weight_scale'] = gate_up_scale[
intermediate_size:, :]
if input_scale is not None:
up_dict['input_scale'] = input_scale
if weight_scale_2 is not None:
up_dict['weight_scale_2'] = weight_scale_2
module.load_weights(weights=[gate_dict, up_dict])
else:

View File

@@ -8,6 +8,7 @@
import copy
import importlib
import os
import re
import sys
from pathlib import Path
from typing import List, Optional, Tuple
@@ -64,10 +65,10 @@ def _load_phi4mm_classes(local_path):
# Add parent folder to sys.path to enable relative import.
original_sys_path = sys.path.copy()
package_folder = Path(local_path)
package_name = package_folder.name
parent_folder = str(package_folder.parent)
if parent_folder not in sys.path:
sys.path.insert(0, parent_folder)
try:
# Import Phi4MMConfig from configuration_phi4mm.py.
config_path = os.path.join(local_path, 'configuration_phi4mm.py')
@@ -87,8 +88,7 @@ def _load_phi4mm_classes(local_path):
# `Phi-4-multimodal-instruct` as the package name to avoid relative import errors.
# `hf_modeling_phi4mm` as the module name to avoid name conflicts.
spec = importlib.util.spec_from_file_location(
"Phi-4-multimodal-instruct.hf_modeling_phi4mm",
modeling_phi4mm_path)
f"{package_name}.hf_modeling_phi4mm", modeling_phi4mm_path)
hf_modeling_phi4mm = importlib.util.module_from_spec(spec)
spec.loader.exec_module(hf_modeling_phi4mm)
Phi4MMAudioEmbedding = hf_modeling_phi4mm.Phi4MMAudioEmbedding
@@ -515,13 +515,17 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
# Load weights into HFPhi4MultimodalEncoder.
if not _is_disagg():
filtered_weights = {}
pattern = r"(audio_embed.encoder.encoders..*.conv.glu.b[12]|image_embed.glb_GN|image_embed.img_processor.head.probe)"
for k, v in weights.items():
if k.startswith("model.embed_tokens."):
new_k = k.replace("model.embed_tokens.", "embed_tokens.")
filtered_weights[new_k] = v
elif k.startswith("model.embed_tokens_extend."):
new_k = k.replace("model.embed_tokens_extend.", "")
filtered_weights[new_k] = v
new_k = k.replace("model.embed_tokens_extend.", "")
if re.match(pattern, new_k):
filtered_weights[new_k] = v.unsqueeze(0)
else:
filtered_weights[new_k] = v
self.hf_phi4mm_model.load_state_dict(filtered_weights, strict=True)
# Filter out non-language model weights.
@@ -539,6 +543,15 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
if 'base_layer.weight' in k:
new_k = k.replace('base_layer.weight', 'weight')
updated_weights[new_k] = weights[k]
elif 'base_layer.input_scale' in k:
new_k = k.replace('base_layer.input_scale', 'input_scale')
updated_weights[new_k] = weights[k]
elif 'base_layer.weight_scale' in k:
new_k = k.replace('base_layer.weight_scale', 'weight_scale')
updated_weights[new_k] = weights[k]
elif 'base_layer.weight_scale_2' in k:
new_k = k.replace('base_layer.weight_scale_2', 'weight_scale_2')
updated_weights[new_k] = weights[k]
else:
updated_weights[k] = weights[k]
weights = updated_weights