TensorRT-LLM/tensorrt_llm/inputs/data.py
# Adapted from
# https://github.com/vllm-project/vllm/blob/2e33fe419186c65a18da6668972d61d7bbc31564/vllm/inputs/data.py
from typing import Any, Dict, List, Union

from typing_extensions import NotRequired, TypedDict


class TextPrompt(TypedDict):
    """Schema for a text prompt."""

    prompt: str
    """The input text to be tokenized before passing to the model."""

    multi_modal_data: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    input processor for multi-modal input processing.
    """


class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt."""

    prompt_token_ids: List[int]
    """A list of token IDs to pass to the model."""

    multi_modal_data: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    input processor for multi-modal input processing.
    """
PromptInputs = Union[str, List[int], TextPrompt, TokensPrompt]


def prompt_inputs(inputs: PromptInputs) -> Union[TextPrompt, TokensPrompt]:
    """Normalize any accepted input form into a TextPrompt or TokensPrompt."""
    if isinstance(inputs, str):
        # A raw string is wrapped as a text prompt.
        prompt_inputs = TextPrompt(prompt=inputs)
    elif isinstance(inputs, list):
        # A list is assumed to hold token IDs; only the first element is
        # type-checked.
        assert isinstance(inputs[0], int)
        prompt_inputs = TokensPrompt(prompt_token_ids=inputs)
    elif isinstance(inputs, dict):
        # An already-built prompt dict is validated and returned unchanged.
        assert inputs.get("prompt") is not None \
            or inputs.get("prompt_token_ids") is not None
        assert inputs.get("multi_modal_data") is None or inputs.get(
            "mm_processor_kwargs"
        ) is not None, ("For multimodal inputs, both the input data and the "
                        "processor kwargs must be specified.")
        return inputs
    else:
        raise TypeError(
            f"Invalid type of inputs for llm.generate: {type(inputs)}")
    return prompt_inputs
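

# Minimal usage sketch (an illustrative addition, not part of the original
# module): prompt_inputs accepts a raw string, a token-ID list, or a
# ready-made prompt dict. The token IDs are placeholders.
if __name__ == "__main__":
    # A raw string becomes a TextPrompt.
    assert prompt_inputs("Hello, world!") == {"prompt": "Hello, world!"}
    # A list of token IDs becomes a TokensPrompt.
    assert prompt_inputs([101, 7592, 102]) == {
        "prompt_token_ids": [101, 7592, 102]
    }
    # A dict that already matches a schema passes through unchanged.
    assert prompt_inputs(TextPrompt(prompt="Hi")) == {"prompt": "Hi"}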