# Adapted from
# https://github.com/vllm-project/vllm/blob/2e33fe419186c65a18da6668972d61d7bbc31564/vllm/inputs/data.py
from typing import Any, Dict, List, Optional, Union

from typing_extensions import NotRequired, TypedDict


class TextPrompt(TypedDict):
    """Schema for a text prompt."""

    prompt: str
    """The input text to be tokenized before passing to the model."""

    multi_modal_data: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    input processor for mm input processing.
    """


class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt."""

    prompt_token_ids: List[int]
    """A list of token IDs to pass to the model."""

    multi_modal_data: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """

    mm_processor_kwargs: NotRequired[Dict[str, Any]]
    """
    Optional multi-modal processor kwargs to be forwarded to the
    input processor for mm input processing.
    """
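
# Illustrative example, not part of the original module: both prompt schemas
# are plain dicts at runtime, so they can be constructed directly. The
# "image" payload below is a hypothetical placeholder, not a format this
# module defines.
#
#   text_prompt = TextPrompt(prompt="Describe the image.",
#                            multi_modal_data={"image": [image_bytes]})
#   tokens_prompt = TokensPrompt(prompt_token_ids=[1, 2, 3])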

PromptInputs = Union[str, List[int], TextPrompt, TokensPrompt]


def prompt_inputs(
    inputs: PromptInputs,
    multi_modal_data: Optional[Dict[str, Any]] = None,
    mm_processor_kwargs: Optional[Dict[str, Any]] = None,
) -> Union[TextPrompt, TokensPrompt]:
    """Normalize any accepted input form into a TextPrompt or TokensPrompt."""
    if isinstance(inputs, str):
        # A bare string is treated as untokenized text.
        prompt_inputs = TextPrompt(prompt=inputs)
    elif isinstance(inputs, list):
        # A bare list is treated as pre-tokenized token IDs.
        assert isinstance(inputs[0], int)
        prompt_inputs = TokensPrompt(prompt_token_ids=inputs)
    elif isinstance(inputs, dict):
        # An already-formed prompt dict is returned as-is; note that the
        # multi_modal_data and mm_processor_kwargs arguments are ignored
        # on this path.
        assert inputs.get("prompt") is not None \
            or inputs.get("prompt_token_ids") is not None
        return inputs
    else:
        raise TypeError(
            f"Invalid type of inputs for llm.generate: {type(inputs)}")

    if multi_modal_data is not None:
        prompt_inputs["multi_modal_data"] = multi_modal_data
    if mm_processor_kwargs is not None:
        prompt_inputs["mm_processor_kwargs"] = mm_processor_kwargs

    return prompt_inputs
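

# Illustrative usage sketch, not part of the original module: exercises the
# normalization paths of prompt_inputs(). TypedDicts are plain dicts at
# runtime, so equality against dict literals holds.
if __name__ == "__main__":
    assert prompt_inputs("Hello") == {"prompt": "Hello"}
    assert prompt_inputs([1, 2, 3]) == {"prompt_token_ids": [1, 2, 3]}
    assert prompt_inputs("Hi", multi_modal_data={"image": ["img.png"]}) == {
        "prompt": "Hi",
        "multi_modal_data": {"image": ["img.png"]},
    }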