feat: TRTLLM-5574 Add phi-4-multimodal pytorch-backend support (#5644)

Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
Author: Wanli Jiang
Date: 2025-07-17 06:30:58 +08:00 (committed by GitHub)
Parent: e09e409dfb
Commit: 2d2b8bae32
20 changed files with 1277 additions and 56 deletions

View File

@ -145,7 +145,7 @@ def parse_arguments():
return args
def setup_llm(args):
def setup_llm(args, **kwargs):
kv_cache_config = KvCacheConfig(
enable_block_reuse=not args.disable_kv_cache_reuse,
free_gpu_memory_fraction=args.kv_cache_fraction,
@ -222,7 +222,9 @@ def setup_llm(args):
speculative_config=spec_config,
trust_remote_code=args.trust_remote_code,
gather_generation_logits=args.return_generation_logits,
max_beam_width=args.max_beam_width)
max_beam_width=args.max_beam_width,
**kwargs,
)
sampling_params = SamplingParams(
max_tokens=args.max_tokens,

View File

@ -7,24 +7,56 @@ from quickstart_advanced import add_llm_args, setup_llm
from tensorrt_llm.inputs import (ALL_SUPPORTED_MULTIMODAL_MODELS,
default_multimodal_input_loader)
example_images = [
"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
"https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
]
example_image_prompts = [
"Describe the natural environment in the image.",
"Describe the object and the weather condition in the image.",
"Describe the traffic condition on the road in the image.",
]
example_videos = [
"https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4",
"https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4",
]
example_video_prompts = [
"Tell me what you see in the video briefly.",
"Describe the scene in the video briefly.",
]
example_medias_and_prompts = {
"image": {
"media": [
"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
"https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
],
"prompt": [
"Describe the natural environment in the image.",
"Describe the object and the weather condition in the image.",
"Describe the traffic condition on the road in the image.",
]
},
"video": {
"media": [
"https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4",
"https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4",
],
"prompt": [
"Tell me what you see in the video briefly.",
"Describe the scene in the video briefly.",
]
},
"audio": {
"media": [
"https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_the_traffic_sign_in_the_image.wav",
"https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_shown_in_this_image.wav",
],
"prompt": [
"Transcribe the audio clip into text, please don't add other text.",
"Transcribe the audio clip into text, please don't add other text.",
]
},
"image_audio": {
"media": [
[
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
"https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_shown_in_this_image.wav"
],
[
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
"https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_shown_in_this_image.wav"
],
],
"prompt": [
"Describe the scene in the image briefly.",
"",
]
}
}
def add_multimodal_args(parser):
@ -34,7 +66,7 @@ def add_multimodal_args(parser):
help="Model type.")
parser.add_argument("--modality",
type=str,
choices=["image", "video"],
choices=["image", "video", "audio", "image_audio"],
default="image",
help="Media type.")
parser.add_argument("--media",
@ -53,11 +85,24 @@ def add_multimodal_args(parser):
return parser
def add_lora_args(parser):
parser.add_argument("--load_lora",
default=False,
action='store_true',
help="Whether to load the LoRA model.")
parser.add_argument("--auto_model_name",
type=str,
default=None,
help="The auto model name in TRTLLM repo.")
return parser
def parse_arguments():
parser = argparse.ArgumentParser(
description="Multimodal models with the PyTorch workflow.")
parser = add_llm_args(parser)
parser = add_multimodal_args(parser)
parser = add_lora_args(parser)
args = parser.parse_args()
args.disable_kv_cache_reuse = True  # KV cache reuse does not work for multimodal; force it off.
@ -71,11 +116,19 @@ def main():
args = parse_arguments()
# set prompts and media to example prompts and images if they are not provided
if args.prompt is None:
args.prompt = example_image_prompts if args.modality == "image" else example_video_prompts
args.prompt = example_medias_and_prompts[args.modality]["prompt"]
if args.media is None:
args.media = example_images if args.modality == "image" else example_videos
args.media = example_medias_and_prompts[args.modality]["media"]
llm, sampling_params = setup_llm(args)
lora_config = None
if args.load_lora:
assert args.auto_model_name is not None, "Please provide the auto model name to load LoRA config."
import importlib
models_module = importlib.import_module('tensorrt_llm._torch.models')
model_class = getattr(models_module, args.auto_model_name)
lora_config = model_class.lora_config(args.model_dir)
llm, sampling_params = setup_llm(args, lora_config=lora_config)
image_format = args.image_format
if args.model_type is not None:
@ -96,7 +149,16 @@ def main():
num_frames=args.num_frames,
device=device)
outputs = llm.generate(inputs, sampling_params)
lora_request = None
if args.load_lora:
lora_request = model_class.lora_request(len(inputs), args.modality,
llm._hf_model_dir)
outputs = llm.generate(
inputs,
sampling_params,
lora_request=lora_request,
)
for i, output in enumerate(outputs):
prompt = args.prompt[i]
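
For reference, a minimal invocation sketch for the updated script with the new flags, mirroring the e2e test added later in this PR; the model directory and media path are placeholders, not part of the diff:

import subprocess

# Hypothetical paths; substitute a real Phi-4-multimodal-instruct checkpoint and audio clip.
cmd = [
    "python3", "examples/llm-api/quickstart_multimodal.py",
    "--model_dir", "/models/Phi-4-multimodal-instruct",
    "--modality", "audio",
    "--prompt", "Transcribe the audio clip into text, please don't add other text.",
    "--media", "/data/what_is_shown_in_this_image.wav",
    "--load_lora",
    "--auto_model_name", "Phi4MMForCausalLM",
]
subprocess.run(cmd, check=True)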

View File

@ -59,3 +59,4 @@ ninja
etcd3
blake3
llguidance==0.7.29
soundfile

View File

@ -351,6 +351,8 @@ class RopeParams:
beta_slow: int = 1
mscale: float = 1.0
mscale_all_dim: float = 0.0
short_factor: Optional[Tuple[float, ...]] = None
long_factor: Optional[Tuple[float, ...]] = None
@staticmethod
def from_config(config) -> "RopeParams":
@ -386,12 +388,18 @@ class RopeParams:
"low_freq_factor", 1.0)
rope_params.high_freq_factor = rope_scaling.get(
"high_freq_factor", 4.0)
rope_params.original_max_positions = rope_scaling.get(
"original_max_position_embeddings", 1024)
rope_params.original_max_positions = getattr(
config,
"original_max_position_embeddings", None) or rope_scaling.get(
"original_max_position_embeddings", None) or 1024
rope_params.beta_fast = rope_scaling.get("beta_fast", 32)
rope_params.beta_slow = rope_scaling.get("beta_slow", 1)
rope_params.mscale = rope_scaling.get("mscale", 1.0)
rope_params.mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
if "short_factor" in rope_scaling:
rope_params.short_factor = tuple(rope_scaling["short_factor"])
if "long_factor" in rope_scaling:
rope_params.long_factor = tuple(rope_scaling["long_factor"])
# Workaround for DeepSeek V3 Lite since its rope_scaling is null in config.json.
elif config.model_type == "deepseek_v3":
rope_params.scale_type = RotaryScalingType.yarn
@ -428,7 +436,14 @@ class RopeParams:
self.mscale_all_dim,
)
elif self.scale_type == RotaryScalingType.longrope:
raise NotImplementedError("Long RoPE is not supported.")
rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_long_rope_for_attention_plugin(
num_pos=self.max_positions,
dim=self.dim,
theta=self.theta,
original_max_pos=self.original_max_positions,
short_factor=self.short_factor,
long_factor=self.long_factor,
)
else:
rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
self.max_positions,

View File

@ -15,6 +15,8 @@ from .modeling_mixtral import MixtralForCausalLM
from .modeling_nemotron import NemotronForCausalLM
from .modeling_nemotron_h import NemotronHForCausalLM
from .modeling_nemotron_nas import NemotronNASForCausalLM
from .modeling_phi3 import Phi3ForCausalLM
from .modeling_phi4mm import Phi4MMForCausalLM
from .modeling_qwen import (Qwen2ForCausalLM, Qwen2ForProcessRewardModel,
Qwen2ForRewardModel)
from .modeling_qwen2vl import Qwen2_5_VLModel, Qwen2VLModel
@ -42,6 +44,8 @@ __all__ = [
"NemotronForCausalLM",
"NemotronHForCausalLM",
"NemotronNASForCausalLM",
"Phi3ForCausalLM",
"Phi4MMForCausalLM",
"Qwen2ForCausalLM",
"Qwen2ForProcessRewardModel",
"Qwen2ForRewardModel",

View File

@ -64,6 +64,7 @@ def fuse_input_embeds(
mm_token_mask = input_ids >= vocab_size
text_token_mask = input_ids < vocab_size
else:
mm_token_ids = mm_token_ids.to(input_ids.device)
mm_token_mask = torch.isin(input_ids, mm_token_ids)
text_token_mask = ~mm_token_mask
text_token_indices = torch.where(text_token_mask)[0]
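
A tiny sketch of why the one-line device move matters: torch.isin expects both tensors on the same device, so a CPU-resident mm_token_ids would otherwise fail against GPU input_ids. Only the two Phi-4 special token ids are real values here; the rest is toy data, and a CUDA device is assumed.

import torch

_IMAGE_SPECIAL_TOKEN_ID, _AUDIO_SPECIAL_TOKEN_ID = 200010, 200011
mm_token_ids = torch.tensor([_IMAGE_SPECIAL_TOKEN_ID, _AUDIO_SPECIAL_TOKEN_ID])  # CPU tensor

input_ids = torch.tensor([1, 5, _IMAGE_SPECIAL_TOKEN_ID, 7, _AUDIO_SPECIAL_TOKEN_ID],
                         device="cuda")
# Move the special-token ids onto the same device before building the masks.
mm_token_mask = torch.isin(input_ids, mm_token_ids.to(input_ids.device))
text_token_mask = ~mm_token_mask
print(mm_token_mask.tolist())  # [False, False, True, False, True]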

View File

@ -0,0 +1,249 @@
from typing import Optional, Tuple
import torch
from torch import nn
from tqdm import tqdm
from transformers import Phi3Config
from tensorrt_llm._torch.attention_backend import AttentionMetadata
from tensorrt_llm._torch.attention_backend.interface import (
PositionalEmbeddingParams, RopeParams)
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
DecoderModelForCausalLM,
register_auto_model)
from tensorrt_llm._torch.modules.attention import Attention
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
from tensorrt_llm._torch.modules.embedding import Embedding
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
from tensorrt_llm._torch.modules.linear import TensorParallelMode
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm.functional import PositionEmbeddingType
class Phi3Attention(Attention):
def __init__(
self,
model_config: ModelConfig[Phi3Config],
layer_idx: Optional[int] = None,
):
config = model_config.pretrained_config
rope_params = RopeParams.from_config(config)
super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=config.attention_bias,
pos_embd_params=PositionalEmbeddingParams(
type=PositionEmbeddingType.rope_gpt_neox,
rope=rope_params,
),
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
)
class Phi3DecoderLayer(DecoderLayer):
def __init__(
self,
model_config: ModelConfig[Phi3Config],
layer_idx: int,
) -> None:
super().__init__()
config = model_config.pretrained_config
self.layer_idx = layer_idx
self.self_attn = Phi3Attention(model_config, layer_idx=layer_idx)
self.mlp = GatedMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
bias=False,
dtype=config.torch_dtype,
config=model_config,
)
self.input_layernorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)
self.post_attention_layernorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)
def forward(
self,
position_ids: torch.LongTensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
lora_params=None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
# Self Attention
hidden_states = self.self_attn(
position_ids=None,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
lora_params=lora_params,
**kwargs,
)
# Fully connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states, **kwargs)
return hidden_states, residual
class Phi3Model(DecoderModel):
def __init__(self, model_config: ModelConfig[Phi3Config]):
super().__init__(model_config)
config = self.model_config.pretrained_config
self.padding_idx = config.pad_token_id
self.embed_tokens = Embedding(
config.vocab_size,
config.hidden_size,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
)
self.layers = nn.ModuleList([
Phi3DecoderLayer(
model_config,
layer_idx,
) for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)
def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
lora_params=None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
residual = None
for decoder_layer in self.layers:
hidden_states, residual = decoder_layer(
hidden_states=hidden_states,
position_ids=position_ids,
residual=residual,
attn_metadata=attn_metadata,
lora_params=lora_params,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@register_auto_model("Phi3ForCausalLM")
class Phi3ForCausalLM(DecoderModelForCausalLM[Phi3Model, Phi3Config]):
def __init__(
self,
model_config: ModelConfig[Phi3Config],
):
super().__init__(Phi3Model(model_config),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size)
def load_weights(self, weights: dict):
hidden_size = self.config.hidden_size
num_heads = self.config.num_attention_heads
num_kv_heads = self.config.num_key_value_heads
head_dim = hidden_size // num_heads
def filter_weights(prefix: str, weights: dict):
result = {}
for k, v in weights.items():
if k.startswith(prefix):
new_k = k[len(prefix) + 1:]
result[new_k] = v
return result
for name, module in tqdm(list(self.named_modules()),
desc="Loading weights"):
if len(module._parameters) > 0:
# Skip loading weights if tied word embeddings are enabled and the layer is lm_head.
if self.config.tie_word_embeddings and name.startswith(
'lm_head'):
continue
module_weights = filter_weights(name, weights)
if hasattr(module, 'load_weights'):
if "self_attn.qkv_proj" in name:
# The weights need to be split correctly before sharding to support tp_size >1.
qkv_weight = module_weights['weight'][:]
q_weight = qkv_weight[:hidden_size, :]
k_weight = qkv_weight[hidden_size:hidden_size +
num_kv_heads * head_dim, :]
v_weight = qkv_weight[hidden_size +
num_kv_heads * head_dim:, :]
module.load_weights(weights=[
{
'weight': q_weight
},
{
'weight': k_weight
},
{
'weight': v_weight
},
])
elif "mlp.gate_up_proj" in name:
# The weights need to be split correctly before sharding to support tp_size >1.
intermediate_size = self.config.intermediate_size
gate_up_weight = module_weights['weight'][:]
gate_weight = gate_up_weight[:intermediate_size, :]
up_weight = gate_up_weight[intermediate_size:, :]
module.load_weights(weights=[
{
'weight': gate_weight
},
{
'weight': up_weight
},
])
else:
module.load_weights(weights=[module_weights])
else:
for n, p in module._parameters.items():
if p is not None:
p.data.copy_(module_weights[n][:])
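
The q/k/v and gate/up splits above assume the fused row layouts [q | k | v] and [gate | up] along dim 0. A shape-only sketch with toy dimensions (not the real Phi-3 config) to illustrate the slicing:

import torch

hidden_size, num_heads, num_kv_heads = 64, 8, 8    # toy values
head_dim = hidden_size // num_heads
intermediate_size = 128                            # toy value

qkv_weight = torch.randn(hidden_size + 2 * num_kv_heads * head_dim, hidden_size)
q_weight = qkv_weight[:hidden_size, :]
k_weight = qkv_weight[hidden_size:hidden_size + num_kv_heads * head_dim, :]
v_weight = qkv_weight[hidden_size + num_kv_heads * head_dim:, :]
assert q_weight.shape[0] == hidden_size and k_weight.shape == v_weight.shape

gate_up_weight = torch.randn(2 * intermediate_size, hidden_size)
gate_weight = gate_up_weight[:intermediate_size, :]
up_weight = gate_up_weight[intermediate_size:, :]
assert gate_weight.shape == up_weight.shape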

View File

@ -0,0 +1,286 @@
# Plan for phi4-mm model support.
# (done) step 1: support legacy inference pipeline for phi4-mm model.
# (todo) step 2: refactor the inference pipeline to use AGGREGATE mode (https://github.com/NVIDIA/TensorRT-LLM/pull/5522).
import copy
from typing import List, Optional, Tuple
import torch
import transformers
from PIL import Image
from ...executor.request import LoRARequest
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
register_input_processor)
from ...logger import logger
from ...lora_manager import LoraConfig
from ...sampling_params import SamplingParams
from ..attention_backend import AttentionMetadata
from ..model_config import ModelConfig
from .modeling_auto import AutoModelForCausalLM
from .modeling_multimodal_utils import fuse_input_embeds
from .modeling_utils import register_auto_model
# Special tokens
_IMAGE_SPECIAL_TOKEN_ID = 200010 # '<|endoftext10|>'
_AUDIO_SPECIAL_TOKEN_ID = 200011 # '<|endoftext11|>'
# Create a PreTrainedModel subclass for the transformers 4.53.1 upgrade.
# The core idea is to provide the `prepare_inputs_for_generation` method from `GenerationMixin`.
class NewPreTrainedModel(transformers.modeling_utils.PreTrainedModel,
transformers.generation.GenerationMixin):
pass
class Phi4MMInputProcessor(InputProcessor):
def __init__(self,
model_path: str,
model_config: transformers.PretrainedConfig,
tokenizer: transformers.AutoTokenizer,
trust_remote_code: bool = True):
assert trust_remote_code, "trust_remote_code must be True for Phi4MM"
self.model_config = model_config
self.device = 'cuda'
self.tokenizer = tokenizer
self.use_fast = True
if self.tokenizer is None:
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=trust_remote_code,
use_fast=self.use_fast)
self.processor = transformers.AutoProcessor.from_pretrained(
model_path,
trust_remote_code=trust_remote_code,
use_fast=self.use_fast)
# Build the pure-PyTorch model architecture for the multimodal encoder.
# Model weights are also loaded here.
OldPreTrainedModel = transformers.modeling_utils.PreTrainedModel
transformers.modeling_utils.PreTrainedModel = NewPreTrainedModel
# TODO: Make separate Phi4VisionEncoder and Phi4AudioEncoder, and move them to the LLM side.
ref_phi4mm_model = transformers.AutoModelForCausalLM.from_pretrained(
model_path,
trust_remote_code=True,
# flash_attention_2 only supports bf16 or fp16; the dtype is set in the HF config.
torch_dtype='auto',
_attn_implementation='flash_attention_2',
).eval()
transformers.modeling_utils.PreTrainedModel = OldPreTrainedModel
self.phi4mm_modal_encoder = ref_phi4mm_model.model.embed_tokens_extend.to(
self.device)
# Required by Phi4MMImageAudioEmbedding.
# See link: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/modeling_phi4mm.py#L701
self.phi4mm_wte = ref_phi4mm_model.model.embed_tokens.to(self.device)
@torch.inference_mode()
def __call__(
self, inputs: TextPrompt, sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
text_prompt, mm_data, mm_processor_kwargs = inputs.get("prompt"), \
inputs.get("multi_modal_data", {}), inputs.get("mm_processor_kwargs", {})
images = mm_data.get("image", None)
audios = mm_data.get("audio", None)
if images is not None:
if isinstance(images[0], torch.Tensor):
# Convert normalized tensors (0-1) to PIL images (0-255).
images = [
Image.fromarray((image.permute(1, 2, 0) * 255).to(
torch.uint8).cpu().numpy()) for image in images
]
# Preprocessing for multimodal data.
inputs = self.processor(text=[text_prompt],
images=images,
audios=audios,
return_tensors='pt').to(self.device)
# Set audio_projection_mode according to the modality.
# Ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/modeling_phi4mm.py#L2103
if images is not None:
audio_projection_mode = 'vision'
elif audios is not None:
audio_projection_mode = 'speech'
else:
audio_projection_mode = 'speech'
# Processing with Phi4MMImageAudioEmbedding.
mm_features = self.phi4mm_modal_encoder(
input_ids=inputs['input_ids'],
input_embeds=None,
input_image_embeds=inputs['input_image_embeds'],
input_audio_embeds=inputs['input_audio_embeds'],
image_sizes=inputs['image_sizes'],
image_attention_mask=inputs['image_attention_mask'],
audio_embed_sizes=inputs['audio_embed_sizes'],
audio_attention_mask=inputs['audio_attention_mask'],
audio_projection_mode=audio_projection_mode,
wte=self.phi4mm_wte,
)
# Postprocessing to get multimodal-only embeddings.
image_token_mask = inputs['input_ids'] == _IMAGE_SPECIAL_TOKEN_ID
audio_token_mask = inputs['input_ids'] == _AUDIO_SPECIAL_TOKEN_ID
mm_token_mask = image_token_mask | audio_token_mask
mm_features = mm_features[mm_token_mask]
multimodal_data = {}
multimodal_data["multimodal_embedding"] = mm_features
return inputs['input_ids'][0].to(torch.int32).tolist(), {
"multimodal_data": multimodal_data,
}
@register_auto_model("Phi4MMForCausalLM")
@register_input_processor(Phi4MMInputProcessor, model_type="phi4mm")
class Phi4MMForCausalLM(transformers.PreTrainedModel):
_supports_flash_attn_2 = True
MM_TOKEN_IDS = torch.tensor(
[_IMAGE_SPECIAL_TOKEN_ID, _AUDIO_SPECIAL_TOKEN_ID])
def __init__(self, model_config: ModelConfig):
config = model_config.pretrained_config
super().__init__(config)
self.model_config = model_config
if hasattr(self, "llm"):
return
# We use Phi3ForCausalLM as the language model.
llm_model_config = copy.deepcopy(model_config)
llm_model_config.pretrained_config.architectures = ["Phi3ForCausalLM"]
# Only build the language model architecture without loading weights.
self.llm = AutoModelForCausalLM.from_config(llm_model_config)
self.vocab_size = config.vocab_size
self.model_dtype = getattr(config, "torch_dtype", torch.float16)
logger.info(f"{self.dtype=} {self.model_dtype=}")
self.post_config()
self.is_loaded = True
def load_weights(self, weights):
# Filter out non-language model weights.
weights = {
k: v
for k, v in weights.items()
if not k.startswith('model.embed_tokens_extend')
}
# Filter out LoRA weights.
# LoRA weights will be loaded by LoraManager.
weights = {k: v for k, v in weights.items() if '.lora_' not in k}
# Rename base layer weights.
updated_weights = {}
for k in weights.keys():
if 'base_layer.weight' in k:
new_k = k.replace('base_layer.weight', 'weight')
updated_weights[new_k] = weights[k]
else:
updated_weights[k] = weights[k]
weights = updated_weights
self.llm.load_weights(weights)
def infer_max_seq_len(self) -> int:
return self.llm.infer_max_seq_len()
def post_config(self):
# use llm.config as config for pytorch model engine
self.config = self.llm.config
self.model_config.pretrained_config = self.llm.config
@torch.inference_mode()
def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
input_embeds: Optional[torch.Tensor] = None,
return_context_logits: bool = False,
**kwargs,
) -> torch.Tensor:
"""
VLM forward logic with inflight batching support.
"""
num_context_requests, num_generation_requests = attn_metadata.num_contexts, attn_metadata.num_generations
logger.debug(
f"num_context_requests: {num_context_requests}, num_generation_requests: {num_generation_requests}"
)
multimodal_params = kwargs.get("multimodal_params", [])
mm_embedding = [
multimodal_param.multimodal_data["multimodal_embedding"]
for multimodal_param in multimodal_params
]
input_ids, input_embeds = fuse_input_embeds(
self.llm.model.embed_tokens,
input_ids,
mm_embedding,
mm_token_ids=self.MM_TOKEN_IDS,
)
output_prob = self.llm.forward(
attn_metadata=attn_metadata,
input_ids=input_ids,
position_ids=position_ids,
inputs_embeds=input_embeds,
return_context_logits=return_context_logits,
lora_params=kwargs.get("lora_params", None),
)
logger.debug(f'output shape: {output_prob.shape}')
return output_prob
@staticmethod
def lora_config(model_dir: str):
_lora_config = LoraConfig(
lora_dir=[
f"{model_dir}/vision-lora",
f"{model_dir}/speech-lora",
],
lora_target_modules=[
"attn_qkv",
"attn_dense",
"mlp_h_to_4h",
"mlp_4h_to_h",
],
trtllm_modules_to_hf_modules={
"attn_qkv": "qkv_proj",
"attn_dense": "o_proj",
"mlp_h_to_4h": "gate_up_proj",
"mlp_4h_to_h": "down_proj",
},
max_lora_rank=320, # Max rank for Phi4MM.
)
return _lora_config
@staticmethod
def lora_request(num_requests: int, modality: str, base_model_dir: str):
# Prepare LoRA requests for different modalities.
# Ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/modeling_phi4mm.py#L2103
lora_request = None
if modality == "image" or modality == "image_audio":
lora_request = [
LoRARequest(
lora_name=f"vision-lora-{i}",
lora_int_id=i,
lora_path=f"{base_model_dir}/vision-lora",
) for i in range(num_requests)
]
elif modality == "audio":
lora_request = [
LoRARequest(
lora_name=f"speech-lora-{i}",
lora_int_id=i,
lora_path=f"{base_model_dir}/speech-lora",
) for i in range(num_requests)
]
return lora_request
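
A minimal sketch of how the two static helpers above behave; the model path is a placeholder. The resulting objects are then passed to LLM(..., lora_config=...) and llm.generate(..., lora_request=...), which is exactly what the quickstart script wires up behind its --load_lora / --auto_model_name flags.

from tensorrt_llm._torch.models import Phi4MMForCausalLM

model_dir = "/path/to/Phi-4-multimodal-instruct"  # placeholder

# LoraConfig pointing at the vision-lora/ and speech-lora/ adapters inside the checkpoint.
lora_config = Phi4MMForCausalLM.lora_config(model_dir)
print(lora_config.lora_dir)        # ['.../vision-lora', '.../speech-lora']
print(lora_config.max_lora_rank)   # 320

# One LoRARequest per input; image/image_audio -> vision-lora, audio -> speech-lora.
reqs = Phi4MMForCausalLM.lora_request(num_requests=2, modality="audio",
                                      base_model_dir=model_dir)
print([r.lora_name for r in reqs])  # ['speech-lora-0', 'speech-lora-1']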

View File

@ -4834,6 +4834,48 @@ class RopeEmbeddingUtils:
True), _compute_sinusoidal_positions(
scaling_long_factors, False, True), short_mscale
@staticmethod
def create_sinusoidal_positions_long_rope_for_attention_plugin(
num_pos: int,
dim: int,
theta: float,
original_max_pos: int,
short_factor: List[float],
long_factor: List[float],
dtype=np.float32):
short_factor = np.array(short_factor, dtype=np.float32)
long_factor = np.array(long_factor, dtype=np.float32)
inv_freq = 1.0 / (theta**(np.arange(0, dim, 2, dtype=np.float32) / dim))
# Short part
inv_freq_short = inv_freq / short_factor
t_short = np.arange(np.min([num_pos, original_max_pos]),
dtype=np.float32)
freqs_short = np.einsum("i,j->ij", t_short, inv_freq_short)
# Long part
inv_freq_long = inv_freq / long_factor
t_long = np.arange(np.max([0, num_pos - original_max_pos]),
dtype=np.float32) + original_max_pos
freqs_long = np.einsum("i,j->ij", t_long, inv_freq_long)
freqs = np.concatenate([freqs_short, freqs_long], axis=0)
sinusoid_inp = freqs.astype(np.float32)[..., np.newaxis]
# Apply scaling
scale = num_pos / original_max_pos
scaling_factor = np.sqrt(1.0 + np.log(scale) / np.log(original_max_pos))
# fuse cos/sin into float2 (cos, sin).
concat = np.concatenate(
(np.cos(sinusoid_inp) * scaling_factor,
np.sin(sinusoid_inp) * scaling_factor),
axis=-1,
)
return None, concat.reshape(1, -1).astype(dtype)
@staticmethod
def create_fake_weight(dim: int, dtype=np.half):
return np.random.rand(dim).astype(dtype)
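
The attention factor applied to the cos/sin table follows the LongRoPE recipe, sqrt(1 + ln(scale) / ln(original_max_pos)). A quick numeric check with NumPy, using illustrative context lengths rather than values taken from a specific config:

import numpy as np

num_pos, original_max_pos = 131072, 4096           # illustrative context lengths
scale = num_pos / original_max_pos                 # 32.0
scaling_factor = np.sqrt(1.0 + np.log(scale) / np.log(original_max_pos))
print(round(float(scaling_factor), 4))             # ~1.1902, applied to both cos and sin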

View File

@ -5,8 +5,9 @@ from .registry import (ExtraProcessedInputs, InputProcessor,
register_input_processor)
from .utils import (ALL_SUPPORTED_MULTIMODAL_MODELS, ConversationMessage,
MultimodalData, MultimodalDataTracker,
add_multimodal_placeholders, async_load_image,
async_load_video, default_multimodal_input_loader,
add_multimodal_placeholders, async_load_audio,
async_load_image, async_load_video,
default_multimodal_input_loader,
encode_base64_content_from_url, load_image, load_video)
__all__ = [
@ -24,6 +25,7 @@ __all__ = [
"MultimodalDataTracker",
"MultimodalData",
"MultimodalInput",
"async_load_audio",
"async_load_image",
"async_load_video",
"add_multimodal_placeholders",

View File

@ -5,12 +5,13 @@ import tempfile
from collections import defaultdict
from io import BytesIO
from pathlib import Path
from typing import Any, Coroutine, Dict, List, Optional, TypedDict, Union
from typing import Any, Coroutine, Dict, List, Optional, Tuple, TypedDict, Union
from urllib.parse import urlparse
import aiohttp
import numpy as np
import requests
import soundfile
import torch
from PIL import Image
from torchvision.transforms import ToTensor
@ -159,6 +160,35 @@ async def async_load_video(
return load_video(video_path, num_frames, format, device)
def load_audio(
audio: str,
format: str = "pt",
device: str = "cuda",
) -> Tuple[np.ndarray, int]:
parsed_url = urlparse(audio)
if parsed_url.scheme in ["http", "https"]:
audio = requests.get(audio, stream=True, timeout=10)
audio = BytesIO(audio.content)
audio = soundfile.read(audio)
return audio
async def async_load_audio(
audio: str,
format: str = "pt",
device: str = "cuda",
) -> Tuple[np.ndarray, int]:
parsed_url = urlparse(audio)
if parsed_url.scheme in ["http", "https"]:
async with aiohttp.ClientSession() as session:
async with session.get(audio) as response:
audio = BytesIO(await response.content.read())
audio = soundfile.read(audio)
return audio
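
A minimal usage sketch of the new loader, assuming it is imported from tensorrt_llm.inputs.utils where it is defined (only async_load_audio is re-exported from the package __init__). The URL is one of the example clips referenced elsewhere in this PR.

from tensorrt_llm.inputs.utils import load_audio

url = ("https://huggingface.co/microsoft/Phi-4-multimodal-instruct/"
       "resolve/main/examples/what_is_shown_in_this_image.wav")
# soundfile.read returns (samples, sample_rate); both are passed through unchanged.
samples, sample_rate = load_audio(url)
print(samples.shape, sample_rate)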
# Copied from https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_client_for_multimodal.py#L38
def encode_base64_content_from_url(content_url: str) -> str:
"""Encode a content retrieved from a remote url to base64 format."""
@ -186,19 +216,24 @@ SUPPORTED_LLAVA_IMAGE_MODEL_GROUP = ["llava_llama", "llava_next"]
SUPPORTED_LLAVA_VIDEO_MODEL_GROUP = ["llava_llama"]
SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP = ["mistral3"]
SUPPORTED_HYPERCLOVAX_MODEL_GROUP = ["hyperclovax_vlm"]
SUPPORTED_PHI_MODEL_GROUP = ["phi4mm"]
ALL_SUPPORTED_IMAGE_MODELS = SUPPORTED_QWEN_MODEL_GROUP \
+ SUPPORTED_LLAMA_MODEL_GROUP \
+ SUPPORTED_LLAVA_IMAGE_MODEL_GROUP \
+ SUPPORTED_HYPERCLOVAX_MODEL_GROUP \
+ SUPPORTED_GEMMA_MODEL_GROUP \
+ SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP
+ SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP \
+ SUPPORTED_PHI_MODEL_GROUP
ALL_SUPPORTED_VIDEO_MODELS = SUPPORTED_QWEN_MODEL_GROUP \
+ SUPPORTED_LLAVA_VIDEO_MODEL_GROUP
ALL_SUPPORTED_AUDIO_MODELS = SUPPORTED_PHI_MODEL_GROUP
ALL_SUPPORTED_MULTIMODAL_MODELS = list(set(ALL_SUPPORTED_IMAGE_MODELS) \
| set(ALL_SUPPORTED_VIDEO_MODELS))
| set(ALL_SUPPORTED_VIDEO_MODELS) \
| set(ALL_SUPPORTED_AUDIO_MODELS))
HF_CHAT_TEMPLATE_EXCEPTIONS = ["llava_llama"]
PLACEHOLDER_EXCEPTIONS = ["llava_next"]
@ -223,6 +258,7 @@ PLACEHOLDER_PLACEMENT_MAP = {
# Ref: https://github.com/mistralai/mistral-common/blob/039465db2bdc0486df36365c9bdb428188482a18/
# src/mistral_common/tokens/tokenizers/base.py#L326
"mistral3": MultimodalPlaceholderPlacement.AFTER_TEXT,
"phi4mm": MultimodalPlaceholderPlacement.BEFORE_TEXT,
}
assert len(PLACEHOLDER_PLACEMENT_MAP) == len(ALL_SUPPORTED_MULTIMODAL_MODELS)
@ -235,7 +271,7 @@ def retrieve_multimodal_placeholder(model_type: str, modality: str,
Args:
model_type: The type of the multimodal model.
modality: The modality of the data.
current_count: The number of multimodal data already added. Currently not used.
current_count: The number of multimodal data already added.
"""
@ -257,6 +293,8 @@ def retrieve_multimodal_placeholder(model_type: str, modality: str,
# Ref: https://github.com/mistralai/mistral-common/blob/26a6bb3a07ee0b78a3808f2797f23e1d28514b93/
# src/mistral_common/tokens/tokenizers/base.py#L60
return "[IMG]"
elif model_type in SUPPORTED_PHI_MODEL_GROUP:
return f"<|image_{current_count}|>"
raise TypeError(
f"For image modality, only {ALL_SUPPORTED_IMAGE_MODELS} are supported but got {model_type}"
)
@ -268,6 +306,9 @@ def retrieve_multimodal_placeholder(model_type: str, modality: str,
raise TypeError(
f"For video modality, only {ALL_SUPPORTED_VIDEO_MODELS} are supported but got {model_type}"
)
elif modality == "audio":
if model_type in SUPPORTED_PHI_MODEL_GROUP:
return f"<|audio_{current_count}|>"
raise TypeError(f"Unknown modality: {modality}")
@ -343,7 +384,10 @@ def add_multimodal_placeholders(model_type: str, text_prompt: str,
case MultimodalPlaceholderPlacement.AFTER_TEXT:
parts.append(text_prompt)
parts.extend(placeholders)
return "\n".join(parts)
if model_type == "phi4mm":
return "".join(parts)
else:
return "\n".join(parts)
def resolve_hf_chat_template(
@ -458,6 +502,34 @@ def default_multimodal_input_loader(
format=image_data_format,
device=device)) for i in media
]
elif modality == "audio":
mm_data = [
MultimodalData(modality=modality,
data=load_audio(i, device=device)) for i in media
]
elif modality == "image_audio":
# Try to load each media item as an image first, then fall back to audio.
mm_data = []
for m in media:
data = None
_modal = None
try:
data = load_image(m,
format=image_data_format,
device=device)
_modal = "image"
except Exception:
pass
if _modal is None:
try:
data = load_audio(m, device=device)
_modal = "audio"
except Exception:
pass
if _modal is None:
raise ValueError(
f"Failed to load media item as image or audio: {m}")
mm_data.append(MultimodalData(modality=_modal, data=data))
else:
raise ValueError(f"Unknown modality: {modality}")
return ConversationMessage(role="user", content=prompt, media=mm_data)

View File

@ -2,7 +2,8 @@ from functools import partial
from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
Optional, Tuple, TypeAlias, TypedDict, Union, cast)
from openai.types.chat import ChatCompletionContentPartImageParam
from openai.types.chat import (ChatCompletionContentPartImageParam,
ChatCompletionContentPartInputAudioParam)
from openai.types.chat import \
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
from openai.types.chat import (ChatCompletionContentPartTextParam,
@ -12,8 +13,8 @@ from typing_extensions import Required
from tensorrt_llm.inputs import (ConversationMessage, MultimodalData,
MultimodalDataTracker,
add_multimodal_placeholders, async_load_image,
async_load_video)
add_multimodal_placeholders, async_load_audio,
async_load_image, async_load_video)
from tensorrt_llm.logger import logger
@ -33,12 +34,16 @@ ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartVideoParam,
str]
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ["text", "image_url", "video_url"]
# TODO: Add "input_audio" to support byte_encoded audio input.
VALID_MESSAGE_CONTENT_MM_PART_TYPES = [
"text", "image_url", "video_url", "audio_url"
]
# Parser Functions
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
_AudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[
str, dict[str, str]]]] = {
@ -48,6 +53,8 @@ MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[
lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
"video_url":
lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
"audio_url":
lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
}
@ -74,7 +81,7 @@ def parse_chat_message_content_part(
part_type, content = _parse_chat_message_content_mm_part(part)
# if part_type is text/image_url/video_url but content is None, log a warning and skip
# if part_type is text/image_url/video_url/audio_url but content is None, log a warning and skip
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
logger.warning(
"Skipping multimodal part '%s' (type: '%s') with empty / unparsable content.",
@ -108,6 +115,18 @@ def parse_chat_message_content_part(
return MultimodalData(modality="video", data=load_video_async())
if part_type == "audio_url":
str_content = cast(str, content)
async def load_audio_async():
try:
return await async_load_audio(str_content)
except Exception as e:
logger.error(f"Failed to load audio: {str(e)}")
return None
return MultimodalData(modality="audio", data=load_audio_async())
raise NotImplementedError(f"Unknown part type: {part_type}")

View File

@ -115,3 +115,5 @@ mistralai/Ministral-8B-Instruct-2410:
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 78.35
microsoft/Phi-4-multimodal-instruct:
- accuracy: 81.19

View File

@ -199,3 +199,5 @@ mistralai/Ministral-8B-Instruct-2410:
- quant_algo: FP8
kv_cache_quant_algo: FP8
accuracy: 65.96
microsoft/Phi-4-multimodal-instruct:
- accuracy: 69.69

View File

@ -1847,3 +1847,16 @@ class TestBielik11BInstruct(LlmapiAccuracyTestHarness):
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)
class TestPhi4MM(LlmapiAccuracyTestHarness):
# Phi-4-MM also supports text-only input.
MODEL_NAME = "microsoft/Phi-4-multimodal-instruct"
MODEL_PATH = f"{llm_models_root()}/multimodals/Phi-4-multimodal-instruct"
def test_auto_dtype(self):
with LLM(self.MODEL_PATH) as llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

View File

@ -1544,7 +1544,25 @@ def test_build_time_benchmark_sanity(llm_root, llm_venv):
])
### Pivot-To-Python examples
### PyTorch examples
def parse_output(text):
results = []
text_lists = re.split(r"\[\d+\] Prompt:", text)
for item in text_lists:
item = item.replace(os.linesep, "")
while True:
match = re.search(r"(Generated text: \'(.*?)\')", item,
re.MULTILINE)
if match is None:
break
_, end = match.span(1)
results.append(match.group(2))
item = item[end:]
return results
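
A worked example of what parse_output extracts, using a hypothetical stdout fragment in the quickstart scripts' output format (not real model output):

sample = ("[0] Prompt: 'Describe the image.', Generated text: 'A sunny beach.'\n"
          "[1] Prompt: 'Transcribe the audio.', Generated text: 'What is shown in this image?'")
assert parse_output(sample) == ["A sunny beach.", "What is shown in this image?"]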
def test_ptp_quickstart(llm_root, llm_venv):
example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
@ -2101,21 +2119,6 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
output = llm_venv.run_cmd(cmd, caller=check_output)
def parse_output(text):
results = []
text_lists = re.split(r"\[\d+\] Prompt:", text)
for item in text_lists:
item = item.replace(os.linesep, "")
while True:
match = re.search(r"(Generated text: \'(.*?)\')", item,
re.MULTILINE)
if match is None:
break
_, end = match.span(1)
results.append(match.group(2))
item = item[end:]
return results
match_ratio = 4.0 / 5
if model_name == "qwen2-vl-7b-instruct" and modality == "image":
match_ratio = 4.0 / 6
@ -2182,6 +2185,92 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
_check_mem_usage(running_log, [peak, 0, 0, 0])
@pytest.mark.parametrize("modality", ["image", "audio", "image_audio"])
def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality):
model_name = "Phi-4-multimodal-instruct"
model_path = "multimodals/Phi-4-multimodal-instruct"
example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
test_data_root = Path(
os.path.join(llm_models_root(), "multimodals", "test_data"))
audio_data_root = Path(
os.path.join(llm_models_root(), "multimodals",
"Phi-4-multimodal-instruct", "examples"))
print(f"Accuracy test {model_name} {modality} mode with example inputs.")
accuracy_inputs = {
"image": {
"prompt": [
"Describe the object and the weather condition in the image.",
"Describe the traffic condition on the road in the image.",
],
"media": [
str(test_data_root / "inpaint.png"),
str(test_data_root / "61.jpg"),
],
},
"audio": {
"prompt": [
"Transcribe the audio clip into text, please don't add other text.",
"Transcribe the audio clip into text, please don't add other text.",
],
"media": [
str(audio_data_root /
"what_is_the_traffic_sign_in_the_image.wav"),
str(audio_data_root / "what_is_shown_in_this_image.wav"),
],
},
"image_audio": {
"prompt": [
"",
],
"media": [
str(test_data_root / "inpaint.png"),
str(audio_data_root / "what_is_shown_in_this_image.wav"),
],
}
}
expected_keywords = {
"image": [
["clear", "sunny", "sky", "image", "object"],
["road", "car", "lane", "strip", "bus"],
],
"audio": [
["what", "is", "the", "traffic", "sign", "in", "image"],
["what", "is", "shown", "in", "this", "image"],
],
"image_audio": [
["Half", "Dome", "Park", "natural", "image"],
],
}
cmd = [
str(example_root / "quickstart_multimodal.py"),
"--model_dir",
f"{llm_models_root()}/{model_path}",
"--modality",
modality,
"--prompt",
*accuracy_inputs[modality]["prompt"],
"--media",
*accuracy_inputs[modality]["media"],
"--load_lora",
"--auto_model_name",
"Phi4MMForCausalLM",
]
output = llm_venv.run_cmd(cmd, caller=check_output)
match_ratio = 0.6
for prompt_output, prompt_keywords in zip(parse_output(output),
expected_keywords[modality]):
matches = [
keyword in prompt_output.lower() for keyword in prompt_keywords
]
obs_match_ratio = 1. * sum(matches) / len(matches)
assert obs_match_ratio >= match_ratio, f"Incorrect output!\nGenerated \"{prompt_output}\"\nExpected keywords \"{prompt_keywords}\"\n Matched keywords: {matches}\n Observed match ratio {obs_match_ratio} below threshold {match_ratio}"
print("All answers are correct!")
@pytest.mark.parametrize("model_name,model_path", [
("BertForSequenceClassification", "bert/bert-base-uncased-yelp-polarity"),
])

View File

@ -486,6 +486,7 @@ accuracy/test_llm_api_pytorch.py::TestBielik11BInstruct::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestBielik11BInstruct::test_fp8
accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_auto_dtype
accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_fp8
accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype
test_e2e.py::test_llama_e2e[use_cpp_session-remove_input_padding-]
test_e2e.py::test_llama_e2e[use_py_session-remove_input_padding-]
@ -528,6 +529,9 @@ test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistr
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-False]
test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-False]
test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[audio]
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image]
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]
test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]

View File

@ -15,6 +15,7 @@ l0_a30:
tests:
# ------------- PyTorch tests ---------------
- unittest/_torch/modeling -k "modeling_nemotron_nas"
- unittest/_torch/modeling -k "modeling_phi3"
- unittest/_torch/modeling -k "modeling_qwen"
- unittest/_torch/modeling -k "modeling_qwen_moe"
- unittest/_torch/auto_deploy/unit/singlegpu

View File

@ -28,6 +28,9 @@ l0_l40s:
- test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-image-True]
- test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video-False]
- test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video-True]
- test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[audio]
- test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image]
- test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]
- test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
- test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
- condition:

View File

@ -0,0 +1,352 @@
import unittest
from copy import deepcopy
from dataclasses import dataclass
from typing import Any
import torch
from transformers import Phi3Config
from transformers import Phi3ForCausalLM as HFPhi3ForCausalLM
from utils.util import default_dtype
import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_phi3 import Phi3ForCausalLM
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
DecodingCUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
PHI3_MINI_4K_CONFIG = {
"_name_or_path": "Phi-3-mini-4k-instruct",
"architectures": ["Phi3ForCausalLM"],
"attention_dropout": 0.0,
"auto_map": {
"AutoConfig": "configuration_phi3.Phi3Config",
"AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM"
},
"bos_token_id": 1,
"embd_pdrop": 0.0,
"eos_token_id": 32000,
"hidden_act": "silu",
"hidden_size": 3072,
"initializer_range": 0.02,
"intermediate_size": 8192,
"max_position_embeddings": 4096,
"model_type": "phi3",
"num_attention_heads": 32,
"num_hidden_layers": 32,
"num_key_value_heads": 32,
"original_max_position_embeddings": 4096,
"pad_token_id": 32000,
"resid_pdrop": 0.0,
"rms_norm_eps": 1e-05,
"rope_scaling": None,
"rope_theta": 10000.0,
"sliding_window": 2047,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"transformers_version": "4.40.2",
"use_cache": True,
"attention_bias": False,
"vocab_size": 32064
}
@dataclass(repr=False)
class Scenario:
backend: str
use_cuda_graph: bool = False
def __repr__(self) -> str:
return f"backend:{self.backend.lower()}-use_cuda_graph:{self.use_cuda_graph}"
def reduce_phi3_config(mem_for_full_model: int,
config_dict: dict[str, Any],
default_num_layers: int = 32):
_, total_mem = torch.cuda.mem_get_info()
# scale model down if gpu memory is low
if total_mem < mem_for_full_model:
model_fraction = total_mem / mem_for_full_model
num_layers = int(config_dict["num_hidden_layers"] * model_fraction)
num_layers = min(num_layers, default_num_layers)
config_dict["num_hidden_layers"] = num_layers
class TestPhi3(unittest.TestCase):
def test_phi3_sanity(self):
config_dict = deepcopy(PHI3_MINI_4K_CONFIG)
# 8B * sizeof(float16) plus some extra for activations
mem_for_full_model = (2 + 1) * 8 * 2**(30)
reduce_phi3_config(mem_for_full_model, config_dict)
if config_dict["num_hidden_layers"] <= 0:
self.skipTest("Insufficient memory for a single Phi3 layer")
phi3_config = Phi3Config.from_dict(config_dict)
dtype = phi3_config.torch_dtype
device = torch.device('cuda')
with torch.device(device), default_dtype(dtype):
model_config = ModelConfig(pretrained_config=phi3_config)
phi3 = Phi3ForCausalLM(model_config).to(device)
input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
dtype=torch.int,
device=device)
context_sequence_lengths = [3, 2, 1]
sequence_lengths = context_sequence_lengths + [1, 1]
past_seen_tokens = [0, 0, 0, 62, 75]
request_ids = list(range(len(sequence_lengths)))
token_nums = (torch.tensor(past_seen_tokens) +
torch.tensor(sequence_lengths)).tolist()
prompt_lens = token_nums[:3] + past_seen_tokens[3:]
num_blocks = 100
tokens_per_block = 128
head_dim = phi3.config.hidden_size // phi3.config.num_attention_heads
num_layers = phi3.config.num_hidden_layers
num_kv_heads = phi3.config.num_key_value_heads
max_seq_len = num_blocks * tokens_per_block
batch_size = len(context_sequence_lengths) + 2
if dtype == torch.half:
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
elif dtype == torch.bfloat16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
else:
raise ValueError("Invalid dtype")
mapping = Mapping(world_size=1, tp_size=1, rank=0)
kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
tokens_per_block)
kv_cache_manager = KVCacheManager(
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
max_batch_size=batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
)
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
attn_metadata = metadata_cls(
seq_lens=torch.tensor(sequence_lengths, dtype=torch.int),
num_contexts=len(context_sequence_lengths),
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=past_seen_tokens,
),
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
max_num_requests=len(context_sequence_lengths) + 2,
max_num_tokens=8192,
)
position_ids = []
for i, tokens in enumerate(past_seen_tokens):
seq_len = context_sequence_lengths[i] if i < len(
context_sequence_lengths) else 1
position_id = torch.arange(tokens,
tokens + seq_len,
device=input_ids.device)
position_ids.append(position_id)
position_ids = torch.cat(position_ids).unsqueeze(0)
with torch.inference_mode():
attn_metadata.prepare()
logits = phi3.forward(input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata)
self.assertEqual(len(past_seen_tokens), logits.shape[0])
with torch.inference_mode():
attn_metadata.prepare()
logits = phi3.forward(input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata,
return_context_logits=True)
self.assertEqual(input_ids.shape, logits.shape[:-1])
kv_cache_manager.shutdown()
@torch.no_grad()
def test_phi3_allclose_to_hf(self) -> None:
"""
Compare output to HF
"""
scenario = Scenario(backend="TRTLLM")
backend = scenario.backend
metadata_cls = get_attention_backend(backend).Metadata
torch.random.manual_seed(0)
config_dict = deepcopy(PHI3_MINI_4K_CONFIG)
# 8B * sizeof(float16) plus some extra for activations
# with extra headroom, since both the HF reference model and the TRT-LLM model are instantiated
mem_for_full_model = (2 + 1) * 8 * 2**(30) * 4
reduce_phi3_config(mem_for_full_model, config_dict)
if config_dict["num_hidden_layers"] <= 0:
self.skipTest("Insufficient memory for a single Phi3 layer")
phi3_config = Phi3Config.from_dict(config_dict)
dtype = phi3_config.torch_dtype
device = torch.device('cuda')
with torch.device(device), default_dtype(dtype):
hf_phi3 = HFPhi3ForCausalLM(phi3_config).eval()
model_config = ModelConfig(pretrained_config=phi3_config,
attn_backend=backend)
phi3 = Phi3ForCausalLM(model_config).to(dtype).to(device)
phi3.load_weights(hf_phi3.state_dict())
num_blocks = 1
tokens_per_block = 128
head_dim = phi3.config.hidden_size // phi3.config.num_attention_heads
num_layers = phi3.config.num_hidden_layers
num_kv_heads = phi3.config.num_key_value_heads
max_seq_len = num_blocks * tokens_per_block
batch_size = 1
if dtype == torch.half:
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
elif dtype == torch.bfloat16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
else:
raise ValueError("Invalid dtype")
mapping = Mapping(world_size=1, tp_size=1, rank=0)
kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
tokens_per_block)
kv_cache_manager = KVCacheManager(
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
num_kv_heads=num_kv_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
max_seq_len=max_seq_len,
max_batch_size=batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
)
# context
input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
dtype=torch.int,
device=device)
num_cached_tokens_per_seq = [0]
request_ids = [1]
token_nums = [input_ids.size(-1)]
prompt_lens = [input_ids.size(-1)]
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
attn_metadata = metadata_cls(
seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int),
num_contexts=1,
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=num_cached_tokens_per_seq,
),
max_num_requests=1,
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
)
# Note: no CUDA graphs for prefill, the graph runner is built for
# decoding only.
position_ids = [torch.arange(0, input_ids.size(-1))]
position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
with torch.inference_mode():
attn_metadata.prepare()
logits = phi3.forward(input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata)
ref = hf_phi3.forward(input_ids=input_ids.unsqueeze(0),
position_ids=position_ids,
use_cache=True)
torch.testing.assert_close(logits,
ref.logits[:, -1].float(),
atol=0.4,
rtol=0.4)
# gen
gen_input_ids = torch.tensor([600], dtype=torch.int, device=device)
num_cached_tokens_per_seq = [input_ids.size(-1)]
attn_metadata = metadata_cls(
seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.int),
num_contexts=0,
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=num_cached_tokens_per_seq,
),
max_num_requests=1,
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
)
gen_position_ids = [
torch.arange(input_ids.size(-1),
input_ids.size(-1) + gen_input_ids.size(-1))
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
if not scenario.use_cuda_graph:
return phi3.forward(input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata)
else:
graph_runner = DecodingCUDAGraphRunner(
attn_metadata.max_num_requests, "cuda", attn_metadata)
graph_runner.capture(lambda inputs: phi3.forward(**inputs))
for _ in range(2):
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
# in prepare().
attn_metadata.prepare()
logits = graph_runner.run({
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
})
return logits
if scenario.use_cuda_graph:
attn_metadata = attn_metadata.create_cuda_graph_metadata(1)
with torch.inference_mode():
logits = run_forward(input_ids=gen_input_ids,
position_ids=gen_position_ids,
attn_metadata=attn_metadata)
ref = hf_phi3.forward(input_ids=gen_input_ids.unsqueeze(0),
position_ids=gen_position_ids,
past_key_values=ref.past_key_values,
use_cache=True)
torch.testing.assert_close(logits,
ref.logits[:, -1].float(),
atol=0.4,
rtol=0.4)
kv_cache_manager.shutdown()