From 50ba4193f417ed67084a6782221ae93c219a42af Mon Sep 17 00:00:00 2001
From: Pamela <179191831+pamelap-nvidia@users.noreply.github.com>
Date: Wed, 3 Sep 2025 15:21:13 +0000
Subject: [PATCH] phi4 fp4

Signed-off-by: Pamela <179191831+pamelap-nvidia@users.noreply.github.com>
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
---
 tensorrt_llm/_torch/models/modeling_phi3.py   | 43 ++++++++++++++++---
 tensorrt_llm/_torch/models/modeling_phi4mm.py | 23 +++++++---
 2 files changed, 56 insertions(+), 10 deletions(-)

diff --git a/tensorrt_llm/_torch/models/modeling_phi3.py b/tensorrt_llm/_torch/models/modeling_phi3.py
index 272cd41e5b..39bb4cc143 100644
--- a/tensorrt_llm/_torch/models/modeling_phi3.py
+++ b/tensorrt_llm/_torch/models/modeling_phi3.py
@@ -225,18 +225,37 @@ class Phi3ForCausalLM(DecoderModelForCausalLM[Phi3Model, Phi3Config]):
                     # Get the scale factor for the fused QKV projection
                     qkv_scale = module_weights.get('weight_scale', None)
+                    input_scale = module_weights.get('input_scale', None)
+                    weight_scale_2 = module_weights.get(
+                        'weight_scale_2', None)

                     q_dict = {'weight': q_weight}
                     if qkv_scale is not None:
-                        q_dict['weight_scale'] = qkv_scale
+                        q_dict['weight_scale'] = qkv_scale[:hidden_size, :]
+                    if input_scale is not None:
+                        q_dict['input_scale'] = input_scale
+                    if weight_scale_2 is not None:
+                        q_dict['weight_scale_2'] = weight_scale_2

                     k_dict = {'weight': k_weight}
                     if qkv_scale is not None:
-                        k_dict['weight_scale'] = qkv_scale  # Use same scale
+                        k_dict['weight_scale'] = qkv_scale[
+                            hidden_size:hidden_size +
+                            num_kv_heads * head_dim, :]
+                    if input_scale is not None:
+                        k_dict['input_scale'] = input_scale
+                    if weight_scale_2 is not None:
+                        k_dict['weight_scale_2'] = weight_scale_2

                     v_dict = {'weight': v_weight}
                     if qkv_scale is not None:
-                        v_dict['weight_scale'] = qkv_scale  # Use same scale
+                        v_dict['weight_scale'] = qkv_scale[
+                            hidden_size +
+                            num_kv_heads * head_dim:, :]
+                    if input_scale is not None:
+                        v_dict['input_scale'] = input_scale
+                    if weight_scale_2 is not None:
+                        v_dict['weight_scale_2'] = weight_scale_2

                     module.load_weights(weights=[q_dict, k_dict, v_dict])
                 elif "mlp.gate_up_proj" in name:
@@ -248,14 +267,28 @@ class Phi3ForCausalLM(DecoderModelForCausalLM[Phi3Model, Phi3Config]):
                     # Get the scale factors if they exist
                     gate_up_scale = module_weights.get('weight_scale', None)
+                    input_scale = module_weights.get('input_scale', None)
+                    weight_scale_2 = module_weights.get(
+                        'weight_scale_2', None)

                     gate_dict = {'weight': gate_weight}
                     if gate_up_scale is not None:
-                        gate_dict['weight_scale'] = gate_up_scale
+                        gate_dict['weight_scale'] = gate_up_scale[
+                            :intermediate_size, :]
+                    if input_scale is not None:
+                        gate_dict['input_scale'] = input_scale
+                    if weight_scale_2 is not None:
+                        gate_dict['weight_scale_2'] = weight_scale_2

                     up_dict = {'weight': up_weight}
                     if gate_up_scale is not None:
-                        up_dict['weight_scale'] = gate_up_scale
+                        up_dict['weight_scale'] = gate_up_scale[
+                            intermediate_size:, :]
+                    if input_scale is not None:
+                        up_dict['input_scale'] = input_scale
+                    if weight_scale_2 is not None:
+                        up_dict['weight_scale_2'] = weight_scale_2

                     module.load_weights(weights=[gate_dict, up_dict])
                 else:
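Note on the fused-QKV hunks above: the slicing assumes the quantized checkpoint stores one scale row per output row, stacked along dim 0 in [q; k; v] order, so each projection takes a contiguous row slice while input_scale and weight_scale_2 remain shared per-tensor. A minimal standalone sketch of that layout, with illustrative sizes and names (not TRT-LLM code):

    # Minimal sketch of the fused-QKV weight_scale split. Assumes one scale
    # row per output row, rows stacked in [q; k; v] order, and one column
    # per 16-element input block (NVFP4-style). Sizes are illustrative.
    import torch

    hidden_size = 3072                 # Q output rows
    num_kv_heads, head_dim = 8, 96     # K and V output rows: 8 * 96 = 768 each
    kv_rows = num_kv_heads * head_dim

    qkv_scale = torch.rand(hidden_size + 2 * kv_rows, hidden_size // 16)

    q_scale = qkv_scale[:hidden_size, :]                       # rows [0, 3072)
    k_scale = qkv_scale[hidden_size:hidden_size + kv_rows, :]  # rows [3072, 3840)
    v_scale = qkv_scale[hidden_size + kv_rows:, :]             # rows [3840, 4608)

    assert q_scale.shape[0] + k_scale.shape[0] + v_scale.shape[0] == qkv_scale.shape[0]

The gate_up_proj hunk follows the same row-slice pattern, with intermediate_size separating the gate rows from the up rows.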
diff --git a/tensorrt_llm/_torch/models/modeling_phi4mm.py b/tensorrt_llm/_torch/models/modeling_phi4mm.py
index bc449e1da5..403848aa19 100644
--- a/tensorrt_llm/_torch/models/modeling_phi4mm.py
+++ b/tensorrt_llm/_torch/models/modeling_phi4mm.py
@@ -8,6 +8,7 @@
 import copy
 import importlib
 import os
+import re
 import sys
 from pathlib import Path
 from typing import List, Optional, Tuple
@@ -64,10 +65,10 @@ def _load_phi4mm_classes(local_path):
     # Add parent folder to sys.path to enable relative import.
     original_sys_path = sys.path.copy()
     package_folder = Path(local_path)
+    package_name = package_folder.name
     parent_folder = str(package_folder.parent)
     if parent_folder not in sys.path:
         sys.path.insert(0, parent_folder)
-
     try:
         # Import Phi4MMConfig from configuration_phi4mm.py.
         config_path = os.path.join(local_path, 'configuration_phi4mm.py')
@@ -87,8 +88,7 @@ def _load_phi4mm_classes(local_path):
         # `Phi-4-multimodal-instruct` as the package name to avoid relative import errors.
         # `hf_modeling_phi4mm` as the module name to avoid name conflicts.
         spec = importlib.util.spec_from_file_location(
-            "Phi-4-multimodal-instruct.hf_modeling_phi4mm",
-            modeling_phi4mm_path)
+            f"{package_name}.hf_modeling_phi4mm", modeling_phi4mm_path)
         hf_modeling_phi4mm = importlib.util.module_from_spec(spec)
         spec.loader.exec_module(hf_modeling_phi4mm)
         Phi4MMAudioEmbedding = hf_modeling_phi4mm.Phi4MMAudioEmbedding
@@ -515,13 +515,17 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
         # Load weights into HFPhi4MultimodalEncoder.
         if not _is_disagg():
             filtered_weights = {}
+            pattern = r"(audio_embed.encoder.encoders..*.conv.glu.b[12]|image_embed.glb_GN|image_embed.img_processor.head.probe)"
             for k, v in weights.items():
                 if k.startswith("model.embed_tokens."):
                     new_k = k.replace("model.embed_tokens.", "embed_tokens.")
                     filtered_weights[new_k] = v
                 elif k.startswith("model.embed_tokens_extend."):
                     new_k = k.replace("model.embed_tokens_extend.", "")
-                    filtered_weights[new_k] = v
+                    if re.match(pattern, new_k):
+                        filtered_weights[new_k] = v.unsqueeze(0)
+                    else:
+                        filtered_weights[new_k] = v
             self.hf_phi4mm_model.load_state_dict(filtered_weights, strict=True)

             # Filter out non-language model weights.
@@ -539,6 +543,15 @@ class Phi4MMForCausalLM(transformers.PreTrainedModel):
             if 'base_layer.weight' in k:
                 new_k = k.replace('base_layer.weight', 'weight')
                 updated_weights[new_k] = weights[k]
+            elif 'base_layer.input_scale' in k:
+                new_k = k.replace('base_layer.input_scale', 'input_scale')
+                updated_weights[new_k] = weights[k]
+            elif 'base_layer.weight_scale' in k:
+                new_k = k.replace('base_layer.weight_scale', 'weight_scale')
+                updated_weights[new_k] = weights[k]
+            elif 'base_layer.weight_scale_2' in k:
+                new_k = k.replace('base_layer.weight_scale_2', 'weight_scale_2')
+                updated_weights[new_k] = weights[k]
             else:
                 updated_weights[k] = weights[k]
         weights = updated_weights
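Reviewer note on the last hunk: because `'base_layer.weight' in k` is a substring test, keys containing `base_layer.weight_scale` or `base_layer.weight_scale_2` are already caught by the first branch, so the two new weight_scale elif branches are effectively unreachable; the substring replace still produces the correct key either way, and behavior is unaffected. A minimal standalone sketch of the rename, using hypothetical key names (not taken from a real checkpoint):

    # Sketch of the 'base_layer' prefix rename. Key names are hypothetical.
    weights = {
        'layers.0.qkv_proj.base_layer.weight': 'w',
        'layers.0.qkv_proj.base_layer.input_scale': 's_in',
        'layers.0.qkv_proj.base_layer.weight_scale': 's_w',
        'layers.0.qkv_proj.base_layer.weight_scale_2': 's_w2',
    }
    updated = {}
    for k, v in weights.items():
        if 'base_layer.weight' in k:  # also matches *.weight_scale and *.weight_scale_2
            updated[k.replace('base_layer.weight', 'weight')] = v
        elif 'base_layer.input_scale' in k:
            updated[k.replace('base_layer.input_scale', 'input_scale')] = v
        else:
            updated[k] = v

    assert sorted(updated) == [
        'layers.0.qkv_proj.input_scale',
        'layers.0.qkv_proj.weight',
        'layers.0.qkv_proj.weight_scale',
        'layers.0.qkv_proj.weight_scale_2',
    ]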