feat: llama4 input processor (#3383)
Signed-off-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com>
Signed-off-by: Haohang Huang <31998628+symphonylyh@users.noreply.github.com>
Co-authored-by: Alexandre Milesi <30204471+milesial@users.noreply.github.com>
Co-authored-by: Haohang Huang <31998628+symphonylyh@users.noreply.github.com>
parent d7472231f9
commit 362a8272f8
@@ -51,7 +51,7 @@ python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --mo
 | `LlavaLlamaModel` | VILA | `Efficient-Large-Model/NVILA-8B` | L + V |
 | `LlavaNextForConditionalGeneration` | LLaVA-NeXT | `llava-hf/llava-v1.6-mistral-7b-hf` | L + V |
 | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA | `meta-llama/Meta-Llama-3.1-70B` | L |
-| `Llama4ForConditionalGeneration` | Llama 4 Scout/Maverick | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct` | L |
+| `Llama4ForConditionalGeneration` | Llama 4 Scout/Maverick | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct` | L + V |
 | `MistralForCausalLM` | Mistral | `mistralai/Mistral-7B-v0.1` | L |
 | `MixtralForCausalLM` | Mixtral | `mistralai/Mixtral-8x7B-v0.1` | L |
 | `MllamaForConditionalGeneration` | Llama 3.2 | `meta-llama/Llama-3.2-11B-Vision` | L |
@@ -26,7 +26,7 @@ torchvision
 nvidia-modelopt[torch]~=0.27.0
 nvidia-nccl-cu12
 nvidia-cuda-nvrtc-cu12
-transformers==4.51.0
+transformers~=4.51.1
 pydantic>=2.9.1
 pillow==10.3.0
 wheel<=0.45.1
@@ -1,15 +1,22 @@
 import copy
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import torch
+from PIL.Image import Image
 from torch import nn
-from transformers import Llama4Config, LlamaConfig
+from transformers import (AutoProcessor, Llama4Config, Llama4VisionModel,
+                          LlamaConfig)
+from transformers.modeling_utils import load_sharded_checkpoint
+from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector

 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
                                              AllReduceParams, DeepseekAllReduce)
 from tensorrt_llm._torch.pipeline_interface import PipelineInterface
 from tensorrt_llm.functional import PositionEmbeddingType

+from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
+                       register_input_processor)
+from ...sampling_params import SamplingParams
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import (PositionalEmbeddingParams,
                                            PredefinedAttentionMask, RopeParams)
@@ -26,6 +33,7 @@ from ..modules.multi_stream_utils import maybe_execute_in_parallel
 from ..modules.rms_norm import RMSNorm
 from ..modules.rotary_embedding import RotaryEmbedding
 from ..speculative import Eagle3SpecMetadata, SpecMetadata
+from .modeling_multimodal_utils import fuse_input_embeds
 from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
                              EagerFusionConfig, MissingLayer,
                              register_auto_model, support_pp,
@@ -829,15 +837,13 @@ class LlamaForCausalLM(DecoderModelForCausalLM[LlamaModel, LlamaConfig]):
                          vocab_size=model_config.pretrained_config.vocab_size)


-@register_auto_model("Llama4ForConditionalGeneration")
-class Llama4ForConditionalGeneration(DecoderModelForCausalLM[Llama4Model,
-                                                             Llama4Config]):
+@register_auto_model("Llama4ForCausalLM")
+class Llama4ForCausalLM(DecoderModelForCausalLM[LlamaModel, Llama4Config]):

     def __init__(
         self,
         model_config: ModelConfig[Llama4Config],
     ):
         # TODO: figure out a better way to handle multimodality.
         model_config = copy.copy(model_config)
-        architectures = model_config.pretrained_config.architectures
         model_config.pretrained_config = model_config.pretrained_config.text_config
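Note: a minimal sketch of the config-narrowing pattern above, using types.SimpleNamespace stand-ins rather than the real ModelConfig/Llama4Config objects (field values are made up). The shallow copy keeps the caller's config intact while the text-only decoder sees just text_config.

import copy
from types import SimpleNamespace

# Illustrative stand-ins only; values are not real Llama 4 settings.
text_cfg = SimpleNamespace(vocab_size=202048, hidden_size=5120)
pretrained = SimpleNamespace(architectures=["Llama4ForConditionalGeneration"],
                             text_config=text_cfg)
model_config = SimpleNamespace(pretrained_config=pretrained)

# Same pattern as Llama4ForCausalLM.__init__: shallow-copy, then narrow
# pretrained_config down to the language-model sub-config.
local_config = copy.copy(model_config)
local_config.pretrained_config = local_config.pretrained_config.text_config

assert model_config.pretrained_config is pretrained  # caller's config untouched
assert local_config.pretrained_config is text_cfg    # decoder sees text_config only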
@@ -876,6 +882,82 @@ class Llama4ForConditionalGeneration(DecoderModelForCausalLM[Llama4Model,
                     idx + 1].input_layernorm


+class Llama4InputProcessor(InputProcessor):
+
+    def __init__(self, model_path, model_config, tokenizer):
+        self.processor = AutoProcessor.from_pretrained(model_path,
+                                                       use_fast=True)
+        self.model_config = model_config
+        self.tokenizer = tokenizer
+        self.vocab_size = model_config.text_config.vocab_size
+        self.image_token_index = model_config.image_token_index
+
+        self.encoder = nn.ModuleDict({
+            "vision_model":
+            Llama4VisionModel(model_config.vision_config),
+            "multi_modal_projector":
+            Llama4MultiModalProjector(model_config)
+        }).cuda()
+        load_sharded_checkpoint(self.encoder, model_path, strict=False)
+
+    @torch.inference_mode()
+    def __call__(
+        self, inputs: TextPrompt, sampling_params: SamplingParams
+    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
+        text_prompt, mm_data = inputs.get("prompt"), inputs.get(
+            "multi_modal_data")
+        images, do_rescale = None, True
+
+        if mm_data and mm_data.get("image"):
+            images = mm_data["image"]
+            img_type = type(mm_data["image"][0])
+            do_rescale = (img_type == Image)
+            assert all(isinstance(img, img_type) for img in mm_data["image"])
+
+        # preprocess images and insert image tokens
+        processed = self.processor(text=text_prompt,
+                                   images=images,
+                                   return_tensors="pt",
+                                   device="cuda",
+                                   do_rescale=do_rescale,
+                                   add_special_tokens=False)
+        if images:
+            token_ids, pixel_values = processed["input_ids"].squeeze(
+            ), processed["pixel_values"]
+            mm_embeds = self.encoder.vision_model(
+                pixel_values.float().cuda()).last_hidden_state.flatten(0, 1)
+            mm_embeds = self.encoder.multi_modal_projector(mm_embeds)
+            # for fuse_input_embeds
+            token_ids[token_ids == self.image_token_index] = self.vocab_size + 1
+            return token_ids.tolist(), {
+                "prompt_tuning_config": [mm_embeds, None, None]
+            }
+        else:
+            return processed["input_ids"].squeeze().tolist(), {}
+
+
+@register_auto_model("Llama4ForConditionalGeneration")
+@register_input_processor(Llama4InputProcessor)
+class Llama4ForConditionalGeneration(Llama4ForCausalLM):
+
+    @torch.inference_mode()
+    def forward(
+        self,
+        attn_metadata: AttentionMetadata,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        return_context_logits: Optional[bool] = False,
+        **kwargs,
+    ) -> torch.Tensor:
+        mm_embed = kwargs.get("multi_modal_data", [])
+        input_ids, inputs_embeds = fuse_input_embeds(self.model.embed_tokens,
+                                                     input_ids, mm_embed)
+        logits = super().forward(attn_metadata, input_ids, position_ids,
+                                 inputs_embeds, return_context_logits)
+        return logits
+
+
 @register_auto_model("MistralForCausalLM")
 class MistralForCausalLM(DecoderModelForCausalLM[LlamaModel, LlamaConfig]):
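Note: a self-contained PyTorch sketch (toy sizes and random tensors; vocab_size, image_token_index, and hidden_size are made-up stand-ins for the real config values) of the contract Llama4InputProcessor.__call__ sets up for the model: image placeholder tokens are remapped past the text vocabulary, and the projected vision features ride along in prompt_tuning_config.

import torch

vocab_size = 32            # stand-in for model_config.text_config.vocab_size
image_token_index = 7      # stand-in for model_config.image_token_index
hidden_size = 16

# What the HF processor would emit for a prompt with two image slots.
token_ids = torch.tensor([1, image_token_index, image_token_index, 9, 4])
# What the vision tower + projector would emit: one row per image token.
mm_embeds = torch.randn(2, hidden_size)

# The remap done at the end of __call__: image slots are pushed out of the
# text vocabulary so fuse_input_embeds can locate them later.
token_ids[token_ids == image_token_index] = vocab_size + 1

extra = {"prompt_tuning_config": [mm_embeds, None, None]}
print(token_ids.tolist())                      # [1, 33, 33, 9, 4]
print(extra["prompt_tuning_config"][0].shape)  # torch.Size([2, 16])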
@@ -221,7 +221,8 @@ class LlavaNextModel(PreTrainedModel):
             mm_embed
         ) == num_context_requests, "Number of multimodal features (if provided) should be equal to number of context requests"

-        input_ids, inputs_embeds = fuse_input_embeds(self, input_ids, mm_embed)
+        input_ids, inputs_embeds = fuse_input_embeds(
+            self.llm.model.embed_tokens, input_ids, mm_embed)
         logits = self.llm.forward(attn_metadata, input_ids, position_ids,
                                   inputs_embeds, return_context_logits)
         return logits
@@ -25,9 +25,11 @@ from einops import rearrange
 from PIL import Image
 from torchvision.transforms import Normalize, Resize, ToTensor

+from tensorrt_llm._torch.modules.embedding import Embedding
+

 def fuse_input_embeds(
-    model,
+    embedding_layer: Embedding,
     input_ids: torch.LongTensor,
     mm_embeds: List[torch.Tensor],
 ) -> Tuple[Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
@@ -44,20 +46,24 @@
     if len(mm_embeds) == 0:
         return input_ids, None

+    vocab_size = embedding_layer.num_embeddings
     mm_embed = torch.cat(mm_embeds, dim=0)

+    text_token_indices = torch.where(input_ids < vocab_size)[0]
+    mm_token_indices = torch.where(input_ids >= vocab_size)[0]
+
+    text_embed = embedding_layer(input_ids[text_token_indices])
     input_embeds = torch.empty(input_ids.shape[0],
                                mm_embed.shape[-1],
-                               device=input_ids.device,
-                               dtype=model.model_dtype)
+                               device=text_embed.device,
+                               dtype=text_embed.dtype)

-    text_token_indices = torch.where(input_ids < model.vocab_size)[0]
-    mm_token_indices = torch.where(input_ids >= model.vocab_size)[0]
+    input_embeds[text_token_indices, :] = text_embed.to(
+        dtype=input_embeds.dtype, device=input_embeds.device)
+    input_embeds[mm_token_indices, :] = mm_embed.to(dtype=input_embeds.dtype,
+                                                    device=input_embeds.device)

-    text_embed = model.llm.model.embed_tokens(input_ids[text_token_indices])
-    input_embeds[text_token_indices, :] = text_embed.to(model.model_dtype)
-    input_embeds[mm_token_indices, :] = mm_embed.to(model.model_dtype)

-    return None, input_embeds.to(model.dtype)
+    return None, input_embeds


 #region VILA utils
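Note: a small plain-PyTorch sketch (toy sizes; the bf16 cast is only for illustration) of why the fused buffer now takes its dtype and device from text_embed instead of a model.model_dtype attribute: whatever precision and placement the embedding layer runs in is what the output inherits, with no extra bookkeeping at the call site.

import torch
from torch import nn

vocab_size, hidden = 8, 4
embedding_layer = nn.Embedding(vocab_size, hidden).to(torch.bfloat16)

input_ids = torch.tensor([1, 2, vocab_size, 3])          # one multimodal slot
mm_embed = torch.randn(1, hidden, dtype=torch.float32)   # projector output

text_token_indices = torch.where(input_ids < vocab_size)[0]
mm_token_indices = torch.where(input_ids >= vocab_size)[0]

text_embed = embedding_layer(input_ids[text_token_indices])
input_embeds = torch.empty(input_ids.shape[0], hidden,
                           device=text_embed.device, dtype=text_embed.dtype)
input_embeds[text_token_indices, :] = text_embed
input_embeds[mm_token_indices, :] = mm_embed.to(dtype=input_embeds.dtype,
                                                device=input_embeds.device)

print(input_embeds.dtype)  # torch.bfloat16, inherited from the embedding layer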
@@ -417,7 +417,8 @@ class Qwen2VLModelBase(PreTrainedModel):
         assert mm_embed == [] or len(
             mm_embed) == num_context_requests, error_msg

-        input_ids, input_embeds = fuse_input_embeds(self, input_ids, mm_embed)
+        input_ids, input_embeds = fuse_input_embeds(self.llm.model.embed_tokens,
+                                                    input_ids, mm_embed)

         mrope_config = kwargs.get("mrope_config", {})
         if mrope_config:
@@ -1167,7 +1167,8 @@ class VilaModel(PreTrainedModel):
             mm_embed
         ) == num_context_requests, "Number of multimodal features (if provided) should be equal to number of context requests"

-        input_ids, inputs_embeds = fuse_input_embeds(self, input_ids, mm_embed)
+        input_ids, inputs_embeds = fuse_input_embeds(
+            self.llm.model.embed_tokens, input_ids, mm_embed)
         logits = self.llm.forward(attn_metadata=attn_metadata,
                                   input_ids=input_ids,
                                   position_ids=position_ids,
@@ -59,6 +59,7 @@ class LMHead(Linear):
             local_in_features -= self.padding_size
         self.in_features = local_in_features
         self.out_features = local_out_features
+        self.num_embeddings = num_embeddings

         weight_shape = (self.out_features, self.in_features)
         self.weight = Parameter(torch.empty(weight_shape, dtype=dtype))
@@ -2,7 +2,7 @@ from .data import PromptInputs, TextPrompt, TokensPrompt, prompt_inputs
 from .registry import (ExtraProcessedInputs, InputProcessor,
                        create_input_processor, register_input_processor)
 from .utils import (INPUT_FORMATTER_MAP, default_image_loader,
-                    default_video_loader, format_llava_next_input,
+                    default_video_loader, format_generic_input,
                     format_qwen2_vl_input, format_vila_input, load_image,
                     load_video)
@@ -11,5 +11,5 @@ __all__ = [
     "InputProcessor", "create_input_processor", "register_input_processor",
     "ExtraProcessedInputs", "load_image", "load_video", "INPUT_FORMATTER_MAP",
     "default_image_loader", "default_video_loader", "format_vila_input",
-    "format_llava_next_input", "format_qwen2_vl_input"
+    "format_generic_input", "format_qwen2_vl_input"
 ]
@@ -104,7 +104,7 @@ def format_vila_input(model_dir, inputs):
     return inputs


-def format_llava_next_input(model_dir, inputs):
+def format_generic_input(model_dir, inputs):
     """
     This function formats the input for the Llava Next VL model.

@@ -122,15 +122,16 @@ def format_llava_next_input(model_dir, inputs):
     def apply_template(prompt, multimodal_data):
         conversation = [
             {
-                "role": "user",
+                "role":
+                "user",
                 "content": [
                     {
                         "type": "text",
                         "text": prompt
                     },
-                    {
+                    *[{
                         "type": "image"
-                    },
+                    } for _ in multimodal_data["image"]],
                 ],
             },
         ]
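Note: a quick standalone sketch of what the starred comprehension buys: one image content entry per attached image instead of a single hard-coded one. The helper below mirrors apply_template but simply returns the conversation, and the image entries are placeholder strings.

def apply_template(prompt, multimodal_data):
    # Mirrors the helper above: text first, then one image entry per image.
    return [{
        "role": "user",
        "content": [
            {"type": "text", "text": prompt},
            *[{"type": "image"} for _ in multimodal_data["image"]],
        ],
    }]

conversation = apply_template("Compare these two photos.",
                              {"image": ["img_a", "img_b"]})
assert sum(c["type"] == "image" for c in conversation[0]["content"]) == 2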
@@ -220,7 +221,8 @@ def default_video_loader(prompts, videos, image_data_format="pt", num_frames=8):

 INPUT_FORMATTER_MAP = {
     "llava_llama": format_vila_input,
-    "llava_next": format_llava_next_input,
+    "llava_next": format_generic_input,
     "qwen2_vl": format_qwen2_vl_input,
     "qwen2_5_vl": format_qwen2_vl_input,
+    "llama4": format_generic_input,
 }
@@ -10,6 +10,9 @@ from typing import Any, List, Literal, Optional, Sequence, Union
 from tqdm import tqdm
 from transformers import PreTrainedTokenizerBase

+from tensorrt_llm.inputs.data import TextPrompt
+from tensorrt_llm.inputs.registry import DefaultInputProcessor
+
 from .. import bindings as tllm
 from .._utils import global_mpi_rank, nvtx_range_debug
 from ..bindings import executor as tllm
@@ -309,6 +312,16 @@ class LLM:
         if queries is not None:
             queries = prompt_inputs(queries)

+        if not inputs.get("prompt") and inputs.get(
+                "prompt_token_ids") and not isinstance(self.input_processor,
+                                                       DefaultInputProcessor):
+            # VLMs need to process/tokenize the prompt in their own way
+            prompt = self.tokenizer.decode(inputs['prompt_token_ids'])
+            inputs = TextPrompt(
+                prompt=prompt,
+                multi_modal_data=inputs.get("multi_modal_data"),
+                mm_processor_kwargs=inputs.get("mm_processor_kwargs"))
+
         query_token_ids = None
         prompt_tuning_config = None
         mrope_config = None
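Note: a control-flow sketch of the new branch using minimal stubs (the tokenizer and processor classes below are stand-ins, not the real tensorrt_llm ones): when a caller supplies prompt_token_ids but the model registered a non-default VLM input processor, the ids are decoded back to text and re-wrapped so the processor can tokenize text and images together.

class StubTokenizer:
    def decode(self, ids):
        return " ".join(str(i) for i in ids)  # stand-in detokenization

class DefaultInputProcessor:  # stand-in for the class in tensorrt_llm.inputs.registry
    pass

class VlmInputProcessor:      # any non-default processor takes the new branch
    pass

def prepare(inputs, input_processor, tokenizer):
    if not inputs.get("prompt") and inputs.get("prompt_token_ids") \
            and not isinstance(input_processor, DefaultInputProcessor):
        # VLMs need to process/tokenize the prompt in their own way.
        prompt = tokenizer.decode(inputs["prompt_token_ids"])
        inputs = {
            "prompt": prompt,
            "multi_modal_data": inputs.get("multi_modal_data"),
            "mm_processor_kwargs": inputs.get("mm_processor_kwargs"),
        }
    return inputs

out = prepare({"prompt_token_ids": [5, 6, 7]}, VlmInputProcessor(), StubTokenizer())
print(out["prompt"])  # "5 6 7" -- handed back to the VLM processor as text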
@@ -603,7 +603,8 @@ class TestVila(unittest.TestCase):
                                 device=device,
                                 dtype=torch.int)
         images = [torch.rand(196, 2560, dtype=dtype, device=device)]
-        input_ids, input_embeds = fuse_input_embeds(model, input_ids, images)
+        input_ids, input_embeds = fuse_input_embeds(
+            model.llm.model.embed_tokens, input_ids, images)
         self.assertIsNone(input_ids)
         self.assertEqual(list(input_embeds.shape), [233, 2560])

@@ -2,6 +2,7 @@
 from difflib import SequenceMatcher

 import pytest
+import torch
 from utils.llm_data import llm_models_root

 from tensorrt_llm import SamplingParams
@@ -19,12 +20,17 @@ from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
 @pytest.mark.parametrize("use_cuda_graph", [True, False],
                          ids=["enable_graph", "disable_graph"])
 def test_llama4(model_name, backend, tp_size, use_cuda_graph):
-    prompts = [
-        "The president of the United States is",
-    ]
+    prompts = [{
+        "prompt": "The president of the United States is"
+    }, {
+        "prompt": "<|image|>This image is of color",
+        "multi_modal_data": {
+            "image": [torch.ones(3, 1024, 1024)]
+        }
+    }]

     expected_outputs = [
-        " the head of state and head of government of the",
+        " the head of state and head of government of the", " solid white"
     ]

     pytorch_config = PyTorchConfig(attn_backend=backend,
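Note: for orientation, a hypothetical end-to-end call mirroring the prompt structure this test exercises. The checkpoint path is a placeholder, and the surrounding LLM-API usage (LLM, generate, SamplingParams(max_tokens=...)) is assumed from the existing API rather than introduced by this diff.

import torch
from tensorrt_llm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-4-Scout-17B-16E-Instruct")  # placeholder path

prompts = [
    {"prompt": "The president of the United States is"},
    {
        "prompt": "<|image|>This image is of color",
        "multi_modal_data": {"image": [torch.ones(3, 1024, 1024)]},
    },
]

for output in llm.generate(prompts, SamplingParams(max_tokens=10)):
    print(output.outputs[0].text)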