[None][fix] Refactoring input prep to allow out-of-tree models (#6497)

Signed-off-by: Rakib Hasan <rhasan@nvidia.com>
rakib-hasan 2025-08-12 17:29:10 -07:00 committed by GitHub
parent bd9a6dd9ab
commit 2923eb88a1
17 changed files with 748 additions and 163 deletions

View File

@ -0,0 +1,52 @@
# Out-of-tree Model Development
The file `modeling_opt.py` shows an example of how a custom model can be defined using TRT-LLM APIs without modifying the source code of TRT-LLM.
The file `main.py` shows how to run inference for such custom models using the LLM API.
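For reference, running such a custom model through the LLM API boils down to a few lines. The sketch below is illustrative rather than the literal contents of `main.py`; the checkpoint path is a placeholder.
```python
# Illustrative only: the checkpoint path below is a placeholder.
from tensorrt_llm import LLM, SamplingParams

# Importing the custom module runs its registration decorators so TRT-LLM
# can resolve the architecture before the checkpoint is loaded.
import modeling_opt  # noqa: F401

llm = LLM(model="./opt_ckpt")
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
for output in outputs:
    print(output.outputs[0].text)
```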
## Out-of-tree Multimodal Models
For multimodal models, TRT-LLM provides `quickstart_multimodal.py` to quickly run a multimodal model that is defined within TRT-LLM, and `trtllm-bench` can be used to benchmark such models.
The following sections describe how to use the same tools with out-of-tree models.
### Prerequisites
To use an out-of-tree model with the quickstart example and `trtllm-bench`, you need to prepare the model definition files as a Python module.
Consider the following file structure as an example:
```
modeling_custom_phi
|-- __init__.py
|-- configuration.py
|-- modeling_custom_phi.py
|-- encoder
|   |-- __init__.py
|   |-- configuration.py
|   |-- modeling_encoder.py
```
Each `__init__.py` file should be populated with the imports needed for the custom model. For example, `modeling_custom_phi/__init__.py` could contain something like:
```
from .modeling_custom_phi import MyVLMForConditionalGeneration
from . import encoder
```
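The model file itself registers the architecture and its input processor. The sketch below mirrors the pattern used by the in-tree models in this commit; the class bodies are elided and the import path of `register_auto_model` is an assumption, so treat it as a starting point rather than a drop-in file.
```python
# Hedged sketch of modeling_custom_phi.py; class bodies are elided and the
# register_auto_model import path is an assumption based on in-tree models.
from transformers import PreTrainedModel

from tensorrt_llm._torch.models.modeling_utils import register_auto_model
from tensorrt_llm.inputs import (InputProcessor, MultimodalPlaceholderMetadata,
                                 MultimodalPlaceholderPlacement,
                                 register_input_processor)


class MyVLMInputProcessor(InputProcessor):
    ...  # tokenize the prompt and preprocess the media here


@register_auto_model("MyVLMForConditionalGeneration")
@register_input_processor(
    MyVLMInputProcessor,
    model_type="custom_phi",  # must match model_type in the HF config.json
    placeholder_metadata=MultimodalPlaceholderMetadata(
        placeholder_map={"image": "<image>"},
        placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
    ))
class MyVLMForConditionalGeneration(PreTrainedModel):
    ...  # model definition elided
```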
### Quickstart Example
Once the model definition files are prepared as a Python module (as described above), you can use the `--custom_module_dirs` flag of `quickstart_multimodal.py` to load your model and run inference.
```
python3 quickstart_multimodal.py --model_dir ./model_ckpt --modality image --max_tokens 10 --prompt "Describe the image." --media ./demo_lower.png --image_format pil --custom_module_dirs ../modeling_custom_phi
```
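For a rough idea of what that flag does, the CLI flow is approximately equivalent to the following Python sketch. The paths, the `model_type` string, and the final `generate` call are illustrative assumptions rather than the exact contents of `quickstart_multimodal.py`.
```python
# Illustrative sketch; paths and the model_type string are placeholders.
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.inputs import default_multimodal_input_loader
from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir

# Run the out-of-tree package's registration decorators before loading the model.
import_custom_module_from_dir("../modeling_custom_phi")

llm = LLM(model="./model_ckpt")
inputs = default_multimodal_input_loader(
    tokenizer=llm.tokenizer,
    model_dir="./model_ckpt",
    model_type="custom_phi",  # as reported by the checkpoint's config.json
    modality="image",
    prompts=["Describe the image."],
    media=["./demo_lower.png"],
    image_data_format="pil",
    device="cpu")
outputs = llm.generate(inputs, SamplingParams(max_tokens=10))
print(outputs[0].outputs[0].text)
```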
### Benchmarking
Similar to the quickstart example, you can use the same CLI argument with `trtllm-bench` to benchmark a custom model.
Prepare the dataset:
```
python ./benchmarks/cpp/prepare_dataset.py --tokenizer ./model_ckpt --stdout dataset --dataset-name lmms-lab/MMMU --dataset-split test --dataset-image-key image --dataset-prompt-key "question" --num-requests 100 --output-len-dist 128,5 > mm_data.jsonl
```
Run the benchmark:
```
trtllm-bench --model ./model_ckpt --model_path ./model_ckpt throughput --dataset mm_data.jsonl --backend pytorch --num_requests 100 --max_batch_size 4 --modality image --streaming --custom_module_dirs ../modeling_custom_phi
```

View File

@ -4,8 +4,9 @@ import os
from quickstart_advanced import add_llm_args, setup_llm
from tensorrt_llm.inputs import (ALL_SUPPORTED_MULTIMODAL_MODELS,
default_multimodal_input_loader)
from tensorrt_llm.inputs import default_multimodal_input_loader
from tensorrt_llm.inputs.registry import MULTIMODAL_PLACEHOLDER_REGISTRY
from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir
example_medias_and_prompts = {
"image": {
@ -79,10 +80,11 @@ example_medias_and_prompts = {
def add_multimodal_args(parser):
parser.add_argument("--model_type",
type=str,
choices=ALL_SUPPORTED_MULTIMODAL_MODELS,
help="Model type.")
parser.add_argument(
"--model_type",
type=str,
choices=MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types(),
help="Model type as specified in the HuggingFace model config.")
parser.add_argument("--modality",
type=str,
choices=[
@ -90,7 +92,7 @@ def add_multimodal_args(parser):
"multiple_image", "mixture_text_image"
],
default="image",
help="Media type.")
help="Media type being used for inference.")
parser.add_argument("--media",
type=str,
nargs="+",
@ -108,6 +110,18 @@ def add_multimodal_args(parser):
type=str,
default="cpu",
help="The device to have the input on.")
parser.add_argument(
"--custom_module_dirs",
type=str,
nargs="+",
default=None,
help=
("Paths to an out-of-tree model directory which should be imported."
" This is useful to load a custom model. The directory should have a structure like:"
" <model_name>"
" ├── __init__.py"
" ├── <model_name>.py"
" └── <sub_dirs>"))
return parser
@ -140,6 +154,15 @@ def parse_arguments():
def main():
args = parse_arguments()
if args.custom_module_dirs is not None:
for custom_module_dir in args.custom_module_dirs:
try:
import_custom_module_from_dir(custom_module_dir)
except Exception as e:
print(
f"Failed to import custom module from {custom_module_dir}: {e}"
)
raise e
lora_config = None
if args.load_lora:
@ -159,8 +182,11 @@ def main():
model_type = args.model_type
else:
model_type = json.load(
open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, f"Unsupported model_type: {model_type}"
open(os.path.join(str(llm._hf_model_dir),
'config.json')))['model_type']
assert model_type in MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types(), \
f"Unsupported model_type: {model_type} found!\n" \
f"Supported types: {MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types()}"
# set prompts and media to example prompts and images if they are not provided
if args.prompt is None:
@ -168,7 +194,7 @@ def main():
if args.media is None:
args.media = example_medias_and_prompts[args.modality]["media"]
inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer,
model_dir=llm._hf_model_dir,
model_dir=str(llm._hf_model_dir),
model_type=model_type,
modality=args.modality,
prompts=args.prompt,

View File

@ -10,7 +10,9 @@ from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
BaseWeightMapper
from ..._utils import nvtx_range
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from ...logger import logger
from ...sampling_params import SamplingParams
@ -137,7 +139,13 @@ class Gemma3MultiModalProjector(torch.nn.Module):
@register_auto_model("Gemma3ForConditionalGeneration")
@register_input_processor(Gemma3InputProcessor, model_type="gemma3")
@register_input_processor(
Gemma3InputProcessor,
model_type="gemma3",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={"image": "<start_of_image>"},
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
))
class Gemma3VLM(PreTrainedModel):
def __init__(self, model_config: ModelConfig[Gemma3Config]):

View File

@ -15,7 +15,9 @@ from transformers.models.auto import CONFIG_MAPPING
from tensorrt_llm.inputs.multimodal import MultimodalParams
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from ...logger import logger
from ...sampling_params import SamplingParams
@ -961,7 +963,23 @@ class HCXVisionModel:
@register_auto_model("HCXVisionForCausalLM")
@register_input_processor(HCXVisionInputProcessor, model_type="hyperclovax_vlm")
@register_input_processor(
HCXVisionInputProcessor,
model_type="hyperclovax_vlm",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={
"image":
('<im_end>\n<|im_start|>user (mime) \n'
'{"type": "image/jpeg", "filename": ""}<|im_end|>\n'
'<|im_start|>user (vector)\n<|dummy3|><|im_end|>\n'
'<|im_start|>image/aux\n'
'다음 중 ocr은 사진에서 검출된 글자이고, lens_keyword는 사진에서 추출된 '
'keyword와 bbox 위치입니다.bbox는 0~1 사이로 정규화된 [x1, y1, x2, y2]의 '
'형태입니다. 참고하여 답변하세요. '
'{"ocr": "", "lens_keywords": "", "lens_local_keywords": ""}')
},
placeholder_placement=MultimodalPlaceholderPlacement.AFTER_TEXT,
))
class HCXVisionForCausalLM(PreTrainedModel):
def __init__(self, model_config: ModelConfig):

View File

@ -20,7 +20,9 @@ from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import HfLoraLoader
from tensorrt_llm.models.convert_utils import split_matrix_tp
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from ...sampling_params import SamplingParams
from ..attention_backend import AttentionMetadata
@ -1173,7 +1175,13 @@ class Llama4InputProcessor(InputProcessor):
@register_auto_model("Llama4ForConditionalGeneration")
@register_input_processor(Llama4InputProcessor, model_type="llama4")
@register_input_processor(
Llama4InputProcessor,
model_type="llama4",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={"image": "<|image|>"},
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
))
class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model,
Llama4Config]):

View File

@ -14,7 +14,9 @@ from transformers.models.llava_next.modeling_llava_next import (
from tensorrt_llm.inputs.multimodal import MultimodalParams
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from ...llmapi.utils import download_hf_model
from ...logger import logger
@ -263,7 +265,13 @@ class LlavaNextVisionModel(nn.Module):
@register_auto_model("LlavaNextForConditionalGeneration")
@register_input_processor(LlavaNextInputProcessor, model_type="llava_next")
@register_input_processor(
LlavaNextInputProcessor,
model_type="llava_next",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={"image": "<image>"},
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
))
class LlavaNextModel(PreTrainedModel):
config_class = LlavaNextConfig

View File

@ -29,7 +29,9 @@ from tensorrt_llm._torch.modules.rms_norm import RMSNorm
from tensorrt_llm._torch.speculative import SpecMetadata
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.inputs import (ExtraProcessedInputs, InputProcessor,
TextPrompt, register_input_processor)
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.logger import logger
@ -269,8 +271,20 @@ class Mistral3InputProcessor(InputProcessor):
@register_auto_model("Mistral3ForConditionalGeneration")
# The decorator below tells the registry which input processor to create for this model in `tensorrt_llm/llmapi/llm.py`.
@register_input_processor(Mistral3InputProcessor, model_type="mistral3")
@register_input_processor(
Mistral3InputProcessor,
model_type="mistral3",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={
"image": "[IMG]",
},
# NOTE: for mistral3 multimodal models, it does not strictly have to be before the text.
# Ref: https://github.com/mistralai/mistral-common/blob/039465db2bdc0486df36365c9bdb428188482a18/
# src/mistral_common/tokens/tokenizers/base.py#L326
# However, accuracy tests show that the model generates higher quality output when the image
# precedes the text (the relative difference can be as much as ~30% for both vLLM and TRT-LLM).
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
))
class Mistral3VLM(PreTrainedModel):
"""Mistral3VLM implementation for TRTLLM.

View File

@ -19,7 +19,9 @@ from PIL import Image
from tensorrt_llm.inputs.multimodal import MultimodalParams
from ...executor.request import LoRARequest
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from ...logger import logger
from ...lora_helper import LoraConfig
@ -461,7 +463,17 @@ class Phi4MMInputProcessor(InputProcessor):
@register_auto_model("Phi4MMForCausalLM")
@register_input_processor(Phi4MMInputProcessor, model_type="phi4mm")
@register_input_processor(
Phi4MMInputProcessor,
model_type="phi4mm",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={
"image": "<|image_{0}|>",
"audio": "<|audio_{0}|>",
},
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
placeholders_separator="",
))
class Phi4MMForCausalLM(transformers.PreTrainedModel):
_supports_flash_attn_2 = True

View File

@ -12,7 +12,9 @@ from tensorrt_llm.inputs.multimodal import MultimodalParams
from ..._utils import nvtx_range_debug
from ...functional import RopeEmbeddingUtils, RotaryScalingType
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from ...logger import logger
from ...sampling_params import SamplingParams
@ -645,7 +647,16 @@ class Qwen2VLModelBase(PreTrainedModel):
@register_auto_model("Qwen2VLForConditionalGeneration")
@register_input_processor(Qwen2VLInputProcessorBase, model_type="qwen2_vl")
@register_input_processor(
Qwen2VLInputProcessorBase,
model_type="qwen2_vl",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={
"image": "<|vision_start|><|image_pad|><|vision_end|>",
"video": "<|vision_start|><|video_pad|><|vision_end|>"
},
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
))
class Qwen2VLModel(Qwen2VLModelBase):
def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,
@ -657,7 +668,14 @@ class Qwen2VLModel(Qwen2VLModelBase):
@register_auto_model("Qwen2_5_VLForConditionalGeneration")
@register_input_processor(Qwen2VLInputProcessorBase, model_type="qwen2_5_vl")
@register_input_processor(
Qwen2VLInputProcessorBase,
model_type="qwen2_5_vl",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={
"image": "<|vision_start|><|image_pad|><|vision_end|>",
"video": "<|vision_start|><|video_pad|><|vision_end|>"
}))
class Qwen2_5_VLModel(Qwen2VLModelBase):
def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,

View File

@ -35,7 +35,9 @@ from transformers import (AutoConfig, AutoImageProcessor, AutoModel,
PreTrainedModel)
from ..._utils import nvtx_range
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
from ...inputs import (ExtraProcessedInputs, InputProcessor,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, TextPrompt,
register_input_processor)
from ...logger import logger
from ...sampling_params import SamplingParams
@ -1118,7 +1120,16 @@ class VilaInputProcessor(InputProcessor):
@register_auto_model(VilaConfig.model_architecture)
@register_input_processor(VilaInputProcessor, model_type="llava_llama")
@register_input_processor(
VilaInputProcessor,
model_type="llava_llama",
placeholder_metadata=MultimodalPlaceholderMetadata(
placeholder_map={
"image": "<image>",
"video": "<vila/video>"
},
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
))
class VilaModel(PreTrainedModel):
config_class = VilaConfig

View File

@ -13,6 +13,7 @@ from huggingface_hub import snapshot_download
from tensorrt_llm.bench.benchmark.utils.asynchronous import async_benchmark
from tensorrt_llm.bench.benchmark.utils.processes import IterationWriter
from tensorrt_llm.bench.build.build import get_model_config
from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir
# isort: off
from tensorrt_llm.bench.benchmark.utils.general import (
@ -49,6 +50,16 @@ from tensorrt_llm.sampling_params import SamplingParams
type=click.Choice(ALL_SUPPORTED_BACKENDS),
default="pytorch",
help="The backend to use when running benchmarking.")
@optgroup.option(
"--custom_module_dirs",
type=click.Path(exists=True,
readable=True,
path_type=Path,
resolve_path=True),
default=None,
multiple=True,
help="Paths to custom module directories to import.",
)
@optgroup.option(
"--extra_llm_api_options",
type=str,
@ -104,6 +115,16 @@ from tensorrt_llm.sampling_params import SamplingParams
required=False,
help="Pass in a dataset file for parsing instead of stdin.",
)
# For text models, tokenizer initialization is not needed when loading the model since the dataset is already tokenized.
# For this reason, we skip tokenizer initialization by default.
# However, for VLM models, tokenizer initialization is needed inside the model since the dataset contains texts and
# raw media data. We cannot skip tokenizer initialization in this case.
@optgroup.option(
"--no_skip_tokenizer_init",
is_flag=True,
default=False,
help="Do not skip tokenizer initialization when loading the model.",
)
@optgroup.option(
"--eos_id",
type=int,
@ -118,6 +139,18 @@ from tensorrt_llm.sampling_params import SamplingParams
default=None,
help="Modality of the multimodal requests.",
)
@optgroup.option(
"--image_data_format",
type=click.Choice(["pt", "pil"]),
default="pt",
help="Format of the image data for multimodal models.",
)
@optgroup.option(
"--data_device",
type=click.Choice(["cuda", "cpu"]),
default="cuda",
help="Device to load the multimodal data on.",
)
@optgroup.option(
"--max_input_len",
type=int,
@ -262,7 +295,17 @@ def throughput_command(
logger.info("Preparing to run throughput benchmark...")
# Parameters from CLI
# Model, experiment, and engine params
custom_module_dirs: list[Path] = params.pop("custom_module_dirs", [])
for custom_module_dir in custom_module_dirs:
try:
import_custom_module_from_dir(custom_module_dir)
except Exception as e:
logger.error(
f"Failed to import custom module from {custom_module_dir}: {e}")
raise e
dataset_path: Path = params.get("dataset")
no_skip_tokenizer_init: bool = params.get("no_skip_tokenizer_init", False)
eos_id: int = params.get("eos_id")
warmup: int = params.get("warmup")
num_requests: int = params.get("num_requests")
@ -274,6 +317,8 @@ def throughput_command(
backend: str = params.get("backend")
modality: str = params.get("modality")
max_input_len: int = params.get("max_input_len")
image_data_format: str = params.get("image_data_format", "pt")
data_device: str = params.get("data_device", "cpu")
model_type = get_model_config(model, checkpoint_path).model_type
# Reporting options
@ -286,7 +331,7 @@ def throughput_command(
# Runtime kwargs and option tracking.
kwargs = {}
# Initialize the HF tokenizer for the specified model.
# Initialize the HF tokenizer for the specified model. This is only used for data preparation.
tokenizer = initialize_tokenizer(checkpoint_path)
# Dataset Loading and Preparation
@ -298,6 +343,8 @@ def throughput_command(
model_dir=checkpoint_path,
model_type=model_type,
modality=modality,
image_data_format=image_data_format,
data_device=data_device,
max_input_seq_len_for_multimodal=max_input_len)
metadata.dataset_path = dataset_path
params["target_input_len"] = params.get(
@ -392,6 +439,7 @@ def throughput_command(
logger.info("Setting up throughput benchmark.")
kwargs = kwargs | runtime_config.get_llm_args()
kwargs['backend'] = backend
kwargs['skip_tokenizer_init'] = not no_skip_tokenizer_init
if backend == "pytorch" and iteration_log is not None:
kwargs["enable_iter_perf_stats"] = True

View File

@ -41,6 +41,8 @@ def create_dataset_from_stream(
model_dir: str = None,
model_type: str = None,
modality: str = None,
image_data_format: str = "pt",
data_device: str = "cpu",
max_input_seq_len_for_multimodal: int = 4096,
) -> Tuple[DatasetMetadata, List[InferenceRequest]]:
"""Generate metadata and a list of requests to drive benchmarking.
@ -130,7 +132,9 @@ def create_dataset_from_stream(
model_type=model_type,
modality=modality,
prompts=prompts,
media=media_paths) # list of dicts
media=media_paths, # list of dicts
image_data_format=image_data_format,
device=data_device)
all_isl = []
all_seq_len = []

View File

@ -1,16 +1,23 @@
from .data import PromptInputs, TextPrompt, TokensPrompt, prompt_inputs
from .multimodal import MultimodalInput
from .registry import (ExtraProcessedInputs, InputProcessor,
create_input_processor, create_input_processor_with_hash,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement, create_input_processor,
create_input_processor_with_hash,
register_input_processor)
from .utils import (ALL_SUPPORTED_MULTIMODAL_MODELS, ConversationMessage,
MultimodalData, MultimodalDataTracker,
from .utils import (ALL_SUPPORTED_AUDIO_MODELS, ALL_SUPPORTED_IMAGE_MODELS,
ALL_SUPPORTED_MULTIMODAL_MODELS, ALL_SUPPORTED_VIDEO_MODELS,
ConversationMessage, MultimodalData, MultimodalDataTracker,
add_multimodal_placeholders, async_load_audio,
async_load_image, async_load_video,
default_multimodal_input_loader,
encode_base64_content_from_url, load_image, load_video)
__all__ = [
"ALL_SUPPORTED_MULTIMODAL_MODELS",
"ALL_SUPPORTED_IMAGE_MODELS",
"ALL_SUPPORTED_VIDEO_MODELS",
"ALL_SUPPORTED_AUDIO_MODELS",
"PromptInputs",
"prompt_inputs",
"TextPrompt",
@ -20,7 +27,8 @@ __all__ = [
"create_input_processor_with_hash",
"register_input_processor",
"ExtraProcessedInputs",
"ALL_SUPPORTED_MULTIMODAL_MODELS",
"MultimodalPlaceholderMetadata",
"MultimodalPlaceholderPlacement",
"ConversationMessage",
"MultimodalDataTracker",
"MultimodalData",

View File

@ -1,3 +1,5 @@
import enum
from dataclasses import dataclass, field
from typing import (Any, Callable, Dict, List, Optional, Protocol, Tuple, Type,
TypeVar)
@ -10,7 +12,6 @@ from .data import TextPrompt
from .multimodal import (MultimodalInput, apply_mm_hashes, default_hasher,
find_mm_token_lengths, find_mm_token_positions,
hexdigest_to_int32, validate_mm_inputs)
from .utils import ALL_SUPPORTED_MULTIMODAL_MODELS
N = TypeVar("N", bound=Type[nn.Module])
@ -32,6 +33,7 @@ class InputProcessor(Protocol):
model_path: any
model_config: any
tokenizer: any
multimodal_hashing_supported: Optional[bool] = None
def __call__(
self, inputs: TextPrompt, sampling_params: SamplingParams
@ -50,6 +52,7 @@ class DefaultInputProcessor(InputProcessor):
self.tokenizer = tokenizer
self.model_config = model_config
self.model_path = model_path
self.multimodal_hashing_supported = None
def __call__(
self, inputs: TextPrompt, sampling_params: SamplingParams
@ -108,6 +111,127 @@ class DefaultInputProcessor(InputProcessor):
return token_ids, None
class MultimodalPlaceholderPlacement(enum.Enum):
"""
The placement of the multimodal placeholder in the prompt. Valid values are:
- BEFORE_TEXT: the placeholders are placed before the text prompt.
- AFTER_TEXT: the placeholders are placed after the text prompt.
"""
INVALID = -1
BEFORE_TEXT = 0
AFTER_TEXT = 1
@dataclass(frozen=True)
class MultimodalPlaceholderMetadata:
"""
Metadata for the multimodal placeholder. It has 3 components:
- placeholder_map:
A mapping from modality to placeholder string.
Modality can be "image", "video", "audio", etc.
- placeholder_placement:
The placement of the placeholders, e.g. before or after the text prompt.
- placeholders_separator:
The separator between the placeholders, e.g. some models use "\n" to separate the placeholders.
"""
placeholder_map: Dict[str, str] = field(default_factory=dict)
placeholder_placement: MultimodalPlaceholderPlacement = MultimodalPlaceholderPlacement.AFTER_TEXT
placeholders_separator: str = "\n"
class MultimodalPlaceholderRegistry:
"""
Registry for the multimodal models to keep track of the placeholder information.
"""
def __init__(self) -> None:
self._multimodal_placeholder_by_model_type: Dict[
str, MultimodalPlaceholderMetadata] = {}
def __str__(self) -> str:
s = ""
for model_type, placeholder_metadata in self._multimodal_placeholder_by_model_type.items(
):
s += "-" * 100 + "\n"
s += f"Model type: {model_type}\n"
s += f"Placeholder map: {placeholder_metadata.placeholder_map}\n"
s += f"Placeholder placement: {placeholder_metadata.placeholder_placement}\n"
s += f"Placeholders separator: \"{placeholder_metadata.placeholders_separator}\"\n"
s += "-" * 80 + "\n"
return s
def set_placeholder_metadata(
self, model_type: str,
placeholder_metadata: MultimodalPlaceholderMetadata):
self._multimodal_placeholder_by_model_type[
model_type] = placeholder_metadata
def remove_placeholder_metadata(self, model_type: str):
if model_type not in self._multimodal_placeholder_by_model_type:
raise ValueError(f"Model type '{model_type}' is not registered")
del self._multimodal_placeholder_by_model_type[model_type]
def is_valid(self, model_type: str, modality: str) -> bool:
return model_type in self._multimodal_placeholder_by_model_type and \
modality in self._multimodal_placeholder_by_model_type[model_type].placeholder_map
def get_placeholder_metadata(
self, model_type: str) -> MultimodalPlaceholderMetadata:
if model_type not in self._multimodal_placeholder_by_model_type:
raise ValueError(
f"Model type {model_type} is not registered in MultimodalPlaceholderRegistry"
)
return self._multimodal_placeholder_by_model_type[model_type]
def get_placeholder(self, model_type: str, modality: str) -> str:
if not self.is_valid(model_type, modality):
raise ValueError(
f"Model type '{model_type}' with modality '{modality}' is not registered."
)
return self._multimodal_placeholder_by_model_type[
model_type].placeholder_map[modality]
def get_placeholder_placement(
self, model_type: str) -> MultimodalPlaceholderPlacement:
if model_type not in self._multimodal_placeholder_by_model_type:
raise ValueError(f"Model type '{model_type}' is not registered")
return self._multimodal_placeholder_by_model_type[
model_type].placeholder_placement
def get_placeholders_separator(self, model_type: str) -> str:
if model_type not in self._multimodal_placeholder_by_model_type:
raise ValueError(f"Model type '{model_type}' is not registered")
return self._multimodal_placeholder_by_model_type[
model_type].placeholders_separator
def get_registered_image_model_types(self) -> Tuple[str, ...]:
return tuple(
model_type
for model_type in self._multimodal_placeholder_by_model_type
if "image" in self.
_multimodal_placeholder_by_model_type[model_type].placeholder_map)
def get_registered_video_model_types(self) -> Tuple[str, ...]:
return tuple(
model_type
for model_type in self._multimodal_placeholder_by_model_type
if "video" in self.
_multimodal_placeholder_by_model_type[model_type].placeholder_map)
def get_registered_audio_model_types(self) -> Tuple[str, ...]:
return tuple(
model_type
for model_type in self._multimodal_placeholder_by_model_type
if "audio" in self.
_multimodal_placeholder_by_model_type[model_type].placeholder_map)
def get_registered_model_types(self) -> Tuple[str, ...]:
return tuple(self._multimodal_placeholder_by_model_type.keys())
MULTIMODAL_PLACEHOLDER_REGISTRY = MultimodalPlaceholderRegistry()
class InputProcessorRegistry:
def __init__(self) -> None:
@ -118,9 +242,10 @@ class InputProcessorRegistry:
INPUT_PROCESSOR_REGISTRY = InputProcessorRegistry()
def register_input_processor(processor_cls: Type[InputProcessor],
model_type: str,
out_of_tree: bool = False):
def register_input_processor(
processor_cls: Type[InputProcessor],
model_type: str,
placeholder_metadata: MultimodalPlaceholderMetadata = None):
"""
Register an input processor to a model class.
NOTE:
@ -128,17 +253,18 @@ def register_input_processor(processor_cls: Type[InputProcessor],
the model type only for that.
2. If this is used for other models in the future, this logic needs to be
updated e.g. adding another version of this API without the model_type.
3. If the model is not in the tree, user needs to set out_of_tree to True
to bypass the model type check and provide their own input preparation.
"""
def wrapper(model_cls: N) -> N:
INPUT_PROCESSOR_REGISTRY._input_processors_cls_by_model_type[
model_cls] = processor_cls
if not out_of_tree:
assert model_type in ALL_SUPPORTED_MULTIMODAL_MODELS, \
f"Model type {model_type} not in {ALL_SUPPORTED_MULTIMODAL_MODELS}.\n" \
"Please see the tensorrt_llm/inputs/utils.py file for more information."
if placeholder_metadata is None:
raise ValueError(
f"A valid placeholder_metadata must be provided but got {placeholder_metadata}"
)
MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata(
model_type, placeholder_metadata)
return model_cls
@ -192,41 +318,88 @@ def create_input_processor_with_hash(
A wrapped processor that modifies prompts before processing.
"""
def multimodal_hashing_process(
inputs: TextPrompt, sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
"""
Process the multimodal hashing for media tokens if possible.
"""
assert 'multi_modal_data' in inputs, "multi_modal_data must be provided for hashing support."
mm_data = inputs['multi_modal_data']
num_mm_tokens = find_mm_token_lengths(mm_data, input_processor)
if len(num_mm_tokens) > 0:
mm_hashes = apply_mm_hashes(mm_data, hash_lib)
prompt_token_ids, extra_processed_inputs = input_processor(
inputs, sampling_params)
start_positions = find_mm_token_positions(
input_ids=prompt_token_ids, # token sequence
num_mm_tokens=
num_mm_tokens, # list of lengths of each chunk of visual tokens
vocab_size=input_processor.model_config.vocab_size,
)
# flatten the hashes from dict to a single list
mm_hashes = [h for hashes in mm_hashes.values() for h in hashes]
validate_mm_inputs(prompt_token_ids, mm_hashes, start_positions,
num_mm_tokens)
mm_hashes_int32 = [hexdigest_to_int32(h) for h in mm_hashes
] # nested list w/ multiple int32 per hash
extra_processed_inputs[
"multimodal_input"] = MultimodalInput.from_components(
mm_hashes_int32, start_positions, num_mm_tokens)
return prompt_token_ids, extra_processed_inputs
return [], None
def input_processor_wrapper(
inputs: TextPrompt, sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
try:
assert 'multi_modal_data' in inputs, "multi_modal_data must be provided for hashing support."
mm_data = inputs['multi_modal_data']
num_mm_tokens = find_mm_token_lengths(mm_data, input_processor)
if len(num_mm_tokens) > 0:
mm_hashes = apply_mm_hashes(mm_data, hash_lib)
prompt_token_ids, extra_processed_inputs = input_processor(
inputs, sampling_params)
start_positions = find_mm_token_positions(
input_ids=prompt_token_ids, # token sequence
num_mm_tokens=
num_mm_tokens, # list of lengths of each chunk of visual tokens
vocab_size=input_processor.model_config.vocab_size,
)
# flatten the hashes from dict to a single list
mm_hashes = [h for hashes in mm_hashes.values() for h in hashes]
validate_mm_inputs(prompt_token_ids, mm_hashes, start_positions,
num_mm_tokens)
mm_hashes_int32 = [hexdigest_to_int32(h) for h in mm_hashes
] # nested list w/ multiple int32 per hash
try_multimodal_hashing = False # only used for first time
use_multimodal_hashing = False # used for subsequent calls
modalities = list(set(inputs['multi_modal_data'].keys())
) if 'multi_modal_data' in inputs else []
if len(modalities) > 0:
# NOTE: tensorrt_llm/inputs/multimodal.py:find_mm_token_lengths only supports image data for now
if len(modalities) == 1 and modalities[0] == "image":
# only try multimodal hashing if the inputs only contain image data
if input_processor.multimodal_hashing_supported is not None:
use_multimodal_hashing = input_processor.multimodal_hashing_supported
else:
# we need to try the multimodal hashing for the first time to determine if it is supported
try_multimodal_hashing = True
extra_processed_inputs[
"multimodal_input"] = MultimodalInput.from_components(
mm_hashes_int32, start_positions, num_mm_tokens)
if try_multimodal_hashing or use_multimodal_hashing:
try:
prompt_token_ids, extra_processed_inputs = multimodal_hashing_process(
inputs, sampling_params)
if try_multimodal_hashing:
# if trying for first time, set the flag to True
input_processor.multimodal_hashing_supported = True
return prompt_token_ids, extra_processed_inputs
else:
except Exception as e:
import traceback
traceback.print_exc()
logger.warning(f"Multimodal hashing failed: {e}.")
if try_multimodal_hashing:
# if trying for first time, fall back to basic input processor
# and set the flag to False so that we don't try again
input_processor.multimodal_hashing_supported = False
logger.warning("Falling back to basic input processor.")
try:
return input_processor(inputs, sampling_params)
except Exception as e2:
import traceback
traceback.print_exc()
logger.warning(f"Basic input processor failed: {e}.")
raise e2
else:
raise e
else:
try:
return input_processor(inputs, sampling_params)
except Exception as e:
# Fall back to basic input processor if multimodal processing fails
logger.warning(
f"Multimodal hashing failed: {e}. Falling back to basic input processor."
)
return input_processor(inputs, sampling_params)
except Exception as e:
import traceback
traceback.print_exc()
logger.warning(f"Basic input processor failed: {e}.")
raise e
return input_processor_wrapper

View File

@ -1,6 +1,5 @@
import asyncio
import base64
import enum
import tempfile
from collections import defaultdict
from io import BytesIO
@ -18,6 +17,8 @@ from torchvision.transforms import ToTensor
from transformers import AutoProcessor, ProcessorMixin
from transformers.utils import logging
from tensorrt_llm.inputs.registry import (MULTIMODAL_PLACEHOLDER_REGISTRY,
MultimodalPlaceholderPlacement)
from tensorrt_llm.llmapi.llm_utils import ModelLoader
from tensorrt_llm.llmapi.tokenizer import TokenizerBase, TransformersTokenizer
@ -209,60 +210,25 @@ NOTE:
placeholder for the model needs to be added in retrieve_multimodal_placeholder().
"""
SUPPORTED_QWEN_MODEL_GROUP = ["qwen2_vl", "qwen2_5_vl"]
SUPPORTED_GEMMA_MODEL_GROUP = ["gemma3"]
SUPPORTED_LLAMA_MODEL_GROUP = ["mllama", "llama4"]
SUPPORTED_LLAVA_IMAGE_MODEL_GROUP = ["llava_llama", "llava_next"]
SUPPORTED_LLAVA_VIDEO_MODEL_GROUP = ["llava_llama"]
SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP = ["mistral3"]
SUPPORTED_HYPERCLOVAX_MODEL_GROUP = ["hyperclovax_vlm"]
SUPPORTED_PHI_MODEL_GROUP = ["phi4mm"]
ALL_SUPPORTED_IMAGE_MODELS = SUPPORTED_QWEN_MODEL_GROUP \
+ SUPPORTED_LLAMA_MODEL_GROUP \
+ SUPPORTED_LLAVA_IMAGE_MODEL_GROUP \
+ SUPPORTED_HYPERCLOVAX_MODEL_GROUP \
+ SUPPORTED_GEMMA_MODEL_GROUP \
+ SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP \
+ SUPPORTED_PHI_MODEL_GROUP
ALL_SUPPORTED_VIDEO_MODELS = SUPPORTED_QWEN_MODEL_GROUP \
+ SUPPORTED_LLAVA_VIDEO_MODEL_GROUP
ALL_SUPPORTED_AUDIO_MODELS = SUPPORTED_PHI_MODEL_GROUP
ALL_SUPPORTED_MULTIMODAL_MODELS = list(set(ALL_SUPPORTED_IMAGE_MODELS) \
| set(ALL_SUPPORTED_VIDEO_MODELS) \
| set(ALL_SUPPORTED_AUDIO_MODELS))
HF_CHAT_TEMPLATE_EXCEPTIONS = ["llava_llama"]
PLACEHOLDER_EXCEPTIONS = ["llava_next"]
class MultimodalPlaceholderPlacement(enum.Enum):
INVALID = -1
BEFORE_TEXT = 0
AFTER_TEXT = 1
# Helpers to always get the latest supported multimodal model types from the registry
def ALL_SUPPORTED_MULTIMODAL_MODELS():
return MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types()
PLACEHOLDER_PLACEMENT_MAP = {
"qwen2_vl": MultimodalPlaceholderPlacement.BEFORE_TEXT,
"qwen2_5_vl": MultimodalPlaceholderPlacement.BEFORE_TEXT,
"llava_llama": MultimodalPlaceholderPlacement.BEFORE_TEXT,
"llava_next": MultimodalPlaceholderPlacement.BEFORE_TEXT,
"llama4": MultimodalPlaceholderPlacement.BEFORE_TEXT,
"mllama": MultimodalPlaceholderPlacement.BEFORE_TEXT,
"hyperclovax_vlm": MultimodalPlaceholderPlacement.AFTER_TEXT,
"gemma3": MultimodalPlaceholderPlacement.BEFORE_TEXT,
# NOTE: for mistral3 multimodal models, it does not strictly have to be before the text.
# Ref: https://github.com/mistralai/mistral-common/blob/039465db2bdc0486df36365c9bdb428188482a18/
# src/mistral_common/tokens/tokenizers/base.py#L326
# However, accuracy tests show that the model generates higher quality output when the image
# precedes the text (the relative difference can be as much as ~30% for both vLLM and TRT-LLM).
"mistral3": MultimodalPlaceholderPlacement.BEFORE_TEXT,
"phi4mm": MultimodalPlaceholderPlacement.BEFORE_TEXT,
}
assert len(PLACEHOLDER_PLACEMENT_MAP) == len(ALL_SUPPORTED_MULTIMODAL_MODELS)
def ALL_SUPPORTED_IMAGE_MODELS():
return MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_image_model_types()
def ALL_SUPPORTED_VIDEO_MODELS():
return MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_video_model_types()
def ALL_SUPPORTED_AUDIO_MODELS():
return MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_audio_model_types()
def retrieve_multimodal_placeholder(model_type: str, modality: str,
@ -276,41 +242,16 @@ def retrieve_multimodal_placeholder(model_type: str, modality: str,
current_count: The number of multimodal data already added.
"""
if modality == "image":
if model_type in SUPPORTED_QWEN_MODEL_GROUP:
return "<|vision_start|><|image_pad|><|vision_end|>"
elif model_type in SUPPORTED_LLAMA_MODEL_GROUP:
return "<|image|>"
elif model_type in SUPPORTED_LLAVA_IMAGE_MODEL_GROUP:
return "<image>"
elif model_type in SUPPORTED_GEMMA_MODEL_GROUP:
return "<start_of_image>"
elif model_type in SUPPORTED_HYPERCLOVAX_MODEL_GROUP:
return '<im_end>\n<|im_start|>user (mime) \n{"type": "image/jpeg", "filename": ""}<|im_end|>\n' + \
'<|im_start|>user (vector)\n<|dummy3|><|im_end|>\n' + \
'<|im_start|>image/aux\n다음 중 ocr은 사진에서 검출된 글자이고, lens_keyword는 사진에서 추출된 keyword와 bbox 위치입니다.' + \
'bbox는 0~1 사이로 정규화된 [x1, y1, x2, y2]의 형태입니다. 참고하여 답변하세요. {"ocr": "", "lens_keywords": "", "lens_local_keywords": ""}'
elif model_type in SUPPORTED_MISTRAL_IMAGE_MODEL_GROUP:
# Ref: https://github.com/mistralai/mistral-common/blob/26a6bb3a07ee0b78a3808f2797f23e1d28514b93/
# src/mistral_common/tokens/tokenizers/base.py#L60
return "[IMG]"
elif model_type in SUPPORTED_PHI_MODEL_GROUP:
return f"<|image_{current_count}|>"
raise TypeError(
f"For image modality, only {ALL_SUPPORTED_IMAGE_MODELS} are supported but got {model_type}"
)
elif modality == "video":
if model_type in SUPPORTED_QWEN_MODEL_GROUP:
return "<|vision_start|><|video_pad|><|vision_end|>"
elif model_type in SUPPORTED_LLAVA_VIDEO_MODEL_GROUP:
return "<vila/video>"
raise TypeError(
f"For video modality, only {ALL_SUPPORTED_VIDEO_MODELS} are supported but got {model_type}"
)
elif modality == "audio":
if model_type in SUPPORTED_PHI_MODEL_GROUP:
return f"<|audio_{current_count}|>"
if MULTIMODAL_PLACEHOLDER_REGISTRY.is_valid(model_type, modality):
"""
The placeholder is a string with a single placeholder for the current count.
- For example, if the placeholder is "<|image_{0}|>", and the current count is 1,
the placeholder will be "<|image_1|>".
- However, if the placeholder is "<|image|>", the current count would be ignored.
In this case, the placeholder would be "<|image|>".
"""
return MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholder(
model_type, modality).format(current_count)
raise TypeError(f"Unknown modality: {modality}")
@ -379,17 +320,15 @@ def add_multimodal_placeholders(model_type: str, text_prompt: str,
for placeholder in mm_placeholder_counts:
placeholders.extend([placeholder] * mm_placeholder_counts[placeholder])
parts = []
match PLACEHOLDER_PLACEMENT_MAP[model_type]:
match MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholder_placement(model_type):
case MultimodalPlaceholderPlacement.BEFORE_TEXT:
parts.extend(placeholders)
parts.append(text_prompt)
case MultimodalPlaceholderPlacement.AFTER_TEXT:
parts.append(text_prompt)
parts.extend(placeholders)
if model_type == "phi4mm":
return "".join(parts)
else:
return "\n".join(parts)
return MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholders_separator(
model_type).join(parts)
def resolve_hf_chat_template(

View File

@ -0,0 +1,132 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
from pathlib import Path
from types import ModuleType
from typing import Optional, Union
def import_custom_module_from_file(
custom_module_path: Union[str, Path]) -> Optional[ModuleType]:
"""Import a custom module from a single file.
Args:
custom_module_path (Union[str, Path]): The path to the custom module file.
Returns:
The imported module object.
Raises:
ImportError: If the module cannot be imported.
"""
if isinstance(custom_module_path, str):
custom_module_path = Path(custom_module_path)
print(f"Importing custom module from file: {custom_module_path}")
# Import single Python file
module = None
spec = importlib.util.spec_from_file_location(custom_module_path.stem,
str(custom_module_path))
if spec is not None:
module = importlib.util.module_from_spec(spec)
if spec.loader is not None:
spec.loader.exec_module(module)
print(
f"Successfully imported custom module from file: {custom_module_path}"
)
else:
raise ImportError(
f"Failed to import custom module from {custom_module_path}")
else:
raise ImportError(
f"Failed to import custom module from {custom_module_path}")
return module
def import_custom_module_from_dir(
custom_module_path: Union[str, Path]) -> Optional[ModuleType]:
"""Import a custom module from a directory.
Args:
custom_module_path (Union[str, Path]): The path to the custom module directory.
Returns:
The imported module object.
Raises:
ImportError: If the module cannot be imported.
Note:
This function will add the parent directory of the custom module directory to sys.path.
This is useful for importing modules that are not in the current working directory.
"""
if isinstance(custom_module_path, str):
custom_module_path = Path(custom_module_path)
print(f"Importing custom module from directory: {custom_module_path}")
# Import directory as a package
# Add the parent directory to sys.path so we can import the package
import sys
parent_dir = str(custom_module_path.parent)
if parent_dir not in sys.path:
sys.path.insert(0, parent_dir)
# Import the package
module = None
package_name = custom_module_path.name
try:
module = importlib.import_module(package_name)
print(
f"Successfully imported custom module from directory: {custom_module_path}"
)
except ImportError as e:
raise ImportError(
f"Failed to import package {package_name} from {custom_module_path}: {e}"
)
return module
def import_custom_module(
custom_module_path: Union[str, Path]) -> Optional[ModuleType]:
"""Import a custom module from a file or directory.
Args:
custom_module_path (Union[str, Path]): The path to the custom module file or directory.
Returns:
The imported module object.
Raises:
ImportError: If the module cannot be imported.
FileNotFoundError: If the custom module path does not exist.
"""
if isinstance(custom_module_path, str):
custom_module_path = Path(custom_module_path)
print(f"Importing custom module from: {custom_module_path}")
if custom_module_path.exists():
if custom_module_path.is_file():
return import_custom_module_from_file(custom_module_path)
elif custom_module_path.is_dir():
return import_custom_module_from_dir(custom_module_path)
else:
raise FileNotFoundError(
f"Custom module path {custom_module_path} is neither a file nor a directory."
)
else:
raise FileNotFoundError(
f"Custom module path {custom_module_path} does not exist.")
return None

View File

@ -0,0 +1,106 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from tensorrt_llm.inputs.registry import (MULTIMODAL_PLACEHOLDER_REGISTRY,
MultimodalPlaceholderMetadata,
MultimodalPlaceholderPlacement)
class TestMultimodalPlaceholderRegistry(unittest.TestCase):
def setUp(self):
self.model_type = "test_model_type"
self.placeholder_metadata = MultimodalPlaceholderMetadata(
placeholder_map={
"image": "IMAGE_PLACEHOLDER",
"video": "VIDEO_PLACEHOLDER"
},
placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
placeholders_separator="\n")
def test_new_registration(self):
MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata(
self.model_type, self.placeholder_metadata)
self.assertEqual(
MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholder_metadata(
self.model_type), self.placeholder_metadata)
MULTIMODAL_PLACEHOLDER_REGISTRY.remove_placeholder_metadata(
self.model_type)
def test_registered_model_types(self):
pre_reg_model_types = list(
MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types())
# register the model type
MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata(
self.model_type, self.placeholder_metadata)
post_reg_model_types = list(
MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types())
self.assertEqual(
len(pre_reg_model_types) + 1, len(post_reg_model_types))
self.assertIn(self.model_type, post_reg_model_types)
MULTIMODAL_PLACEHOLDER_REGISTRY.remove_placeholder_metadata(
self.model_type)
def test_validity(self):
MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata(
self.model_type, self.placeholder_metadata)
self.assertTrue(
MULTIMODAL_PLACEHOLDER_REGISTRY.is_valid(self.model_type, "image"))
self.assertTrue(
MULTIMODAL_PLACEHOLDER_REGISTRY.is_valid(self.model_type, "video"))
self.assertFalse(
MULTIMODAL_PLACEHOLDER_REGISTRY.is_valid(self.model_type, "audio"))
MULTIMODAL_PLACEHOLDER_REGISTRY.remove_placeholder_metadata(
self.model_type)
def test_model_types_per_modality(self):
pre_reg_image_model_types = list(
MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_image_model_types())
pre_reg_video_model_types = list(
MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_video_model_types())
pre_reg_audio_model_types = list(
MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_audio_model_types())
# register the model type for image and video
MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata(
self.model_type, self.placeholder_metadata)
post_reg_image_model_types = list(
MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_image_model_types())
post_reg_video_model_types = list(
MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_video_model_types())
post_reg_audio_model_types = list(
MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_audio_model_types())
self.assertEqual(
len(pre_reg_image_model_types) + 1, len(post_reg_image_model_types))
self.assertEqual(
len(pre_reg_video_model_types) + 1, len(post_reg_video_model_types))
self.assertEqual(len(pre_reg_audio_model_types),
len(post_reg_audio_model_types))
self.assertIn(self.model_type, post_reg_image_model_types)
self.assertIn(self.model_type, post_reg_video_model_types)
self.assertNotIn(self.model_type, post_reg_audio_model_types)
MULTIMODAL_PLACEHOLDER_REGISTRY.remove_placeholder_metadata(
self.model_type)
if __name__ == "__main__":
unittest.main()