mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5520490][fix] Fix intermittent test failures by avoiding external web data pulls (#7879)
Signed-off-by: Chang Liu (Enterprise Products) <9713593+chang-l@users.noreply.github.com>
This commit is contained in:
parent 8adaf0bb78
commit 2e317a7db6
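
In short: the two test files below stop pulling images, videos, and model weights from the network at test time. Test assets are read from a pre-staged llm_models_root()/multimodals/test_data tree, model directories are resolved locally via a new 'model_dir' entry next to the existing 'hf_model_dir' hub id, and the requests-based download_image helper is dropped in favor of tensorrt_llm.inputs.utils.load_image.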
@@ -1,9 +1,9 @@
-import io
+import os
+from pathlib import Path

 import pytest
-import requests
 from PIL import Image
 from transformers import AutoConfig, AutoTokenizer
 from utils.llm_data import llm_models_root

 from tensorrt_llm import MultimodalEncoder
 from tensorrt_llm._torch.models.modeling_llava_next import \
@@ -12,38 +12,34 @@ from tensorrt_llm._torch.models.modeling_qwen2vl import \
     Qwen2VLInputProcessorBase
 from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
 from tensorrt_llm.inputs import default_multimodal_input_loader
-from tensorrt_llm.inputs.utils import load_video
+from tensorrt_llm.inputs.utils import load_image, load_video

+test_data_root = Path(
+    os.path.join(llm_models_root(), "multimodals", "test_data"))
 example_images = [
-    "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
-    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
-    "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
+    str(test_data_root / "seashore.png"),
+    str(test_data_root / "inpaint.png"),
+    str(test_data_root / "61.jpg"),
 ]

 example_videos = [
-    "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4",
-    "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4",
+    str(test_data_root / "OAI-sora-tokyo-walk.mp4"),
+    str(test_data_root / "world.mp4"),
 ]


-def download_image(url: str) -> Image.Image:
-    """Download image from URL and return as PIL Image."""
-    response = requests.get(url, timeout=30)
-    response.raise_for_status()
-    img = Image.open(io.BytesIO(response.content))
-    return img.convert("RGB")
-
-
 @pytest.fixture(scope="function")
 def multimodal_model_configs():
     """Get multimodal model configurations for testing."""
     model_configs = {
         'llava-v1.6-mistral-7b-hf': {
             'hf_model_dir': 'llava-hf/llava-v1.6-mistral-7b-hf',
+            'model_dir':
+            llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf",
             'model_type': 'llava_next',
         },
         'qwen2.5-vl': {
             'hf_model_dir': 'Qwen/Qwen2.5-VL-3B-Instruct',
+            'model_dir': llm_models_root() / "Qwen2.5-VL-3B-Instruct",
             'model_type': 'qwen2_5_vl',
         },
     }
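
As a reading aid, here is a minimal, self-contained sketch of the loading pattern the hunk above switches to. It assumes the assets named in example_images/example_videos have been mirrored into the llm_models_root() tree, and that load_image and load_video accept local filesystem paths, which is exactly what the updated tests rely on:

import os
from pathlib import Path

from utils.llm_data import llm_models_root
from tensorrt_llm.inputs.utils import load_image, load_video

# Resolve assets from the locally staged test-data tree; no network access.
test_data_root = Path(
    os.path.join(llm_models_root(), "multimodals", "test_data"))

# format="pil" yields PIL images; load_video returns a list of PIL frames.
image = load_image(str(test_data_root / "seashore.png"), format="pil")
frames = load_video(str(test_data_root / "world.mp4"), num_frames=8,
                    format="pil")
print(image.size, frames[0].size)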
@@ -66,7 +62,7 @@ def test_get_num_tokens_per_image(model_key, multimodal_model_configs):
         pytest.skip(f"Skipping test for {model_key} - model not available")

     model_config = multimodal_model_configs[model_key]
-    encoder_model_dir = model_config['hf_model_dir']
+    encoder_model_dir = model_config['model_dir']
     model_type = model_config['model_type']

     # Test configuration
@@ -119,10 +115,10 @@ def test_get_num_tokens_per_image(model_key, multimodal_model_configs):
         example_images
     ), f"Expected {len(example_images)} encoder outputs, got {len(encoder_outputs)}"

-    for image_idx, test_image_url in enumerate(example_images):
+    for image_idx, test_image in enumerate(example_images):

         # Get test image dimensions
-        test_image = download_image(test_image_url)
+        test_image = load_image(test_image, format="pil")
         image_width, image_height = test_image.size

         # Get actual embedding tensor for this image
@@ -173,7 +169,7 @@ def test_get_num_tokens_per_video(model_key, multimodal_model_configs):
         pytest.skip(f"Skipping test for {model_key} - model not available")

     model_config = multimodal_model_configs[model_key]
-    encoder_model_dir = model_config['hf_model_dir']
+    encoder_model_dir = model_config['model_dir']
     model_type = model_config['model_type']

     # Test configuration
@@ -226,10 +222,10 @@ def test_get_num_tokens_per_video(model_key, multimodal_model_configs):
         example_videos
     ), f"Expected {len(example_videos)} encoder outputs, got {len(encoder_outputs)}"

-    for video_idx, test_video_url in enumerate(example_videos):
+    for video_idx, test_video in enumerate(example_videos):

         # Get test video dimensions
-        test_video = load_video(test_video_url, num_frames=8, format="pil")
+        test_video = load_video(test_video, num_frames=8, format="pil")
         # load_video returns a list of frames, we only have one video
         video_width, video_height = test_video[0].size
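
The per-test change is identical in both tests: encoder_model_dir now comes from model_config['model_dir'] rather than the hub id. A hedged sketch of that resolution, with local_model_dir as a hypothetical helper name and assuming llm_models_root() returns a pathlib.Path (the diff already composes it with the / operator):

import pytest

from utils.llm_data import llm_models_root

def local_model_dir(*parts):
    """Hypothetical helper: prefer the locally staged checkout and skip,
    rather than fall back to a network download, when it is missing."""
    model_dir = llm_models_root()
    for part in parts:
        model_dir = model_dir / part
    if not model_dir.exists():
        pytest.skip(f"Skipping test - model not available: {model_dir}")
    return model_dir

# Inside a test:
encoder_model_dir = local_model_dir("multimodals", "llava-v1.6-mistral-7b-hf")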
@@ -2,6 +2,7 @@ import json
 import os

 import pytest
+from utils.llm_data import llm_models_root

 from tensorrt_llm import MultimodalEncoder
 from tensorrt_llm._torch.shared_tensor import SharedTensorContainer
@@ -22,8 +23,12 @@ def multimodal_model_config():
     # You can extend this to support multiple models or get from environment
     model_configs = {
         'llava-v1.6-mistral-7b-hf': {
-            'model_name': 'llava-v1.6-mistral-7b-hf',
-            'hf_model_dir': 'llava-hf/llava-v1.6-mistral-7b-hf',
+            'model_name':
+            'llava-v1.6-mistral-7b-hf',
+            'hf_model_dir':
+            'llava-hf/llava-v1.6-mistral-7b-hf',
+            'model_dir':
+            llm_models_root() / "multimodals" / "llava-v1.6-mistral-7b-hf",
         }
     }
@@ -47,7 +52,7 @@ def test_single_image_chat(model_key, multimodal_model_config):
     )

     # Extract model information from config
-    encoder_model_dir = multimodal_model_config['hf_model_dir']
+    encoder_model_dir = multimodal_model_config['model_dir']

     # Test configuration
     max_tokens = 64
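
One practical consequence worth noting: these tests now require the test_data assets and the model checkouts to be staged under llm_models_root() ahead of time. On machines where they are absent, the existing pytest.skip guards keep the suite green, instead of a dropped connection intermittently failing it mid-run, which is the failure mode the commit title targets.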