TensorRT-LLMs/tensorrt_llm/inputs/registry.py

import enum
import random
import traceback
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import (Any, Callable, Dict, List, Optional, Protocol, Tuple, Type,
TypeVar, Union)
import torch
from PIL import Image
from torch import Tensor, nn
from transformers import (AutoProcessor, PretrainedConfig,
PreTrainedTokenizerBase)
import tensorrt_llm
from .._utils import nvtx_range_debug
from ..logger import logger
from ..sampling_params import SamplingParams
from .data import TextPrompt
from .multimodal import (MultimodalInput, apply_mm_hashes, default_hasher,
find_mm_token_lengths, find_mm_token_positions,
hexdigest_to_int32, validate_mm_inputs)
N = TypeVar("N", bound=Type[nn.Module])
ExtraProcessedInputs = Dict[str, Any]
class InputProcessor(Protocol):
"""
Protocol for InputProcessor classes.
    InputProcessor's methods are most relevant to multimodal use cases:
        - Preprocess: extra steps to manipulate the prompts.
        - Forward: the main logic to process the inputs. In multimodal cases, this may run a multimodal encoder model.
        - Postprocess: extra steps to manipulate the outputs.
    Model-specific implementations should:
        - Inherit this class and implement the forward() method.
        - Register the subclass with the model class using @register_input_processor(...).
"""
    model_path: Any
    config: Any
    tokenizer: Any
def __call__(
self, inputs: TextPrompt, sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
...
class DefaultInputProcessor(InputProcessor):
"""Preprocess the inputs to the model."""
def __init__(self,
model_path,
config,
tokenizer,
trust_remote_code: bool = True) -> None:
self.tokenizer = tokenizer
self.config = config
self.model_path = model_path
self.multimodal_hashing_supported = None
def __call__(
self, inputs: TextPrompt, sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
"""The default input processor handles only tokenization."""
if self.tokenizer is None:
raise ValueError("tokenizer is required to tokenize string prompt")
kwargs = {}
if sampling_params.truncate_prompt_tokens is not None:
kwargs = dict(truncation=True,
max_length=sampling_params.truncate_prompt_tokens)
        tiktoken_special_tokens = {
"<|startoftext|>",
"<|endoftext|>",
"<|reserved_200000|>",
"<|reserved_200001|>",
"<|return|>",
"<|constrain|>",
"<|reserved_200004|>",
"<|channel|>",
"<|start|>",
"<|end|>",
"<|message|>",
"<|reserved_200009|>",
"<|reserved_200010|>",
"<|reserved_200011|>",
"<|call|>",
"<|reserved_200013|>",
}
with nvtx_range_debug("tokenize prompt"):
try:
token_ids = self.tokenizer.encode(
inputs["prompt"],
add_special_tokens=sampling_params.add_special_tokens,
**kwargs)
            except Exception:
                # Tiktoken path
                token_ids = self.tokenizer.encode(
                    inputs["prompt"], allowed_special=tiktoken_special_tokens)
if "query" in inputs:
with nvtx_range_debug("tokenize query"):
try:
query_token_ids = self.tokenizer.encode(
inputs["query"],
add_special_tokens=sampling_params.add_special_tokens,
**kwargs)
                except Exception:
                    # Tiktoken path
                    query_token_ids = self.tokenizer.encode(
                        inputs["query"],
                        allowed_special=tiktoken_special_tokens)
return token_ids, {"query_token_ids": query_token_ids}
return token_ids, None
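# Usage sketch (illustrative, not part of the module): the default processor only
# needs a tokenizer; model_path and config may be None. The checkpoint path below
# is a placeholder.
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("/path/to/any-hf-model")
#     processor = DefaultInputProcessor(model_path=None, config=None, tokenizer=tokenizer)
#     token_ids, extra = processor({"prompt": "Hello, world"}, SamplingParams())
#     # `extra` is None unless a "query" field was provided in the inputs.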
class BaseMultimodalInputProcessor(ABC):
"""
Base class for multimodal input processors with default implementations
of get_num_tokens_per_image and get_num_tokens_per_video methods.
This class provides default implementations that work with most AutoProcessor-based
models. Specific processors can override these methods if they need custom logic.
"""
def __init__(self,
model_path,
config,
tokenizer,
trust_remote_code: bool = True,
**kwargs) -> None:
super().__init__(**kwargs)
self._config = config
self._model_path = model_path
self._tokenizer = tokenizer
self._use_fast: bool = kwargs.get('use_fast', True)
self._multimodal_hashing_supported: Optional[bool] = None
def attach_multimodal_embeddings(
self,
inputs: TextPrompt,
multimodal_embedding: Dict[str, List[torch.Tensor]],
sampling_params: SamplingParams,
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
"""
Handle externally provided multimodal input embeddings.
While inputs["multi_modal_data"] is handled by __call__, this method is intended to process
inputs["multi_modal_embeddings"].
"""
raise NotImplementedError(
"Input processor does not support multimodal embedding input")
@property
@abstractmethod
def processor(self) -> AutoProcessor:
"""The HF AutoProcessor for this model."""
...
@property
@abstractmethod
def tokenizer(self) -> PreTrainedTokenizerBase:
"""The HF tokenizer for this model."""
return self._tokenizer
@property
@abstractmethod
def config(self) -> PretrainedConfig:
"""The HF pretrained config for this model."""
return self._config
@property
@abstractmethod
def dtype(self) -> torch.dtype:
"""The dtype for this model."""
...
@abstractmethod
def __call__(
self, inputs: TextPrompt, sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
...
@property
def use_fast(self) -> bool:
"""
Whether to use fast tokenizer for AutoProcessor.
Default is True for most multimodal models.
"""
return self._use_fast
@property
def multimodal_hashing_supported(self) -> Optional[bool]:
"""
Whether multimodal hashing is supported for this processor.
Returns None if unknown (will be detected at runtime),
True if supported, False if not supported.
"""
return self._multimodal_hashing_supported
@multimodal_hashing_supported.setter
def multimodal_hashing_supported(self, value: Optional[bool]) -> None:
"""Set the multimodal hashing support status (used for runtime detection)."""
self._multimodal_hashing_supported = value
def get_vocab_size(self) -> Optional[int]:
"""Return the tokenizer/model vocabulary size if available; otherwise None.
Resolution order:
1) self.config.vocab_size
2) self.tokenizer.vocab_size
"""
# 1) Model config
if hasattr(self.config, 'vocab_size'):
return int(self.config.vocab_size)
# 2) Direct tokenizer on self
if hasattr(self.tokenizer, 'vocab_size'):
return int(self.tokenizer.vocab_size)
logger.debug(
f"Cannot determine vocab_size from {self.__class__.__name__}. "
"Please override this method to provide the vocabulary size. ")
return None
def get_mm_token_ids(self) -> Optional[Tensor]:
"""Return multimodal token IDs if available; otherwise None.
The token IDs filtered by this method should be contiguous for each multimodal item, i.e. special tokens if any should be included.
"""
if hasattr(self.processor, 'mm_token_ids'):
return self.processor.mm_token_ids
logger.debug(
f"Cannot find mm_token_ids in {self.__class__.__name__}.processor. "
"If needed, please override this method to return multimodal token ids. "
)
return None
def get_mm_special_token_ids(self) -> Optional[Tensor]:
"""
Return multimodal special token IDs if available; otherwise None.
Special tokens refer to multimodal-related tokens (e.g. <image_end>, <image_break>) that are not part
of the ViT output but come from text embeddings. Some VLMs
(e.g., Mistral3, LLaMA4) mix special tokens with multimodal tokens,
so they need to be returned separately.
"""
return getattr(self.processor, "mm_special_token_ids", None)
@property
def get_num_multimodal_tokens(self):
"""
Get the Hugging Face processor's '_get_num_multimodal_tokens' method.
"""
if hasattr(self.processor, '_get_num_multimodal_tokens'):
return self.processor._get_num_multimodal_tokens
else:
raise NotImplementedError(
f"get_num_multimodal_tokens not implemented for {self.__class__.__name__}. "
"Please override this method or ensure the processor has _get_num_multimodal_tokens method."
)
def get_num_tokens_per_image(
self,
*,
image: Image.Image,
**kwargs,
):
"""
Calculate the number of tokens generated for an image.
This (default) method delegates to the Hugging Face processor's '_get_num_multimodal_tokens' method.
Returns the token count for the given image.
Subclasses can override this method to provide custom logic to calculate the number of tokens.
"""
image_height = image.height
image_width = image.width
image_size = (image_height, image_width)
return self.get_num_multimodal_tokens([image_size],
**kwargs)["num_image_tokens"][0]
def get_num_tokens_per_video(
self,
*,
video: List[Image.Image],
**kwargs,
):
"""
Calculate the number of tokens generated for a video.
This (default) method delegates to the Hugging Face processor's '_get_num_multimodal_tokens' method.
Returns the token count for the given video.
Subclasses can override this method to provide custom logic to calculate the number of tokens.
"""
video_width = video[0].width
video_height = video[0].height
num_frames = len(video)
video_size = (num_frames, video_height, video_width)
try:
num_video_tokens = self.get_num_multimodal_tokens(
video_sizes=[video_size], **kwargs)["num_video_tokens"][0]
return num_video_tokens
except Exception:
# Fallback: treat video as sequence of frames
num_tokens_per_frame = self.get_num_tokens_per_image(image=video[0],
**kwargs)
            temporal_patch_size = getattr(self, 'temporal_patch_size', 1)
return num_tokens_per_frame * num_frames // temporal_patch_size
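# Subclassing sketch (hedged; MyVLMInputProcessor is a hypothetical name). A concrete
# processor supplies the abstract properties and implements __call__ with its
# model-specific preprocessing; the token-counting helpers above then work through
# the HF processor's _get_num_multimodal_tokens when available.
#
#     class MyVLMInputProcessor(BaseMultimodalInputProcessor, InputProcessor):
#         def __init__(self, model_path, config, tokenizer, trust_remote_code=True):
#             super().__init__(model_path, config, tokenizer, trust_remote_code)
#             self._processor = AutoProcessor.from_pretrained(
#                 model_path, trust_remote_code=trust_remote_code, use_fast=self.use_fast)
#         @property
#         def processor(self) -> AutoProcessor:
#             return self._processor
#         @property
#         def tokenizer(self) -> PreTrainedTokenizerBase:
#             return self._tokenizer
#         @property
#         def config(self) -> PretrainedConfig:
#             return self._config
#         @property
#         def dtype(self) -> torch.dtype:
#             return torch.bfloat16
#         def __call__(self, inputs, sampling_params):
#             ...  # tokenize the prompt and run the HF processor on the media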
class BaseMultimodalDummyInputsBuilder(ABC):
"""
    Base class for generating dummy inputs, primarily for profiling.
"""
DEFAULT_IMAGE_MAX_DIM = 16384
DEFAULT_IMAGE_MIN_DIM = 128
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.image_max_dim = kwargs.get('image_max_dim',
self.DEFAULT_IMAGE_MAX_DIM)
self.image_min_dim = kwargs.get('image_min_dim',
self.DEFAULT_IMAGE_MIN_DIM)
@property
@abstractmethod
def tokenizer(self) -> PreTrainedTokenizerBase:
...
@property
@abstractmethod
def config(self) -> PretrainedConfig:
...
@property
@abstractmethod
def model_path(self) -> str:
...
def get_dummy_image(self, max_width: int, max_height: int) -> Image.Image:
image = Image.new("RGB", (max_width, max_height),
                          color=random.randint(0, 255))
return image
def get_dummy_prompt(self, input_seq_len: int):
        # TODO(yechank): We use the max resolution as a starting point and keep reducing it until the prompt length fits within the input sequence length.
        # Need to find a better way to calculate the dummy prompt length, as this iteration may not be efficient.
# Use the registered model_type from the decorator if available,
# otherwise fall back to HuggingFace config's model_type.
# This ensures consistency between placeholder registration and lookup.
registered_model_type = getattr(self.__class__,
'_registered_model_type', None)
config_model_type = self.config.model_type
model_type = registered_model_type or config_model_type
logger.debug(
f"[get_dummy_prompt] registered_model_type={registered_model_type}, "
f"config.model_type={config_model_type}, using model_type={model_type}"
)
while self.image_max_dim >= self.image_min_dim:
image = self.get_dummy_image(max_width=self.image_max_dim,
max_height=self.image_max_dim)
test_mm_prompt = tensorrt_llm.inputs.utils.default_multimodal_input_loader(
tokenizer=self.tokenizer,
model_dir=self.model_path,
model_type=model_type,
modality="image",
prompts=[""],
media=[[image]],
image_data_format="pt")[0]
prompt_token_ids_single_img, _ = self(test_mm_prompt, None)
if len(prompt_token_ids_single_img) <= input_seq_len:
return test_mm_prompt
# reduce img resolution
self.image_max_dim = self.image_max_dim >> 1
return None
class MultimodalPlaceholderPlacement(enum.Enum):
"""
The placement of the multimodal placeholder in the prompt. Valid values are:
- BEFORE_TEXT: the placeholders are placed before the text prompt.
- AFTER_TEXT: the placeholders are placed after the text prompt.
"""
INVALID = -1
BEFORE_TEXT = 0
AFTER_TEXT = 1
@dataclass(frozen=True)
class MultimodalPlaceholderMetadata:
"""
Metadata for the multimodal placeholder. It has 3 components:
- placeholder_map:
A mapping from modality to placeholder string.
Modality can be "image", "video", "audio", etc.
- placeholder_placement:
The placement of the placeholders, e.g. before or after the text prompt.
- placeholders_separator:
The separator between the placeholders, e.g. some models use "\n" to separate the placeholders.
"""
placeholder_map: Dict[str, str] = field(default_factory=dict)
placeholder_placement: MultimodalPlaceholderPlacement = MultimodalPlaceholderPlacement.AFTER_TEXT
placeholders_separator: str = "\n"
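# Illustrative metadata (hedged; "<image>"/"<video>" are example placeholder strings,
# not values this module mandates): a model whose prompt format expects the media
# placeholders before the text, separated by newlines.
#
#     _EXAMPLE_PLACEHOLDER_METADATA = MultimodalPlaceholderMetadata(
#         placeholder_map={"image": "<image>", "video": "<video>"},
#         placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
#         placeholders_separator="\n",
#     )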
class MultimodalPlaceholderRegistry:
"""
Registry for the multimodal models to keep track of the placeholder information.
"""
def __init__(self) -> None:
self._multimodal_placeholder_by_model_type: Dict[
str, MultimodalPlaceholderMetadata] = {}
def __str__(self) -> str:
s = ""
for model_type, placeholder_metadata in self._multimodal_placeholder_by_model_type.items(
):
s += "-" * 100 + "\n"
s += f"Model type: {model_type}\n"
s += f"Placeholder map: {placeholder_metadata.placeholder_map}\n"
s += f"Placeholder placement: {placeholder_metadata.placeholder_placement}\n"
s += f"Placeholders separator: \"{placeholder_metadata.placeholders_separator}\"\n"
s += "-" * 80 + "\n"
return s
def set_placeholder_metadata(
self, model_type: str,
placeholder_metadata: MultimodalPlaceholderMetadata):
self._multimodal_placeholder_by_model_type[
model_type] = placeholder_metadata
def remove_placeholder_metadata(self, model_type: str):
if model_type not in self._multimodal_placeholder_by_model_type:
raise ValueError(f"Model type '{model_type}' is not registered")
del self._multimodal_placeholder_by_model_type[model_type]
def is_valid(self, model_type: str, modality: str) -> bool:
return model_type in self._multimodal_placeholder_by_model_type and \
modality in self._multimodal_placeholder_by_model_type[model_type].placeholder_map
def get_placeholder_metadata(
self, model_type: str) -> MultimodalPlaceholderMetadata:
if model_type not in self._multimodal_placeholder_by_model_type:
raise ValueError(
f"Model type {model_type} is not registered in MultimodalPlaceholderRegistry"
)
return self._multimodal_placeholder_by_model_type[model_type]
def get_placeholder(self, model_type: str, modality: str) -> str:
if not self.is_valid(model_type, modality):
raise ValueError(
f"Model type '{model_type}' with modality '{modality}' is not registered."
)
return self._multimodal_placeholder_by_model_type[
model_type].placeholder_map[modality]
def get_placeholder_placement(
self, model_type: str) -> MultimodalPlaceholderPlacement:
if model_type not in self._multimodal_placeholder_by_model_type:
raise ValueError(f"Model type '{model_type}' is not registered")
return self._multimodal_placeholder_by_model_type[
model_type].placeholder_placement
def get_placeholders_separator(self, model_type: str) -> str:
if model_type not in self._multimodal_placeholder_by_model_type:
raise ValueError(f"Model type '{model_type}' is not registered")
return self._multimodal_placeholder_by_model_type[
model_type].placeholders_separator
    def get_registered_image_model_types(self) -> Tuple[str, ...]:
        return tuple(
            model_type
            for model_type, metadata in self._multimodal_placeholder_by_model_type.items()
            if "image" in metadata.placeholder_map)
    def get_registered_video_model_types(self) -> Tuple[str, ...]:
        return tuple(
            model_type
            for model_type, metadata in self._multimodal_placeholder_by_model_type.items()
            if "video" in metadata.placeholder_map)
    def get_registered_audio_model_types(self) -> Tuple[str, ...]:
        return tuple(
            model_type
            for model_type, metadata in self._multimodal_placeholder_by_model_type.items()
            if "audio" in metadata.placeholder_map)
def get_registered_model_types(self) -> Tuple[str, ...]:
return tuple(self._multimodal_placeholder_by_model_type.keys())
MULTIMODAL_PLACEHOLDER_REGISTRY = MultimodalPlaceholderRegistry()
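# Query sketch (illustrative; "my_vlm" is a hypothetical model type that would have
# to be registered via register_input_processor first):
#
#     if MULTIMODAL_PLACEHOLDER_REGISTRY.is_valid("my_vlm", "image"):
#         placeholder = MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholder("my_vlm", "image")
#         separator = MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholders_separator("my_vlm")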
class InputProcessorRegistry:
def __init__(self) -> None:
self._input_processors_cls_by_model_type: Dict[
Type[nn.Module], Type[InputProcessor]] = {}
INPUT_PROCESSOR_REGISTRY = InputProcessorRegistry()
def support_multimodal_disaggregated(model_cls: Type[nn.Module]):
"""
Model-class decorator to declare support for multimodal disaggregated inputs.
Apply this to a model class AFTER its input processor has been registered via
@register_input_processor. The decorator will locate the processor class,
    validate requirements, and set `support_mm_disagg = True` on both
the processor class and the model class.
"""
processor_cls = INPUT_PROCESSOR_REGISTRY._input_processors_cls_by_model_type.get(
model_cls)
if processor_cls is None:
raise RuntimeError(
f"No input processor registered for {model_cls.__name__}; ensure @register_input_processor is applied closer to the class than @supports_multimodal_disagg."
)
if not issubclass(processor_cls, BaseMultimodalInputProcessor):
raise TypeError(
f"{processor_cls.__name__} must inherit from BaseMultimodalInputProcessor to support multimodal disagg"
)
method = getattr(processor_cls, "get_prompt_token_ids", None)
if method is None or not callable(method):
raise TypeError(
f"{processor_cls.__name__} must implement a callable method `get_prompt_token_ids` to support multimodal disagg"
)
setattr(processor_cls, "support_mm_disagg", True)
setattr(model_cls, "support_mm_disagg", True)
return model_cls
def register_input_processor(
processor_cls: Type[InputProcessor],
model_type: str,
        placeholder_metadata: Optional[MultimodalPlaceholderMetadata] = None):
"""
Register an input processor to a model class.
NOTE:
        1. Since this API is only used for multimodal models, the model_type is
           only checked for that case.
        2. If this is used for other models in the future, this logic needs to be
           updated, e.g. by adding another version of this API without the model_type.
"""
def wrapper(model_cls: N) -> N:
INPUT_PROCESSOR_REGISTRY._input_processors_cls_by_model_type[
model_cls] = processor_cls
if placeholder_metadata is None:
raise ValueError(
f"A valid placeholder_metadata must be provided but got {placeholder_metadata}"
)
MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata(
model_type, placeholder_metadata)
# Store model_type on processor class for use in get_dummy_prompt
processor_cls._registered_model_type = model_type
return model_cls
return wrapper
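# Registration sketch (hedged; MyVLMInputProcessor and MyVLMForConditionalGeneration
# are placeholder names -- real registrations live next to the model definitions):
#
#     @register_input_processor(
#         MyVLMInputProcessor,
#         model_type="my_vlm",
#         placeholder_metadata=MultimodalPlaceholderMetadata(
#             placeholder_map={"image": "<image>"},
#             placeholder_placement=MultimodalPlaceholderPlacement.BEFORE_TEXT,
#         ))
#     class MyVLMForConditionalGeneration(nn.Module):
#         ...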
def create_input_processor(
model_path_or_dir: str,
tokenizer,
checkpoint_format: Optional[str] = "HF",
) -> Union[InputProcessor, BaseMultimodalInputProcessor]:
"""Create an input processor for a specific model.
Args:
model_path_or_dir: Path or repo id used to locate pretrained config/tokenizer.
tokenizer: Tokenizer instance.
        checkpoint_format: Checkpoint format identifier. "HF" (default) uses Hugging Face-style
            config loading; "mistral"/"mistral_large_3" use the Mistral config loader; any other
            value skips config loading.
Returns:
An InputProcessor implementation (model-specific if registered; otherwise DefaultInputProcessor).
"""
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models import get_model_architecture
config = None
if checkpoint_format == "HF":
try:
model_config = ModelConfig.from_pretrained(model_path_or_dir,
trust_remote_code=True)
config = model_config.pretrained_config
except (ValueError, EnvironmentError) as e:
logger.debug(
f"Unable to load HF config from {model_path_or_dir}: {e}. Falling back."
)
elif checkpoint_format in ("mistral", "mistral_large_3"):
logger.debug(f"Detected checkpoint_format={checkpoint_format}.")
from tensorrt_llm._torch.models.checkpoints.mistral.config_loader import \
MistralConfigLoader
model_config = MistralConfigLoader().load(model_path_or_dir)
config = model_config.pretrained_config
else:
logger.debug(
f"checkpoint_format={checkpoint_format}; skipping HF config load.")
if config is not None:
try:
model_cls, _ = get_model_architecture(config)
input_processor_cls = INPUT_PROCESSOR_REGISTRY._input_processors_cls_by_model_type \
.get(model_cls)
except RuntimeError: # unregistered model
logger.info("Unregistered model, using DefaultInputProcessor")
input_processor_cls = None
if input_processor_cls is not None:
return input_processor_cls(model_path_or_dir,
config,
tokenizer,
trust_remote_code=True)
return DefaultInputProcessor(None, None, tokenizer)
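# Usage sketch (illustrative; the checkpoint path is a placeholder): for a model with
# a registered processor, the model-specific class is returned; otherwise tokenization
# falls back to DefaultInputProcessor.
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("/path/to/checkpoint")
#     input_processor = create_input_processor("/path/to/checkpoint", tokenizer)
#     token_ids, extra = input_processor({"prompt": "Describe the image."},
#                                        SamplingParams())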
def create_input_processor_with_hash(
input_processor: BaseMultimodalInputProcessor,
hash_lib=default_hasher,
) -> Callable[[TextPrompt, SamplingParams], Tuple[
List[int], Optional[ExtraProcessedInputs]]]:
"""Creates a modified processor that applies additional logic like (hashing, find mm chunk positions) to the input processor
Args:
original_processor: The original input processor to wrap.
hash_lib: hasher to use (default: blake3)
Returns:
A wrapped processor that modifies prompts before processing.
"""
def multimodal_hashing_process(
inputs: TextPrompt, sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
"""
        Compute multimodal hashing for media tokens if possible.
"""
assert 'multi_modal_data' in inputs, "multi_modal_data must be provided for hashing support."
mm_data = inputs['multi_modal_data']
mm_hashes = apply_mm_hashes(mm_data, hash_lib)
prompt_token_ids, extra_processed_inputs = input_processor(
inputs, sampling_params)
num_mm_tokens = find_mm_token_lengths(mm_data, input_processor)
# TODO: here we assume there is only one modality for now
num_mm_tokens = next(iter(num_mm_tokens.values()))
if len(num_mm_tokens) <= 0:
return [], None
vocab_size = input_processor.get_vocab_size()
mm_ids = input_processor.get_mm_token_ids()
mm_special_token_ids = input_processor.get_mm_special_token_ids()
if vocab_size is None and mm_ids is None:
raise ValueError(
"Cannot locate vocab_size or mm_token_ids for multimodal token preprocessing"
)
        start_positions, start_special_token_positions = find_mm_token_positions(
            input_ids=prompt_token_ids,  # token sequence
            num_mm_tokens=num_mm_tokens,  # lengths of each chunk of multimodal tokens
            vocab_size=vocab_size,
            mm_token_ids=mm_ids,
            mm_special_token_ids=mm_special_token_ids,
        )
# Store special token offsets if available
if len(start_special_token_positions
) > 0 and mm_special_token_ids is not None:
extra_processed_inputs["multimodal_data"][
"special_token_offsets"] = start_special_token_positions
# flatten the hashes from dict to a single list
mm_hashes = [h for hashes in mm_hashes.values() for h in hashes]
validate_mm_inputs(prompt_token_ids, mm_hashes, start_positions,
num_mm_tokens)
mm_hashes_int32 = [hexdigest_to_int32(h) for h in mm_hashes
] # nested list w/ multiple int32 per hash
extra_processed_inputs[
"multimodal_input"] = MultimodalInput.from_components(
mm_hashes_int32, start_positions, num_mm_tokens)
return prompt_token_ids, extra_processed_inputs
def input_processor_wrapper(
inputs: TextPrompt, sampling_params: SamplingParams
) -> Tuple[List[int], Optional[ExtraProcessedInputs]]:
try_multimodal_hashing = False # only used for first time
use_multimodal_hashing = False # used for subsequent calls
modalities = list(set(inputs['multi_modal_data'].keys())
) if 'multi_modal_data' in inputs else []
if len(modalities) > 0:
# TODO: support multimodal hashing for multiple modalities within the same request
# TODO: add audio support
if len(modalities) == 1 and modalities[0] in ['image', 'video']:
# only try multimodal hashing if the inputs only contain image data
if input_processor.multimodal_hashing_supported is not None:
use_multimodal_hashing = input_processor.multimodal_hashing_supported
else:
# we need to try the multimodal hashing for the first time to determine if it is supported
try_multimodal_hashing = True
if try_multimodal_hashing or use_multimodal_hashing:
try:
prompt_token_ids, extra_processed_inputs = multimodal_hashing_process(
inputs, sampling_params)
if try_multimodal_hashing:
# if trying for first time, set the flag to True
input_processor.multimodal_hashing_supported = True
return prompt_token_ids, extra_processed_inputs
except Exception as e:
logger.warning(f"Multimodal hashing failed: {e}.")
if try_multimodal_hashing:
# if trying for first time, fall back to basic input processor
# and set the flag to False so that we don't try again
input_processor.multimodal_hashing_supported = False
logger.warning("Falling back to basic input processor.")
try:
return input_processor(inputs, sampling_params)
except Exception as e2:
logger.warning(f"Basic input processor failed: {e}.")
logger.debug(traceback.format_exc())
raise e2
else:
raise e
else:
try:
return input_processor(inputs, sampling_params)
except Exception as e:
logger.warning(f"Basic input processor failed: {e}.")
logger.debug(traceback.format_exc())
raise e
return input_processor_wrapper
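# Wrapping sketch (hedged; `pil_image` stands for any PIL.Image the caller has loaded,
# and the placeholder string depends on the registered model type). The wrapper hashes
# the media, locates the multimodal token chunks in the token sequence, and attaches a
# MultimodalInput to the extra processed inputs when hashing is supported.
#
#     hashed_processor = create_input_processor_with_hash(input_processor)
#     token_ids, extra = hashed_processor(
#         {"prompt": "<image>\nWhat is shown here?",
#          "multi_modal_data": {"image": [pil_image]}},
#         SamplingParams())
#     mm_input = extra["multimodal_input"]  # hashes, start positions, lengths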