Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][fix] Refactoring input prep to allow out-of-tree models (#6497)
Signed-off-by: Rakib Hasan <rhasan@nvidia.com>
This commit is contained in:
parent bd9a6dd9ab
commit 2923eb88a1
52  examples/llm-api/out_of_tree_example/readme.md (new file)
@@ -0,0 +1,52 @@
# Out-of-tree Model Development

The file `modeling_opt.py` shows an example of how a custom model can be defined using TRT-LLM APIs without modifying the TRT-LLM source code.

The file `main.py` shows how to run inference for such custom models using the LLM API.

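For reference, inference for such a custom model is a standard LLM API workflow. Below is a minimal, hypothetical sketch (not the contents of `main.py`): it assumes the custom model registers itself with TRT-LLM when its module is imported, and the module and checkpoint names are placeholders.

```python
from tensorrt_llm import LLM, SamplingParams

import modeling_opt  # hypothetical: importing the module runs its TRT-LLM registration hooks


def main():
    llm = LLM(model="./custom_opt_ckpt")  # placeholder checkpoint directory
    sampling_params = SamplingParams(max_tokens=32)
    for output in llm.generate(["Hello, my name is"], sampling_params):
        print(output.outputs[0].text)


if __name__ == "__main__":
    main()
```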

## Out-of-tree Multimodal Models

For multimodal models, TRT-LLM provides `quickstart_multimodal.py` to quickly run a multimodal model that is defined within TRT-LLM, and `trtllm-bench` can be used to benchmark such models.
The following sections describe how to use those tools for out-of-tree models.

### Prerequisites

To use an out-of-tree model with the quickstart example and `trtllm-bench`, you need to prepare the model definition files as a Python module.
Consider the following file structure as an example:

```
modeling_custom_phi
|-- __init__.py
|-- configuration.py
|-- modeling_custom_phi.py
|-- encoder
|   |-- __init__.py
|   |-- configuration.py
|   |-- modeling_encoder.py
```

Each `__init__.py` should contain the imports needed to register the custom model. For example, `modeling_custom_phi/__init__.py` can contain something like:

```python
from .modeling_custom_phi import MyVLMForConditionalGeneration
from . import encoder
```
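The registration itself happens inside the model definition file (here `modeling_custom_phi.py`). The snippet below is only an illustrative sketch: the class names, `model_type` value, placeholder string, and the input-processor constructor signature are hypothetical, and the import path of `register_auto_model` may differ across TRT-LLM versions. It uses the `register_input_processor` / `MultimodalPlaceholderMetadata` API that out-of-tree multimodal models hook into.

```python
import torch

from tensorrt_llm._torch.models.modeling_utils import register_auto_model  # import path may vary
from tensorrt_llm.inputs import (MultimodalPlaceholderMetadata,
                                 MultimodalPlaceholderPlacement,
                                 register_input_processor)


class MyVLMInputProcessor:
    """Hypothetical input processor: turns a text prompt (+ media) into token ids."""

    def __init__(self, model_path, model_config, tokenizer):
        self.model_path = model_path
        self.model_config = model_config
        self.tokenizer = tokenizer

    def __call__(self, inputs, sampling_params):
        token_ids = self.tokenizer.encode(inputs["prompt"])
        return token_ids, None


@register_auto_model("MyVLMForConditionalGeneration")
@register_input_processor(
    MyVLMInputProcessor,
    model_type="my_vlm",  # must match `model_type` in the checkpoint's config.json
    placeholder_metadata=MultimodalPlaceholderMetadata(
        placeholder_map={"image": "<image>"},  # hypothetical placeholder token
        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
    ))
class MyVLMForConditionalGeneration(torch.nn.Module):
    """Hypothetical model implementation goes here."""
```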

### Quickstart Example

Once the model definition files are prepared as a Python module (as described above), you can use the `--custom_module_dirs` flag of `quickstart_multimodal.py` to load your model and run inference:

```bash
python3 quickstart_multimodal.py --model_dir ./model_ckpt --modality image --max_tokens 10 --prompt "Describe the image." --media ./demo_lower.png --image_format pil --custom_module_dirs ../modeling_custom_phi
```

### Benchmarking

Similar to the quickstart example, you can pass the same CLI argument to `trtllm-bench` to benchmark a custom model.

Prepare the dataset:

```bash
python ./benchmarks/cpp/prepare_dataset.py --tokenizer ./model_ckpt --stdout dataset --dataset-name lmms-lab/MMMU --dataset-split test --dataset-image-key image --dataset-prompt-key "question" --num-requests 100 --output-len-dist 128,5 > mm_data.jsonl
```

Run the benchmark:

```bash
trtllm-bench --model ./model_ckpt --model_path ./model_ckpt throughput --dataset mm_data.jsonl --backend pytorch --num_requests 100 --max_batch_size 4 --modality image --streaming --custom_module_dirs ../modeling_custom_phi
```
@@ -4,8 +4,9 @@ import os

from quickstart_advanced import add_llm_args, setup_llm

from tensorrt_llm.inputs import (ALL_SUPPORTED_MULTIMODAL_MODELS,
                                 default_multimodal_input_loader)
from tensorrt_llm.inputs import default_multimodal_input_loader
from tensorrt_llm.inputs.registry import MULTIMODAL_PLACEHOLDER_REGISTRY
from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir

example_medias_and_prompts = {
    "image": {
@@ -79,10 +80,11 @@ example_medias_and_prompts = {


def add_multimodal_args(parser):
    parser.add_argument("--model_type",
                        type=str,
                        choices=ALL_SUPPORTED_MULTIMODAL_MODELS,
                        help="Model type.")
    parser.add_argument(
        "--model_type",
        type=str,
        choices=MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types(),
        help="Model type as specified in the HuggingFace model config.")
    parser.add_argument("--modality",
                        type=str,
                        choices=[
@@ -90,7 +92,7 @@ def add_multimodal_args(parser):
                            "multiple_image", "mixture_text_image"
                        ],
                        default="image",
                        help="Media type.")
                        help="Media type being used for inference.")
    parser.add_argument("--media",
                        type=str,
                        nargs="+",
@@ -108,6 +110,18 @@ def add_multimodal_args(parser):
                        type=str,
                        default="cpu",
                        help="The device to have the input on.")
    parser.add_argument(
        "--custom_module_dirs",
        type=str,
        nargs="+",
        default=None,
        help=
        ("Paths to an out-of-tree model directory which should be imported."
         " This is useful to load a custom model. The directory should have a structure like:"
         " <model_name>"
         " ├── __init__.py"
         " ├── <model_name>.py"
         " └── <sub_dirs>"))
    return parser


@@ -140,6 +154,15 @@ def parse_arguments():

def main():
    args = parse_arguments()
    if args.custom_module_dirs is not None:
        for custom_module_dir in args.custom_module_dirs:
            try:
                import_custom_module_from_dir(custom_module_dir)
            except Exception as e:
                print(
                    f"Failed to import custom module from {custom_module_dir}: {e}"
                )
                raise e

    lora_config = None
    if args.load_lora:
@@ -159,8 +182,11 @@ def main():
        model_type = args.model_type
    else:
        model_type = json.load(
            open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
        assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, f"Unsupported model_type: {model_type}"
            open(os.path.join(str(llm._hf_model_dir),
                              'config.json')))['model_type']
        assert model_type in MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types(), \
            f"Unsupported model_type: {model_type} found!\n" \
            f"Supported types: {MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types()}"

    # set prompts and media to example prompts and images if they are not provided
    if args.prompt is None:
@@ -168,7 +194,7 @@ def main():
    if args.media is None:
        args.media = example_medias_and_prompts[args.modality]["media"]
    inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer,
                                             model_dir=llm._hf_model_dir,
                                             model_dir=str(llm._hf_model_dir),
                                             model_type=model_type,
                                             modality=args.modality,
                                             prompts=args.prompt,

@@ -10,7 +10,9 @@ from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
    BaseWeightMapper

from ..._utils import nvtx_range
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
                       MultimodalPlaceholderMetadata,
                       MultimodalPlaceholderPlacement, TextPrompt,
                       register_input_processor)
from ...logger import logger
from ...sampling_params import SamplingParams
@@ -137,7 +139,13 @@ class Gemma3MultiModalProjector(torch.nn.Module):


@register_auto_model("Gemma3ForConditionalGeneration")
@register_input_processor(Gemma3InputProcessor, model_type="gemma3")
@register_input_processor(
    Gemma3InputProcessor,
    model_type="gemma3",
    placeholder_metadata=MultimodalPlaceholderMetadata(
        placeholder_map={"image": "<start_of_image>"},
        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
    ))
class Gemma3VLM(PreTrainedModel):

    def __init__(self, model_config: ModelConfig[Gemma3Config]):

@@ -15,7 +15,9 @@ from transformers.models.auto import CONFIG_MAPPING

from tensorrt_llm.inputs.multimodal import MultimodalParams

from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
                       MultimodalPlaceholderMetadata,
                       MultimodalPlaceholderPlacement, TextPrompt,
                       register_input_processor)
from ...logger import logger
from ...sampling_params import SamplingParams
@@ -961,7 +963,23 @@ class HCXVisionModel:


@register_auto_model("HCXVisionForCausalLM")
@register_input_processor(HCXVisionInputProcessor, model_type="hyperclovax_vlm")
@register_input_processor(
    HCXVisionInputProcessor,
    model_type="hyperclovax_vlm",
    placeholder_metadata=MultimodalPlaceholderMetadata(
        placeholder_map={
            "image":
            ('<im_end>\n<|im_start|>user (mime) \n'
             '{"type": "image/jpeg", "filename": ""}<|im_end|>\n'
             '<|im_start|>user (vector)\n<|dummy3|><|im_end|>\n'
             '<|im_start|>image/aux\n'
             '다음 중 ocr은 사진에서 검출된 글자이고, lens_keyword는 사진에서 추출된 '
             'keyword와 bbox 위치입니다.bbox는 0~1 사이로 정규화된 [x1, y1, x2, y2]의 '
             '형태입니다. 참고하여 답변하세요. '
             '{"ocr": "", "lens_keywords": "", "lens_local_keywords": ""}')
        },
        placeholder_placement=MultimodalPlaceholderPlacement.AFTER_TEXT,
    ))
class HCXVisionForCausalLM(PreTrainedModel):

    def __init__(self, model_config: ModelConfig):

@@ -20,7 +20,9 @@ from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import HfLoraLoader
from tensorrt_llm.models.convert_utils import split_matrix_tp

from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
                       MultimodalPlaceholderMetadata,
                       MultimodalPlaceholderPlacement, TextPrompt,
                       register_input_processor)
from ...sampling_params import SamplingParams
from ..attention_backend import AttentionMetadata
@@ -1173,7 +1175,13 @@ class Llama4InputProcessor(InputProcessor):


@register_auto_model("Llama4ForConditionalGeneration")
@register_input_processor(Llama4InputProcessor, model_type="llama4")
@register_input_processor(
    Llama4InputProcessor,
    model_type="llama4",
    placeholder_metadata=MultimodalPlaceholderMetadata(
        placeholder_map={"image": "<|image|>"},
        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
    ))
class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model,
                                                                 Llama4Config]):


@@ -14,7 +14,9 @@ from transformers.models.llava_next.modeling_llava_next import (

from tensorrt_llm.inputs.multimodal import MultimodalParams

from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
                       MultimodalPlaceholderMetadata,
                       MultimodalPlaceholderPlacement, TextPrompt,
                       register_input_processor)
from ...llmapi.utils import download_hf_model
from ...logger import logger
@@ -263,7 +265,13 @@ class LlavaNextVisionModel(nn.Module):


@register_auto_model("LlavaNextForConditionalGeneration")
@register_input_processor(LlavaNextInputProcessor, model_type="llava_next")
@register_input_processor(
    LlavaNextInputProcessor,
    model_type="llava_next",
    placeholder_metadata=MultimodalPlaceholderMetadata(
        placeholder_map={"image": "<image>"},
        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
    ))
class LlavaNextModel(PreTrainedModel):
    config_class = LlavaNextConfig


@@ -29,7 +29,9 @@ from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._torch.speculative import SpecMetadata
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.inputs import (ExtraProcessedInputs, InputProcessor,
                                 TextPrompt, register_input_processor)
                                 MultimodalPlaceholderMetadata,
                                 MultimodalPlaceholderPlacement, TextPrompt,
                                 register_input_processor)
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.logger import logger

@@ -269,8 +271,20 @@ class Mistral3InputProcessor(InputProcessor):


@register_auto_model("Mistral3ForConditionalGeneration")
# The below informs the registry which input registry to create for this in `tensorrt_llm/llmapi/llm.py`.
@register_input_processor(Mistral3InputProcessor, model_type="mistral3")
@register_input_processor(
    Mistral3InputProcessor,
    model_type="mistral3",
    placeholder_metadata=MultimodalPlaceholderMetadata(
        placeholder_map={
            "image": "[IMG]",
        },
        # NOTE: for mistral3 multimodal models, it does not strictly have to be before the text.
        # Ref: https://github.com/mistralai/mistral-common/blob/039465db2bdc0486df36365c9bdb428188482a18/
        # src/mistral_common/tokens/tokenizers/base.py#L326
        # However, accuracy tests show that the model generates higher quality output when the image
        # precedes the text (the relative difference can be as much as ~30% for both vLLM and TRT-LLM).
        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
    ))
class Mistral3VLM(PreTrainedModel):
    """Mistral3VLM implementation for TRTLLM.


@@ -19,7 +19,9 @@ from PIL import Image
from tensorrt_llm.inputs.multimodal import MultimodalParams

from ...executor.request import LoRARequest
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
                       MultimodalPlaceholderMetadata,
                       MultimodalPlaceholderPlacement, TextPrompt,
                       register_input_processor)
from ...logger import logger
from ...lora_helper import LoraConfig
@@ -461,7 +463,17 @@ class Phi4MMInputProcessor(InputProcessor):


@register_auto_model("Phi4MMForCausalLM")
@register_input_processor(Phi4MMInputProcessor, model_type="phi4mm")
@register_input_processor(
    Phi4MMInputProcessor,
    model_type="phi4mm",
    placeholder_metadata=MultimodalPlaceholderMetadata(
        placeholder_map={
            "image": "<|image_{0}|>",
            "audio": "<|audio_{0}|>",
        },
        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
        placeholders_separator="",
    ))
class Phi4MMForCausalLM(transformers.PreTrainedModel):

    _supports_flash_attn_2 = True

@@ -12,7 +12,9 @@ from tensorrt_llm.inputs.multimodal import MultimodalParams

from ..._utils import nvtx_range_debug
from ...functional import RopeEmbeddingUtils, RotaryScalingType
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
                       MultimodalPlaceholderMetadata,
                       MultimodalPlaceholderPlacement, TextPrompt,
                       register_input_processor)
from ...logger import logger
from ...sampling_params import SamplingParams
@@ -645,7 +647,16 @@ class Qwen2VLModelBase(PreTrainedModel):


@register_auto_model("Qwen2VLForConditionalGeneration")
@register_input_processor(Qwen2VLInputProcessorBase, model_type="qwen2_vl")
@register_input_processor(
    Qwen2VLInputProcessorBase,
    model_type="qwen2_vl",
    placeholder_metadata=MultimodalPlaceholderMetadata(
        placeholder_map={
            "image": "<|vision_start|><|image_pad|><|vision_end|>",
            "video": "<|vision_start|><|video_pad|><|vision_end|>"
        },
        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
    ))
class Qwen2VLModel(Qwen2VLModelBase):

    def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
@@ -657,7 +668,14 @@ class Qwen2VLModel(Qwen2VLModelBase):


@register_auto_model("Qwen2_5_VLForConditionalGeneration")
@register_input_processor(Qwen2VLInputProcessorBase, model_type="qwen2_5_vl")
@register_input_processor(
    Qwen2VLInputProcessorBase,
    model_type="qwen2_5_vl",
    placeholder_metadata=MultimodalPlaceholderMetadata(
        placeholder_map={
            "image": "<|vision_start|><|image_pad|><|vision_end|>",
            "video": "<|vision_start|><|video_pad|><|vision_end|>"
        }))
class Qwen2_5_VLModel(Qwen2VLModelBase):

    def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,

@@ -35,7 +35,9 @@ from transformers import (AutoConfig, AutoImageProcessor, AutoModel,
                          PreTrainedModel)

from ..._utils import nvtx_range
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
                       MultimodalPlaceholderMetadata,
                       MultimodalPlaceholderPlacement, TextPrompt,
                       register_input_processor)
from ...logger import logger
from ...sampling_params import SamplingParams
@@ -1118,7 +1120,16 @@ class VilaInputProcessor(InputProcessor):


@register_auto_model(VilaConfig.model_architecture)
@register_input_processor(VilaInputProcessor, model_type="llava_llama")
@register_input_processor(
    VilaInputProcessor,
    model_type="llava_llama",
    placeholder_metadata=MultimodalPlaceholderMetadata(
        placeholder_map={
            "image": "<image>",
            "video": "<vila/video>"
        },
        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
    ))
class VilaModel(PreTrainedModel):
    config_class = VilaConfig


@@ -13,6 +13,7 @@ from huggingface_hub import snapshot_download
from tensorrt_llm.bench.benchmark.utils.asynchronous import async_benchmark
from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
from tensorrt_llm.bench.build.build import get_model_config
from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir

# isort: off
from tensorrt_llm.bench.benchmark.utils.general import (
@@ -49,6 +50,16 @@ from tensorrt_llm.sampling_params import SamplingParams
    type=click.Choice(ALL_SUPPORTED_BACKENDS),
    default="pytorch",
    help="The backend to use when running benchmarking.")
@optgroup.option(
    "--custom_module_dirs",
    type=click.Path(exists=True,
                    readable=True,
                    path_type=Path,
                    resolve_path=True),
    default=None,
    multiple=True,
    help="Paths to custom module directories to import.",
)
@optgroup.option(
    "--extra_llm_api_options",
    type=str,
@@ -104,6 +115,16 @@ from tensorrt_llm.sampling_params import SamplingParams
    required=False,
    help="Pass in a dataset file for parsing instead of stdin.",
)
# For text models, tokenizer initialization is not needed when loading the model since the dataset is already tokenized.
# For this reason, we skip tokenizer initialization by default.
# However, for VLM models, tokenizer initialization is needed inside the model since the dataset contains texts and
# raw media data. We cannot skip tokenizer initialization in this case.
@optgroup.option(
    "--no_skip_tokenizer_init",
    is_flag=True,
    default=False,
    help="Do not skip tokenizer initialization when loading the model.",
)
@optgroup.option(
    "--eos_id",
    type=int,
@@ -118,6 +139,18 @@ from tensorrt_llm.sampling_params import SamplingParams
    default=None,
    help="Modality of the multimodal requests.",
)
@optgroup.option(
    "--image_data_format",
    type=click.Choice(["pt", "pil"]),
    default="pt",
    help="Format of the image data for multimodal models.",
)
@optgroup.option(
    "--data_device",
    type=click.Choice(["cuda", "cpu"]),
    default="cuda",
    help="Device to load the multimodal data on.",
)
@optgroup.option(
    "--max_input_len",
    type=int,
@@ -262,7 +295,17 @@ def throughput_command(
    logger.info("Preparing to run throughput benchmark...")
    # Parameters from CLI
    # Model, experiment, and engine params
    custom_module_dirs: list[Path] = params.pop("custom_module_dirs", [])
    for custom_module_dir in custom_module_dirs:
        try:
            import_custom_module_from_dir(custom_module_dir)
        except Exception as e:
            logger.error(
                f"Failed to import custom module from {custom_module_dir}: {e}")
            raise e

    dataset_path: Path = params.get("dataset")
    no_skip_tokenizer_init: bool = params.get("no_skip_tokenizer_init", False)
    eos_id: int = params.get("eos_id")
    warmup: int = params.get("warmup")
    num_requests: int = params.get("num_requests")
@@ -274,6 +317,8 @@ def throughput_command(
    backend: str = params.get("backend")
    modality: str = params.get("modality")
    max_input_len: int = params.get("max_input_len")
    image_data_format: str = params.get("image_data_format", "pt")
    data_device: str = params.get("data_device", "cpu")
    model_type = get_model_config(model, checkpoint_path).model_type

    # Reporting options
@@ -286,7 +331,7 @@ def throughput_command(
    # Runtime kwargs and option tracking.
    kwargs = {}

    # Initialize the HF tokenizer for the specified model.
    # Initialize the HF tokenizer for the specified model. This is only used for data preparation.
    tokenizer = initialize_tokenizer(checkpoint_path)

    # Dataset Loading and Preparation
@@ -298,6 +343,8 @@ def throughput_command(
        model_dir=checkpoint_path,
        model_type=model_type,
        modality=modality,
        image_data_format=image_data_format,
        data_device=data_device,
        max_input_seq_len_for_multimodal=max_input_len)
    metadata.dataset_path = dataset_path
    params["target_input_len"] = params.get(
@@ -392,6 +439,7 @@ def throughput_command(
    logger.info("Setting up throughput benchmark.")
    kwargs = kwargs | runtime_config.get_llm_args()
    kwargs['backend'] = backend
    kwargs['skip_tokenizer_init'] = not no_skip_tokenizer_init

    if backend == "pytorch" and iteration_log is not None:
        kwargs["enable_iter_perf_stats"] = True

@@ -41,6 +41,8 @@ def create_dataset_from_stream(
    model_dir: str = None,
    model_type: str = None,
    modality: str = None,
    image_data_format: str = "pt",
    data_device: str = "cpu",
    max_input_seq_len_for_multimodal: int = 4096,
) -> Tuple[DatasetMetadata, List[InferenceRequest]]:
    """Generate metadata and a list of requests to drive benchmarking.
@@ -130,7 +132,9 @@ def create_dataset_from_stream(
            model_type=model_type,
            modality=modality,
            prompts=prompts,
            media=media_paths)  # list of dicts
            media=media_paths,  # list of dicts
            image_data_format=image_data_format,
            device=data_device)

        all_isl = []
        all_seq_len = []

@@ -1,16 +1,23 @@
from .data import PromptInputs, TextPrompt, TokensPrompt, prompt_inputs
from .multimodal import MultimodalInput
from .registry import (ExtraProcessedInputs, InputProcessor,
                       create_input_processor, create_input_processor_with_hash,
                       MultimodalPlaceholderMetadata,
                       MultimodalPlaceholderPlacement, create_input_processor,
                       create_input_processor_with_hash,
                       register_input_processor)
from .utils import (ALL_SUPPORTED_MULTIMODAL_MODELS, ConversationMessage,
                    MultimodalData, MultimodalDataTracker,
from .utils import (ALL_SUPPORTED_AUDIO_MODELS, ALL_SUPPORTED_IMAGE_MODELS,
                    ALL_SUPPORTED_MULTIMODAL_MODELS, ALL_SUPPORTED_VIDEO_MODELS,
                    ConversationMessage, MultimodalData, MultimodalDataTracker,
                    add_multimodal_placeholders, async_load_audio,
                    async_load_image, async_load_video,
                    default_multimodal_input_loader,
                    encode_base64_content_from_url, load_image, load_video)

__all__ = [
    "ALL_SUPPORTED_MULTIMODAL_MODELS",
    "ALL_SUPPORTED_IMAGE_MODELS",
    "ALL_SUPPORTED_VIDEO_MODELS",
    "ALL_SUPPORTED_AUDIO_MODELS",
    "PromptInputs",
    "prompt_inputs",
    "TextPrompt",
@@ -20,7 +27,8 @@ __all__ = [
    "create_input_processor_with_hash",
    "register_input_processor",
    "ExtraProcessedInputs",
    "ALL_SUPPORTED_MULTIMODAL_MODELS",
    "MultimodalPlaceholderMetadata",
    "MultimodalPlaceholderPlacement",
    "ConversationMessage",
    "MultimodalDataTracker",
    "MultimodalData",

@@ -1,3 +1,5 @@
import enum
from dataclasses import dataclass, field
from typing import (Any, Callable, Dict, List, Optional, Protocol, Tuple, Type,
                    TypeVar)

@@ -10,7 +12,6 @@ from .data import TextPrompt
from .multimodal import (MultimodalInput, apply_mm_hashes, default_hasher,
                         find_mm_token_lengths, find_mm_token_positions,
                         hexdigest_to_int32, validate_mm_inputs)
from .utils import ALL_SUPPORTED_MULTIMODAL_MODELS

N = TypeVar("N", bound=Type[nn.Module])

@@ -32,6 +33,7 @@ class InputProcessor(Protocol):
    model_path: any
    model_config: any
    tokenizer: any
    multimodal_hashing_supported: Optional[bool] = None

    def __call__(
        self, inputs: TextPrompt, sampling_params: SamplingParams
@@ -50,6 +52,7 @@ class DefaultInputProcessor(InputProcessor):
        self.tokenizer = tokenizer
        self.model_config = model_config
        self.model_path = model_path
        self.multimodal_hashing_supported = None

    def __call__(
        self, inputs: TextPrompt, sampling_params: SamplingParams
@@ -108,6 +111,127 @@ class DefaultInputProcessor(InputProcessor):
        return token_ids, None


class MultimodalPlaceholderPlacement(enum.Enum):
    """
    The placement of the multimodal placeholder in the prompt. Valid values are:
    - BEFORE_TEXT: the placeholders are placed before the text prompt.
    - AFTER_TEXT: the placeholders are placed after the text prompt.
    """
    INVALID = -1
    BEFORE_TEXT = 0
    AFTER_TEXT = 1


@dataclass(frozen=True)
class MultimodalPlaceholderMetadata:
    """
    Metadata for the multimodal placeholder. It has 3 components:
    - placeholder_map:
        A mapping from modality to placeholder string.
        Modality can be "image", "video", "audio", etc.
    - placeholder_placement:
        The placement of the placeholders, e.g. before or after the text prompt.
    - placeholders_separator:
        The separator between the placeholders, e.g. some models use "\n" to separate the placeholders.
    """
    placeholder_map: Dict[str, str] = field(default_factory=dict)
    placeholder_placement: MultimodalPlaceholderPlacement = MultimodalPlaceholderPlacement.AFTER_TEXT
    placeholders_separator: str = "\n"


class MultimodalPlaceholderRegistry:
    """
    Registry for the multimodal models to keep track of the placeholder information.
    """

    def __init__(self) -> None:
        self._multimodal_placeholder_by_model_type: Dict[
            str, MultimodalPlaceholderMetadata] = {}

    def __str__(self) -> str:
        s = ""
        for model_type, placeholder_metadata in self._multimodal_placeholder_by_model_type.items(
        ):
            s += "-" * 100 + "\n"
            s += f"Model type: {model_type}\n"
            s += f"Placeholder map: {placeholder_metadata.placeholder_map}\n"
            s += f"Placeholder placement: {placeholder_metadata.placeholder_placement}\n"
            s += f"Placeholders separator: \"{placeholder_metadata.placeholders_separator}\"\n"
            s += "-" * 80 + "\n"
        return s

    def set_placeholder_metadata(
            self, model_type: str,
            placeholder_metadata: MultimodalPlaceholderMetadata):
        self._multimodal_placeholder_by_model_type[
            model_type] = placeholder_metadata

    def remove_placeholder_metadata(self, model_type: str):
        if model_type not in self._multimodal_placeholder_by_model_type:
            raise ValueError(f"Model type '{model_type}' is not registered")
        del self._multimodal_placeholder_by_model_type[model_type]

    def is_valid(self, model_type: str, modality: str) -> bool:
        return model_type in self._multimodal_placeholder_by_model_type and \
            modality in self._multimodal_placeholder_by_model_type[model_type].placeholder_map

    def get_placeholder_metadata(
            self, model_type: str) -> MultimodalPlaceholderMetadata:
        if model_type not in self._multimodal_placeholder_by_model_type:
            raise ValueError(
                f"Model type {model_type} is not registered in MultimodalPlaceholderRegistry"
            )
        return self._multimodal_placeholder_by_model_type[model_type]

    def get_placeholder(self, model_type: str, modality: str) -> str:
        if not self.is_valid(model_type, modality):
            raise ValueError(
                f"Model type '{model_type}' with modality '{modality}' is not registered."
            )
        return self._multimodal_placeholder_by_model_type[
            model_type].placeholder_map[modality]

    def get_placeholder_placement(
            self, model_type: str) -> MultimodalPlaceholderPlacement:
        if model_type not in self._multimodal_placeholder_by_model_type:
            raise ValueError(f"Model type '{model_type}' is not registered")
        return self._multimodal_placeholder_by_model_type[
            model_type].placeholder_placement

    def get_placeholders_separator(self, model_type: str) -> str:
        if model_type not in self._multimodal_placeholder_by_model_type:
            raise ValueError(f"Model type '{model_type}' is not registered")
        return self._multimodal_placeholder_by_model_type[
            model_type].placeholders_separator

    def get_registered_image_model_types(self) -> Tuple[str, ...]:
        return (
            model_type
            for model_type in self._multimodal_placeholder_by_model_type
            if "image" in self.
            _multimodal_placeholder_by_model_type[model_type].placeholder_map)

    def get_registered_video_model_types(self) -> Tuple[str, ...]:
        return (
            model_type
            for model_type in self._multimodal_placeholder_by_model_type
            if "video" in self.
            _multimodal_placeholder_by_model_type[model_type].placeholder_map)

    def get_registered_audio_model_types(self) -> Tuple[str, ...]:
        return (
            model_type
            for model_type in self._multimodal_placeholder_by_model_type
            if "audio" in self.
            _multimodal_placeholder_by_model_type[model_type].placeholder_map)

    def get_registered_model_types(self) -> Tuple[str, ...]:
        return tuple(self._multimodal_placeholder_by_model_type.keys())


MULTIMODAL_PLACEHOLDER_REGISTRY = MultimodalPlaceholderRegistry()

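For illustration, the global registry defined above can be exercised directly. The `"demo_vlm"` model type and its placeholder string below are made up for this example; in-tree and out-of-tree models register the same information through the `register_input_processor` decorator instead of calling the registry by hand.

```python
from tensorrt_llm.inputs.registry import (MULTIMODAL_PLACEHOLDER_REGISTRY,
                                          MultimodalPlaceholderMetadata,
                                          MultimodalPlaceholderPlacement)

# Register placeholder metadata for a hypothetical model type.
MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata(
    "demo_vlm",
    MultimodalPlaceholderMetadata(
        placeholder_map={"image": "<image>"},
        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
    ))

# Query it back the way the input-preparation utilities do.
assert MULTIMODAL_PLACEHOLDER_REGISTRY.is_valid("demo_vlm", "image")
print(MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholder("demo_vlm", "image"))
print(MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholder_placement("demo_vlm"))

# Clean up the global registry afterwards (the unit tests below do the same).
MULTIMODAL_PLACEHOLDER_REGISTRY.remove_placeholder_metadata("demo_vlm")
```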
class InputProcessorRegistry:

    def __init__(self) -> None:
@@ -118,9 +242,10 @@ class InputProcessorRegistry:
INPUT_PROCESSOR_REGISTRY = InputProcessorRegistry()


def register_input_processor(processor_cls: Type[InputProcessor],
                             model_type: str,
                             out_of_tree: bool = False):
def register_input_processor(
        processor_cls: Type[InputProcessor],
        model_type: str,
        placeholder_metadata: MultimodalPlaceholderMetadata = None):
    """
    Register an input processor to a model class.
    NOTE:
@@ -128,17 +253,18 @@ def register_input_processor(processor_cls: Type[InputProcessor],
       the model type only for that.
    2. If this is used for other models in the future, this logic needs to be
       updated e.g. adding another version of this API without the model_type.
    3. If the model is not in the tree, user needs to set out_of_tree to True
       to bypass the model type check and provide their own input preparation.
    """

    def wrapper(model_cls: N) -> N:
        INPUT_PROCESSOR_REGISTRY._input_processors_cls_by_model_type[
            model_cls] = processor_cls
        if not out_of_tree:
            assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, \
                f"Model type {model_type} not in {ALL_SUPPORTED_MULTIMODAL_MODELS}.\n" \
                "Please see the tensorrt_llm/inputs/utils.py file for more information."
        if placeholder_metadata is None:
            raise ValueError(
                f"A valid placeholder_metadata must be provided but got {placeholder_metadata}"
            )

        MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata(
            model_type, placeholder_metadata)

        return model_cls

@@ -192,41 +318,88 @@ def create_input_processor_with_hash(
        A wrapped processor that modifies prompts before processing.
    """

    def multimodal_hashing_process(
        inputs: TextPrompt, sampling_params: SamplingParams
    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
        """
        Process the multimodal hashing for media tokens if possible.
        """
        assert 'multi_modal_data' in inputs, "multi_modal_data must be provided for hashing support."
        mm_data = inputs['multi_modal_data']
        num_mm_tokens = find_mm_token_lengths(mm_data, input_processor)
        if len(num_mm_tokens) > 0:
            mm_hashes = apply_mm_hashes(mm_data, hash_lib)
            prompt_token_ids, extra_processed_inputs = input_processor(
                inputs, sampling_params)
            start_positions = find_mm_token_positions(
                input_ids=prompt_token_ids,  # token sequence
                num_mm_tokens=
                num_mm_tokens,  # list of lengths of each chunk of visual tokens
                vocab_size=input_processor.model_config.vocab_size,
            )
            # flatten the hashes from dict to a single list
            mm_hashes = [h for hashes in mm_hashes.values() for h in hashes]
            validate_mm_inputs(prompt_token_ids, mm_hashes, start_positions,
                               num_mm_tokens)
            mm_hashes_int32 = [hexdigest_to_int32(h) for h in mm_hashes
                               ]  # nested list w/ multiple int32 per hash

            extra_processed_inputs[
                "multimodal_input"] = MultimodalInput.from_components(
                    mm_hashes_int32, start_positions, num_mm_tokens)
            return prompt_token_ids, extra_processed_inputs
        return [], None

    def input_processor_wrapper(
        inputs: TextPrompt, sampling_params: SamplingParams
    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
        try:
            assert 'multi_modal_data' in inputs, "multi_modal_data must be provided for hashing support."
            mm_data = inputs['multi_modal_data']
            num_mm_tokens = find_mm_token_lengths(mm_data, input_processor)
            if len(num_mm_tokens) > 0:
                mm_hashes = apply_mm_hashes(mm_data, hash_lib)
                prompt_token_ids, extra_processed_inputs = input_processor(
                    inputs, sampling_params)
                start_positions = find_mm_token_positions(
                    input_ids=prompt_token_ids,  # token sequence
                    num_mm_tokens=
                    num_mm_tokens,  # list of lengths of each chunk of visual tokens
                    vocab_size=input_processor.model_config.vocab_size,
                )
                # flatten the hashes from dict to a single list
                mm_hashes = [h for hashes in mm_hashes.values() for h in hashes]
                validate_mm_inputs(prompt_token_ids, mm_hashes, start_positions,
                                   num_mm_tokens)
                mm_hashes_int32 = [hexdigest_to_int32(h) for h in mm_hashes
                                   ]  # nested list w/ multiple int32 per hash
        try_multimodal_hashing = False  # only used for first time
        use_multimodal_hashing = False  # used for subsequent calls
        modalities = list(set(inputs['multi_modal_data'].keys())
                          ) if 'multi_modal_data' in inputs else []
        if len(modalities) > 0:
            # NOTE: tensorrt_llm/inputs/multimodal.py:find_mm_token_lengths only supports image data for now
            if len(modalities) == 1 and modalities[0] == "image":
                # only try multimodal hashing if the inputs only contain image data
                if input_processor.multimodal_hashing_supported is not None:
                    use_multimodal_hashing = input_processor.multimodal_hashing_supported
                else:
                    # we need to try the multimodal hashing for the first time to determine if it is supported
                    try_multimodal_hashing = True

                extra_processed_inputs[
                    "multimodal_input"] = MultimodalInput.from_components(
                        mm_hashes_int32, start_positions, num_mm_tokens)
        if try_multimodal_hashing or use_multimodal_hashing:
            try:
                prompt_token_ids, extra_processed_inputs = multimodal_hashing_process(
                    inputs, sampling_params)
                if try_multimodal_hashing:
                    # if trying for first time, set the flag to True
                    input_processor.multimodal_hashing_supported = True
                return prompt_token_ids, extra_processed_inputs
            else:
            except Exception as e:
                import traceback
                traceback.print_exc()
                logger.warning(f"Multimodal hashing failed: {e}.")
                if try_multimodal_hashing:
                    # if trying for first time, fall back to basic input processor
                    # and set the flag to False so that we don't try again
                    input_processor.multimodal_hashing_supported = False
                    logger.warning("Falling back to basic input processor.")
                    try:
                        return input_processor(inputs, sampling_params)
                    except Exception as e2:
                        import traceback
                        traceback.print_exc()
                        logger.warning(f"Basic input processor failed: {e}.")
                        raise e2
                else:
                    raise e
        else:
            try:
                return input_processor(inputs, sampling_params)
            except Exception as e:
                # Fall back to basic input processor if multimodal processing fails
                logger.warning(
                    f"Multimodal hashing failed: {e}. Falling back to basic input processor."
                )
                return input_processor(inputs, sampling_params)
        except Exception as e:
            import traceback
            traceback.print_exc()
            logger.warning(f"Basic input processor failed: {e}.")
            raise e

    return input_processor_wrapper

@@ -1,6 +1,5 @@
import asyncio
import base64
import enum
import tempfile
from collections import defaultdict
from io import BytesIO
@@ -18,6 +17,8 @@ from torchvision.transforms import ToTensor
from transformers import AutoProcessor, ProcessorMixin
from transformers.utils import logging

from tensorrt_llm.inputs.registry import (MULTIMODAL_PLACEHOLDER_REGISTRY,
                                          MultimodalPlaceholderPlacement)
from tensorrt_llm.llmapi.llm_utils import ModelLoader
from tensorrt_llm.llmapi.tokenizer import TokenizerBase, TransformersTokenizer

@@ -209,60 +210,25 @@ NOTE:
    placeholder for the model needs to be added in retrieve_multimodal_placeholder().
"""

SUPPORTED_QWEN_MODEL_GROUP = ["qwen2_vl", "qwen2_5_vl"]
SUPPORTED_GEMMA_MODEL_GROUP = ["gemma3"]
SUPPORTED_LLAMA_MODEL_GROUP = ["mllama", "llama4"]
SUPPORTED_LLAVA_IMAGE_MODEL_GROUP = ["llava_llama", "llava_next"]
SUPPORTED_LLAVA_VIDEO_MODEL_GROUP = ["llava_llama"]
SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP = ["mistral3"]
SUPPORTED_HYPERCLOVAX_MODEL_GROUP = ["hyperclovax_vlm"]
SUPPORTED_PHI_MODEL_GROUP = ["phi4mm"]

ALL_SUPPORTED_IMAGE_MODELS = SUPPORTED_QWEN_MODEL_GROUP \
    + SUPPORTED_LLAMA_MODEL_GROUP \
    + SUPPORTED_LLAVA_IMAGE_MODEL_GROUP \
    + SUPPORTED_HYPERCLOVAX_MODEL_GROUP \
    + SUPPORTED_GEMMA_MODEL_GROUP \
    + SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP \
    + SUPPORTED_PHI_MODEL_GROUP

ALL_SUPPORTED_VIDEO_MODELS = SUPPORTED_QWEN_MODEL_GROUP \
    + SUPPORTED_LLAVA_VIDEO_MODEL_GROUP

ALL_SUPPORTED_AUDIO_MODELS = SUPPORTED_PHI_MODEL_GROUP

ALL_SUPPORTED_MULTIMODAL_MODELS = list(set(ALL_SUPPORTED_IMAGE_MODELS) \
    | set(ALL_SUPPORTED_VIDEO_MODELS) \
    | set(ALL_SUPPORTED_AUDIO_MODELS))

HF_CHAT_TEMPLATE_EXCEPTIONS = ["llava_llama"]
PLACEHOLDER_EXCEPTIONS = ["llava_next"]


class MultimodalPlaceholderPlacement(enum.Enum):
    INVALID = -1
    BEFORE_TEXT = 0
    AFTER_TEXT = 1
# Helpers to always get the latest supported multimodal model types from the registry
def ALL_SUPPORTED_MULTIMODAL_MODELS():
    return MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types()


PLACEHOLDER_PLACEMENT_MAP = {
    "qwen2_vl": MultimodalPlaceholderPlacement.BEFORE_TEXT,
    "qwen2_5_vl": MultimodalPlaceholderPlacement.BEFORE_TEXT,
    "llava_llama": MultimodalPlaceholderPlacement.BEFORE_TEXT,
    "llava_next": MultimodalPlaceholderPlacement.BEFORE_TEXT,
    "llama4": MultimodalPlaceholderPlacement.BEFORE_TEXT,
    "mllama": MultimodalPlaceholderPlacement.BEFORE_TEXT,
    "hyperclovax_vlm": MultimodalPlaceholderPlacement.AFTER_TEXT,
    "gemma3": MultimodalPlaceholderPlacement.BEFORE_TEXT,
    # NOTE: for mistral3 multimodal models, it does not strictly have to be before the text.
    # Ref: https://github.com/mistralai/mistral-common/blob/039465db2bdc0486df36365c9bdb428188482a18/
    # src/mistral_common/tokens/tokenizers/base.py#L326
    # However, accuracy tests show that the model generates higher quality output when the image
    # precedes the text (the relative difference can be as much as ~30% for both vLLM and TRT-LLM).
    "mistral3": MultimodalPlaceholderPlacement.BEFORE_TEXT,
    "phi4mm": MultimodalPlaceholderPlacement.BEFORE_TEXT,
}
assert len(PLACEHOLDER_PLACEMENT_MAP) == len(ALL_SUPPORTED_MULTIMODAL_MODELS)
def ALL_SUPPORTED_IMAGE_MODELS():
    return MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_image_model_types()


def ALL_SUPPORTED_VIDEO_MODELS():
    return MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_video_model_types()


def ALL_SUPPORTED_AUDIO_MODELS():
    return MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_audio_model_types()


def retrieve_multimodal_placeholder(model_type: str, modality: str,
@@ -276,41 +242,16 @@ def retrieve_multimodal_placeholder(model_type: str, modality: str,
        current_count: The number of multimodal data already added.

    """

    if modality == "image":
        if model_type in SUPPORTED_QWEN_MODEL_GROUP:
            return "<|vision_start|><|image_pad|><|vision_end|>"
        elif model_type in SUPPORTED_LLAMA_MODEL_GROUP:
            return "<|image|>"
        elif model_type in SUPPORTED_LLAVA_IMAGE_MODEL_GROUP:
            return "<image>"
        elif model_type in SUPPORTED_GEMMA_MODEL_GROUP:
            return "<start_of_image>"
        elif model_type in SUPPORTED_HYPERCLOVAX_MODEL_GROUP:
            return '<im_end>\n<|im_start|>user (mime) \n{"type": "image/jpeg", "filename": ""}<|im_end|>\n' + \
                '<|im_start|>user (vector)\n<|dummy3|><|im_end|>\n' + \
                '<|im_start|>image/aux\n다음 중 ocr은 사진에서 검출된 글자이고, lens_keyword는 사진에서 추출된 keyword와 bbox 위치입니다.' + \
                'bbox는 0~1 사이로 정규화된 [x1, y1, x2, y2]의 형태입니다. 참고하여 답변하세요. {"ocr": "", "lens_keywords": "", "lens_local_keywords": ""}'
        elif model_type in SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP:
            # Ref: https://github.com/mistralai/mistral-common/blob/26a6bb3a07ee0b78a3808f2797f23e1d28514b93/
            # src/mistral_common/tokens/tokenizers/base.py#L60
            return "[IMG]"
        elif model_type in SUPPORTED_PHI_MODEL_GROUP:
            return f"<|image_{current_count}|>"
        raise TypeError(
            f"For image modality, only {ALL_SUPPORTED_IMAGE_MODELS} are supported but got {model_type}"
        )
    elif modality == "video":
        if model_type in SUPPORTED_QWEN_MODEL_GROUP:
            return "<|vision_start|><|video_pad|><|vision_end|>"
        elif model_type in SUPPORTED_LLAVA_VIDEO_MODEL_GROUP:
            return "<vila/video>"
        raise TypeError(
            f"For video modality, only {ALL_SUPPORTED_VIDEO_MODELS} are supported but got {model_type}"
        )
    elif modality == "audio":
        if model_type in SUPPORTED_PHI_MODEL_GROUP:
            return f"<|audio_{current_count}|>"
    if MULTIMODAL_PLACEHOLDER_REGISTRY.is_valid(model_type, modality):
        """
        The placeholder is a string with a single placeholder for the current count.
        - For example, if the placeholder is "<|image_{0}|>", and the current count is 1,
          the placeholder will be "<|image_1|>".
        - However, if the placeholder is "<|image|>", the current count would be ignored.
          In this case, the placeholder would be "<|image|>".
        """
        return MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholder(
            model_type, modality).format(current_count)
    raise TypeError(f"Unknown modality: {modality}")


@@ -379,17 +320,15 @@ def add_multimodal_placeholders(model_type: str, text_prompt: str,
    for placeholder in mm_placeholder_counts:
        placeholders.extend([placeholder] * mm_placeholder_counts[placeholder])
    parts = []
    match PLACEHOLDER_PLACEMENT_MAP[model_type]:
    match MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholder_placement(model_type):
        case MultimodalPlaceholderPlacement.BEFORE_TEXT:
            parts.extend(placeholders)
            parts.append(text_prompt)
        case MultimodalPlaceholderPlacement.AFTER_TEXT:
            parts.append(text_prompt)
            parts.extend(placeholders)
    if model_type == "phi4mm":
        return "".join(parts)
    else:
        return "\n".join(parts)
    return MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholders_separator(
        model_type).join(parts)


def resolve_hf_chat_template(

132  tensorrt_llm/tools/importlib_utils.py (new file)
@@ -0,0 +1,132 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
from pathlib import Path
from types import ModuleType
from typing import Optional, Union


def import_custom_module_from_file(
        custom_module_path: Union[str, Path]) -> Optional[ModuleType]:
    """Import a custom module from a single file.

    Args:
        custom_module_path (Union[str, Path]): The path to the custom module file.

    Returns:
        The imported module object.

    Raises:
        ImportError: If the module cannot be imported.
    """
    if isinstance(custom_module_path, str):
        custom_module_path = Path(custom_module_path)
    print(f"Importing custom module from file: {custom_module_path}")

    # Import single Python file
    module = None
    spec = importlib.util.spec_from_file_location(custom_module_path.stem,
                                                  str(custom_module_path))
    if spec is not None:
        module = importlib.util.module_from_spec(spec)
        if spec.loader is not None:
            spec.loader.exec_module(module)
            print(
                f"Successfully imported custom module from file: {custom_module_path}"
            )
        else:
            raise ImportError(
                f"Failed to import custom module from {custom_module_path}")
    else:
        raise ImportError(
            f"Failed to import custom module from {custom_module_path}")
    return module


def import_custom_module_from_dir(
        custom_module_path: Union[str, Path]) -> Optional[ModuleType]:
    """Import a custom module from a directory.

    Args:
        custom_module_path (Union[str, Path]): The path to the custom module directory.

    Returns:
        The imported module object.

    Raises:
        ImportError: If the module cannot be imported.

    Note:
        This function will add the parent directory of the custom module directory to sys.path.
        This is useful for importing modules that are not in the current working directory.
    """
    if isinstance(custom_module_path, str):
        custom_module_path = Path(custom_module_path)
    print(f"Importing custom module from directory: {custom_module_path}")

    # Import directory as a package
    # Add the parent directory to sys.path so we can import the package
    import sys
    parent_dir = str(custom_module_path.parent)
    if parent_dir not in sys.path:
        sys.path.insert(0, parent_dir)

    # Import the package
    module = None
    package_name = custom_module_path.name
    try:
        module = importlib.import_module(package_name)
        print(
            f"Successfully imported custom module from directory: {custom_module_path}"
        )
    except ImportError as e:
        raise ImportError(
            f"Failed to import package {package_name} from {custom_module_path}: {e}"
        )
    return module


def import_custom_module(
        custom_module_path: Union[str, Path]) -> Optional[ModuleType]:
    """Import a custom module from a file or directory.

    Args:
        custom_module_path (Union[str, Path]): The path to the custom module file or directory.

    Returns:
        The imported module object.

    Raises:
        ImportError: If the module cannot be imported.
        FileNotFoundError: If the custom module path does not exist.
    """
    if isinstance(custom_module_path, str):
        custom_module_path = Path(custom_module_path)
    print(f"Importing custom module from: {custom_module_path}")

    if custom_module_path.exists():
        if custom_module_path.is_file():
            return import_custom_module_from_file(custom_module_path)
        elif custom_module_path.is_dir():
            return import_custom_module_from_dir(custom_module_path)
        else:
            raise FileNotFoundError(
                f"Custom module path {custom_module_path} is neither a file nor a directory."
            )
    else:
        raise FileNotFoundError(
            f"Custom module path {custom_module_path} does not exist.")
    return None
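A brief usage sketch of the helpers above; the directory path is the hypothetical out-of-tree package from the readme earlier in this commit:

```python
from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir

# Importing the package executes its __init__.py, which in turn runs the model's
# register_auto_model / register_input_processor decorators.
import_custom_module_from_dir("../modeling_custom_phi")
```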
106  tests/unittest/others/test_multimodal_registry.py (new file)
@@ -0,0 +1,106 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from tensorrt_llm.inputs.registry import (MULTIMODAL_PLACEHOLDER_REGISTRY,
                                          MultimodalPlaceholderMetadata,
                                          MultimodalPlaceholderPlacement)


class TestMultimodalPlaceholderRegistry(unittest.TestCase):

    def setUp(self):
        self.model_type = "test_model_type"
        self.placeholder_metadata = MultimodalPlaceholderMetadata(
            placeholder_map={
                "image": "IMAGE_PLACEHOLDER",
                "video": "VIDEO_PLACEHOLDER"
            },
            placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
            placeholders_separator="\n")

    def test_new_registration(self):
        MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata(
            self.model_type, self.placeholder_metadata)
        self.assertEqual(
            MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholder_metadata(
                self.model_type), self.placeholder_metadata)
        MULTIMODAL_PLACEHOLDER_REGISTRY.remove_placeholder_metadata(
            self.model_type)

    def test_registered_model_types(self):
        pre_reg_model_types = list(
            MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types())

        # register the model type
        MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata(
            self.model_type, self.placeholder_metadata)

        post_reg_model_types = list(
            MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types())
        self.assertEqual(
            len(pre_reg_model_types) + 1, len(post_reg_model_types))
        self.assertIn(self.model_type, post_reg_model_types)

        MULTIMODAL_PLACEHOLDER_REGISTRY.remove_placeholder_metadata(
            self.model_type)

    def test_validity(self):
        MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata(
            self.model_type, self.placeholder_metadata)
        self.assertTrue(
            MULTIMODAL_PLACEHOLDER_REGISTRY.is_valid(self.model_type, "image"))
        self.assertTrue(
            MULTIMODAL_PLACEHOLDER_REGISTRY.is_valid(self.model_type, "video"))
        self.assertFalse(
            MULTIMODAL_PLACEHOLDER_REGISTRY.is_valid(self.model_type, "audio"))

        MULTIMODAL_PLACEHOLDER_REGISTRY.remove_placeholder_metadata(
            self.model_type)

    def test_model_types_per_modality(self):
        pre_reg_image_model_types = list(
            MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_image_model_types())
        pre_reg_video_model_types = list(
            MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_video_model_types())
        pre_reg_audio_model_types = list(
            MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_audio_model_types())

        # register the model type for image and video
        MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata(
            self.model_type, self.placeholder_metadata)

        post_reg_image_model_types = list(
            MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_image_model_types())
        post_reg_video_model_types = list(
            MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_video_model_types())
        post_reg_audio_model_types = list(
            MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_audio_model_types())
        self.assertEqual(
            len(pre_reg_image_model_types) + 1, len(post_reg_image_model_types))
        self.assertEqual(
            len(pre_reg_video_model_types) + 1, len(post_reg_video_model_types))
        self.assertEqual(len(pre_reg_audio_model_types),
                         len(post_reg_audio_model_types))
        self.assertIn(self.model_type, post_reg_image_model_types)
        self.assertIn(self.model_type, post_reg_video_model_types)
        self.assertNotIn(self.model_type, post_reg_audio_model_types)

        MULTIMODAL_PLACEHOLDER_REGISTRY.remove_placeholder_metadata(
            self.model_type)


if __name__ == "__main__":
    unittest.main()