Mirror of https://github.com/NVIDIA/TensorRT-LLM.git
Synced 2026-01-13 22:18:36 +08:00
[https://nvbugs/5752687][fix] Choose register model config over root config for VLM (#10553)
Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
This commit is contained in:
parent d80f01d205
commit fdbdbba540
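
In short: for some VLMs, the model_type registered via the input-processor decorator differs from the model_type on the root HuggingFace config, so get_dummy_prompt could build the dummy prompt under the wrong key. Below is a minimal runnable sketch of the fallback this commit introduces; DummyHFConfig and DummyInputsBuilder are simplified stand-ins (not the real TRT-LLM classes), and only the _registered_model_type attribute and the `or` fallback mirror the diff.

# Minimal sketch of the model_type fallback added in this commit.
class DummyHFConfig:
    model_type = "llava_next"  # the root config's model_type (illustrative value)


class DummyInputsBuilder:
    config = DummyHFConfig()

    def resolve_model_type(self) -> str:
        # Prefer the model_type stamped on the class at registration
        # time; otherwise fall back to the HF config's model_type.
        registered = getattr(self.__class__, '_registered_model_type', None)
        return registered or self.config.model_type


builder = DummyInputsBuilder()
print(builder.resolve_model_type())  # "llava_next" (nothing registered yet)

# Simulate what register_input_processor now does:
DummyInputsBuilder._registered_model_type = "llava"
print(builder.resolve_model_type())  # "llava" (registered type wins)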
@@ -356,6 +356,20 @@ class BaseMultimodalDummyInputsBuilder(ABC):
     def get_dummy_prompt(self, input_seq_len: int):
         # TODO(yechank): We use the max resolution as starting point and keep reducing the resolution until the prompt length is less than the input sequence length.
         # Need to find better way to calculate the dummy prompt length as this iteration may not be efficient.
+
+        # Use the registered model_type from the decorator if available,
+        # otherwise fall back to HuggingFace config's model_type.
+        # This ensures consistency between placeholder registration and lookup.
+        registered_model_type = getattr(self.__class__,
+                                        '_registered_model_type', None)
+        config_model_type = self.config.model_type
+        model_type = registered_model_type or config_model_type
+
+        logger.debug(
+            f"[get_dummy_prompt] registered_model_type={registered_model_type}, "
+            f"config.model_type={config_model_type}, using model_type={model_type}"
+        )
+
         while self.image_max_dim >= self.image_min_dim:
             image = self.get_dummy_image(max_width=self.image_max_dim,
                                          max_height=self.image_max_dim)
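
The TODO above describes the loop that follows: start at the maximum image resolution and shrink until the rendered dummy prompt fits input_seq_len. A rough standalone sketch of that search, where prompt_len_for is a hypothetical stand-in for building the dummy prompt at a given dimension and measuring its token length (the real code steps differently and goes through the tokenizer/loader):

# Rough sketch of the resolution-reduction idea from the TODO.
def prompt_len_for(dim: int) -> int:
    # Placeholder cost model: pretend prompt length grows with image area.
    return (dim // 14) ** 2 + 32


def pick_image_dim(input_seq_len: int,
                   image_max_dim: int = 1024,
                   image_min_dim: int = 64) -> int | None:
    # Walk down from the max resolution until the prompt fits;
    # halving each step keeps the number of iterations small.
    dim = image_max_dim
    while dim >= image_min_dim:
        if prompt_len_for(dim) <= input_seq_len:
            return dim
        dim //= 2
    return None  # no resolution small enough


print(pick_image_dim(2048))  # 512 under this toy cost model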
@@ -363,7 +377,7 @@ class BaseMultimodalDummyInputsBuilder(ABC):
             test_mm_prompt = tensorrt_llm.inputs.utils.default_multimodal_input_loader(
                 tokenizer=self.tokenizer,
                 model_dir=self.model_path,
-                model_type=self.config.model_type,
+                model_type=model_type,
                 modality="image",
                 prompts=[""],
                 media=[[image]],
@@ -565,6 +579,9 @@ def register_input_processor(
         MULTIMODAL_PLACEHOLDER_REGISTRY.set_placeholder_metadata(
             model_type, placeholder_metadata)

+        # Store model_type on processor class for use in get_dummy_prompt
+        processor_cls._registered_model_type = model_type
+
         return model_cls

     return wrapper
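
On the registration side, the decorator now stamps the registered model_type onto the processor class so get_dummy_prompt can read it back. A simplified sketch of that pattern; the registry, classes, and decorator signature here are toy stand-ins for the TRT-LLM versions:

# Simplified sketch of the decorator pattern in register_input_processor:
# stamp the registered model_type onto the processor class so later
# lookups use the same key as placeholder registration.
PLACEHOLDER_REGISTRY: dict[str, str] = {}


def register_input_processor(processor_cls, model_type: str, placeholder: str):
    def wrapper(model_cls):
        PLACEHOLDER_REGISTRY[model_type] = placeholder
        # Store model_type on the processor class for use in
        # get_dummy_prompt (the line this commit adds).
        processor_cls._registered_model_type = model_type
        return model_cls
    return wrapper


class MyInputProcessor:
    pass


@register_input_processor(MyInputProcessor, model_type="llava",
                          placeholder="<image>")
class MyVLM:
    pass


assert MyInputProcessor._registered_model_type == "llava"
assert PLACEHOLDER_REGISTRY["llava"] == "<image>"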