# Adapted from
# https://github.com/vllm-project/vllm/blob/2e33fe419186c65a18da6668972d61d7bbc31564/vllm/inputs/data.py
from typing import Any, Dict, List, Sequence, Union

from typing_extensions import NotRequired, TypedDict


class TextPrompt(TypedDict):
    """Schema for a text prompt."""

    prompt: str
    """The input text to be tokenized before passing to the model."""

    multi_modal_data: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    multi_modal_uuids: NotRequired[Dict[str, List[Any]]]
    """
    Optional user-provided UUIDs for multimodal items.
    Structure mirrors multi_modal_data: {"image": ["uuid1", None, "uuid3"]}.
    When a UUID is provided for an item, it will be returned in KV cache events
    instead of the computed content hash. Use None to fall back to content
    hashing for specific items.
    """

    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    input processor for mm input processing.
    """

    query: NotRequired[str]
    """The query input text for star attention."""
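
# Illustrative usage (a sketch, kept as a comment so nothing runs at import
# time): a TextPrompt pairing text with multi-modal data. The variable `image`
# stands in for whatever image object the model's input processor accepts.
#
#     prompt = TextPrompt(
#         prompt="Describe this image.",
#         multi_modal_data={"image": [image]},
#         multi_modal_uuids={"image": ["uuid1"]},  # None -> content hashing
#     )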


class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt."""

    prompt_token_ids: List[int]
    """A list of token IDs to pass to the model."""

    multi_modal_data: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    multi_modal_uuids: NotRequired[Dict[str, List[Any]]]
    """
    Optional user-provided UUIDs for multimodal items.
    Structure mirrors multi_modal_data: {"image": ["uuid1", None, "uuid3"]}.
    When a UUID is provided for an item, it will be returned in KV cache events
    instead of the computed content hash. Use None to fall back to content
    hashing for specific items.
    """

    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    input processor for mm input processing.
    """

    query_token_ids: NotRequired[List[int]]
    """The query input token IDs for star attention."""
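
# Illustrative usage (a sketch): the tokenized counterpart of TextPrompt.
#
#     prompt = TokensPrompt(prompt_token_ids=[101, 2054, 102])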


PromptInputs = Union[str, List[int], TextPrompt, TokensPrompt]


def prompt_inputs(inputs: PromptInputs) -> Union[TextPrompt, TokensPrompt]:
    """Normalize raw user input into a TextPrompt or TokensPrompt."""
    if isinstance(inputs, str):
        # Plain string -> text prompt.
        prompt_inputs = TextPrompt(prompt=inputs)
    elif isinstance(inputs, list):
        # List of ints -> tokenized prompt.
        assert isinstance(inputs[0], int)
        prompt_inputs = TokensPrompt(prompt_token_ids=inputs)
    elif isinstance(inputs, dict):
        # Already a prompt dict; require one of the two prompt keys.
        assert inputs.get("prompt") is not None \
            or inputs.get("prompt_token_ids") is not None
        return inputs
    else:
        raise TypeError(
            f"Invalid type of inputs for llm.generate: {type(inputs)}")

    return prompt_inputs
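
# Illustrative normalization behavior (a sketch):
#
#     prompt_inputs("Hello")           # -> {"prompt": "Hello"}
#     prompt_inputs([1, 2, 3])         # -> {"prompt_token_ids": [1, 2, 3]}
#     prompt_inputs({"prompt": "Hi"})  # returned unchanged after validation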


class VisualGenTextPrompt(TypedDict):
    """Schema for a visual-generation text prompt."""

    prompt: str
    negative_prompt: NotRequired[str]


class VisualGenTokensPrompt(TypedDict):
    """Schema for a visual-generation tokenized prompt."""

    prompt_token_ids: List[int]
    negative_prompt_token_ids: NotRequired[List[int]]


VisualGenPromptInputs = Union[
    str,
    List[int],
    VisualGenTextPrompt,
    VisualGenTokensPrompt,
]

VisualGenInputs = Union[
    VisualGenPromptInputs,
    Sequence[VisualGenPromptInputs],
]
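
# Illustrative usage (a sketch): a text prompt with an optional negative
# prompt, in the form accepted by visual_gen_inputs below.
#
#     prompt = VisualGenTextPrompt(
#         prompt="a watercolor fox",
#         negative_prompt="blurry, low quality",
#     )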


def visual_gen_inputs(
    inputs: "VisualGenPromptInputs",
) -> Union["VisualGenTextPrompt", "VisualGenTokensPrompt"]:
    """Normalize and validate a single VisualGen prompt input."""
    # str -> text prompt
    if isinstance(inputs, str):
        return VisualGenTextPrompt(prompt=inputs)

    # list[int] -> token prompt
    if isinstance(inputs, list):
        if len(inputs) == 0:
            raise ValueError("`inputs` token list cannot be empty.")
        if not all(isinstance(t, int) for t in inputs):
            raise TypeError(
                "`inputs` list must contain only ints when used as token IDs.")
        return VisualGenTokensPrompt(prompt_token_ids=inputs)

    # dict form
    if isinstance(inputs, dict):
        has_prompt = "prompt" in inputs
        has_prompt_token_ids = "prompt_token_ids" in inputs

        if has_prompt == has_prompt_token_ids:
            raise ValueError(
                "VisualGen prompt dict must contain exactly one of "
                "`prompt` or `prompt_token_ids`.")

        if has_prompt:
            prompt = inputs.get("prompt")
            if not isinstance(prompt, str) or prompt == "":
                raise TypeError("`prompt` must be a non-empty string.")
            if "negative_prompt" in inputs and not isinstance(
                    inputs["negative_prompt"], str):
                raise TypeError("`negative_prompt` must be a string.")
            return inputs  # VisualGenTextPrompt

        token_ids = inputs.get("prompt_token_ids")
        if not isinstance(token_ids, list) or len(token_ids) == 0:
            raise TypeError("`prompt_token_ids` must be a non-empty list[int].")
        if not all(isinstance(t, int) for t in token_ids):
            raise TypeError("`prompt_token_ids` must contain only ints.")
        if "negative_prompt_token_ids" in inputs:
            neg_ids = inputs["negative_prompt_token_ids"]
            if not isinstance(neg_ids, list) or not all(
                    isinstance(t, int) for t in neg_ids):
                raise TypeError(
                    "`negative_prompt_token_ids` must be a list[int].")
        return inputs  # VisualGenTokensPrompt

    raise TypeError(
        "Invalid `inputs` for VisualGen.generate. "
        "Expected one of: str, list[int], VisualGenTextPrompt, VisualGenTokensPrompt."
    )
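

# A minimal runnable smoke test for the normalization helpers above. This is
# an illustrative addition, not part of the original module API; run the file
# directly to exercise it.
if __name__ == "__main__":
    # prompt_inputs: str / list[int] / dict are all normalized to prompt dicts.
    assert prompt_inputs("Hello") == {"prompt": "Hello"}
    assert prompt_inputs([1, 2, 3]) == {"prompt_token_ids": [1, 2, 3]}
    assert prompt_inputs({"prompt": "Hi"}) == {"prompt": "Hi"}

    # visual_gen_inputs: validates and normalizes VisualGen prompts.
    assert visual_gen_inputs("a watercolor fox") == {"prompt": "a watercolor fox"}
    assert visual_gen_inputs([7, 8, 9]) == {"prompt_token_ids": [7, 8, 9]}
    assert visual_gen_inputs(
        {"prompt": "a fox", "negative_prompt": "blurry"}
    ) == {"prompt": "a fox", "negative_prompt": "blurry"}
    print("all input-normalization checks passed")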