feat: llama4 input processor (#3383)

Signed-off-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com>
Signed-off-by: Haohang Huang <31998628+symphonylyh@users.noreply.github.com>
Co-authored-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com>
Co-authored-by: Haohang Huang <31998628+symphonylyh@users.noreply.github.com>
milesial 2025-04-25 16:47:14 -07:00 committed by GitHub
parent d7472231f9
commit 362a8272f8
13 changed files with 147 additions and 33 deletions

View File

@ -51,7 +51,7 @@ python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --mo
| `LlavaLlamaModel` | VILA | `Efficient-Large-Model/NVILA-8B` | L + V |
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | `llava-hf/llava-v1.6-mistral-7b-hf` | L + V |
| `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA | `meta-llama/Meta-Llama-3.1-70B` | L |
| `Llama4ForConditionalGeneration` | Llama 4 Scout/Maverick | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct` | L |
| `Llama4ForConditionalGeneration` | Llama 4 Scout/Maverick | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct` | L + V |
| `MistralForCausalLM` | Mistral | `mistralai/Mistral-7B-v0.1` | L |
| `MixtralForCausalLM` | Mixtral | `mistralai/Mixtral-8x7B-v0.1` | L |
| `MllamaForConditionalGeneration` | Llama 3.2 | `meta-llama/Llama-3.2-11B-Vision` | L |

View File

@ -26,7 +26,7 @@ torchvision
nvidia-modelopt[torch]~=0.27.0
nvidia-nccl-cu12
nvidia-cuda-nvrtc-cu12
transformers==4.51.0
transformers~=4.51.1
pydantic>=2.9.1
pillow==10.3.0
wheel<=0.45.1

View File

@ -1,15 +1,22 @@
import copy
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import torch
from PIL.Image import Image
from torch import nn
from transformers import Llama4Config, LlamaConfig
from transformers import (AutoProcessor, Llama4Config, Llama4VisionModel,
LlamaConfig)
from transformers.modeling_utils import load_sharded_checkpoint
from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
AllReduceParams, DeepseekAllReduce)
from tensorrt_llm._torch.pipeline_interface import PipelineInterface
from tensorrt_llm.functional import PositionEmbeddingType
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
register_input_processor)
from ...sampling_params import SamplingParams
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import (PositionalEmbeddingParams,
PredefinedAttentionMask, RopeParams)
@ -26,6 +33,7 @@ from ..modules.multi_stream_utils import maybe_execute_in_parallel
from ..modules.rms_norm import RMSNorm
from ..modules.rotary_embedding import RotaryEmbedding
from ..speculative import Eagle3SpecMetadata, SpecMetadata
from .modeling_multimodal_utils import fuse_input_embeds
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
EagerFusionConfig, MissingLayer,
register_auto_model, support_pp,
@ -829,15 +837,13 @@ class LlamaForCausalLM(DecoderModelForCausalLM[LlamaModel, LlamaConfig]):
vocab_size=model_config.pretrained_config.vocab_size)
@register_auto_model("Llama4ForConditionalGeneration")
class Llama4ForConditionalGeneration(DecoderModelForCausalLM[Llama4Model,
Llama4Config]):
@register_auto_model("Llama4ForCausalLM")
class Llama4ForCausalLM(DecoderModelForCausalLM[LlamaModel, Llama4Config]):
def __init__(
self,
model_config: ModelConfig[Llama4Config],
):
# TODO: figure out a better way to handle multimodality.
model_config = copy.copy(model_config)
architectures = model_config.pretrained_config.architectures
model_config.pretrained_config = model_config.pretrained_config.text_config
@ -876,6 +882,82 @@ class Llama4ForConditionalGeneration(DecoderModelForCausalLM[Llama4Model,
idx + 1].input_layernorm
class Llama4InputProcessor(InputProcessor):
def __init__(self, model_path, model_config, tokenizer):
self.processor = AutoProcessor.from_pretrained(model_path,
use_fast=True)
self.model_config = model_config
self.tokenizer = tokenizer
self.vocab_size = model_config.text_config.vocab_size
self.image_token_index = model_config.image_token_index
self.encoder = nn.ModuleDict({
"vision_model":
Llama4VisionModel(model_config.vision_config),
"multi_modal_projector":
Llama4MultiModalProjector(model_config)
}).cuda()
load_sharded_checkpoint(self.encoder, model_path, strict=False)
@torch.inference_mode()
def __call__(
self, inputs: TextPrompt, sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
text_prompt, mm_data = inputs.get("prompt"), inputs.get(
"multi_modal_data")
images, do_rescale = None, True
if mm_data and mm_data.get("image"):
images = mm_data["image"]
img_type = type(mm_data["image"][0])
do_rescale = (img_type == Image)
assert all(isinstance(img, img_type) for img in mm_data["image"])
# preprocess images and insert image tokens
processed = self.processor(text=text_prompt,
images=images,
return_tensors="pt",
device="cuda",
do_rescale=do_rescale,
add_special_tokens=False)
if images:
token_ids, pixel_values = processed["input_ids"].squeeze(
), processed["pixel_values"]
mm_embeds = self.encoder.vision_model(
pixel_values.float().cuda()).last_hidden_state.flatten(0, 1)
mm_embeds = self.encoder.multi_modal_projector(mm_embeds)
# map image placeholder tokens to an out-of-vocabulary id so that
# fuse_input_embeds later routes them to the multimodal embeddings
token_ids[token_ids == self.image_token_index] = self.vocab_size + 1
return token_ids.tolist(), {
"prompt_tuning_config": [mm_embeds, None, None]
}
else:
return processed["input_ids"].squeeze().tolist(), {}
@register_auto_model("Llama4ForConditionalGeneration")
@register_input_processor(Llama4InputProcessor)
class Llama4ForConditionalGeneration(Llama4ForCausalLM):
@torch.inference_mode()
def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
return_context_logits: Optional[bool] = False,
**kwargs,
) -> torch.Tensor:
mm_embed = kwargs.get("multi_modal_data", [])
input_ids, inputs_embeds = fuse_input_embeds(self.model.embed_tokens,
input_ids, mm_embed)
logits = super().forward(attn_metadata, input_ids, position_ids,
inputs_embeds, return_context_logits)
return logits
@register_auto_model("MistralForCausalLM")
class MistralForCausalLM(DecoderModelForCausalLM[LlamaModel, LlamaConfig]):
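Taken together, the new Llama4InputProcessor tokenizes the prompt, expands "<|image|>" into image tokens, runs the Llama4VisionModel plus multimodal projector, and passes the projected embeddings to the model through prompt_tuning_config, where fuse_input_embeds splices them into the input embeddings. A minimal usage sketch against the LLM API, mirroring the test at the end of this diff (the import path, model path, and sampling settings are illustrative assumptions, not taken from this commit):

import torch
from tensorrt_llm import LLM, SamplingParams  # assumed import path for the LLM API

# Sketch only: one text-only prompt and one prompt carrying an image placeholder.
llm = LLM(model="meta-llama/Llama-4-Scout-17B-16E-Instruct")  # construction args may differ
prompts = [
    {"prompt": "The president of the United States is"},
    {
        "prompt": "<|image|>This image is of color",
        "multi_modal_data": {"image": [torch.ones(3, 1024, 1024)]},
    },
]
for out in llm.generate(prompts, SamplingParams(max_tokens=8)):
    print(out.outputs[0].text)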

View File

@ -221,7 +221,8 @@ class LlavaNextModel(PreTrainedModel):
mm_embed
) == num_context_requests, "Number of multimodal features (if provided) should be equal to number of context requests"
input_ids, inputs_embeds = fuse_input_embeds(self, input_ids, mm_embed)
input_ids, inputs_embeds = fuse_input_embeds(
self.llm.model.embed_tokens, input_ids, mm_embed)
logits = self.llm.forward(attn_metadata, input_ids, position_ids,
inputs_embeds, return_context_logits)
return logits

View File

@ -25,9 +25,11 @@ from einops import rearrange
from PIL import Image
from torchvision.transforms import Normalize, Resize, ToTensor
from tensorrt_llm._torch.modules.embedding import Embedding
def fuse_input_embeds(
model,
embedding_layer: Embedding,
input_ids: torch.LongTensor,
mm_embeds: List[torch.Tensor],
) -> Tuple[Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
@ -44,20 +46,24 @@ def fuse_input_embeds(
if len(mm_embeds) == 0:
return input_ids, None
vocab_size = embedding_layer.num_embeddings
mm_embed = torch.cat(mm_embeds, dim=0)
text_token_indices = torch.where(input_ids < vocab_size)[0]
mm_token_indices = torch.where(input_ids >= vocab_size)[0]
text_embed = embedding_layer(input_ids[text_token_indices])
input_embeds = torch.empty(input_ids.shape[0],
mm_embed.shape[-1],
device=input_ids.device,
dtype=model.model_dtype)
device=text_embed.device,
dtype=text_embed.dtype)
text_token_indices = torch.where(input_ids < model.vocab_size)[0]
mm_token_indices = torch.where(input_ids >= model.vocab_size)[0]
input_embeds[text_token_indices, :] = text_embed.to(
dtype=input_embeds.dtype, device=input_embeds.device)
input_embeds[mm_token_indices, :] = mm_embed.to(dtype=input_embeds.dtype,
device=input_embeds.device)
text_embed = model.llm.model.embed_tokens(input_ids[text_token_indices])
input_embeds[text_token_indices, :] = text_embed.to(model.model_dtype)
input_embeds[mm_token_indices, :] = mm_embed.to(model.model_dtype)
return None, input_embeds.to(model.dtype)
return None, input_embeds
#region VILA utils
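The refactor above makes fuse_input_embeds depend only on an embedding layer rather than on the whole model wrapper, deriving device and dtype from the computed text embeddings; this appears to be why the diff also adds a num_embeddings attribute to LMHead below. A self-contained toy sketch of the fusion rule, using a dummy embedding table rather than the library types:

import torch
from torch import nn

# Toy illustration of the fusion rule: ids below the vocab size are looked up
# in the text embedding table, ids at or above it are placeholders that are
# filled with the multimodal embeddings in order.
vocab_size, hidden = 8, 4
embedding = nn.Embedding(vocab_size, hidden)
input_ids = torch.tensor([1, 2, vocab_size + 1, vocab_size + 1, 3])
mm_embeds = [torch.randn(2, hidden)]  # e.g. two projected image embeddings

with torch.no_grad():
    mm_embed = torch.cat(mm_embeds, dim=0)
    text_idx = torch.where(input_ids < vocab_size)[0]
    mm_idx = torch.where(input_ids >= vocab_size)[0]
    fused = torch.empty(input_ids.shape[0], hidden)
    fused[text_idx] = embedding(input_ids[text_idx])
    fused[mm_idx] = mm_embed
print(fused.shape)  # torch.Size([5, 4])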

View File

@ -417,7 +417,8 @@ class Qwen2VLModelBase(PreTrainedModel):
assert mm_embed == [] or len(
mm_embed) == num_context_requests, error_msg
input_ids, input_embeds = fuse_input_embeds(self, input_ids, mm_embed)
input_ids, input_embeds = fuse_input_embeds(self.llm.model.embed_tokens,
input_ids, mm_embed)
mrope_config = kwargs.get("mrope_config", {})
if mrope_config:

View File

@ -1167,7 +1167,8 @@ class VilaModel(PreTrainedModel):
mm_embed
) == num_context_requests, "Number of multimodal features (if provided) should be equal to number of context requests"
input_ids, inputs_embeds = fuse_input_embeds(self, input_ids, mm_embed)
input_ids, inputs_embeds = fuse_input_embeds(
self.llm.model.embed_tokens, input_ids, mm_embed)
logits = self.llm.forward(attn_metadata=attn_metadata,
input_ids=input_ids,
position_ids=position_ids,

View File

@ -59,6 +59,7 @@ class LMHead(Linear):
local_in_features -= self.padding_size
self.in_features = local_in_features
self.out_features = local_out_features
self.num_embeddings = num_embeddings
weight_shape = (self.out_features, self.in_features)
self.weight = Parameter(torch.empty(weight_shape, dtype=dtype))

View File

@ -2,7 +2,7 @@ from .data import PromptInputs, TextPrompt, TokensPrompt, prompt_inputs
from .registry import (ExtraProcessedInputs, InputProcessor,
create_input_processor, register_input_processor)
from .utils import (INPUT_FORMATTER_MAP, default_image_loader,
default_video_loader, format_llava_next_input,
default_video_loader, format_generic_input,
format_qwen2_vl_input, format_vila_input, load_image,
load_video)
@ -11,5 +11,5 @@ __all__ = [
"InputProcessor", "create_input_processor", "register_input_processor",
"ExtraProcessedInputs", "load_image", "load_video", "INPUT_FORMATTER_MAP",
"default_image_loader", "default_video_loader", "format_vila_input",
"format_llava_next_input", "format_qwen2_vl_input"
"format_generic_input", "format_qwen2_vl_input"
]

View File

@ -104,7 +104,7 @@ def format_vila_input(model_dir, inputs):
return inputs
def format_llava_next_input(model_dir, inputs):
def format_generic_input(model_dir, inputs):
"""
This function formats the input for generic VL models such as LLaVA-NeXT and Llama 4.
@ -122,15 +122,16 @@ def format_llava_next_input(model_dir, inputs):
def apply_template(prompt, multimodal_data):
conversation = [
{
"role": "user",
"role":
"user",
"content": [
{
"type": "text",
"text": prompt
},
{
*[{
"type": "image"
},
} for _ in multimodal_data["image"]],
],
},
]
@ -220,7 +221,8 @@ def default_video_loader(prompts, videos, image_data_format="pt", num_frames=8):
INPUT_FORMATTER_MAP = {
"llava_llama": format_vila_input,
"llava_next": format_llava_next_input,
"llava_next": format_generic_input,
"qwen2_vl": format_qwen2_vl_input,
"qwen2_5_vl": format_qwen2_vl_input,
"llama4": format_generic_input,
}
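The template change above, now shared by LLaVA-NeXT and Llama 4 through format_generic_input, emits one image content entry per provided image instead of a single hard-coded one. A standalone sketch of the structure it produces (illustrative reconstruction, not the library code):

# Builds a chat-template conversation with one text entry followed by one
# {"type": "image"} entry per image, as in apply_template above.
def build_conversation(prompt, multimodal_data):
    return [{
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            *[{"type": "image"} for _ in multimodal_data["image"]],
        ],
    }]

print(build_conversation("Compare these.", {"image": ["img0", "img1"]}))
# one text entry and two image entries, ready for the processor's chat template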

View File

@ -10,6 +10,9 @@ from typing import Any, List, Literal, Optional, Sequence, Union
from tqdm import tqdm
from transformers import PreTrainedTokenizerBase
from tensorrt_llm.inputs.data import TextPrompt
from tensorrt_llm.inputs.registry import DefaultInputProcessor
from .. import bindings as tllm
from .._utils import global_mpi_rank, nvtx_range_debug
from ..bindings import executor as tllm
@ -309,6 +312,16 @@ class LLM:
if queries is not None:
queries = prompt_inputs(queries)
if not inputs.get("prompt") and inputs.get(
"prompt_token_ids") and not isinstance(self.input_processor,
DefaultInputProcessor):
# VLMs need to process/tokenize the prompt in their own way
prompt = self.tokenizer.decode(inputs['prompt_token_ids'])
inputs = TextPrompt(
prompt=prompt,
multi_modal_data=inputs.get("multi_modal_data"),
mm_processor_kwargs=inputs.get("mm_processor_kwargs"))
query_token_ids = None
prompt_tuning_config = None
mrope_config = None
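The added branch handles requests that arrive as prompt_token_ids: when a non-default (VLM) input processor is registered, the ids are decoded back into text so the processor can re-tokenize the prompt together with its image tokens. A toy illustration of that rewrap step in isolation (stub tokenizer, not the real request flow):

from tensorrt_llm.inputs.data import TextPrompt

# Stub standing in for the real detokenizer; illustration only.
class FakeTokenizer:
    def decode(self, token_ids):
        return " ".join(str(t) for t in token_ids)

inputs = {"prompt_token_ids": [101, 102], "multi_modal_data": {"image": ["img"]}}
prompt = FakeTokenizer().decode(inputs["prompt_token_ids"])
rewrapped = TextPrompt(prompt=prompt,
                       multi_modal_data=inputs.get("multi_modal_data"),
                       mm_processor_kwargs=inputs.get("mm_processor_kwargs"))
# rewrapped is then handed to the model-specific input processor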

View File

@ -603,7 +603,8 @@ class TestVila(unittest.TestCase):
device=device,
dtype=torch.int)
images = [torch.rand(196, 2560, dtype=dtype, device=device)]
input_ids, input_embeds = fuse_input_embeds(model, input_ids, images)
input_ids, input_embeds = fuse_input_embeds(
model.llm.model.embed_tokens, input_ids, images)
self.assertIsNone(input_ids)
self.assertEqual(list(input_embeds.shape), [233, 2560])

View File

@ -2,6 +2,7 @@
from difflib import SequenceMatcher
import pytest
import torch
from utils.llm_data import llm_models_root
from tensorrt_llm import SamplingParams
@ -19,12 +20,17 @@ from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
@pytest.mark.parametrize("use_cuda_graph", [True, False],
ids=["enable_graph", "disable_graph"])
def test_llama4(model_name, backend, tp_size, use_cuda_graph):
prompts = [
"The president of the United States is",
]
prompts = [{
"prompt": "The president of the United States is"
}, {
"prompt": "<|image|>This image is of color",
"multi_modal_data": {
"image": [torch.ones(3, 1024, 1024)]
}
}]
expected_outputs = [
" the head of state and head of government of the",
" the head of state and head of government of the", " solid white"
]
pytorch_config = PyTorchConfig(attn_backend=backend,