mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
feat: TRTLLM-5574 Add phi-4-multimodal pytorch-backend support (#5644)
Signed-off-by: Wanli Jiang <35160485+Wanli-Jiang@users.noreply.github.com>
This commit is contained in:
parent
e09e409dfb
commit
2d2b8bae32
@ -145,7 +145,7 @@ def parse_arguments():
|
||||
return args
|
||||
|
||||
|
||||
def setup_llm(args):
|
||||
def setup_llm(args, **kwargs):
|
||||
kv_cache_config = KvCacheConfig(
|
||||
enable_block_reuse=not args.disable_kv_cache_reuse,
|
||||
free_gpu_memory_fraction=args.kv_cache_fraction,
|
||||
@ -222,7 +222,9 @@ def setup_llm(args):
|
||||
speculative_config=spec_config,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
gather_generation_logits=args.return_generation_logits,
|
||||
max_beam_width=args.max_beam_width)
|
||||
max_beam_width=args.max_beam_width,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
sampling_params = SamplingParams(
|
||||
max_tokens=args.max_tokens,
|
||||
|
||||
@ -7,24 +7,56 @@ from quickstart_advanced import add_llm_args, setup_llm
|
||||
from tensorrt_llm.inputs import (ALL_SUPPORTED_MULTIMODAL_MODELS,
|
||||
default_multimodal_input_loader)
|
||||
|
||||
example_images = [
|
||||
"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
|
||||
"https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
|
||||
]
|
||||
example_image_prompts = [
|
||||
"Describe the natural environment in the image.",
|
||||
"Describe the object and the weather condition in the image.",
|
||||
"Describe the traffic condition on the road in the image.",
|
||||
]
|
||||
example_videos = [
|
||||
"https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4",
|
||||
"https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4",
|
||||
]
|
||||
example_video_prompts = [
|
||||
"Tell me what you see in the video briefly.",
|
||||
"Describe the scene in the video briefly.",
|
||||
]
|
||||
example_medias_and_prompts = {
|
||||
"image": {
|
||||
"media": [
|
||||
"https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
|
||||
"https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
|
||||
],
|
||||
"prompt": [
|
||||
"Describe the natural environment in the image.",
|
||||
"Describe the object and the weather condition in the image.",
|
||||
"Describe the traffic condition on the road in the image.",
|
||||
]
|
||||
},
|
||||
"video": {
|
||||
"media": [
|
||||
"https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4",
|
||||
"https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4",
|
||||
],
|
||||
"prompt": [
|
||||
"Tell me what you see in the video briefly.",
|
||||
"Describe the scene in the video briefly.",
|
||||
]
|
||||
},
|
||||
"audio": {
|
||||
"media": [
|
||||
"https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_the_traffic_sign_in_the_image.wav",
|
||||
"https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_shown_in_this_image.wav",
|
||||
],
|
||||
"prompt": [
|
||||
"Transcribe the audio clip into text, please don't add other text.",
|
||||
"Transcribe the audio clip into text, please don't add other text.",
|
||||
]
|
||||
},
|
||||
"image_audio": {
|
||||
"media": [
|
||||
[
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
|
||||
"https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_shown_in_this_image.wav"
|
||||
],
|
||||
[
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
|
||||
"https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_shown_in_this_image.wav"
|
||||
],
|
||||
],
|
||||
"prompt": [
|
||||
"Describe the scene in the image briefly.",
|
||||
"",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def add_multimodal_args(parser):
|
||||
@ -34,7 +66,7 @@ def add_multimodal_args(parser):
|
||||
help="Model type.")
|
||||
parser.add_argument("--modality",
|
||||
type=str,
|
||||
choices=["image", "video"],
|
||||
choices=["image", "video", "audio", "image_audio"],
|
||||
default="image",
|
||||
help="Media type.")
|
||||
parser.add_argument("--media",
|
||||
@ -53,11 +85,24 @@ def add_multimodal_args(parser):
|
||||
return parser
|
||||
|
||||
|
||||
def add_lora_args(parser):
|
||||
parser.add_argument("--load_lora",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="Whether to load the LoRA model.")
|
||||
parser.add_argument("--auto_model_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The auto model name in TRTLLM repo.")
|
||||
return parser
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Multimodal models with the PyTorch workflow.")
|
||||
parser = add_llm_args(parser)
|
||||
parser = add_multimodal_args(parser)
|
||||
parser = add_lora_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
args.disable_kv_cache_reuse = True # kv cache reuse does not work for multimodal, force overwrite
|
||||
@ -71,11 +116,19 @@ def main():
|
||||
args = parse_arguments()
|
||||
# set prompts and media to example prompts and images if they are not provided
|
||||
if args.prompt is None:
|
||||
args.prompt = example_image_prompts if args.modality == "image" else example_video_prompts
|
||||
args.prompt = example_medias_and_prompts[args.modality]["prompt"]
|
||||
if args.media is None:
|
||||
args.media = example_images if args.modality == "image" else example_videos
|
||||
args.media = example_medias_and_prompts[args.modality]["media"]
|
||||
|
||||
llm, sampling_params = setup_llm(args)
|
||||
lora_config = None
|
||||
if args.load_lora:
|
||||
assert args.auto_model_name is not None, "Please provide the auto model name to load LoRA config."
|
||||
import importlib
|
||||
models_module = importlib.import_module('tensorrt_llm._torch.models')
|
||||
model_class = getattr(models_module, args.auto_model_name)
|
||||
lora_config = model_class.lora_config(args.model_dir)
|
||||
|
||||
llm, sampling_params = setup_llm(args, lora_config=lora_config)
|
||||
|
||||
image_format = args.image_format
|
||||
if args.model_type is not None:
|
||||
@ -96,7 +149,16 @@ def main():
|
||||
num_frames=args.num_frames,
|
||||
device=device)
|
||||
|
||||
outputs = llm.generate(inputs, sampling_params)
|
||||
lora_request = None
|
||||
if args.load_lora:
|
||||
lora_request = model_class.lora_request(len(inputs), args.modality,
|
||||
llm._hf_model_dir)
|
||||
|
||||
outputs = llm.generate(
|
||||
inputs,
|
||||
sampling_params,
|
||||
lora_request=lora_request,
|
||||
)
|
||||
|
||||
for i, output in enumerate(outputs):
|
||||
prompt = args.prompt[i]
|
||||
|
||||
@ -59,3 +59,4 @@ ninja
|
||||
etcd3
|
||||
blake3
|
||||
llguidance==0.7.29
|
||||
soundfile
|
||||
|
||||
@ -351,6 +351,8 @@ class RopeParams:
|
||||
beta_slow: int = 1
|
||||
mscale: float = 1.0
|
||||
mscale_all_dim: float = 0.0
|
||||
short_factor: Optional[Tuple[float]] = None
|
||||
long_factor: Optional[Tuple[float]] = None
|
||||
|
||||
@staticmethod
|
||||
def from_config(config) -> "RopeParams":
|
||||
@ -386,12 +388,18 @@ class RopeParams:
|
||||
"low_freq_factor", 1.0)
|
||||
rope_params.high_freq_factor = rope_scaling.get(
|
||||
"high_freq_factor", 4.0)
|
||||
rope_params.original_max_positions = rope_scaling.get(
|
||||
"original_max_position_embeddings", 1024)
|
||||
rope_params.original_max_positions = getattr(
|
||||
config,
|
||||
"original_max_position_embeddings", None) or rope_scaling.get(
|
||||
"original_max_position_embeddings", None) or 1024
|
||||
rope_params.beta_fast = rope_scaling.get("beta_fast", 32)
|
||||
rope_params.beta_slow = rope_scaling.get("beta_slow", 1)
|
||||
rope_params.mscale = rope_scaling.get("mscale", 1.0)
|
||||
rope_params.mscale_all_dim = rope_scaling.get("mscale_all_dim", 0.0)
|
||||
if "short_factor" in rope_scaling:
|
||||
rope_params.short_factor = tuple(rope_scaling["short_factor"])
|
||||
if "long_factor" in rope_scaling:
|
||||
rope_params.long_factor = tuple(rope_scaling["long_factor"])
|
||||
# Workaround for DeepSeek V3 Lite since its rope_scaling is null in config.json.
|
||||
elif config.model_type == "deepseek_v3":
|
||||
rope_params.scale_type = RotaryScalingType.yarn
|
||||
@ -428,7 +436,14 @@ class RopeParams:
|
||||
self.mscale_all_dim,
|
||||
)
|
||||
elif self.scale_type == RotaryScalingType.longrope:
|
||||
raise NotImplementedError("Long RoPE is not supported.")
|
||||
rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_long_rope_for_attention_plugin(
|
||||
num_pos=self.max_positions,
|
||||
dim=self.dim,
|
||||
theta=self.theta,
|
||||
original_max_pos=self.original_max_positions,
|
||||
short_factor=self.short_factor,
|
||||
long_factor=self.long_factor,
|
||||
)
|
||||
else:
|
||||
rope_inv_freq, rope_cos_sin = RopeEmbeddingUtils.create_sinusoidal_positions_for_attention_plugin(
|
||||
self.max_positions,
|
||||
|
||||
@ -15,6 +15,8 @@ from .modeling_mixtral import MixtralForCausalLM
|
||||
from .modeling_nemotron import NemotronForCausalLM
|
||||
from .modeling_nemotron_h import NemotronHForCausalLM
|
||||
from .modeling_nemotron_nas import NemotronNASForCausalLM
|
||||
from .modeling_phi3 import Phi3ForCausalLM
|
||||
from .modeling_phi4mm import Phi4MMForCausalLM
|
||||
from .modeling_qwen import (Qwen2ForCausalLM, Qwen2ForProcessRewardModel,
|
||||
Qwen2ForRewardModel)
|
||||
from .modeling_qwen2vl import Qwen2_5_VLModel, Qwen2VLModel
|
||||
@ -42,6 +44,8 @@ __all__ = [
|
||||
"NemotronForCausalLM",
|
||||
"NemotronHForCausalLM",
|
||||
"NemotronNASForCausalLM",
|
||||
"Phi3ForCausalLM",
|
||||
"Phi4MMForCausalLM",
|
||||
"Qwen2ForCausalLM",
|
||||
"Qwen2ForProcessRewardModel",
|
||||
"Qwen2ForRewardModel",
|
||||
|
||||
@ -64,6 +64,7 @@ def fuse_input_embeds(
|
||||
mm_token_mask = input_ids >= vocab_size
|
||||
text_token_mask = input_ids < vocab_size
|
||||
else:
|
||||
mm_token_ids = mm_token_ids.to(input_ids.device)
|
||||
mm_token_mask = torch.isin(input_ids, mm_token_ids)
|
||||
text_token_mask = ~mm_token_mask
|
||||
text_token_indices = torch.where(text_token_mask)[0]
|
||||
|
||||
249
tensorrt_llm/_torch/models/modeling_phi3.py
Normal file
249
tensorrt_llm/_torch/models/modeling_phi3.py
Normal file
@ -0,0 +1,249 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from transformers import Phi3Config
|
||||
|
||||
from tensorrt_llm._torch.attention_backend import AttentionMetadata
|
||||
from tensorrt_llm._torch.attention_backend.interface import (
|
||||
PositionalEmbeddingParams, RopeParams)
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.modeling_utils import (DecoderModel,
|
||||
DecoderModelForCausalLM,
|
||||
register_auto_model)
|
||||
from tensorrt_llm._torch.modules.attention import Attention
|
||||
from tensorrt_llm._torch.modules.decoder_layer import DecoderLayer
|
||||
from tensorrt_llm._torch.modules.embedding import Embedding
|
||||
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
|
||||
from tensorrt_llm._torch.modules.linear import TensorParallelMode
|
||||
from tensorrt_llm._torch.modules.rms_norm import RMSNorm
|
||||
from tensorrt_llm.functional import PositionEmbeddingType
|
||||
|
||||
|
||||
class Phi3Attention(Attention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig[Phi3Config],
|
||||
layer_idx: Optional[int] = None,
|
||||
):
|
||||
config = model_config.pretrained_config
|
||||
|
||||
rope_params = RopeParams.from_config(config)
|
||||
super().__init__(
|
||||
hidden_size=config.hidden_size,
|
||||
num_attention_heads=config.num_attention_heads,
|
||||
num_key_value_heads=config.num_key_value_heads,
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
bias=config.attention_bias,
|
||||
pos_embd_params=PositionalEmbeddingParams(
|
||||
type=PositionEmbeddingType.rope_gpt_neox,
|
||||
rope=rope_params,
|
||||
),
|
||||
layer_idx=layer_idx,
|
||||
dtype=config.torch_dtype,
|
||||
config=model_config,
|
||||
)
|
||||
|
||||
|
||||
class Phi3DecoderLayer(DecoderLayer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig[Phi3Config],
|
||||
layer_idx: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
super().__init__()
|
||||
config = model_config.pretrained_config
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.self_attn = Phi3Attention(model_config, layer_idx=layer_idx)
|
||||
|
||||
self.mlp = GatedMLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
bias=False,
|
||||
dtype=config.torch_dtype,
|
||||
config=model_config,
|
||||
)
|
||||
|
||||
self.input_layernorm = RMSNorm(
|
||||
hidden_size=config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
dtype=config.torch_dtype,
|
||||
)
|
||||
self.post_attention_layernorm = RMSNorm(
|
||||
hidden_size=config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
dtype=config.torch_dtype,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
position_ids: torch.LongTensor,
|
||||
hidden_states: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
residual: Optional[torch.Tensor],
|
||||
lora_params=None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
|
||||
# Self Attention
|
||||
hidden_states = self.self_attn(
|
||||
position_ids=None,
|
||||
hidden_states=hidden_states,
|
||||
attn_metadata=attn_metadata,
|
||||
lora_params=lora_params,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Fully connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states, **kwargs)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class Phi3Model(DecoderModel):
|
||||
|
||||
def __init__(self, model_config: ModelConfig[Phi3Config]):
|
||||
super().__init__(model_config)
|
||||
config = self.model_config.pretrained_config
|
||||
self.padding_idx = config.pad_token_id
|
||||
|
||||
self.embed_tokens = Embedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
dtype=config.torch_dtype,
|
||||
mapping=model_config.mapping,
|
||||
tensor_parallel_mode=TensorParallelMode.COLUMN,
|
||||
gather_output=True,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
Phi3DecoderLayer(
|
||||
model_config,
|
||||
layer_idx,
|
||||
) for layer_idx in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(
|
||||
hidden_size=config.hidden_size,
|
||||
eps=config.rms_norm_eps,
|
||||
dtype=config.torch_dtype,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
attn_metadata: AttentionMetadata,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_params=None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
|
||||
)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
residual = None
|
||||
for decoder_layer in self.layers:
|
||||
hidden_states, residual = decoder_layer(
|
||||
hidden_states=hidden_states,
|
||||
position_ids=position_ids,
|
||||
residual=residual,
|
||||
attn_metadata=attn_metadata,
|
||||
lora_params=lora_params,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
@register_auto_model("Phi3ForCausalLM")
|
||||
class Phi3ForCausalLM(DecoderModelForCausalLM[Phi3Model, Phi3Config]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig[Phi3Config],
|
||||
):
|
||||
super().__init__(Phi3Model(model_config),
|
||||
config=model_config,
|
||||
hidden_size=model_config.pretrained_config.hidden_size,
|
||||
vocab_size=model_config.pretrained_config.vocab_size)
|
||||
|
||||
def load_weights(self, weights: dict):
|
||||
self.model_config.mapping.tp_size
|
||||
hidden_size = self.config.hidden_size
|
||||
num_heads = self.config.num_attention_heads
|
||||
num_kv_heads = self.config.num_key_value_heads
|
||||
head_dim = hidden_size // num_heads
|
||||
|
||||
def filter_weights(prefix: str, weights: dict):
|
||||
result = {}
|
||||
for k, v in weights.items():
|
||||
if k.startswith(prefix):
|
||||
new_k = k[len(prefix) + 1:]
|
||||
result[new_k] = v
|
||||
return result
|
||||
|
||||
for name, module in tqdm(list(self.named_modules()),
|
||||
desc="Loading weights"):
|
||||
if len(module._parameters) > 0:
|
||||
# skip load weights if tie word embeddings is enabled and layer is lm_head
|
||||
if self.config.tie_word_embeddings and name.startswith(
|
||||
'lm_head'):
|
||||
continue
|
||||
|
||||
module_weights = filter_weights(name, weights)
|
||||
if hasattr(module, 'load_weights'):
|
||||
if "self_attn.qkv_proj" in name:
|
||||
# The weights need to be split correctly before sharding to support tp_size >1.
|
||||
qkv_weight = module_weights['weight'][:]
|
||||
q_weight = qkv_weight[:hidden_size, :]
|
||||
k_weight = qkv_weight[hidden_size:hidden_size +
|
||||
num_kv_heads * head_dim, :]
|
||||
v_weight = qkv_weight[hidden_size +
|
||||
num_kv_heads * head_dim:, :]
|
||||
module.load_weights(weights=[
|
||||
{
|
||||
'weight': q_weight
|
||||
},
|
||||
{
|
||||
'weight': k_weight
|
||||
},
|
||||
{
|
||||
'weight': v_weight
|
||||
},
|
||||
])
|
||||
elif "mlp.gate_up_proj" in name:
|
||||
# The weights need to be split correctly before sharding to support tp_size >1.
|
||||
intermediate_size = self.config.intermediate_size
|
||||
gate_up_weight = module_weights['weight'][:]
|
||||
gate_weight = gate_up_weight[:intermediate_size, :]
|
||||
up_weight = gate_up_weight[intermediate_size:, :]
|
||||
module.load_weights(weights=[
|
||||
{
|
||||
'weight': gate_weight
|
||||
},
|
||||
{
|
||||
'weight': up_weight
|
||||
},
|
||||
])
|
||||
else:
|
||||
module.load_weights(weights=[module_weights])
|
||||
else:
|
||||
for n, p in module._parameters.items():
|
||||
if p is not None:
|
||||
p.data.copy_(module_weights[n][:])
|
||||
286
tensorrt_llm/_torch/models/modeling_phi4mm.py
Normal file
286
tensorrt_llm/_torch/models/modeling_phi4mm.py
Normal file
@ -0,0 +1,286 @@
|
||||
# Plan for phi4-mm model support.
|
||||
# (done) step 1: support legacy inference pipeline for phi4-mm model.
|
||||
# (todo) step 2: refactor the inference pipeline to use AGGREGATE mode (https://github.com/NVIDIA/TensorRT-LLM/pull/5522).
|
||||
|
||||
import copy
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from PIL import Image
|
||||
|
||||
from ...executor.request import LoRARequest
|
||||
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
|
||||
register_input_processor)
|
||||
from ...logger import logger
|
||||
from ...lora_manager import LoraConfig
|
||||
from ...sampling_params import SamplingParams
|
||||
from ..attention_backend import AttentionMetadata
|
||||
from ..model_config import ModelConfig
|
||||
from .modeling_auto import AutoModelForCausalLM
|
||||
from .modeling_multimodal_utils import fuse_input_embeds
|
||||
from .modeling_utils import register_auto_model
|
||||
|
||||
# Special tokens
|
||||
_IMAGE_SPECIAL_TOKEN_ID = 200010 # '<|endoftext10|>'
|
||||
_AUDIO_SPECIAL_TOKEN_ID = 200011 # '<|endoftext11|>'
|
||||
|
||||
|
||||
# Create a PreTrainedModel class for transformers=4.53.1 upgrade.
|
||||
# Core idea is to provide `prepare_inputs_for_generation` method from `GenerationMixin`.
|
||||
class NewPreTrainedModel(transformers.modeling_utils.PreTrainedModel,
|
||||
transformers.generation.GenerationMixin):
|
||||
pass
|
||||
|
||||
|
||||
class Phi4MMInputProcessor(InputProcessor):
|
||||
|
||||
def __init__(self,
|
||||
model_path: str,
|
||||
model_config: transformers.PretrainedConfig,
|
||||
tokenizer: transformers.AutoTokenizer,
|
||||
trust_remote_code: bool = True):
|
||||
assert trust_remote_code, "trust_remote_code must be True for Phi4MM"
|
||||
|
||||
self.model_config = model_config
|
||||
self.device = 'cuda'
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
self.use_fast = True
|
||||
if self.tokenizer is None:
|
||||
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
|
||||
model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_fast=self.use_fast)
|
||||
|
||||
self.processor = transformers.AutoProcessor.from_pretrained(
|
||||
model_path,
|
||||
trust_remote_code=trust_remote_code,
|
||||
use_fast=self.use_fast)
|
||||
|
||||
# Build pure-pytorch model architecture for multimodal encoder.
|
||||
# Model weights are also loaded here.
|
||||
OldPreTrainedModel = transformers.modeling_utils.PreTrainedModel
|
||||
transformers.modeling_utils.PreTrainedModel = NewPreTrainedModel
|
||||
# TODO: Make separate Phi4VisionEncoder and Phi4AudioEncoder, and move them to LLM-side.
|
||||
ref_phi4mm_model = transformers.AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
trust_remote_code=True,
|
||||
# Flash_attn_2 only supports bf16 or fp16 and set in HF config.
|
||||
torch_dtype='auto',
|
||||
_attn_implementation='flash_attention_2',
|
||||
).eval()
|
||||
transformers.modeling_utils.PreTrainedModel = OldPreTrainedModel
|
||||
self.phi4mm_modal_encoder = ref_phi4mm_model.model.embed_tokens_extend.to(
|
||||
self.device)
|
||||
# Required by Phi4MMImageAudioEmbedding.
|
||||
# See link: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/modeling_phi4mm.py#L701
|
||||
self.phi4mm_wte = ref_phi4mm_model.model.embed_tokens.to(self.device)
|
||||
|
||||
@torch.inference_mode()
|
||||
def __call__(
|
||||
self, inputs: TextPrompt, sampling_params: SamplingParams
|
||||
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
|
||||
text_prompt, mm_data, mm_processor_kwargs = inputs.get("prompt"), \
|
||||
inputs.get("multi_modal_data", {}), inputs.get("mm_processor_kwargs", {})
|
||||
images = mm_data.get("image", None)
|
||||
audios = mm_data.get("audio", None)
|
||||
|
||||
if images is not None:
|
||||
if isinstance(images[0], torch.Tensor):
|
||||
# Convert normalized tensors (0-1) to PIL images (0-255).
|
||||
images = [
|
||||
Image.fromarray((image.permute(1, 2, 0) * 255).to(
|
||||
torch.uint8).cpu().numpy()) for image in images
|
||||
]
|
||||
|
||||
# Preprocessing for multimodal data.
|
||||
inputs = self.processor(text=[text_prompt],
|
||||
images=images,
|
||||
audios=audios,
|
||||
return_tensors='pt').to(self.device)
|
||||
|
||||
# Set audio_projection_mode according to the modality.
|
||||
# Ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/modeling_phi4mm.py#L2103
|
||||
if images is not None:
|
||||
audio_projection_mode = 'vision'
|
||||
elif audios is not None:
|
||||
audio_projection_mode = 'speech'
|
||||
else:
|
||||
audio_projection_mode = 'speech'
|
||||
|
||||
# Processing with Phi4MMImageAudioEmbedding.
|
||||
mm_features = self.phi4mm_modal_encoder(
|
||||
input_ids=inputs['input_ids'],
|
||||
input_embeds=None,
|
||||
input_image_embeds=inputs['input_image_embeds'],
|
||||
input_audio_embeds=inputs['input_audio_embeds'],
|
||||
image_sizes=inputs['image_sizes'],
|
||||
image_attention_mask=inputs['image_attention_mask'],
|
||||
audio_embed_sizes=inputs['audio_embed_sizes'],
|
||||
audio_attention_mask=inputs['audio_attention_mask'],
|
||||
audio_projection_mode=audio_projection_mode,
|
||||
wte=self.phi4mm_wte,
|
||||
)
|
||||
|
||||
# Postprocessing to get multimodal-only embeddings.
|
||||
image_token_mask = inputs['input_ids'] == _IMAGE_SPECIAL_TOKEN_ID
|
||||
audio_token_mask = inputs['input_ids'] == _AUDIO_SPECIAL_TOKEN_ID
|
||||
mm_token_mask = image_token_mask | audio_token_mask
|
||||
mm_features = mm_features[mm_token_mask]
|
||||
|
||||
multimodal_data = {}
|
||||
multimodal_data["multimodal_embedding"] = mm_features
|
||||
|
||||
return inputs['input_ids'][0].to(torch.int32).tolist(), {
|
||||
"multimodal_data": multimodal_data,
|
||||
}
|
||||
|
||||
|
||||
@register_auto_model("Phi4MMForCausalLM")
|
||||
@register_input_processor(Phi4MMInputProcessor, model_type="phi4mm")
|
||||
class Phi4MMForCausalLM(transformers.PreTrainedModel):
|
||||
|
||||
_supports_flash_attn_2 = True
|
||||
MM_TOKEN_IDS = torch.tensor(
|
||||
[_IMAGE_SPECIAL_TOKEN_ID, _AUDIO_SPECIAL_TOKEN_ID])
|
||||
|
||||
def __init__(self, model_config: ModelConfig):
|
||||
|
||||
config = model_config.pretrained_config
|
||||
super().__init__(config)
|
||||
|
||||
self.model_config = model_config
|
||||
if hasattr(self, "llm"):
|
||||
return
|
||||
|
||||
# We use Phi3ForCausalLM as the language model.
|
||||
llm_model_config = copy.deepcopy(model_config)
|
||||
llm_model_config.pretrained_config.architectures = ["Phi3ForCausalLM"]
|
||||
# Only build the language model architecture without loading weights.
|
||||
self.llm = AutoModelForCausalLM.from_config(llm_model_config)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.model_dtype = getattr(config, "torch_dtype", torch.float16)
|
||||
logger.info(f"{self.dtype=} {self.model_dtype=}")
|
||||
self.post_config()
|
||||
self.is_loaded = True
|
||||
|
||||
def load_weights(self, weights):
|
||||
# Filter out non-language model weights.
|
||||
weights = {
|
||||
k: v
|
||||
for k, v in weights.items()
|
||||
if not k.startswith('model.embed_tokens_extend')
|
||||
}
|
||||
# Filter out LoRA weights.
|
||||
# LoRA weights will be loaded by LoraManager.
|
||||
weights = {k: v for k, v in weights.items() if '.lora_' not in k}
|
||||
# Rename base layer weights.
|
||||
updated_weights = {}
|
||||
for k in weights.keys():
|
||||
if 'base_layer.weight' in k:
|
||||
new_k = k.replace('base_layer.weight', 'weight')
|
||||
updated_weights[new_k] = weights[k]
|
||||
else:
|
||||
updated_weights[k] = weights[k]
|
||||
weights = updated_weights
|
||||
|
||||
self.llm.load_weights(weights)
|
||||
|
||||
def infer_max_seq_len(self) -> int:
|
||||
return self.llm.infer_max_seq_len()
|
||||
|
||||
def post_config(self):
|
||||
# use llm.config as config for pytorch model engine
|
||||
self.config = self.llm.config
|
||||
self.model_config.pretrained_config = self.llm.config
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(
|
||||
self,
|
||||
attn_metadata: AttentionMetadata,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
input_embeds: Optional[torch.Tensor] = None,
|
||||
return_context_logits: bool = False,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
VLM forward logic with inflight batching support.
|
||||
"""
|
||||
num_context_requests, num_generation_requests = attn_metadata.num_contexts, attn_metadata.num_generations
|
||||
logger.debug(
|
||||
f"num_context_requests: {num_context_requests}, num_generation_requests: {num_generation_requests}"
|
||||
)
|
||||
|
||||
multimodal_params = kwargs.get("multimodal_params", [])
|
||||
mm_embedding = [
|
||||
multimodal_param.multimodal_data["multimodal_embedding"]
|
||||
for multimodal_param in multimodal_params
|
||||
]
|
||||
input_ids, input_embeds = fuse_input_embeds(
|
||||
self.llm.model.embed_tokens,
|
||||
input_ids,
|
||||
mm_embedding,
|
||||
mm_token_ids=self.MM_TOKEN_IDS,
|
||||
)
|
||||
|
||||
output_prob = self.llm.forward(
|
||||
attn_metadata=attn_metadata,
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=input_embeds,
|
||||
return_context_logits=return_context_logits,
|
||||
lora_params=kwargs.get("lora_params", None),
|
||||
)
|
||||
|
||||
logger.debug(f'output shape: {output_prob.shape}')
|
||||
return output_prob
|
||||
|
||||
@staticmethod
|
||||
def lora_config(model_dir: str):
|
||||
_lora_config = LoraConfig(
|
||||
lora_dir=[
|
||||
f"{model_dir}/vision-lora",
|
||||
f"{model_dir}/speech-lora",
|
||||
],
|
||||
lora_target_modules=[
|
||||
"attn_qkv",
|
||||
"attn_dense",
|
||||
"mlp_h_to_4h",
|
||||
"mlp_4h_to_h",
|
||||
],
|
||||
trtllm_modules_to_hf_modules={
|
||||
"attn_qkv": "qkv_proj",
|
||||
"attn_dense": "o_proj",
|
||||
"mlp_h_to_4h": "gate_up_proj",
|
||||
"mlp_4h_to_h": "down_proj",
|
||||
},
|
||||
max_lora_rank=320, # Max rank for Phi4MM.
|
||||
)
|
||||
return _lora_config
|
||||
|
||||
@staticmethod
|
||||
def lora_request(num_requests: int, modality: str, base_model_dir: str):
|
||||
# Prepare LoRA requests for different modalities.
|
||||
# Ref: https://huggingface.co/microsoft/Phi-4-multimodal-instruct/blob/main/modeling_phi4mm.py#L2103
|
||||
lora_request = None
|
||||
if modality == "image" or modality == "image_audio":
|
||||
lora_request = [
|
||||
LoRARequest(
|
||||
lora_name=f"vision-lora-{i}",
|
||||
lora_int_id=i,
|
||||
lora_path=f"{base_model_dir}/vision-lora",
|
||||
) for i in range(num_requests)
|
||||
]
|
||||
elif modality == "audio":
|
||||
lora_request = [
|
||||
LoRARequest(
|
||||
lora_name=f"speech-lora-{i}",
|
||||
lora_int_id=i,
|
||||
lora_path=f"{base_model_dir}/speech-lora",
|
||||
) for i in range(num_requests)
|
||||
]
|
||||
|
||||
return lora_request
|
||||
@ -4834,6 +4834,48 @@ class RopeEmbeddingUtils:
|
||||
True), _compute_sinusoidal_positions(
|
||||
scaling_long_factors, False, True), short_mscale
|
||||
|
||||
@staticmethod
|
||||
def create_sinusoidal_positions_long_rope_for_attention_plugin(
|
||||
num_pos: int,
|
||||
dim: int,
|
||||
theta: float,
|
||||
original_max_pos: int,
|
||||
short_factor: List[float],
|
||||
long_factor: List[float],
|
||||
dtype=np.float32):
|
||||
short_factor = np.array(short_factor, dtype=np.float32)
|
||||
long_factor = np.array(long_factor, dtype=np.float32)
|
||||
|
||||
inv_freq = 1.0 / (theta**(np.arange(0, dim, 2, dtype=np.float32) / dim))
|
||||
|
||||
# Short part
|
||||
inv_freq_short = inv_freq / short_factor
|
||||
t_short = np.arange(np.min([num_pos, original_max_pos]),
|
||||
dtype=np.float32)
|
||||
freqs_short = np.einsum("i,j->ij", t_short, inv_freq_short)
|
||||
|
||||
# Long part
|
||||
inv_freq_long = inv_freq / long_factor
|
||||
t_long = np.arange(np.max([0, num_pos - original_max_pos]),
|
||||
dtype=np.float32) + original_max_pos
|
||||
freqs_long = np.einsum("i,j->ij", t_long, inv_freq_long)
|
||||
|
||||
freqs = np.concatenate([freqs_short, freqs_long], axis=0)
|
||||
sinusoid_inp = freqs.astype(np.float32)[..., np.newaxis]
|
||||
|
||||
# Apply scaling
|
||||
scale = num_pos / original_max_pos
|
||||
scaling_factor = np.sqrt(1.0 + np.log(scale) / np.log(original_max_pos))
|
||||
|
||||
# fuse cos/sin into float2 (cos, sin).
|
||||
concat = np.concatenate(
|
||||
(np.cos(sinusoid_inp) * scaling_factor,
|
||||
np.sin(sinusoid_inp) * scaling_factor),
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
return None, concat.reshape(1, -1).astype(dtype)
|
||||
|
||||
@staticmethod
|
||||
def create_fake_weight(dim: int, dtype=np.half):
|
||||
return np.random.rand(dim).astype(dtype)
|
||||
|
||||
@ -5,8 +5,9 @@ from .registry import (ExtraProcessedInputs, InputProcessor,
|
||||
register_input_processor)
|
||||
from .utils import (ALL_SUPPORTED_MULTIMODAL_MODELS, ConversationMessage,
|
||||
MultimodalData, MultimodalDataTracker,
|
||||
add_multimodal_placeholders, async_load_image,
|
||||
async_load_video, default_multimodal_input_loader,
|
||||
add_multimodal_placeholders, async_load_audio,
|
||||
async_load_image, async_load_video,
|
||||
default_multimodal_input_loader,
|
||||
encode_base64_content_from_url, load_image, load_video)
|
||||
|
||||
__all__ = [
|
||||
@ -24,6 +25,7 @@ __all__ = [
|
||||
"MultimodalDataTracker",
|
||||
"MultimodalData",
|
||||
"MultimodalInput",
|
||||
"async_load_audio",
|
||||
"async_load_image",
|
||||
"async_load_video",
|
||||
"add_multimodal_placeholders",
|
||||
|
||||
@ -5,12 +5,13 @@ import tempfile
|
||||
from collections import defaultdict
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Coroutine, Dict, List, Optional, TypedDict, Union
|
||||
from typing import Any, Coroutine, Dict, List, Optional, Tuple, TypedDict, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
import requests
|
||||
import soundfile
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms import ToTensor
|
||||
@ -159,6 +160,35 @@ async def async_load_video(
|
||||
return load_video(video_path, num_frames, format, device)
|
||||
|
||||
|
||||
def load_audio(
|
||||
audio: str,
|
||||
format: str = "pt",
|
||||
device: str = "cuda",
|
||||
) -> Tuple[np.ndarray, int]:
|
||||
parsed_url = urlparse(audio)
|
||||
if parsed_url.scheme in ["http", "https"]:
|
||||
audio = requests.get(audio, stream=True, timeout=10)
|
||||
audio = BytesIO(audio.content)
|
||||
|
||||
audio = soundfile.read(audio)
|
||||
return audio
|
||||
|
||||
|
||||
async def async_load_audio(
|
||||
audio: str,
|
||||
format: str = "pt",
|
||||
device: str = "cuda",
|
||||
) -> Tuple[np.ndarray, int]:
|
||||
parsed_url = urlparse(audio)
|
||||
if parsed_url.scheme in ["http", "https"]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(audio) as response:
|
||||
audio = BytesIO(await response.content.read())
|
||||
|
||||
audio = soundfile.read(audio)
|
||||
return audio
|
||||
|
||||
|
||||
# Copied from https://github.com/vllm-project/vllm/blob/main/examples/online_serving/openai_chat_completion_client_for_multimodal.py#L38
|
||||
def encode_base64_content_from_url(content_url: str) -> str:
|
||||
"""Encode a content retrieved from a remote url to base64 format."""
|
||||
@ -186,19 +216,24 @@ SUPPORTED_LLAVA_IMAGE_MODEL_GROUP = ["llava_llama", "llava_next"]
|
||||
SUPPORTED_LLAVA_VIDEO_MODEL_GROUP = ["llava_llama"]
|
||||
SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP = ["mistral3"]
|
||||
SUPPORTED_HYPERCLOVAX_MODEL_GROUP = ["hyperclovax_vlm"]
|
||||
SUPPORTED_PHI_MODEL_GROUP = ["phi4mm"]
|
||||
|
||||
ALL_SUPPORTED_IMAGE_MODELS = SUPPORTED_QWEN_MODEL_GROUP \
|
||||
+ SUPPORTED_LLAMA_MODEL_GROUP \
|
||||
+ SUPPORTED_LLAVA_IMAGE_MODEL_GROUP \
|
||||
+ SUPPORTED_HYPERCLOVAX_MODEL_GROUP \
|
||||
+ SUPPORTED_GEMMA_MODEL_GROUP \
|
||||
+ SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP
|
||||
+ SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP \
|
||||
+ SUPPORTED_PHI_MODEL_GROUP
|
||||
|
||||
ALL_SUPPORTED_VIDEO_MODELS = SUPPORTED_QWEN_MODEL_GROUP \
|
||||
+ SUPPORTED_LLAVA_VIDEO_MODEL_GROUP
|
||||
|
||||
ALL_SUPPORTED_AUDIO_MODELS = SUPPORTED_PHI_MODEL_GROUP
|
||||
|
||||
ALL_SUPPORTED_MULTIMODAL_MODELS = list(set(ALL_SUPPORTED_IMAGE_MODELS) \
|
||||
| set(ALL_SUPPORTED_VIDEO_MODELS))
|
||||
| set(ALL_SUPPORTED_VIDEO_MODELS) \
|
||||
| set(ALL_SUPPORTED_AUDIO_MODELS))
|
||||
|
||||
HF_CHAT_TEMPLATE_EXCEPTIONS = ["llava_llama"]
|
||||
PLACEHOLDER_EXCEPTIONS = ["llava_next"]
|
||||
@ -223,6 +258,7 @@ PLACEHOLDER_PLACEMENT_MAP = {
|
||||
# Ref: https://github.com/mistralai/mistral-common/blob/039465db2bdc0486df36365c9bdb428188482a18/
|
||||
# src/mistral_common/tokens/tokenizers/base.py#L326
|
||||
"mistral3": MultimodalPlaceholderPlacement.AFTER_TEXT,
|
||||
"phi4mm": MultimodalPlaceholderPlacement.BEFORE_TEXT,
|
||||
}
|
||||
assert len(PLACEHOLDER_PLACEMENT_MAP) == len(ALL_SUPPORTED_MULTIMODAL_MODELS)
|
||||
|
||||
@ -235,7 +271,7 @@ def retrieve_multimodal_placeholder(model_type: str, modality: str,
|
||||
Args:
|
||||
model_type: The type of the multimodal model.
|
||||
modality: The modality of the data.
|
||||
current_count: The number of multimodal data already added. Currently not used.
|
||||
current_count: The number of multimodal data already added.
|
||||
|
||||
"""
|
||||
|
||||
@ -257,6 +293,8 @@ def retrieve_multimodal_placeholder(model_type: str, modality: str,
|
||||
# Ref: https://github.com/mistralai/mistral-common/blob/26a6bb3a07ee0b78a3808f2797f23e1d28514b93/
|
||||
# src/mistral_common/tokens/tokenizers/base.py#L60
|
||||
return "[IMG]"
|
||||
elif model_type in SUPPORTED_PHI_MODEL_GROUP:
|
||||
return f"<|image_{current_count}|>"
|
||||
raise TypeError(
|
||||
f"For image modality, only {ALL_SUPPORTED_IMAGE_MODELS} are supported but got {model_type}"
|
||||
)
|
||||
@ -268,6 +306,9 @@ def retrieve_multimodal_placeholder(model_type: str, modality: str,
|
||||
raise TypeError(
|
||||
f"For video modality, only {ALL_SUPPORTED_VIDEO_MODELS} are supported but got {model_type}"
|
||||
)
|
||||
elif modality == "audio":
|
||||
if model_type in SUPPORTED_PHI_MODEL_GROUP:
|
||||
return f"<|audio_{current_count}|>"
|
||||
raise TypeError(f"Unknown modality: {modality}")
|
||||
|
||||
|
||||
@ -343,7 +384,10 @@ def add_multimodal_placeholders(model_type: str, text_prompt: str,
|
||||
case MultimodalPlaceholderPlacement.AFTER_TEXT:
|
||||
parts.append(text_prompt)
|
||||
parts.extend(placeholders)
|
||||
return "\n".join(parts)
|
||||
if model_type == "phi4mm":
|
||||
return "".join(parts)
|
||||
else:
|
||||
return "\n".join(parts)
|
||||
|
||||
|
||||
def resolve_hf_chat_template(
|
||||
@ -458,6 +502,34 @@ def default_multimodal_input_loader(
|
||||
format=image_data_format,
|
||||
device=device)) for i in media
|
||||
]
|
||||
elif modality == "audio":
|
||||
mm_data = [
|
||||
MultimodalData(modality=modality,
|
||||
data=load_audio(i, device=device)) for i in media
|
||||
]
|
||||
elif modality == "image_audio":
|
||||
# Use different load_xxx functions to match the modality.
|
||||
mm_data = []
|
||||
for m in media:
|
||||
data = None
|
||||
_modal = None
|
||||
if _modal is None:
|
||||
try:
|
||||
data = load_image(m,
|
||||
format=image_data_format,
|
||||
device=device)
|
||||
_modal = "image"
|
||||
except Exception:
|
||||
pass
|
||||
if _modal is None:
|
||||
try:
|
||||
data = load_audio(m, device=device)
|
||||
_modal = "audio"
|
||||
except Exception:
|
||||
pass
|
||||
if _modal is None:
|
||||
raise ValueError(f"Unknown matching modality: {modality}")
|
||||
mm_data.append(MultimodalData(modality=_modal, data=data))
|
||||
else:
|
||||
raise ValueError(f"Unknown modality: {modality}")
|
||||
return ConversationMessage(role="user", content=prompt, media=mm_data)
|
||||
|
||||
@ -2,7 +2,8 @@ from functools import partial
|
||||
from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
|
||||
Optional, Tuple, TypeAlias, TypedDict, Union, cast)
|
||||
|
||||
from openai.types.chat import ChatCompletionContentPartImageParam
|
||||
from openai.types.chat import (ChatCompletionContentPartImageParam,
|
||||
ChatCompletionContentPartInputAudioParam)
|
||||
from openai.types.chat import \
|
||||
ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam
|
||||
from openai.types.chat import (ChatCompletionContentPartTextParam,
|
||||
@ -12,8 +13,8 @@ from typing_extensions import Required
|
||||
|
||||
from tensorrt_llm.inputs import (ConversationMessage, MultimodalData,
|
||||
MultimodalDataTracker,
|
||||
add_multimodal_placeholders, async_load_image,
|
||||
async_load_video)
|
||||
add_multimodal_placeholders, async_load_audio,
|
||||
async_load_image, async_load_video)
|
||||
from tensorrt_llm.logger import logger
|
||||
|
||||
|
||||
@ -33,12 +34,16 @@ ChatCompletionContentPartParam: TypeAlias = Union[
|
||||
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartVideoParam,
|
||||
str]
|
||||
|
||||
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ["text", "image_url", "video_url"]
|
||||
# TODO: Add "input_audio" to support byte_encoded audio input.
|
||||
VALID_MESSAGE_CONTENT_MM_PART_TYPES = [
|
||||
"text", "image_url", "video_url", "audio_url"
|
||||
]
|
||||
|
||||
# Parser Functions
|
||||
_TextParser = partial(cast, ChatCompletionContentPartTextParam)
|
||||
_ImageParser = partial(cast, ChatCompletionContentPartImageParam)
|
||||
_VideoParser = partial(cast, ChatCompletionContentPartVideoParam)
|
||||
_AudioParser = partial(cast, ChatCompletionContentPartInputAudioParam)
|
||||
|
||||
MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[
|
||||
str, dict[str, str]]]] = {
|
||||
@ -48,6 +53,8 @@ MM_PARSER_MAP: dict[str, Callable[[ChatCompletionContentPartParam], Union[
|
||||
lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
|
||||
"video_url":
|
||||
lambda part: _VideoParser(part).get("video_url", {}).get("url", None),
|
||||
"audio_url":
|
||||
lambda part: _AudioParser(part).get("audio_url", {}).get("url", None),
|
||||
}
|
||||
|
||||
|
||||
@ -74,7 +81,7 @@ def parse_chat_message_content_part(
|
||||
|
||||
part_type, content = _parse_chat_message_content_mm_part(part)
|
||||
|
||||
# if part_type is text/image_url/video_url but content is None, log a warning and skip
|
||||
# if part_type is text/image_url/video_url/audio_url but content is None, log a warning and skip
|
||||
if part_type in VALID_MESSAGE_CONTENT_MM_PART_TYPES and content is None:
|
||||
logger.warning(
|
||||
"Skipping multimodal part '%s' (type: '%s') with empty / unparsable content.",
|
||||
@ -108,6 +115,18 @@ def parse_chat_message_content_part(
|
||||
|
||||
return MultimodalData(modality="video", data=load_video_async())
|
||||
|
||||
if part_type == "audio_url":
|
||||
str_content = cast(str, content)
|
||||
|
||||
async def load_audio_async():
|
||||
try:
|
||||
return await async_load_audio(str_content)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load audio: {str(e)}")
|
||||
return None
|
||||
|
||||
return MultimodalData(modality="audio", data=load_audio_async())
|
||||
|
||||
raise NotImplementedError(f"Unknown part type: {part_type}")
|
||||
|
||||
|
||||
|
||||
@ -115,3 +115,5 @@ mistralai/Ministral-8B-Instruct-2410:
|
||||
- quant_algo: FP8
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 78.35
|
||||
microsoft/Phi-4-multimodal-instruct:
|
||||
- accuracy: 81.19
|
||||
|
||||
@ -199,3 +199,5 @@ mistralai/Ministral-8B-Instruct-2410:
|
||||
- quant_algo: FP8
|
||||
kv_cache_quant_algo: FP8
|
||||
accuracy: 65.96
|
||||
microsoft/Phi-4-multimodal-instruct:
|
||||
- accuracy: 69.69
|
||||
|
||||
@ -1847,3 +1847,16 @@ class TestBielik11BInstruct(LlmapiAccuracyTestHarness):
|
||||
task.evaluate(llm)
|
||||
task = GSM8K(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
|
||||
|
||||
class TestPhi4MM(LlmapiAccuracyTestHarness):
|
||||
# phi4-mm can also support text input.
|
||||
MODEL_NAME = "microsoft/Phi-4-multimodal-instruct"
|
||||
MODEL_PATH = f"{llm_models_root()}/multimodals/Phi-4-multimodal-instruct"
|
||||
|
||||
def test_auto_dtype(self):
|
||||
with LLM(self.MODEL_PATH) as llm:
|
||||
task = MMLU(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
task = GSM8K(self.MODEL_NAME)
|
||||
task.evaluate(llm)
|
||||
|
||||
@ -1544,7 +1544,25 @@ def test_build_time_benchmark_sanity(llm_root, llm_venv):
|
||||
])
|
||||
|
||||
|
||||
### Pivot-To-Python examples
|
||||
### PyTorch examples
|
||||
|
||||
|
||||
def parse_output(text):
|
||||
results = []
|
||||
text_lists = re.split(r"\[\d+\] Prompt:", text)
|
||||
for item in text_lists:
|
||||
item = item.replace(os.linesep, "")
|
||||
while True:
|
||||
match = re.search(r"(Generated text: \'(.*?)\')", item,
|
||||
re.MULTILINE)
|
||||
if match is None:
|
||||
break
|
||||
_, end = match.span(1)
|
||||
results.append(match.group(2))
|
||||
item = item[end:]
|
||||
return results
|
||||
|
||||
|
||||
def test_ptp_quickstart(llm_root, llm_venv):
|
||||
example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
|
||||
|
||||
@ -2101,21 +2119,6 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
|
||||
|
||||
output = llm_venv.run_cmd(cmd, caller=check_output)
|
||||
|
||||
def parse_output(text):
|
||||
results = []
|
||||
text_lists = re.split(r"\[\d+\] Prompt:", text)
|
||||
for item in text_lists:
|
||||
item = item.replace(os.linesep, "")
|
||||
while True:
|
||||
match = re.search(r"(Generated text: \'(.*?)\')", item,
|
||||
re.MULTILINE)
|
||||
if match is None:
|
||||
break
|
||||
_, end = match.span(1)
|
||||
results.append(match.group(2))
|
||||
item = item[end:]
|
||||
return results
|
||||
|
||||
match_ratio = 4.0 / 5
|
||||
if model_name == "qwen2-vl-7b-instruct" and modality == "image":
|
||||
match_ratio = 4.0 / 6
|
||||
@ -2182,6 +2185,92 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
|
||||
_check_mem_usage(running_log, [peak, 0, 0, 0])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("modality", ["image", "audio", "image_audio"])
|
||||
def test_ptp_quickstart_multimodal_phi4mm(llm_root, llm_venv, modality):
|
||||
model_name = "Phi-4-multimodal-instruct"
|
||||
model_path = "multimodals/Phi-4-multimodal-instruct"
|
||||
|
||||
example_root = Path(os.path.join(llm_root, "examples", "llm-api"))
|
||||
test_data_root = Path(
|
||||
os.path.join(llm_models_root(), "multimodals", "test_data"))
|
||||
audio_data_root = Path(
|
||||
os.path.join(llm_models_root(), "multimodals",
|
||||
"Phi-4-multimodal-instruct", "examples"))
|
||||
print(f"Accuracy test {model_name} {modality} mode with example inputs.")
|
||||
accuracy_inputs = {
|
||||
"image": {
|
||||
"prompt": [
|
||||
"Describe the object and the weather condition in the image.",
|
||||
"Describe the traffic condition on the road in the image.",
|
||||
],
|
||||
"media": [
|
||||
str(test_data_root / "inpaint.png"),
|
||||
str(test_data_root / "61.jpg"),
|
||||
],
|
||||
},
|
||||
"audio": {
|
||||
"prompt": [
|
||||
"Transcribe the audio clip into text, please don't add other text.",
|
||||
"Transcribe the audio clip into text, please don't add other text.",
|
||||
],
|
||||
"media": [
|
||||
str(audio_data_root /
|
||||
"what_is_the_traffic_sign_in_the_image.wav"),
|
||||
str(audio_data_root / "what_is_shown_in_this_image.wav"),
|
||||
],
|
||||
},
|
||||
"image_audio": {
|
||||
"prompt": [
|
||||
"",
|
||||
],
|
||||
"media": [
|
||||
str(test_data_root / "inpaint.png"),
|
||||
str(audio_data_root / "what_is_shown_in_this_image.wav"),
|
||||
],
|
||||
}
|
||||
}
|
||||
expected_keywords = {
|
||||
"image": [
|
||||
["clear", "sunny", "sky", "image", "object"],
|
||||
["road", "car", "lane", "strip", "bus"],
|
||||
],
|
||||
"audio": [
|
||||
["what", "is", "the", "traffic", "sign", "in", "image"],
|
||||
["what", "is", "shown", "in", "this", "image"],
|
||||
],
|
||||
"image_audio": [
|
||||
["Half", "Dome", "Park", "natural", "image"],
|
||||
],
|
||||
}
|
||||
|
||||
cmd = [
|
||||
str(example_root / "quickstart_multimodal.py"),
|
||||
"--model_dir",
|
||||
f"{llm_models_root()}/{model_path}",
|
||||
"--modality",
|
||||
modality,
|
||||
"--prompt",
|
||||
*accuracy_inputs[modality]["prompt"],
|
||||
"--media",
|
||||
*accuracy_inputs[modality]["media"],
|
||||
"--load_lora",
|
||||
"--auto_model_name",
|
||||
"Phi4MMForCausalLM",
|
||||
]
|
||||
output = llm_venv.run_cmd(cmd, caller=check_output)
|
||||
|
||||
match_ratio = 0.6
|
||||
for prompt_output, prompt_keywords in zip(parse_output(output),
|
||||
expected_keywords[modality]):
|
||||
matches = [
|
||||
keyword in prompt_output.lower() for keyword in prompt_keywords
|
||||
]
|
||||
obs_match_ratio = 1. * sum(matches) / len(matches)
|
||||
assert obs_match_ratio >= match_ratio, f"Incorrect output!\nGenerated \"{prompt_output}\"\nExpected keywords \"{prompt_keywords}\"\n Matched keywords: {matches}\n Observed match ratio {obs_match_ratio} below threshold {match_ratio}"
|
||||
|
||||
print("All answers are correct!")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name,model_path", [
|
||||
("BertForSequenceClassification", "bert/bert-base-uncased-yelp-polarity"),
|
||||
])
|
||||
|
||||
@ -486,6 +486,7 @@ accuracy/test_llm_api_pytorch.py::TestBielik11BInstruct::test_auto_dtype
|
||||
accuracy/test_llm_api_pytorch.py::TestBielik11BInstruct::test_fp8
|
||||
accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_auto_dtype
|
||||
accuracy/test_llm_api_pytorch.py::TestMinistral8BInstruct::test_fp8
|
||||
accuracy/test_llm_api_pytorch.py::TestPhi4MM::test_auto_dtype
|
||||
|
||||
test_e2e.py::test_llama_e2e[use_cpp_session-remove_input_padding-]
|
||||
test_e2e.py::test_llama_e2e[use_py_session-remove_input_padding-]
|
||||
@ -528,6 +529,9 @@ test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistr
|
||||
test_e2e.py::test_ptp_quickstart_multimodal[mistral-small-3.1-24b-instruct-Mistral-Small-3.1-24B-Instruct-2503-image-False]
|
||||
test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-False]
|
||||
test_e2e.py::test_ptp_quickstart_multimodal[gemma-3-27b-it-gemma/gemma-3-27b-it-image-True]
|
||||
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[audio]
|
||||
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image]
|
||||
test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]
|
||||
test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
|
||||
test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
|
||||
test_e2e.py::test_ptp_star_attention_example[Llama3.1-8B-BF16-llama-3.1-model/Meta-Llama-3.1-8B]
|
||||
|
||||
@ -15,6 +15,7 @@ l0_a30:
|
||||
tests:
|
||||
# ------------- PyTorch tests ---------------
|
||||
- unittest/_torch/modeling -k "modeling_nemotron_nas"
|
||||
- unittest/_torch/modeling -k "modeling_phi3"
|
||||
- unittest/_torch/modeling -k "modeling_qwen"
|
||||
- unittest/_torch/modeling -k "modeling_qwen_moe"
|
||||
- unittest/_torch/auto_deploy/unit/singlegpu
|
||||
|
||||
@ -28,6 +28,9 @@ l0_l40s:
|
||||
- test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-image-True]
|
||||
- test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video-False]
|
||||
- test_e2e.py::test_ptp_quickstart_multimodal[qwen2.5-vl-7b-instruct-Qwen2.5-VL-7B-Instruct-video-True]
|
||||
- test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[audio]
|
||||
- test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image]
|
||||
- test_e2e.py::test_ptp_quickstart_multimodal_phi4mm[image_audio]
|
||||
- test_e2e.py::test_ptp_quickstart_bert[VANILLA-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
|
||||
- test_e2e.py::test_ptp_quickstart_bert[TRTLLM-BertForSequenceClassification-bert/bert-base-uncased-yelp-polarity]
|
||||
- condition:
|
||||
|
||||
352
tests/unittest/_torch/modeling/test_modeling_phi3.py
Normal file
352
tests/unittest/_torch/modeling/test_modeling_phi3.py
Normal file
@ -0,0 +1,352 @@
|
||||
import unittest
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from transformers import Phi3Config
|
||||
from transformers import Phi3ForCausalLM as HFPhi3ForCausalLM
|
||||
from utils.util import default_dtype
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
|
||||
from tensorrt_llm._torch.metadata import KVCacheParams
|
||||
from tensorrt_llm._torch.model_config import ModelConfig
|
||||
from tensorrt_llm._torch.models.modeling_phi3 import Phi3ForCausalLM
|
||||
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
|
||||
DecodingCUDAGraphRunner
|
||||
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
|
||||
from tensorrt_llm.bindings.executor import KvCacheConfig
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
PHI3_MINI_4K_CONFIG = {
|
||||
"_name_or_path": "Phi-3-mini-4k-instruct",
|
||||
"architectures": ["Phi3ForCausalLM"],
|
||||
"attention_dropout": 0.0,
|
||||
"auto_map": {
|
||||
"AutoConfig": "configuration_phi3.Phi3Config",
|
||||
"AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM"
|
||||
},
|
||||
"bos_token_id": 1,
|
||||
"embd_pdrop": 0.0,
|
||||
"eos_token_id": 32000,
|
||||
"hidden_act": "silu",
|
||||
"hidden_size": 3072,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 8192,
|
||||
"max_position_embeddings": 4096,
|
||||
"model_type": "phi3",
|
||||
"num_attention_heads": 32,
|
||||
"num_hidden_layers": 32,
|
||||
"num_key_value_heads": 32,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"pad_token_id": 32000,
|
||||
"resid_pdrop": 0.0,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"rope_scaling": None,
|
||||
"rope_theta": 10000.0,
|
||||
"sliding_window": 2047,
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.40.2",
|
||||
"use_cache": True,
|
||||
"attention_bias": False,
|
||||
"vocab_size": 32064
|
||||
}
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class Scenario:
|
||||
backend: str
|
||||
use_cuda_graph: bool = False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"backend:{self.backend.lower()}-use_cuda_graph:{self.use_cuda_graph}"
|
||||
|
||||
|
||||
def reduce_phi3_config(mem_for_full_model: int,
|
||||
config_dict: dict[str, Any],
|
||||
default_num_layers: int = 32):
|
||||
_, total_mem = torch.cuda.mem_get_info()
|
||||
# scale model down if gpu memory is low
|
||||
if total_mem < mem_for_full_model:
|
||||
model_fraction = total_mem / mem_for_full_model
|
||||
num_layers = int(config_dict["num_hidden_layers"] * model_fraction)
|
||||
num_layers = min(num_layers, default_num_layers)
|
||||
config_dict["num_hidden_layers"] = num_layers
|
||||
|
||||
|
||||
class TestPhi3(unittest.TestCase):
|
||||
|
||||
def test_phi3_sanity(self):
|
||||
config_dict = deepcopy(PHI3_MINI_4K_CONFIG)
|
||||
# 8B * sizeof(float16) plus some extra for activations
|
||||
mem_for_full_model = (2 + 1) * 8 * 2**(30)
|
||||
reduce_phi3_config(mem_for_full_model, config_dict)
|
||||
if config_dict["num_hidden_layers"] <= 0:
|
||||
self.skipTest("Insufficient memory for a single Phi3 layer")
|
||||
phi3_config = Phi3Config.from_dict(config_dict)
|
||||
dtype = phi3_config.torch_dtype
|
||||
device = torch.device('cuda')
|
||||
|
||||
with torch.device(device), default_dtype(dtype):
|
||||
model_config = ModelConfig(pretrained_config=phi3_config)
|
||||
phi3 = Phi3ForCausalLM(model_config).to(device)
|
||||
|
||||
input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
|
||||
dtype=torch.int,
|
||||
device=device)
|
||||
|
||||
context_sequence_lengths = [3, 2, 1]
|
||||
sequence_lengths = context_sequence_lengths + [1, 1]
|
||||
past_seen_tokens = [0, 0, 0, 62, 75]
|
||||
request_ids = list(range(len(sequence_lengths)))
|
||||
token_nums = (torch.tensor(past_seen_tokens) +
|
||||
torch.tensor(sequence_lengths)).tolist()
|
||||
prompt_lens = token_nums[:3] + past_seen_tokens[3:]
|
||||
|
||||
        num_blocks = 100
        tokens_per_block = 128
        head_dim = phi3.config.hidden_size // phi3.config.num_attention_heads
        num_layers = phi3.config.num_hidden_layers
        num_kv_heads = phi3.config.num_key_value_heads
        max_seq_len = num_blocks * tokens_per_block
        batch_size = len(context_sequence_lengths) + 2

        if dtype == torch.half:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
        elif dtype == torch.bfloat16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
        else:
            raise ValueError("Invalid dtype")

        mapping = Mapping(world_size=1, tp_size=1, rank=0)
        kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
                                        tokens_per_block)
        kv_cache_manager = KVCacheManager(
            kv_cache_config,
            tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
            num_layers=num_layers,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            tokens_per_block=tokens_per_block,
            max_seq_len=max_seq_len,
            max_batch_size=batch_size,
            mapping=mapping,
            dtype=kv_cache_dtype,
        )
        kv_cache_manager.add_dummy_requests(request_ids, token_nums)

        metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
        attn_metadata = metadata_cls(
            seq_lens=torch.tensor(sequence_lengths, dtype=torch.int),
            num_contexts=len(context_sequence_lengths),
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=past_seen_tokens,
            ),
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
            max_num_requests=len(context_sequence_lengths) + 2,
            max_num_tokens=8192,
        )

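        # Position ids resume from the number of tokens each sequence has
        # already seen (0 for the context requests, 62/75 for the generation
        # requests).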
        position_ids = []
        for i, tokens in enumerate(past_seen_tokens):
            seq_len = context_sequence_lengths[i] if i < len(
                context_sequence_lengths) else 1
            position_id = torch.arange(tokens,
                                       tokens + seq_len,
                                       device=input_ids.device)
            position_ids.append(position_id)

        position_ids = torch.cat(position_ids).unsqueeze(0)

        with torch.inference_mode():
            attn_metadata.prepare()
            logits = phi3.forward(input_ids=input_ids,
                                  position_ids=position_ids,
                                  attn_metadata=attn_metadata)

        self.assertEqual(len(past_seen_tokens), logits.shape[0])

        with torch.inference_mode():
            attn_metadata.prepare()
            logits = phi3.forward(input_ids=input_ids,
                                  position_ids=position_ids,
                                  attn_metadata=attn_metadata,
                                  return_context_logits=True)
        self.assertEqual(input_ids.shape, logits.shape[:-1])

        kv_cache_manager.shutdown()

    @torch.no_grad()
    def test_phi3_allclose_to_hf(self) -> None:
        """
        Compare output to HF
        """
        scenario = Scenario(backend="TRTLLM")
        backend = scenario.backend
        metadata_cls = get_attention_backend(backend).Metadata

        torch.random.manual_seed(0)
        config_dict = deepcopy(PHI3_MINI_4K_CONFIG)
        # 8B * sizeof(float16) plus some extra for activations
        # times 2, since we'll need 2 of these
        mem_for_full_model = (2 + 1) * 8 * 2**(30) * 4
        reduce_phi3_config(mem_for_full_model, config_dict)
        if config_dict["num_hidden_layers"] <= 0:
            self.skipTest("Insufficient memory for a single Phi3 layer")
        phi3_config = Phi3Config.from_dict(config_dict)
        dtype = phi3_config.torch_dtype
        device = torch.device('cuda')

        with torch.device(device), default_dtype(dtype):
            hf_phi3 = HFPhi3ForCausalLM(phi3_config).eval()

            model_config = ModelConfig(pretrained_config=phi3_config,
                                       attn_backend=backend)

            phi3 = Phi3ForCausalLM(model_config).to(dtype).to(device)
            phi3.load_weights(hf_phi3.state_dict())

        num_blocks = 1
        tokens_per_block = 128
        head_dim = phi3.config.hidden_size // phi3.config.num_attention_heads
        num_layers = phi3.config.num_hidden_layers
        num_kv_heads = phi3.config.num_key_value_heads
        max_seq_len = num_blocks * tokens_per_block
        batch_size = 1

        if dtype == torch.half:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
        elif dtype == torch.bfloat16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
        else:
            raise ValueError("Invalid dtype")

        mapping = Mapping(world_size=1, tp_size=1, rank=0)
        kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
                                        tokens_per_block)
        kv_cache_manager = KVCacheManager(
            kv_cache_config,
            tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
            num_layers=num_layers,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            tokens_per_block=tokens_per_block,
            max_seq_len=max_seq_len,
            max_batch_size=batch_size,
            mapping=mapping,
            dtype=kv_cache_dtype,
        )

        # context
        input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
                                 dtype=torch.int,
                                 device=device)

        num_cached_tokens_per_seq = [0]
        request_ids = [1]
        token_nums = [input_ids.size(-1)]
        prompt_lens = [input_ids.size(-1)]
        kv_cache_manager.add_dummy_requests(request_ids, token_nums)

        attn_metadata = metadata_cls(
            seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int),
            num_contexts=1,
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=num_cached_tokens_per_seq,
            ),
            max_num_requests=1,
            max_num_tokens=8192,
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
        )

        # Note: no CUDA graphs for prefill, the graph runner is built for
        # decoding only.
        position_ids = [torch.arange(0, input_ids.size(-1))]
        position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
        with torch.inference_mode():
            attn_metadata.prepare()
            logits = phi3.forward(input_ids=input_ids,
                                  position_ids=position_ids,
                                  attn_metadata=attn_metadata)
            ref = hf_phi3.forward(input_ids=input_ids.unsqueeze(0),
                                  position_ids=position_ids,
                                  use_cache=True)

        torch.testing.assert_close(logits,
                                   ref.logits[:, -1].float(),
                                   atol=0.4,
                                   rtol=0.4)

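        # Decode a single new token, feeding the HF reference its own
        # past_key_values so both implementations see the same KV history.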
        # gen
        gen_input_ids = torch.tensor([600], dtype=torch.int, device=device)

        num_cached_tokens_per_seq = [input_ids.size(-1)]

        attn_metadata = metadata_cls(
            seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.int),
            num_contexts=0,
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=num_cached_tokens_per_seq,
            ),
            max_num_requests=1,
            max_num_tokens=8192,
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
        )

        gen_position_ids = [
            torch.arange(input_ids.size(-1),
                         input_ids.size(-1) + gen_input_ids.size(-1))
        ]
        gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()

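        # run_forward runs the decode step either eagerly or through the
        # CUDA-graph runner, depending on scenario.use_cuda_graph.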
        def run_forward(input_ids, position_ids, attn_metadata):
            attn_metadata.prepare()
            if not scenario.use_cuda_graph:
                return phi3.forward(input_ids=input_ids,
                                    position_ids=position_ids,
                                    attn_metadata=attn_metadata)
            else:
                graph_runner = DecodingCUDAGraphRunner(
                    attn_metadata.max_num_requests, "cuda", attn_metadata)
                graph_runner.capture(lambda inputs: phi3.forward(**inputs))

                for _ in range(2):
                    # Run it twice. This helps us catch problems if buffers are accidentally reallocated
                    # in prepare().
                    attn_metadata.prepare()
                    logits = graph_runner.run({
                        "input_ids": input_ids,
                        "position_ids": position_ids,
                        "attn_metadata": attn_metadata,
                    })
                return logits

        if scenario.use_cuda_graph:
            attn_metadata = attn_metadata.create_cuda_graph_metadata(1)

        with torch.inference_mode():
            logits = run_forward(input_ids=gen_input_ids,
                                 position_ids=gen_position_ids,
                                 attn_metadata=attn_metadata)
            ref = hf_phi3.forward(input_ids=gen_input_ids.unsqueeze(0),
                                  position_ids=gen_position_ids,
                                  past_key_values=ref.past_key_values,
                                  use_cache=True)

        torch.testing.assert_close(logits,
                                   ref.logits[:, -1].float(),
                                   atol=0.4,
                                   rtol=0.4)

        kv_cache_manager.shutdown()