from typing import Any, Dict, List, Optional, Protocol, Tuple, Type, TypeVar

from torch import nn

from ..sampling_params import SamplingParams
from .data import TextPrompt

# Type variable bound to nn.Module subclasses so register_input_processor
# returns the decorated model class with its original type preserved.
N = TypeVar("N", bound=Type[nn.Module])

# Extra (non-token) inputs a processor may return alongside the token ids.
ExtraProcessedInputs = Dict[str, Any]


class InputProcessor(Protocol):
    """Protocol for InputProcessor classes."""

    tokenizer: Any
    model_config: Any

    def __call__(
            self, inputs: TextPrompt, sampling_params: SamplingParams
    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
        """Process the inputs to the model."""
        ...


class DefaultInputProcessor(InputProcessor):
    """Preprocess the inputs to the model."""

    def __init__(self, model_config, tokenizer) -> None:
        self.tokenizer = tokenizer
        self.model_config = model_config

    def __call__(
            self, inputs: TextPrompt, sampling_params: SamplingParams
    ) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
        """The default input processor handles only tokenization."""
        if self.tokenizer is None:
            raise ValueError("tokenizer is required to tokenize string prompt")
        if sampling_params.truncate_prompt_tokens is None:
            token_ids = self.tokenizer.encode(
                inputs["prompt"],
                add_special_tokens=sampling_params.add_special_tokens)
        else:
            # Truncate the prompt to at most truncate_prompt_tokens tokens.
            token_ids = self.tokenizer.encode(
                inputs["prompt"],
                add_special_tokens=sampling_params.add_special_tokens,
                truncation=True,
                max_length=sampling_params.truncate_prompt_tokens)
        # Text-only prompts produce no extra processed inputs.
        return token_ids, None
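

# A minimal usage sketch (assumptions: a Hugging Face tokenizer loaded via
# transformers.AutoTokenizer, "gpt2" as an illustrative checkpoint, and that
# truncate_prompt_tokens can be passed to the SamplingParams constructor --
# this module only reads it as an attribute):
#
#     from transformers import AutoTokenizer
#     from tensorrt_llm.sampling_params import SamplingParams
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     processor = DefaultInputProcessor(model_config=None, tokenizer=tokenizer)
#     token_ids, _ = processor(
#         {"prompt": "Hello world"},
#         SamplingParams(truncate_prompt_tokens=8))  # keep at most 8 tokens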


class InputProcessorRegistry:
    """Registry mapping model classes to their input processor classes."""

    def __init__(self) -> None:
        self._input_processors_cls_by_model_type: Dict[
            Type[nn.Module], Type[InputProcessor]] = {}


# Module-level registry shared by register_input_processor and
# create_input_processor below.
INPUT_PROCESSOR_REGISTRY = InputProcessorRegistry()


def register_input_processor(processor_cls: Type[InputProcessor]):
    """Register an input processor to a model class."""

    def wrapper(model_cls: N) -> N:
        INPUT_PROCESSOR_REGISTRY._input_processors_cls_by_model_type[
            model_cls] = processor_cls
        return model_cls

    return wrapper
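

# A minimal usage sketch of the decorator: attach a custom processor to a
# model class at definition time. "MyInputProcessor" and "MyModel" are
# hypothetical names for illustration, not part of this module:
#
#     @register_input_processor(MyInputProcessor)
#     class MyModel(nn.Module):
#         ...
#
# Afterwards, create_input_processor (below) resolves MyModel to
# MyInputProcessor via INPUT_PROCESSOR_REGISTRY.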


def create_input_processor(model_path_or_dir: str, tokenizer):
    """Create an input processor for a specific model."""
    from tensorrt_llm._torch.model_config import ModelConfig
    from tensorrt_llm._torch.models import get_model_architecture

    # Try to load the model config; if the checkpoint cannot be parsed, fall
    # through to the default processor below.
    model_config = None
    try:
        config = ModelConfig.from_pretrained(model_path_or_dir,
                                             trust_remote_code=True)
        model_config = config.pretrained_config
    except ValueError:
        config = None

    if model_config is not None:
        try:
            model_cls, _ = get_model_architecture(model_config)
            input_processor_cls = INPUT_PROCESSOR_REGISTRY._input_processors_cls_by_model_type \
                .get(model_cls)
        except RuntimeError:  # unregistered model
            input_processor_cls = None
        if input_processor_cls is not None:
            return input_processor_cls(model_config, tokenizer)

    # No registered processor for this model: plain tokenization suffices.
    return DefaultInputProcessor(None, tokenizer)
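

# A minimal end-to-end sketch of the factory (assumptions: a Hugging Face
# tokenizer and "gpt2" as an illustrative checkpoint; constructing
# SamplingParams with defaults is assumed to be valid):
#
#     from transformers import AutoTokenizer
#     from tensorrt_llm.sampling_params import SamplingParams
#
#     tokenizer = AutoTokenizer.from_pretrained("gpt2")
#     processor = create_input_processor("gpt2", tokenizer)
#     token_ids, extra = processor({"prompt": "Hello world"}, SamplingParams())
#
# If the model class was decorated with register_input_processor, its custom
# processor is returned; otherwise this yields a DefaultInputProcessor.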