diff --git a/tensorrt_llm/_torch/models/modeling_phi3.py b/tensorrt_llm/_torch/models/modeling_phi3.py
index 39bb4cc143..0fc2fec816 100644
--- a/tensorrt_llm/_torch/models/modeling_phi3.py
+++ b/tensorrt_llm/_torch/models/modeling_phi3.py
@@ -217,45 +217,45 @@ class Phi3ForCausalLM(DecoderModelForCausalLM[Phi3Model, Phi3Config]):
                 if "self_attn.qkv_proj" in name:
                     # The weights need to be split correctly before sharding to support tp_size >1.
                     qkv_weight = module_weights['weight'][:]
-                    q_weight = qkv_weight[:hidden_size, :]
-                    k_weight = qkv_weight[hidden_size:hidden_size +
-                                          num_kv_heads * head_dim, :]
-                    v_weight = qkv_weight[hidden_size +
-                                          num_kv_heads * head_dim:, :]
+                    qk_split_index = hidden_size
+                    kv_split_index = hidden_size + num_kv_heads * head_dim
+
+                    q_dict = {'weight': qkv_weight[:qk_split_index, :]}
+                    k_dict = {
+                        'weight':
+                        qkv_weight[qk_split_index:kv_split_index, :]
+                    }
+                    v_dict = {'weight': qkv_weight[kv_split_index:, :]}
 
                     # Get the scale factor for the fused QKV projection
                     qkv_scale = module_weights.get('weight_scale', None)
+
+                    if qkv_scale is not None:
+                        if qkv_scale.shape[0] == qkv_weight.shape[0]:
+                            q_dict[
+                                'weight_scale'] = qkv_scale[:
+                                                            qk_split_index, :]
+                            k_dict['weight_scale'] = qkv_scale[
+                                qk_split_index:kv_split_index, :]
+                            v_dict['weight_scale'] = qkv_scale[
+                                kv_split_index:, :]
+                        else:  # use same scale
+                            q_dict['weight_scale'] = qkv_scale
+                            k_dict['weight_scale'] = qkv_scale
+                            v_dict['weight_scale'] = qkv_scale
+
                     input_scale = module_weights.get('input_scale', None)
+                    if input_scale is not None:
+                        q_dict['input_scale'] = input_scale
+                        k_dict['input_scale'] = input_scale
+                        v_dict['input_scale'] = input_scale
+
                     weight_scale_2 = module_weights.get(
                         'weight_scale_2', None)
-
-                    q_dict = {'weight': q_weight}
-                    if qkv_scale is not None:
-                        q_dict['weight_scale'] = qkv_scale[:hidden_size, :]
-                    if input_scale is not None:
-                        q_dict['input_scale'] = input_scale
-                    if weight_scale_2 is not None:
-                        q_dict['weight_scale_2'] = weight_scale_2
-
-                    k_dict = {'weight': k_weight}
-                    if qkv_scale is not None:
-                        k_dict['weight_scale'] = qkv_scale[
-                            hidden_size:hidden_size +
-                            num_kv_heads * head_dim, :]
-                    if input_scale is not None:
-                        k_dict['input_scale'] = input_scale
-                    if weight_scale_2 is not None:
-                        k_dict['weight_scale_2'] = weight_scale_2
-
-                    v_dict = {'weight': v_weight}
-                    if qkv_scale is not None:
-                        v_dict['weight_scale'] = qkv_scale[hidden_size +
-                                                           num_kv_heads *
-                                                           head_dim:, :]
-                    if input_scale is not None:
-                        v_dict['input_scale'] = input_scale
-                    if weight_scale_2 is not None:
-                        v_dict['weight_scale_2'] = weight_scale_2
+                    if weight_scale_2 is not None:
+                        q_dict['weight_scale_2'] = weight_scale_2
+                        k_dict['weight_scale_2'] = weight_scale_2
+                        v_dict['weight_scale_2'] = weight_scale_2
 
                     module.load_weights(weights=[q_dict, k_dict, v_dict])
                 elif "mlp.gate_up_proj" in name:
@@ -265,30 +265,33 @@ class Phi3ForCausalLM(DecoderModelForCausalLM[Phi3Model, Phi3Config]):
 
                     gate_weight = gate_up_weight[:intermediate_size, :]
                     up_weight = gate_up_weight[intermediate_size:, :]
+                    gate_dict = {'weight': gate_weight}
+                    up_dict = {'weight': up_weight}
+
                     # Get the scale factors if they exist
                     gate_up_scale = module_weights.get('weight_scale', None)
+                    if gate_up_scale is not None:
+                        if gate_up_scale.shape[0] == gate_up_weight.shape[
+                                0]:
+                            gate_dict[
+                                'weight_scale'] = gate_up_scale[:
+                                                                intermediate_size, :]
+                            up_dict['weight_scale'] = gate_up_scale[
+                                intermediate_size:, :]
+                        else:  # use same scale
+                            gate_dict['weight_scale'] = gate_up_scale
+                            up_dict['weight_scale'] = gate_up_scale
+
                     input_scale = module_weights.get('input_scale', None)
+                    if input_scale is not None:
+                        gate_dict['input_scale'] = input_scale
+                        up_dict['input_scale'] = input_scale
+
                     weight_scale_2 = module_weights.get(
                         'weight_scale_2', None)
-
-                    gate_dict = {'weight': gate_weight}
-                    if gate_up_scale is not None:
-                        gate_dict[
-                            'weight_scale'] = gate_up_scale[:
-                                                            intermediate_size, :]
-                    if input_scale is not None:
-                        gate_dict['input_scale'] = input_scale
-                    if weight_scale_2 is not None:
-                        gate_dict['weight_scale_2'] = weight_scale_2
-
-                    up_dict = {'weight': up_weight}
-                    if gate_up_scale is not None:
-                        up_dict['weight_scale'] = gate_up_scale[
-                            intermediate_size:, :]
-                    if input_scale is not None:
-                        up_dict['input_scale'] = input_scale
-                    if weight_scale_2 is not None:
-                        up_dict['weight_scale_2'] = weight_scale_2
+                    if weight_scale_2 is not None:
+                        gate_dict['weight_scale_2'] = weight_scale_2
+                        up_dict['weight_scale_2'] = weight_scale_2
 
                     module.load_weights(weights=[gate_dict, up_dict])
                 else:
diff --git a/tensorrt_llm/_torch/models/modeling_phi4mm.py b/tensorrt_llm/_torch/models/modeling_phi4mm.py
index a83649bc08..80afe29ac5 100644
--- a/tensorrt_llm/_torch/models/modeling_phi4mm.py
+++ b/tensorrt_llm/_torch/models/modeling_phi4mm.py
@@ -8,7 +8,6 @@
 import copy
 import importlib
 import os
-import re
 import sys
 from pathlib import Path
 from typing import List, Optional, Tuple
@@ -404,7 +403,6 @@ class Phi4MMInputProcessor(InputProcessor):
 
         self.tokenizer = tokenizer
         self.use_fast = True
-        model_path = "microsoft/Phi-4-multimodal-instruct"
         if self.tokenizer is None:
             self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                 model_path,
@@ -516,17 +514,13 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
         # Load weights into HFPhi4MultimodalEncoder.
         if not _is_disagg():
             filtered_weights = {}
-            pattern = r"(audio_embed.encoder.encoders..*.conv.glu.b[12]|image_embed.glb_GN|image_embed.img_processor.head.probe)"
             for k, v in weights.items():
                 if k.startswith("model.embed_tokens."):
                     new_k = k.replace("model.embed_tokens.", "embed_tokens.")
                     filtered_weights[new_k] = v
                 elif k.startswith("model.embed_tokens_extend."):
-                    new_k = k.replace("model.embed_tokeqns_extend.", "")
-                    if re.match(pattern, new_k):
-                        filtered_weights[new_k] = v.unsqueeze(0)
-                    else:
-                        filtered_weights[new_k] = v
+                    new_k = k.replace("model.embed_tokens_extend.", "")
+                    filtered_weights[new_k] = v
             self.hf_phi4mm_model.load_state_dict(filtered_weights, strict=True)
 
         # Filter out non-language model weights.
@@ -540,21 +534,16 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
         weights = {k: v for k, v in weights.items() if '.lora_' not in k}
         # Rename base layer weights.
         updated_weights = {}
+        base_layers = [
+            'weight', 'input_scale', 'weight_scale', 'weight_scale_2'
+        ]
         for k in weights.keys():
-            if 'base_layer.weight' in k:
-                new_k = k.replace('base_layer.weight', 'weight')
-                updated_weights[new_k] = weights[k]
-            elif 'base_layer.input_scale' in k:
-                new_k = k.replace('base_layer.input_scale', 'input_scale')
-                updated_weights[new_k] = weights[k]
-            elif 'base_layer.weight_scale' in k:
-                new_k = k.replace('base_layer.weight_scale', 'weight_scale')
-                updated_weights[new_k] = weights[k]
-            elif 'base_layer.weight_scale_2' in k:
-                new_k = k.replace('base_layer.weight_scale_2', 'weight_scale_2')
-                updated_weights[new_k] = weights[k]
-            else:
-                updated_weights[k] = weights[k]
+            new_k = k
+            for layer in base_layers:
+                if f'base_layer.{layer}' in k:
+                    new_k = k.replace(f'base_layer.{layer}', layer)
+                    break
+            updated_weights[new_k] = weights[k]
         weights = updated_weights
         self.llm.load_weights(weights)
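
Note on the modeling_phi3.py change: a fused weight_scale tensor is sliced only when it is per-output-channel, i.e. when its leading dimension matches the fused weight's; otherwise it is treated as a per-tensor scale and shared by all splits. Below is a minimal standalone sketch of that decision, assuming torch and invented dimensions; it is an illustration, not the TensorRT-LLM code itself.

# Standalone sketch (with made-up dimensions) of how the Phi3 change splits
# a fused QKV weight and its quantization scale before tensor-parallel sharding.
import torch

hidden_size, num_kv_heads, head_dim = 64, 2, 16
qk_split_index = hidden_size
kv_split_index = hidden_size + num_kv_heads * head_dim
total_rows = kv_split_index + num_kv_heads * head_dim  # Q + K + V output rows

qkv_weight = torch.randn(total_rows, hidden_size)
q_dict = {'weight': qkv_weight[:qk_split_index, :]}
k_dict = {'weight': qkv_weight[qk_split_index:kv_split_index, :]}
v_dict = {'weight': qkv_weight[kv_split_index:, :]}

qkv_scale = torch.rand(total_rows, 1)  # per-output-channel scale in this example
if qkv_scale.shape[0] == qkv_weight.shape[0]:
    # Per-channel scales follow the same row split as the weights,
    # so each of Q/K/V keeps exactly the scale rows that belong to it.
    q_dict['weight_scale'] = qkv_scale[:qk_split_index, :]
    k_dict['weight_scale'] = qkv_scale[qk_split_index:kv_split_index, :]
    v_dict['weight_scale'] = qkv_scale[kv_split_index:, :]
else:
    # A per-tensor scale applies unchanged to all three splits.
    q_dict['weight_scale'] = k_dict['weight_scale'] = v_dict['weight_scale'] = qkv_scale

assert q_dict['weight'].shape[0] == q_dict['weight_scale'].shape[0]

The same shape check drives the gate/up split in the mlp.gate_up_proj branch, with intermediate_size as the single split point.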
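The modeling_phi4mm.py change collapses four near-identical base_layer renaming branches into one data-driven loop. A self-contained sketch of the same idea, using hypothetical checkpoint keys for illustration:

# Sketch of the data-driven rename in modeling_phi4mm.py: strip the LoRA
# 'base_layer.' prefix from known parameter names so checkpoint keys match
# the plain module names. The checkpoint keys below are hypothetical.
base_layers = ['weight', 'input_scale', 'weight_scale', 'weight_scale_2']

def rename_base_layer_keys(weights: dict) -> dict:
    updated_weights = {}
    for k, v in weights.items():
        new_k = k
        for layer in base_layers:
            if f'base_layer.{layer}' in k:
                new_k = k.replace(f'base_layer.{layer}', layer)
                break
        updated_weights[new_k] = v
    return updated_weights

# Hypothetical keys, for illustration only:
ckpt = {
    'model.layers.0.self_attn.qkv_proj.base_layer.weight': 0,
    'model.layers.0.self_attn.qkv_proj.base_layer.weight_scale_2': 1,
    'model.layers.0.input_layernorm.weight': 2,
}
renamed = rename_base_layer_keys(ckpt)
assert 'model.layers.0.self_attn.qkv_proj.weight' in renamed
assert 'model.layers.0.self_attn.qkv_proj.weight_scale_2' in renamed
assert 'model.layers.0.input_layernorm.weight' in renamed

Because 'weight' comes first in the list and is a prefix of the other names, an early match on 'base_layer.weight' still rewrites 'base_layer.weight_scale_2' to 'weight_scale_2', so the break does not change the outcome.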