mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-16 07:53:55 +08:00

revert WAR for shapes

Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com>
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>

This commit is contained in:
parent 1a219f1ed5
commit 1010e15bb5
@@ -217,45 +217,45 @@ class Phi3ForCausalLM(DecoderModelForCausalLM[Phi3Model, Phi3Config]):
         if "self_attn.qkv_proj" in name:
+            # The weights need to be split correctly before sharding to support tp_size >1.
             qkv_weight = module_weights['weight'][:]
-            qk_split_index = hidden_size
-            kv_split_index = hidden_size + num_kv_heads * head_dim
-
-            q_dict = {'weight': qkv_weight[:qk_split_index, :]}
-            k_dict = {'weight': qkv_weight[qk_split_index:kv_split_index, :]}
-            v_dict = {'weight': qkv_weight[kv_split_index:, :]}
+            q_weight = qkv_weight[:hidden_size, :]
+            k_weight = qkv_weight[hidden_size:hidden_size + num_kv_heads * head_dim, :]
+            v_weight = qkv_weight[hidden_size + num_kv_heads * head_dim:, :]
 
             # Get the scale factor for the fused QKV projection
             qkv_scale = module_weights.get('weight_scale', None)
-            if qkv_scale is not None:
-                if qkv_scale.shape[0] == qkv_weight.shape[0]:
-                    q_dict['weight_scale'] = qkv_scale[:qk_split_index, :]
-                    k_dict['weight_scale'] = qkv_scale[qk_split_index:kv_split_index, :]
-                    v_dict['weight_scale'] = qkv_scale[kv_split_index:, :]
-                else:  # use same scale
-                    q_dict['weight_scale'] = qkv_scale
-                    k_dict['weight_scale'] = qkv_scale
-                    v_dict['weight_scale'] = qkv_scale
-
             input_scale = module_weights.get('input_scale', None)
-            if input_scale is not None:
-                q_dict['input_scale'] = input_scale
-                k_dict['input_scale'] = input_scale
-                v_dict['input_scale'] = input_scale
-
             weight_scale_2 = module_weights.get('weight_scale_2', None)
-            if weight_scale_2 is not None:
-                q_dict['weight_scale_2'] = weight_scale_2
-                k_dict['weight_scale_2'] = weight_scale_2
-                v_dict['weight_scale_2'] = weight_scale_2
+
+            q_dict = {'weight': q_weight}
+            if qkv_scale is not None:
+                q_dict['weight_scale'] = qkv_scale[:hidden_size, :]
+            if input_scale is not None:
+                q_dict['input_scale'] = input_scale
+            if weight_scale_2 is not None:
+                q_dict['weight_scale_2'] = weight_scale_2
+
+            k_dict = {'weight': k_weight}
+            if qkv_scale is not None:
+                k_dict['weight_scale'] = qkv_scale[hidden_size:hidden_size + num_kv_heads * head_dim, :]
+            if input_scale is not None:
+                k_dict['input_scale'] = input_scale
+            if weight_scale_2 is not None:
+                k_dict['weight_scale_2'] = weight_scale_2
+
+            v_dict = {'weight': v_weight}
+            if qkv_scale is not None:
+                v_dict['weight_scale'] = qkv_scale[hidden_size + num_kv_heads * head_dim:, :]
+            if input_scale is not None:
+                v_dict['input_scale'] = input_scale
+            if weight_scale_2 is not None:
+                v_dict['weight_scale_2'] = weight_scale_2
 
             module.load_weights(weights=[q_dict, k_dict, v_dict])
         elif "mlp.gate_up_proj" in name:
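For context: the restored branch slices the fused [Q; K; V] weight along dim 0 before calling module.load_weights, so Q, K, and V can then be sharded independently when tp_size > 1. A minimal standalone sketch of that split, with made-up Phi-3-like dimensions rather than values taken from this diff:

import torch

# Hypothetical dimensions, for illustration only.
hidden_size = 3072
num_kv_heads = 32
head_dim = 96

# Fused projection weight: rows are [Q | K | V] stacked along dim 0.
qkv_weight = torch.randn(hidden_size + 2 * num_kv_heads * head_dim, hidden_size)

# The same slicing the restored code applies before sharding.
q_weight = qkv_weight[:hidden_size, :]
k_weight = qkv_weight[hidden_size:hidden_size + num_kv_heads * head_dim, :]
v_weight = qkv_weight[hidden_size + num_kv_heads * head_dim:, :]

assert q_weight.shape == (hidden_size, hidden_size)
assert k_weight.shape == (num_kv_heads * head_dim, hidden_size)
assert v_weight.shape == (num_kv_heads * head_dim, hidden_size)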
@@ -265,30 +265,33 @@ class Phi3ForCausalLM(DecoderModelForCausalLM[Phi3Model, Phi3Config]):
             gate_weight = gate_up_weight[:intermediate_size, :]
             up_weight = gate_up_weight[intermediate_size:, :]
 
-            gate_dict = {'weight': gate_weight}
-            up_dict = {'weight': up_weight}
-
             # Get the scale factors if they exist
             gate_up_scale = module_weights.get('weight_scale', None)
-            if gate_up_scale is not None:
-                if gate_up_scale.shape[0] == gate_up_weight.shape[0]:
-                    gate_dict['weight_scale'] = gate_up_scale[:intermediate_size, :]
-                    up_dict['weight_scale'] = gate_up_scale[intermediate_size:, :]
-                else:  # use same scale
-                    gate_dict['weight_scale'] = gate_up_scale
-                    up_dict['weight_scale'] = gate_up_scale
-
             input_scale = module_weights.get('input_scale', None)
-            if input_scale is not None:
-                gate_dict['input_scale'] = input_scale
-                up_dict['input_scale'] = input_scale
-
             weight_scale_2 = module_weights.get('weight_scale_2', None)
-            if weight_scale_2 is not None:
-                gate_dict['weight_scale_2'] = weight_scale_2
-                up_dict['weight_scale_2'] = weight_scale_2
+
+            gate_dict = {'weight': gate_weight}
+            if gate_up_scale is not None:
+                gate_dict['weight_scale'] = gate_up_scale[:intermediate_size, :]
+            if input_scale is not None:
+                gate_dict['input_scale'] = input_scale
+            if weight_scale_2 is not None:
+                gate_dict['weight_scale_2'] = weight_scale_2
+
+            up_dict = {'weight': up_weight}
+            if gate_up_scale is not None:
+                up_dict['weight_scale'] = gate_up_scale[intermediate_size:, :]
+            if input_scale is not None:
+                up_dict['input_scale'] = input_scale
+            if weight_scale_2 is not None:
+                up_dict['weight_scale_2'] = weight_scale_2
 
             module.load_weights(weights=[gate_dict, up_dict])
         else:
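The WAR being reverted above guarded the scale split behind a shape check: a weight_scale with one row per weight row can be sliced with the same indices as the weight, while any other layout (e.g. a per-tensor scale) has to be shared by both halves. A sketch of that distinction, assuming plain torch tensors rather than TensorRT-LLM's actual scale layouts:

import torch

intermediate_size = 8192
gate_up_weight = torch.randn(2 * intermediate_size, 3072)

def split_scale(scale):
    # Row-aligned scale: split exactly like the weight.
    if scale.shape[0] == gate_up_weight.shape[0]:
        return scale[:intermediate_size], scale[intermediate_size:]
    # Otherwise (e.g. per-tensor): both halves reuse the same scale.
    return scale, scale

row_scale = torch.rand(2 * intermediate_size, 1)
gate_scale, up_scale = split_scale(row_scale)
assert gate_scale.shape[0] == intermediate_size

tensor_scale = torch.tensor([0.5])
gate_scale, up_scale = split_scale(tensor_scale)
assert gate_scale is tensor_scale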
@@ -8,7 +8,6 @@
 import copy
 import importlib
 import os
-import re
 import sys
 from pathlib import Path
 from typing import List, Optional, Tuple
@@ -404,7 +403,6 @@ class Phi4MMInputProcessor(InputProcessor):
 
         self.tokenizer = tokenizer
         self.use_fast = True
-        model_path = "microsoft/Phi-4-multimodal-instruct"
         if self.tokenizer is None:
             self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                 model_path,
@@ -516,17 +514,13 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
         # Load weights into HFPhi4MultimodalEncoder.
         if not _is_disagg():
             filtered_weights = {}
-            pattern = r"(audio_embed.encoder.encoders..*.conv.glu.b[12]|image_embed.glb_GN|image_embed.img_processor.head.probe)"
             for k, v in weights.items():
                 if k.startswith("model.embed_tokens."):
                     new_k = k.replace("model.embed_tokens.", "embed_tokens.")
                     filtered_weights[new_k] = v
                 elif k.startswith("model.embed_tokens_extend."):
-                    new_k = k.replace("model.embed_tokens_extend.", "")
-                    if re.match(pattern, new_k):
-                        filtered_weights[new_k] = v.unsqueeze(0)
-                    else:
-                        filtered_weights[new_k] = v
+                    new_k = k.replace("model.embed_tokens_extend.", "")
+                    filtered_weights[new_k] = v
             self.hf_phi4mm_model.load_state_dict(filtered_weights, strict=True)
 
         # Filter out non-language model weights.
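The encoder-weight hunk above reduces to re-keying a checkpoint dict before load_state_dict: strip the "model." prefix and hand the tensors over unchanged. A toy, self-contained version of the same rewrite (the module layout here is invented for the example):

import torch
from torch import nn

# Stand-in for HFPhi4MultimodalEncoder; the real layout differs.
encoder = nn.ModuleDict({"embed_tokens": nn.Embedding(10, 4)})

# Checkpoint keys carry a "model." prefix that the target module lacks.
weights = {"model.embed_tokens.weight": torch.randn(10, 4)}

filtered_weights = {}
for k, v in weights.items():
    if k.startswith("model.embed_tokens."):
        filtered_weights[k.replace("model.embed_tokens.", "embed_tokens.")] = v

encoder.load_state_dict(filtered_weights, strict=True)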
@@ -540,21 +534,16 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
         weights = {k: v for k, v in weights.items() if '.lora_' not in k}
         # Rename base layer weights.
         updated_weights = {}
+        base_layers = [
+            'weight', 'input_scale', 'weight_scale', 'weight_scale_2'
+        ]
         for k in weights.keys():
-            if 'base_layer.weight' in k:
-                new_k = k.replace('base_layer.weight', 'weight')
-                updated_weights[new_k] = weights[k]
-            elif 'base_layer.input_scale' in k:
-                new_k = k.replace('base_layer.input_scale', 'input_scale')
-                updated_weights[new_k] = weights[k]
-            elif 'base_layer.weight_scale' in k:
-                new_k = k.replace('base_layer.weight_scale', 'weight_scale')
-                updated_weights[new_k] = weights[k]
-            elif 'base_layer.weight_scale_2' in k:
-                new_k = k.replace('base_layer.weight_scale_2', 'weight_scale_2')
-                updated_weights[new_k] = weights[k]
-            else:
-                updated_weights[k] = weights[k]
+            new_k = k
+            for layer in base_layers:
+                if f'base_layer.{layer}' in k:
+                    new_k = k.replace(f'base_layer.{layer}', layer)
+                    break
+            updated_weights[new_k] = weights[k]
         weights = updated_weights
         self.llm.load_weights(weights)
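The rewritten loop generalizes the old if/elif chain: any of the four known leaf names behind a "base_layer." segment is renamed, and everything else passes through untouched. A quick check of its behavior on hypothetical checkpoint keys (not taken from a real Phi-4-MM checkpoint):

base_layers = ['weight', 'input_scale', 'weight_scale', 'weight_scale_2']

def rename(k):
    # Strip the "base_layer." segment in front of any known leaf name.
    for layer in base_layers:
        if f'base_layer.{layer}' in k:
            return k.replace(f'base_layer.{layer}', layer)
    return k

assert rename('model.layers.0.mlp.up_proj.base_layer.weight') == \
    'model.layers.0.mlp.up_proj.weight'
assert rename('model.layers.0.mlp.up_proj.base_layer.weight_scale') == \
    'model.layers.0.mlp.up_proj.weight_scale'
assert rename('model.layers.0.mlp.up_proj.bias') == 'model.layers.0.mlp.up_proj.bias'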