From 362a8272f80f07fc2056c5f6130a64e51e7f0615 Mon Sep 17 00:00:00 2001
From: milesial
Date: Fri, 25 Apr 2025 16:47:14 -0700
Subject: [PATCH] feat: llama4 input processor (#3383)

Signed-off-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com>
Signed-off-by: Haohang Huang <31998628+symphonylyh@users.noreply.github.com>
Co-authored-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com>
Co-authored-by: Haohang Huang <31998628+symphonylyh@users.noreply.github.com>
---
 examples/pytorch/README.md                    |  2 +-
 requirements.txt                              |  2 +-
 tensorrt_llm/_torch/models/modeling_llama.py  | 94 +++++++++++++++++--
 .../_torch/models/modeling_llava_next.py      |  3 +-
 .../models/modeling_multimodal_utils.py       | 26 +++--
 .../_torch/models/modeling_qwen2vl.py         |  3 +-
 tensorrt_llm/_torch/models/modeling_vila.py   |  3 +-
 tensorrt_llm/_torch/modules/embedding.py      |  1 +
 tensorrt_llm/inputs/__init__.py               |  4 +-
 tensorrt_llm/inputs/utils.py                  | 12 ++-
 tensorrt_llm/llmapi/llm.py                    | 13 +++
 .../_torch/modeling/test_modeling_vila.py     |  3 +-
 .../_torch/multi_gpu_modeling/test_llama4.py  | 14 ++-
 13 files changed, 147 insertions(+), 33 deletions(-)

diff --git a/examples/pytorch/README.md b/examples/pytorch/README.md
index 5441cbd29f..37fadec45d 100644
--- a/examples/pytorch/README.md
+++ b/examples/pytorch/README.md
@@ -51,7 +51,7 @@ python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --mo
 | `LlavaLlamaModel` | VILA | `Efficient-Large-Model/NVILA-8B` | L + V |
 | `LlavaNextForConditionalGeneration` | LLaVA-NeXT | `llava-hf/llava-v1.6-mistral-7b-hf` | L + V |
 | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA | `meta-llama/Meta-Llama-3.1-70B` | L |
-| `Llama4ForConditionalGeneration` | Llama 4 Scout/Maverick | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct` | L |
+| `Llama4ForConditionalGeneration` | Llama 4 Scout/Maverick | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct` | L + V |
 | `MistralForCausalLM` | Mistral | `mistralai/Mistral-7B-v0.1` | L |
 | `MixtralForCausalLM` | Mixtral | `mistralai/Mixtral-8x7B-v0.1` | L |
 | `MllamaForConditionalGeneration` | Llama 3.2 | `meta-llama/Llama-3.2-11B-Vision` | L |
diff --git a/requirements.txt b/requirements.txt
index b352fdd54c..ba1bb6542e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -26,7 +26,7 @@ torchvision
 nvidia-modelopt[torch]~=0.27.0
 nvidia-nccl-cu12
 nvidia-cuda-nvrtc-cu12
-transformers==4.51.0
+transformers~=4.51.1
 pydantic>=2.9.1
 pillow==10.3.0
 wheel<=0.45.1
diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py
index eab023a3be..eaf5afa3c1 100644
--- a/tensorrt_llm/_torch/models/modeling_llama.py
+++ b/tensorrt_llm/_torch/models/modeling_llama.py
@@ -1,15 +1,22 @@
 import copy
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
+from PIL.Image import Image
 from torch import nn
-from transformers import Llama4Config, LlamaConfig
+from transformers import (AutoProcessor, Llama4Config, Llama4VisionModel,
+                          LlamaConfig)
+from transformers.modeling_utils import load_sharded_checkpoint
+from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
 
 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
                                              AllReduceParams, DeepseekAllReduce)
 from tensorrt_llm._torch.pipeline_interface import PipelineInterface
 from tensorrt_llm.functional import PositionEmbeddingType
 
+from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
+                       register_input_processor)
+from ...sampling_params import SamplingParams
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import (PositionalEmbeddingParams,
                                            PredefinedAttentionMask, RopeParams)
@@ -26,6 +33,7 @@ from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
 from ..modules.rotary_embedding import RotaryEmbedding
 from ..speculative import Eagle3SpecMetadata, SpecMetadata
+from .modeling_multimodal_utils import fuse_input_embeds
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
                              EagerFusionConfig, MissingLayer,
                              register_auto_model, support_pp,
@@ -829,15 +837,13 @@ class LlamaForCausalLM(DecoderModelForCausalLM[LlamaModel, LlamaConfig]):
             vocab_size=model_config.pretrained_config.vocab_size)
 
 
-@register_auto_model("Llama4ForConditionalGeneration")
-class Llama4ForConditionalGeneration(DecoderModelForCausalLM[Llama4Model,
-                                                             Llama4Config]):
+@register_auto_model("Llama4ForCausalLM")
+class Llama4ForCausalLM(DecoderModelForCausalLM[LlamaModel, Llama4Config]):
 
     def __init__(
         self,
         model_config: ModelConfig[Llama4Config],
     ):
-        # TODO: figure out a better way to handle multimodality.
         model_config = copy.copy(model_config)
         architectures = model_config.pretrained_config.architectures
         model_config.pretrained_config = model_config.pretrained_config.text_config
@@ -876,6 +882,82 @@ class Llama4ForConditionalGeneration(DecoderModelForCausalLM[Llama4Model,
                 idx + 1].input_layernorm
 
 
+class Llama4InputProcessor(InputProcessor):
+
+    def __init__(self, model_path, model_config, tokenizer):
+        self.processor = AutoProcessor.from_pretrained(model_path,
+                                                       use_fast=True)
+        self.model_config = model_config
+        self.tokenizer = tokenizer
+        self.vocab_size = model_config.text_config.vocab_size
+        self.image_token_index = model_config.image_token_index
+
+        self.encoder = nn.ModuleDict({
+            "vision_model":
+            Llama4VisionModel(model_config.vision_config),
+            "multi_modal_projector":
+            Llama4MultiModalProjector(model_config)
+        }).cuda()
+        load_sharded_checkpoint(self.encoder, model_path, strict=False)
+
+    @torch.inference_mode()
+    def __call__(
+        self, inputs: TextPrompt, sampling_params: SamplingParams
+    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
+        text_prompt, mm_data = inputs.get("prompt"), inputs.get(
+            "multi_modal_data")
+        images, do_rescale = None, True
+
+        if mm_data and mm_data.get("image"):
+            images = mm_data["image"]
+            img_type = type(mm_data["image"][0])
+            do_rescale = (img_type == Image)
+            assert all(isinstance(img, img_type) for img in mm_data["image"])
+
+        # preprocess images and insert image tokens
+        processed = self.processor(text=text_prompt,
+                                   images=images,
+                                   return_tensors="pt",
+                                   device="cuda",
+                                   do_rescale=do_rescale,
+                                   add_special_tokens=False)
+        if images:
+            token_ids, pixel_values = processed["input_ids"].squeeze(
+            ), processed["pixel_values"]
+            mm_embeds = self.encoder.vision_model(
+                pixel_values.float().cuda()).last_hidden_state.flatten(0, 1)
+            mm_embeds = self.encoder.multi_modal_projector(mm_embeds)
+            # for fuse_input_embeds
+            token_ids[token_ids == self.image_token_index] = self.vocab_size + 1
+            return token_ids.tolist(), {
+                "prompt_tuning_config": [mm_embeds, None, None]
+            }
+        else:
+            return processed["input_ids"].squeeze().tolist(), {}
+
+
+@register_auto_model("Llama4ForConditionalGeneration")
+@register_input_processor(Llama4InputProcessor)
+class Llama4ForConditionalGeneration(Llama4ForCausalLM):
+
+    @torch.inference_mode()
+    def forward(
+        self,
+        attn_metadata: AttentionMetadata,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        return_context_logits: Optional[bool] = False,
+        **kwargs,
+    ) -> torch.Tensor:
+        mm_embed = kwargs.get("multi_modal_data", [])
+        input_ids, inputs_embeds = fuse_input_embeds(self.model.embed_tokens,
+                                                     input_ids, mm_embed)
+        logits = super().forward(attn_metadata, input_ids, position_ids,
+                                 inputs_embeds, return_context_logits)
+        return logits
+
+
 @register_auto_model("MistralForCausalLM")
 class MistralForCausalLM(DecoderModelForCausalLM[LlamaModel, LlamaConfig]):
 
diff --git a/tensorrt_llm/_torch/models/modeling_llava_next.py b/tensorrt_llm/_torch/models/modeling_llava_next.py
index e50a8835b7..67b0b7042f 100644
--- a/tensorrt_llm/_torch/models/modeling_llava_next.py
+++ b/tensorrt_llm/_torch/models/modeling_llava_next.py
@@ -221,7 +221,8 @@ class LlavaNextModel(PreTrainedModel):
             mm_embed
         ) == num_context_requests, "Number of multimodal features (if provided) should be equal to number of context requests"
 
-        input_ids, inputs_embeds = fuse_input_embeds(self, input_ids, mm_embed)
+        input_ids, inputs_embeds = fuse_input_embeds(
+            self.llm.model.embed_tokens, input_ids, mm_embed)
         logits = self.llm.forward(attn_metadata, input_ids, position_ids,
                                   inputs_embeds, return_context_logits)
         return logits
diff --git a/tensorrt_llm/_torch/models/modeling_multimodal_utils.py b/tensorrt_llm/_torch/models/modeling_multimodal_utils.py
index 3c185ff6ff..620ca65f5c 100644
--- a/tensorrt_llm/_torch/models/modeling_multimodal_utils.py
+++ b/tensorrt_llm/_torch/models/modeling_multimodal_utils.py
@@ -25,9 +25,11 @@ from einops import rearrange
 from PIL import Image
 from torchvision.transforms import Normalize, Resize, ToTensor
 
+from tensorrt_llm._torch.modules.embedding import Embedding
+
 
 def fuse_input_embeds(
-    model,
+    embedding_layer: Embedding,
     input_ids: torch.LongTensor,
     mm_embeds: List[torch.Tensor],
 ) -> Tuple[Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
@@ -44,20 +46,24 @@ def fuse_input_embeds(
     if len(mm_embeds) == 0:
         return input_ids, None
 
+    vocab_size = embedding_layer.num_embeddings
     mm_embed = torch.cat(mm_embeds, dim=0)
+
+    text_token_indices = torch.where(input_ids < vocab_size)[0]
+    mm_token_indices = torch.where(input_ids >= vocab_size)[0]
+
+    text_embed = embedding_layer(input_ids[text_token_indices])
     input_embeds = torch.empty(input_ids.shape[0],
                                mm_embed.shape[-1],
-                               device=input_ids.device,
-                               dtype=model.model_dtype)
+                               device=text_embed.device,
+                               dtype=text_embed.dtype)
 
-    text_token_indices = torch.where(input_ids < model.vocab_size)[0]
-    mm_token_indices = torch.where(input_ids >= model.vocab_size)[0]
+    input_embeds[text_token_indices, :] = text_embed.to(
+        dtype=input_embeds.dtype, device=input_embeds.device)
+    input_embeds[mm_token_indices, :] = mm_embed.to(dtype=input_embeds.dtype,
+                                                    device=input_embeds.device)
 
-    text_embed = model.llm.model.embed_tokens(input_ids[text_token_indices])
-    input_embeds[text_token_indices, :] = text_embed.to(model.model_dtype)
-    input_embeds[mm_token_indices, :] = mm_embed.to(model.model_dtype)
-
-    return None, input_embeds.to(model.dtype)
+    return None, input_embeds
 
 
 #region VILA utils
diff --git a/tensorrt_llm/_torch/models/modeling_qwen2vl.py b/tensorrt_llm/_torch/models/modeling_qwen2vl.py
index 792907c915..aa93a8614d 100644
--- a/tensorrt_llm/_torch/models/modeling_qwen2vl.py
+++ b/tensorrt_llm/_torch/models/modeling_qwen2vl.py
@@ -417,7 +417,8 @@ class Qwen2VLModelBase(PreTrainedModel):
         assert mm_embed == [] or len(
             mm_embed) == num_context_requests, error_msg
 
-        input_ids, input_embeds = fuse_input_embeds(self, input_ids, mm_embed)
+        input_ids, input_embeds = fuse_input_embeds(self.llm.model.embed_tokens,
+                                                    input_ids, mm_embed)
 
         mrope_config = kwargs.get("mrope_config", {})
         if mrope_config:
diff --git a/tensorrt_llm/_torch/models/modeling_vila.py b/tensorrt_llm/_torch/models/modeling_vila.py
index 058b2b893a..3ed73b30df 100644
--- a/tensorrt_llm/_torch/models/modeling_vila.py
+++ b/tensorrt_llm/_torch/models/modeling_vila.py
@@ -1167,7 +1167,8 @@ class VilaModel(PreTrainedModel):
             mm_embed
         ) == num_context_requests, "Number of multimodal features (if provided) should be equal to number of context requests"
 
-        input_ids, inputs_embeds = fuse_input_embeds(self, input_ids, mm_embed)
+        input_ids, inputs_embeds = fuse_input_embeds(
+            self.llm.model.embed_tokens, input_ids, mm_embed)
         logits = self.llm.forward(attn_metadata=attn_metadata,
                                   input_ids=input_ids,
                                   position_ids=position_ids,
diff --git a/tensorrt_llm/_torch/modules/embedding.py b/tensorrt_llm/_torch/modules/embedding.py
index 15cc50ab1f..fa31aedc95 100644
--- a/tensorrt_llm/_torch/modules/embedding.py
+++ b/tensorrt_llm/_torch/modules/embedding.py
@@ -59,6 +59,7 @@ class LMHead(Linear):
             local_in_features -= self.padding_size
         self.in_features = local_in_features
         self.out_features = local_out_features
+        self.num_embeddings = num_embeddings
 
         weight_shape = (self.out_features, self.in_features)
         self.weight = Parameter(torch.empty(weight_shape, dtype=dtype))
diff --git a/tensorrt_llm/inputs/__init__.py b/tensorrt_llm/inputs/__init__.py
index fd9887ddb1..a9afd9184a 100644
--- a/tensorrt_llm/inputs/__init__.py
+++ b/tensorrt_llm/inputs/__init__.py
@@ -2,7 +2,7 @@ from .data import PromptInputs, TextPrompt, TokensPrompt, prompt_inputs
 from .registry import (ExtraProcessedInputs, InputProcessor,
                        create_input_processor, register_input_processor)
 from .utils import (INPUT_FORMATTER_MAP, default_image_loader,
-                    default_video_loader, format_llava_next_input,
+                    default_video_loader, format_generic_input,
                     format_qwen2_vl_input, format_vila_input, load_image,
                     load_video)
 
@@ -11,5 +11,5 @@ __all__ = [
     "InputProcessor", "create_input_processor", "register_input_processor",
     "ExtraProcessedInputs", "load_image", "load_video", "INPUT_FORMATTER_MAP",
     "default_image_loader", "default_video_loader", "format_vila_input",
-    "format_llava_next_input", "format_qwen2_vl_input"
+    "format_generic_input", "format_qwen2_vl_input"
 ]
diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py
index 863bbcfeff..0da8ec6175 100644
--- a/tensorrt_llm/inputs/utils.py
+++ b/tensorrt_llm/inputs/utils.py
@@ -104,7 +104,7 @@ def format_vila_input(model_dir, inputs):
     return inputs
 
 
-def format_llava_next_input(model_dir, inputs):
+def format_generic_input(model_dir, inputs):
     """
     This function formats the input for the Llava Next VL model.
@@ -122,15 +122,16 @@ def format_llava_next_input(model_dir, inputs):
     def apply_template(prompt, multimodal_data):
         conversation = [
             {
-                "role": "user",
+                "role":
+                "user",
                 "content": [
                     {
                         "type": "text",
                         "text": prompt
                     },
-                    {
+                    *[{
                         "type": "image"
-                    },
+                    } for _ in multimodal_data["image"]],
                 ],
             },
         ]
@@ -220,7 +221,8 @@ def default_video_loader(prompts, videos, image_data_format="pt", num_frames=8):
 
 INPUT_FORMATTER_MAP = {
     "llava_llama": format_vila_input,
-    "llava_next": format_llava_next_input,
+    "llava_next": format_generic_input,
     "qwen2_vl": format_qwen2_vl_input,
     "qwen2_5_vl": format_qwen2_vl_input,
+    "llama4": format_generic_input,
 }
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index c3c6adfa26..d21abfe5e1 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -10,6 +10,9 @@ from typing import Any, List, Literal, Optional, Sequence, Union
 from tqdm import tqdm
 from transformers import PreTrainedTokenizerBase
 
+from tensorrt_llm.inputs.data import TextPrompt
+from tensorrt_llm.inputs.registry import DefaultInputProcessor
+
 from .. import bindings as tllm
 from .._utils import global_mpi_rank, nvtx_range_debug
 from ..bindings import executor as tllm
@@ -309,6 +312,16 @@ class LLM:
         if queries is not None:
             queries = prompt_inputs(queries)
 
+        if not inputs.get("prompt") and inputs.get(
+                "prompt_token_ids") and not isinstance(self.input_processor,
+                                                       DefaultInputProcessor):
+            # VLMs need to process/tokenize the prompt in their own way
+            prompt = self.tokenizer.decode(inputs['prompt_token_ids'])
+            inputs = TextPrompt(
+                prompt=prompt,
+                multi_modal_data=inputs.get("multi_modal_data"),
+                mm_processor_kwargs=inputs.get("mm_processor_kwargs"))
+
         query_token_ids = None
         prompt_tuning_config = None
         mrope_config = None
diff --git a/tests/unittest/_torch/modeling/test_modeling_vila.py b/tests/unittest/_torch/modeling/test_modeling_vila.py
index fdd3245fad..3343fd9255 100644
--- a/tests/unittest/_torch/modeling/test_modeling_vila.py
+++ b/tests/unittest/_torch/modeling/test_modeling_vila.py
@@ -603,7 +603,8 @@ class TestVila(unittest.TestCase):
                                  device=device,
                                  dtype=torch.int)
         images = [torch.rand(196, 2560, dtype=dtype, device=device)]
-        input_ids, input_embeds = fuse_input_embeds(model, input_ids, images)
+        input_ids, input_embeds = fuse_input_embeds(
+            model.llm.model.embed_tokens, input_ids, images)
         self.assertIsNone(input_ids)
         self.assertEqual(list(input_embeds.shape), [233, 2560])
 
diff --git a/tests/unittest/_torch/multi_gpu_modeling/test_llama4.py b/tests/unittest/_torch/multi_gpu_modeling/test_llama4.py
index cb379480cb..366c27dbac 100644
--- a/tests/unittest/_torch/multi_gpu_modeling/test_llama4.py
+++ b/tests/unittest/_torch/multi_gpu_modeling/test_llama4.py
@@ -2,6 +2,7 @@ from difflib import SequenceMatcher
 
 import pytest
+import torch
 from utils.llm_data import llm_models_root
 
 from tensorrt_llm import SamplingParams
@@ -19,12 +20,17 @@ from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
 @pytest.mark.parametrize("use_cuda_graph", [True, False],
                          ids=["enable_graph", "disable_graph"])
 def test_llama4(model_name, backend, tp_size, use_cuda_graph):
-    prompts = [
-        "The president of the United States is",
-    ]
+    prompts = [{
+        "prompt": "The president of the United States is"
+    }, {
+        "prompt": "<|image|>This image is of color",
+        "multi_modal_data": {
+            "image": [torch.ones(3, 1024, 1024)]
+        }
+    }]
 
     expected_outputs = [
-        " the head of state and head of government of the",
+        " the head of state and head of government of the", " solid white"
     ]
pytorch_config = PyTorchConfig(attn_backend=backend,