diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index 0dc09547b3..b1165ac34e 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -500,7 +500,7 @@ class MultimodalDataTracker: media_type: str, data: Union[Coroutine, Any], *, - is_embedding: bool = False): + is_embedding: bool = False) -> Optional[str]: current_count = len(self._data[media_type]) + len( self._embeddings[media_type]) + 1 placeholder = retrieve_multimodal_placeholder(self._model_type, @@ -509,6 +509,7 @@ class MultimodalDataTracker: if is_embedding else self._data)[media_type].append(data) if placeholder: self._placeholder_counts[placeholder] += 1 + return placeholder def placeholder_counts(self) -> Dict[str, int]: """Get the count of multimodal placeholders.""" diff --git a/tensorrt_llm/serve/chat_utils.py b/tensorrt_llm/serve/chat_utils.py index e08caadaaf..1d2096a387 100644 --- a/tensorrt_llm/serve/chat_utils.py +++ b/tensorrt_llm/serve/chat_utils.py @@ -277,21 +277,28 @@ def parse_chat_messages_coroutines( mm_placeholder_counts = [] mm_data_tracker = MultimodalDataTracker(model_config.model_type, multimodal_server_config) - for msg in messages: parsed_msg = parse_chat_message_content(msg, mm_data_tracker) conversation.append(parsed_msg) + + # Track placeholders added for this message only. + msg_placeholder_counts = {} if parsed_msg["media"]: for mdata in parsed_msg["media"]: - mm_data_tracker.add_data(mdata["modality"], - mdata["data"], - is_embedding=mdata["is_embedding"]) - mm_placeholder_count = mm_data_tracker.placeholder_counts() - if mm_placeholder_count: + placeholder = mm_data_tracker.add_data( + mdata["modality"], + mdata["data"], + is_embedding=mdata["is_embedding"]) + if placeholder: + msg_placeholder_counts[ + placeholder] = msg_placeholder_counts.get( + placeholder, 0) + 1 + + if msg_placeholder_counts: parsed_msg["content"] = add_multimodal_placeholders( model_config.model_type, parsed_msg["content"], - mm_placeholder_count) - mm_placeholder_counts.append(mm_placeholder_count) + msg_placeholder_counts) + mm_placeholder_counts.append(msg_placeholder_counts) return conversation, mm_data_tracker.retrieve_all_async( ), mm_placeholder_counts diff --git a/tests/unittest/llmapi/apps/test_chat_utils.py b/tests/unittest/llmapi/apps/test_chat_utils.py index 4e169ad968..efb1e8a665 100644 --- a/tests/unittest/llmapi/apps/test_chat_utils.py +++ b/tests/unittest/llmapi/apps/test_chat_utils.py @@ -1,8 +1,14 @@ from unittest.mock import MagicMock import pytest +from transformers import AutoConfig -from tensorrt_llm.serve.chat_utils import load_chat_template, parse_chat_message_content +from tensorrt_llm.inputs.registry import MULTIMODAL_PLACEHOLDER_REGISTRY +from tensorrt_llm.serve.chat_utils import ( + load_chat_template, + parse_chat_message_content, + parse_chat_messages_coroutines, +) @pytest.fixture @@ -222,3 +228,123 @@ class TestLoadChatTemplate: template = "{{ messages }}" template_content = load_chat_template(template) assert template_content == template + + +_MM_MODEL_TYPE = "qwen3_vl" +_IMG_PLACEHOLDER = MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholder( + model_type=_MM_MODEL_TYPE, modality="image" +) +_VIDEO_PLACEHOLDER = MULTIMODAL_PLACEHOLDER_REGISTRY.get_placeholder( + model_type=_MM_MODEL_TYPE, modality="video" +) + + +class TestMultimodalPlaceholderCounts: + """Verify per-message multimodal placeholder counts. + + Regression test: previously, image/video counts leaked between messages, + causing later text-only messages to report stale placeholder counts. + """ + + @pytest.mark.parametrize( + "messages, expected_mm_placeholder_counts", + [ + # Case #1: 2 messages with 1 image each, 3rd message is text-only. + ( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_url", "image_url": {"url": "foo"}}, + ], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "And this one?"}, + {"type": "image_url", "image_url": {"url": "bar"}}, + ], + }, + {"role": "user", "content": "No image here, just text"}, + ], + [{_IMG_PLACEHOLDER: 1}, {_IMG_PLACEHOLDER: 1}, {}], + ), + # Case #2: first and last message have one image each, 2nd is text-only. + ( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + {"type": "image_url", "image_url": {"url": "foo"}}, + ], + }, + {"role": "user", "content": "No image here, just text"}, + { + "role": "user", + "content": [ + {"type": "text", "text": "And this one?"}, + {"type": "image_url", "image_url": {"url": "bar"}}, + ], + }, + ], + [{_IMG_PLACEHOLDER: 1}, {}, {_IMG_PLACEHOLDER: 1}], + ), + # Case #3: 1st message with several images, 2nd without any, 3rd with a video. + ( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What do these images have in common?"}, + {"type": "image_url", "image_url": {"url": "foo1"}}, + {"type": "image_url", "image_url": {"url": "foo2"}}, + {"type": "image_url", "image_url": {"url": "foo3"}}, + ], + }, + {"role": "user", "content": "No image here, just text"}, + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe the image and the video."}, + {"type": "image_url", "image_url": {"url": "bar"}}, + {"type": "video_url", "video_url": {"url": "baz"}}, + ], + }, + ], + [{_IMG_PLACEHOLDER: 3}, {}, {_IMG_PLACEHOLDER: 1, _VIDEO_PLACEHOLDER: 1}], + ), + # Case #4: 1st message with image and video, 2nd with several videos, last is text-only. + ( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe the image and the video."}, + {"type": "image_url", "image_url": {"url": "bar"}}, + {"type": "video_url", "video_url": {"url": "baz"}}, + ], + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "What do these videos have in common?"}, + {"type": "video_url", "video_url": {"url": "foo1"}}, + {"type": "video_url", "video_url": {"url": "foo2"}}, + {"type": "video_url", "video_url": {"url": "foo3"}}, + ], + }, + {"role": "user", "content": "No image here, just text"}, + ], + [{_IMG_PLACEHOLDER: 1, _VIDEO_PLACEHOLDER: 1}, {_VIDEO_PLACEHOLDER: 3}, {}], + ), + ], + ) + def test_per_message_counts(self, messages, expected_mm_placeholder_counts): + mock_config = MagicMock(spec=AutoConfig) + mock_config.model_type = _MM_MODEL_TYPE + + _, _, mm_placeholder_counts = parse_chat_messages_coroutines(messages, mock_config, None) + + assert mm_placeholder_counts == expected_mm_placeholder_counts