TensorRT-LLM/tensorrt_llm/inputs/data.py

# Adapted from
# https://github.com/vllm-project/vllm/blob/2e33fe419186c65a18da6668972d61d7bbc31564/vllm/inputs/data.py
from typing import Any, Dict, List, Sequence, Union

from typing_extensions import NotRequired, TypedDict


class TextPrompt(TypedDict):
    """Schema for a text prompt."""

    prompt: str
    """The input text to be tokenized before passing to the model."""

    multi_modal_data: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    multi_modal_uuids: NotRequired[Dict[str, List[Any]]]
    """
    Optional user-provided UUIDs for multimodal items.
    Structure mirrors multi_modal_data: {"image": ["uuid1", None, "uuid3"]}.
    When a UUID is provided for an item, it will be returned in KV cache events
    instead of the computed content hash. Use None to fall back to content
    hashing for specific items.
    """

    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    input processor for mm input processing.
    """

    query: NotRequired[str]
    """The query input text for star attention."""


class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt."""

    prompt_token_ids: List[int]
    """A list of token IDs to pass to the model."""

    multi_modal_data: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    multi_modal_uuids: NotRequired[Dict[str, List[Any]]]
    """
    Optional user-provided UUIDs for multimodal items.
    Structure mirrors multi_modal_data: {"image": ["uuid1", None, "uuid3"]}.
    When a UUID is provided for an item, it will be returned in KV cache events
    instead of the computed content hash. Use None to fall back to content
    hashing for specific items.
    """

    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    input processor for mm input processing.
    """

    query_token_ids: NotRequired[List[int]]
    """The query input token IDs for star attention."""


PromptInputs = Union[str, List[int], TextPrompt, TokensPrompt]


def prompt_inputs(inputs: PromptInputs) -> Union[TextPrompt, TokensPrompt]:
    """Normalize raw prompt inputs into a TextPrompt or TokensPrompt."""
    if isinstance(inputs, str):
        prompt_inputs = TextPrompt(prompt=inputs)
    elif isinstance(inputs, list):
        # Cheap sanity check: a token list must be non-empty and hold ints.
        assert inputs and isinstance(inputs[0], int)
        prompt_inputs = TokensPrompt(prompt_token_ids=inputs)
    elif isinstance(inputs, dict):
        # Already in dict form; require at least one of the prompt fields.
        assert inputs.get("prompt") is not None \
            or inputs.get("prompt_token_ids") is not None
        return inputs
    else:
        raise TypeError(
            f"Invalid type of inputs for llm.generate: {type(inputs)}")
    return prompt_inputs
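
# Illustrative usage sketch of `prompt_inputs` (return values shown as the
# plain dicts the TypedDicts evaluate to at runtime):
#
#     prompt_inputs("Hello")             # -> {"prompt": "Hello"}
#     prompt_inputs([1, 2, 3])           # -> {"prompt_token_ids": [1, 2, 3]}
#     prompt_inputs({"prompt": "Hi"})    # -> passed through unchanged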


class VisualGenTextPrompt(TypedDict):
    """Schema for a text prompt to a visual generation (AIGV) model."""

    prompt: str
    """The positive text prompt describing the desired output."""

    negative_prompt: NotRequired[str]
    """Optional negative text prompt describing what to avoid."""


class VisualGenTokensPrompt(TypedDict):
    """Schema for a tokenized prompt to a visual generation (AIGV) model."""

    prompt_token_ids: List[int]
    """Token IDs of the positive prompt."""

    negative_prompt_token_ids: NotRequired[List[int]]
    """Optional token IDs of the negative prompt."""


VisualGenPromptInputs = Union[
    str,
    List[int],
    VisualGenTextPrompt,
    VisualGenTokensPrompt,
]

VisualGenInputs = Union[
    VisualGenPromptInputs,
    Sequence[VisualGenPromptInputs],
]
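
# Illustrative sketch: `VisualGenInputs` accepts either a single prompt or a
# batch (any sequence of prompts), e.g.:
#
#     single: VisualGenInputs = "a cat surfing a wave"
#     batch: VisualGenInputs = [
#         "a cat surfing a wave",
#         VisualGenTextPrompt(prompt="a dog in space",
#                             negative_prompt="blurry, low quality"),
#     ]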


def visual_gen_inputs(
    inputs: "VisualGenPromptInputs",
) -> Union["VisualGenTextPrompt", "VisualGenTokensPrompt"]:
    # str -> text prompt
    if isinstance(inputs, str):
        return VisualGenTextPrompt(prompt=inputs)

    # list[int] -> token prompt
    if isinstance(inputs, list):
        if len(inputs) == 0:
            raise ValueError("`inputs` token list cannot be empty.")
        if not all(isinstance(t, int) for t in inputs):
            raise TypeError(
                "`inputs` list must contain only ints when used as token IDs.")
        return VisualGenTokensPrompt(prompt_token_ids=inputs)

    # dict form
    if isinstance(inputs, dict):
        has_prompt = "prompt" in inputs
        has_prompt_token_ids = "prompt_token_ids" in inputs
        if has_prompt == has_prompt_token_ids:
            raise ValueError(
                "VisualGen prompt dict must contain exactly one of "
                "`prompt` or `prompt_token_ids`.")
        if has_prompt:
            prompt = inputs.get("prompt")
            if not isinstance(prompt, str) or prompt == "":
                raise TypeError("`prompt` must be a non-empty string.")
            if "negative_prompt" in inputs and not isinstance(
                    inputs["negative_prompt"], str):
                raise TypeError("`negative_prompt` must be a string.")
            return inputs  # VisualGenTextPrompt

        token_ids = inputs.get("prompt_token_ids")
        if not isinstance(token_ids, list) or len(token_ids) == 0:
            raise TypeError("`prompt_token_ids` must be a non-empty list[int].")
        if not all(isinstance(t, int) for t in token_ids):
            raise TypeError("`prompt_token_ids` must contain only ints.")
        if "negative_prompt_token_ids" in inputs:
            neg_ids = inputs["negative_prompt_token_ids"]
            if not isinstance(neg_ids, list) or not all(
                    isinstance(t, int) for t in neg_ids):
                raise TypeError(
                    "`negative_prompt_token_ids` must be a list[int].")
        return inputs  # VisualGenTokensPrompt

    raise TypeError(
        "Invalid `inputs` for VisualGen.generate. "
        "Expected one of: str, list[int], VisualGenTextPrompt, VisualGenTokensPrompt."
    )
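
# Illustrative usage sketch of `visual_gen_inputs`:
#
#     visual_gen_inputs("a cat")         # -> {"prompt": "a cat"}
#     visual_gen_inputs([5, 7, 9])       # -> {"prompt_token_ids": [5, 7, 9]}
#     visual_gen_inputs({"prompt": "a cat",
#                        "prompt_token_ids": [5]})  # ValueError: both keys set
#     visual_gen_inputs({"prompt": ""})  # TypeError: empty prompt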