revert WAR for shapes

Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com>
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Pamela 2025-09-09 00:14:42 +00:00 committed by Faraz Khoubsirat
parent 1a219f1ed5
commit 1010e15bb5
2 changed files with 66 additions and 74 deletions


@@ -217,45 +217,45 @@ class Phi3ForCausalLM(DecoderModelForCausalLM[Phi3Model, Phi3Config]):
if "self_attn.qkv_proj" in name:
# The weights need to be split correctly before sharding to support tp_size >1.
qkv_weight = module_weights['weight'][:]
q_weight = qkv_weight[:hidden_size, :]
k_weight = qkv_weight[hidden_size:hidden_size + num_kv_heads * head_dim, :]
v_weight = qkv_weight[hidden_size + num_kv_heads * head_dim:, :]
qk_split_index = hidden_size
kv_split_index = hidden_size + num_kv_heads * head_dim
q_dict = {'weight': qkv_weight[:qk_split_index, :]}
k_dict = {'weight': qkv_weight[qk_split_index:kv_split_index, :]}
v_dict = {'weight': qkv_weight[kv_split_index:, :]}
# Get the scale factor for the fused QKV projection
qkv_scale = module_weights.get('weight_scale', None)
if qkv_scale is not None:
if qkv_scale.shape[0] == qkv_weight.shape[0]:
q_dict['weight_scale'] = qkv_scale[:qk_split_index, :]
k_dict['weight_scale'] = qkv_scale[qk_split_index:kv_split_index, :]
v_dict['weight_scale'] = qkv_scale[kv_split_index:, :]
else: # use same scale
q_dict['weight_scale'] = qkv_scale
k_dict['weight_scale'] = qkv_scale
v_dict['weight_scale'] = qkv_scale
input_scale = module_weights.get('input_scale', None)
if input_scale is not None:
q_dict['input_scale'] = input_scale
k_dict['input_scale'] = input_scale
v_dict['input_scale'] = input_scale
weight_scale_2 = module_weights.get('weight_scale_2', None)
q_dict = {'weight': q_weight}
if qkv_scale is not None:
q_dict['weight_scale'] = qkv_scale[:hidden_size, :]
if input_scale is not None:
q_dict['input_scale'] = input_scale
if weight_scale_2 is not None:
q_dict['weight_scale_2'] = weight_scale_2
k_dict = {'weight': k_weight}
if qkv_scale is not None:
k_dict['weight_scale'] = qkv_scale[hidden_size:hidden_size + num_kv_heads * head_dim, :]
if input_scale is not None:
k_dict['input_scale'] = input_scale
if weight_scale_2 is not None:
k_dict['weight_scale_2'] = weight_scale_2
v_dict = {'weight': v_weight}
if qkv_scale is not None:
v_dict['weight_scale'] = qkv_scale[hidden_size + num_kv_heads * head_dim:, :]
if input_scale is not None:
v_dict['input_scale'] = input_scale
if weight_scale_2 is not None:
v_dict['weight_scale_2'] = weight_scale_2
if weight_scale_2 is not None:
q_dict['weight_scale_2'] = weight_scale_2
k_dict['weight_scale_2'] = weight_scale_2
v_dict['weight_scale_2'] = weight_scale_2
module.load_weights(weights=[q_dict, k_dict, v_dict])
elif "mlp.gate_up_proj" in name:
@@ -265,30 +265,33 @@ class Phi3ForCausalLM(DecoderModelForCausalLM[Phi3Model, Phi3Config]):
gate_weight = gate_up_weight[:intermediate_size, :]
up_weight = gate_up_weight[intermediate_size:, :]
gate_dict = {'weight': gate_weight}
up_dict = {'weight': up_weight}
# Get the scale factors if they exist
gate_up_scale = module_weights.get('weight_scale', None)
if gate_up_scale is not None:
if gate_up_scale.shape[0] == gate_up_weight.shape[0]:
gate_dict['weight_scale'] = gate_up_scale[:intermediate_size, :]
up_dict['weight_scale'] = gate_up_scale[intermediate_size:, :]
else: # use same scale
gate_dict['weight_scale'] = gate_up_scale
up_dict['weight_scale'] = gate_up_scale
input_scale = module_weights.get('input_scale', None)
if input_scale is not None:
gate_dict['input_scale'] = input_scale
up_dict['input_scale'] = input_scale
weight_scale_2 = module_weights.get('weight_scale_2', None)
gate_dict = {'weight': gate_weight}
if gate_up_scale is not None:
gate_dict['weight_scale'] = gate_up_scale[:intermediate_size, :]
if input_scale is not None:
gate_dict['input_scale'] = input_scale
if weight_scale_2 is not None:
gate_dict['weight_scale_2'] = weight_scale_2
up_dict = {'weight': up_weight}
if gate_up_scale is not None:
up_dict['weight_scale'] = gate_up_scale[intermediate_size:, :]
if input_scale is not None:
up_dict['input_scale'] = input_scale
if weight_scale_2 is not None:
up_dict['weight_scale_2'] = weight_scale_2
if weight_scale_2 is not None:
gate_dict['weight_scale_2'] = weight_scale_2
up_dict['weight_scale_2'] = weight_scale_2
module.load_weights(weights=[gate_dict, up_dict])
else:

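The gate_up branch follows the same pattern with intermediate_size as the split point; a quick round-trip check with toy shapes (illustrative only, not the loader itself):

import torch

intermediate_size, hidden_size = 64, 32
gate_up_weight = torch.randn(2 * intermediate_size, hidden_size)

gate_weight = gate_up_weight[:intermediate_size, :]
up_weight = gate_up_weight[intermediate_size:, :]

# Concatenating the two slices along dim 0 reproduces the fused tensor,
# so no rows are lost or duplicated by the split.
assert torch.equal(torch.cat([gate_weight, up_weight], dim=0), gate_up_weight)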

@@ -8,7 +8,6 @@
import copy
import importlib
import os
import re
import sys
from pathlib import Path
from typing import List, Optional, Tuple
@@ -404,7 +403,6 @@ class Phi4MMInputProcessor(InputProcessor):
self.tokenizer = tokenizer
self.use_fast = True
model_path = "microsoft/Phi-4-multimodal-instruct"
if self.tokenizer is None:
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_path,
@@ -516,17 +514,13 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
# Load weights into HFPhi4MultimodalEncoder.
if not _is_disagg():
filtered_weights = {}
pattern = r"(audio_embed.encoder.encoders..*.conv.glu.b[12]|image_embed.glb_GN|image_embed.img_processor.head.probe)"
for k, v in weights.items():
if k.startswith("model.embed_tokens."):
new_k = k.replace("model.embed_tokens.", "embed_tokens.")
filtered_weights[new_k] = v
elif k.startswith("model.embed_tokens_extend."):
new_k = k.replace("model.embed_tokeqns_extend.", "")
if re.match(pattern, new_k):
filtered_weights[new_k] = v.unsqueeze(0)
else:
filtered_weights[new_k] = v
new_k = k.replace("model.embed_tokens_extend.", "")
filtered_weights[new_k] = v
self.hf_phi4mm_model.load_state_dict(filtered_weights, strict=True)
# Filter out non-language model weights.
@@ -540,21 +534,16 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
weights = {k: v for k, v in weights.items() if '.lora_' not in k}
# Rename base layer weights.
updated_weights = {}
base_layers = ['weight', 'input_scale', 'weight_scale', 'weight_scale_2']
for k in weights.keys():
if 'base_layer.weight' in k:
new_k = k.replace('base_layer.weight', 'weight')
updated_weights[new_k] = weights[k]
elif 'base_layer.input_scale' in k:
new_k = k.replace('base_layer.input_scale', 'input_scale')
updated_weights[new_k] = weights[k]
elif 'base_layer.weight_scale' in k:
new_k = k.replace('base_layer.weight_scale', 'weight_scale')
updated_weights[new_k] = weights[k]
elif 'base_layer.weight_scale_2' in k:
new_k = k.replace('base_layer.weight_scale_2', 'weight_scale_2')
updated_weights[new_k] = weights[k]
else:
updated_weights[k] = weights[k]
new_k = k
for layer in base_layers:
if f'base_layer.{layer}' in k:
new_k = k.replace(f'base_layer.{layer}', layer)
break
updated_weights[new_k] = weights[k]
weights = updated_weights
self.llm.load_weights(weights)
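The loop above replaces the per-suffix elif chain with a single pass over base_layers; a standalone sketch of the same renaming (hypothetical keys, for illustration only):

base_layers = ['weight', 'input_scale', 'weight_scale', 'weight_scale_2']

def strip_base_layer(key: str) -> str:
    # Rename e.g. "...o_proj.base_layer.weight_scale" -> "...o_proj.weight_scale";
    # keys without a base_layer prefix pass through unchanged.
    for layer in base_layers:
        if f'base_layer.{layer}' in key:
            return key.replace(f'base_layer.{layer}', layer)
    return key

# Hypothetical checkpoint keys, not taken from the real model.
assert strip_base_layer('layers.0.mlp.down_proj.base_layer.weight') == 'layers.0.mlp.down_proj.weight'
assert strip_base_layer('layers.0.self_attn.o_proj.base_layer.weight_scale_2') == 'layers.0.self_attn.o_proj.weight_scale_2'
assert strip_base_layer('lm_head.weight') == 'lm_head.weight'

Because each replacement only strips the 'base_layer.' prefix, list order does not matter: a 'weight_scale_2' key matched by the earlier 'weight' entry still ends up with the correct name.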