mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-16 07:53:55 +08:00
* Why? As reported in #11170, when a single request contains multiple messages and only a subset of those messages includes multimodal data, the previous logic incorrectly added placeholder tokens to subsequent messages that do not contain such data. * What? This commit fixes the issue and adds unit tests that would have caught it. Signed-off-by: William Zhang <133824995+2ez4bz@users.noreply.github.com>
This commit is contained in:
parent
b4e9669d2c
commit
4debf153d8
@ -500,7 +500,7 @@ class MultimodalDataTracker:
|
||||
media_type: str,
|
||||
data: Union[Coroutine, Any],
|
||||
*,
|
||||
is_embedding: bool = False):
|
||||
is_embedding: bool = False) -> Optional[str]:
|
||||
current_count = len(self._data[media_type]) + len(
|
||||
self._embeddings[media_type]) + 1
|
||||
placeholder = retrieve_multimodal_placeholder(self._model_type,
|
||||
@ -509,6 +509,7 @@ class MultimodalDataTracker:
|
||||
if is_embedding else self._data)[media_type].append(data)
|
||||
if placeholder:
|
||||
self._placeholder_counts[placeholder] += 1
|
||||
return placeholder
|
||||
|
||||
def placeholder_counts(self) -> Dict[str, int]:
|
||||
"""Get the count of multimodal placeholders."""
|
||||
|
||||
@ -277,21 +277,28 @@ def parse_chat_messages_coroutines(
|
||||
mm_placeholder_counts = []
|
||||
mm_data_tracker = MultimodalDataTracker(model_config.model_type,
|
||||
multimodal_server_config)
|
||||
|
||||
for msg in messages:
|
||||
parsed_msg = parse_chat_message_content(msg, mm_data_tracker)
|
||||
conversation.append(parsed_msg)
|
||||
|
||||
# Track placeholders added for this message only.
|
||||
msg_placeholder_counts = {}
|
||||
if parsed_msg["media"]:
|
||||
for mdata in parsed_msg["media"]:
|
||||
mm_data_tracker.add_data(mdata["modality"],
|
||||
mdata["data"],
|
||||
is_embedding=mdata["is_embedding"])
|
||||
mm_placeholder_count = mm_data_tracker.placeholder_counts()
|
||||
if mm_placeholder_count:
|
||||
placeholder = mm_data_tracker.add_data(
|
||||
mdata["modality"],
|
||||
mdata["data"],
|
||||
is_embedding=mdata["is_embedding"])
|
||||
if placeholder:
|
||||
msg_placeholder_counts[
|
||||
placeholder] = msg_placeholder_counts.get(
|
||||
placeholder, 0) + 1
|
||||
|
||||
if msg_placeholder_counts:
|
||||
parsed_msg["content"] = add_multimodal_placeholders(
|
||||
model_config.model_type, parsed_msg["content"],
|
||||
mm_placeholder_count)
|
||||
mm_placeholder_counts.append(mm_placeholder_count)
|
||||
msg_placeholder_counts)
|
||||
mm_placeholder_counts.append(msg_placeholder_counts)
|
||||
|
||||
return conversation, mm_data_tracker.retrieve_all_async(
|
||||
), mm_placeholder_counts
|
||||
|
||||
@ -1,8 +1,14 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from transformers import AutoConfig
|
||||
|
||||
from tensorrt_llm.serve.chat_utils import load_chat_template, parse_chat_message_content
|
||||
from tensorrt_llm.inputs.registry import MULTIMODAL_PLACEHOLDER_REGISTRY
|
||||
from tensorrt_llm.serve.chat_utils import (
|
||||
load_chat_template,
|
||||
parse_chat_message_content,
|
||||
parse_chat_messages_coroutines,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -222,3 +228,123 @@ class TestLoadChatTemplate:
|
||||
template = "{{ messages }}"
|
||||
template_content = load_chat_template(template)
|
||||
assert template_content == template
|
||||
|
||||
|
||||
# Model type used for every placeholder lookup and mocked config in these tests.
_MM_MODEL_TYPE = "qwen3_vl"

# Placeholder strings the registry returns for this model type, one per
# modality (image first, then video).
_IMG_PLACEHOLDER, _VIDEO_PLACEHOLDER = (
    MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholder(
        model_type=_MM_MODEL_TYPE, modality=modality
    )
    for modality in ("image", "video")
)
|
||||
|
||||
|
||||
def _text_message(text):
    """Build a user message whose content is a plain string."""
    return {"role": "user", "content": text}


def _media_message(text, *media):
    """Build a user message with one text part followed by media URL parts.

    Each entry of ``media`` is a ``(kind, url)`` pair with ``kind`` being
    ``"image"`` or ``"video"``; it becomes a ``{kind}_url`` content part.
    A fresh dict is returned on every call so cases never share objects.
    """
    parts = [{"type": "text", "text": text}]
    for kind, url in media:
        part_type = f"{kind}_url"
        parts.append({"type": part_type, part_type: {"url": url}})
    return {"role": "user", "content": parts}


class TestMultimodalPlaceholderCounts:
    """Check the multimodal placeholder counts reported for each message.

    Regression test: previously, image/video counts leaked between messages,
    so later text-only messages reported stale placeholder counts.
    """

    @pytest.mark.parametrize(
        "messages, expected_mm_placeholder_counts",
        [
            # Case #1: 2 messages with 1 image each, 3rd message is text-only.
            (
                [
                    _media_message("What's in this image?", ("image", "foo")),
                    _media_message("And this one?", ("image", "bar")),
                    _text_message("No image here, just text"),
                ],
                [{_IMG_PLACEHOLDER: 1}, {_IMG_PLACEHOLDER: 1}, {}],
            ),
            # Case #2: first and last message have one image each, 2nd is text-only.
            (
                [
                    _media_message("What's in this image?", ("image", "foo")),
                    _text_message("No image here, just text"),
                    _media_message("And this one?", ("image", "bar")),
                ],
                [{_IMG_PLACEHOLDER: 1}, {}, {_IMG_PLACEHOLDER: 1}],
            ),
            # Case #3: 1st message with several images, 2nd without any, 3rd with a video.
            (
                [
                    _media_message(
                        "What do these images have in common?",
                        ("image", "foo1"),
                        ("image", "foo2"),
                        ("image", "foo3"),
                    ),
                    _text_message("No image here, just text"),
                    _media_message(
                        "Describe the image and the video.",
                        ("image", "bar"),
                        ("video", "baz"),
                    ),
                ],
                [{_IMG_PLACEHOLDER: 3}, {}, {_IMG_PLACEHOLDER: 1, _VIDEO_PLACEHOLDER: 1}],
            ),
            # Case #4: 1st message with image and video, 2nd with several videos, last is text-only.
            (
                [
                    _media_message(
                        "Describe the image and the video.",
                        ("image", "bar"),
                        ("video", "baz"),
                    ),
                    _media_message(
                        "What do these videos have in common?",
                        ("video", "foo1"),
                        ("video", "foo2"),
                        ("video", "foo3"),
                    ),
                    _text_message("No image here, just text"),
                ],
                [{_IMG_PLACEHOLDER: 1, _VIDEO_PLACEHOLDER: 1}, {_VIDEO_PLACEHOLDER: 3}, {}],
            ),
        ],
    )
    def test_per_message_counts(self, messages, expected_mm_placeholder_counts):
        """Parse the messages and compare the per-message counts to the expectation."""
        # Only `model_type` is set explicitly; everything else is a spec'd mock.
        model_config = MagicMock(spec=AutoConfig)
        model_config.model_type = _MM_MODEL_TYPE

        # The parser returns a 3-tuple; only the last element is checked here.
        _conversation, _mm_coroutines, actual_counts = parse_chat_messages_coroutines(
            messages, model_config, None
        )

        assert actual_counts == expected_mm_placeholder_counts
|
||||
|
||||
Loading…
Reference in New Issue
Block a user