TensorRT-LLM/tests/_torch/modeling/test_modeling_vila.py

import os
import sys
import unittest
from copy import deepcopy
from typing import Any
from unittest.mock import patch
import torch
from parameterized import parameterized
import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_vila import (VilaLlamaConfig,
VilaLlamaModel)
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import getSMVersion
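# HuggingFace-style config for VILA 1.5 3B: a Llama LLM (hidden_size 2560,
# 32 layers), a SigLIP vision tower, and an "mlp_downsample" multimodal
# projector, all in bfloat16.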
VILA_1_5_3B_CONFIG = {
"_name_or_path": "./vlm",
"architectures": ["LlavaLlamaModel"],
"drop_path_rate": 0.0,
"hidden_size": 2560,
"image_aspect_ratio": "resize",
"interpolate_mode": "linear",
"llm_cfg": {
"_name_or_path": "./llm",
"add_cross_attention": False,
"architectures": ["LlamaForCausalLM"],
"attention_bias": False,
"attention_dropout": 0.0,
"bad_words_ids": None,
"begin_suppress_tokens": None,
"bos_token_id": 1,
"chunk_size_feed_forward": 0,
"cross_attention_hidden_size": None,
"decoder_start_token_id": None,
"diversity_penalty": 0.0,
"do_sample": False,
"early_stopping": False,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": 2,
"exponential_decay_length_penalty": None,
"finetuning_task": None,
"forced_bos_token_id": None,
"forced_eos_token_id": None,
"hidden_act": "silu",
"hidden_size": 2560,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"initializer_range": 0.02,
"intermediate_size": 6912,
"is_decoder": False,
"is_encoder_decoder": False,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"length_penalty": 1.0,
"max_length": 20,
"max_position_embeddings": 4096,
"min_length": 0,
"model_max_length": 4096,
"model_type": "llama",
"no_repeat_ngram_size": 0,
"num_attention_heads": 20,
"num_beam_groups": 1,
"num_beams": 1,
"num_hidden_layers": 32,
"num_key_value_heads": 20,
"num_return_sequences": 1,
"output_attentions": False,
"output_hidden_states": False,
"output_scores": False,
"pad_token_id": 0,
"prefix": None,
"pretraining_tp": 1,
"problem_type": None,
"pruned_heads": {},
"remove_invalid_values": False,
"repetition_penalty": 1.0,
"return_dict": True,
"return_dict_in_generate": False,
"rms_norm_eps": 1e-5,
"rope_scaling": None,
"rope_theta": 10000.0,
"sep_token_id": None,
"suppress_tokens": None,
"task_specific_params": None,
"temperature": 1.0,
"tf_legacy_loss": False,
"tie_encoder_decoder": False,
"tie_word_embeddings": False,
"tokenizer_class": None,
"tokenizer_model_max_length": 4096,
"tokenizer_padding_side": "right",
"top_k": 50,
"top_p": 1.0,
"torch_dtype": "bfloat16",
"torchscript": False,
"typical_p": 1.0,
"use_bfloat16": False,
"use_cache": True,
"vocab_size": 32000
},
"mm_hidden_size": 1152,
"mm_projector_cfg": {
"_name_or_path": "./mm_projector",
"add_cross_attention": False,
"architectures": ["MultimodalProjector"],
"bad_words_ids": None,
"begin_suppress_tokens": None,
"bos_token_id": None,
"chunk_size_feed_forward": 0,
"cross_attention_hidden_size": None,
"decoder_start_token_id": None,
"diversity_penalty": 0.0,
"do_sample": False,
"early_stopping": False,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": None,
"exponential_decay_length_penalty": None,
"finetuning_task": None,
"forced_bos_token_id": None,
"forced_eos_token_id": None,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"is_decoder": False,
"is_encoder_decoder": False,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"length_penalty": 1.0,
"max_length": 20,
"min_length": 0,
"mm_projector_type": "mlp_downsample",
"model_type": "v2l_projector",
"no_repeat_ngram_size": 0,
"num_beam_groups": 1,
"num_beams": 1,
"num_return_sequences": 1,
"output_attentions": False,
"output_hidden_states": False,
"output_scores": False,
"pad_token_id": None,
"prefix": None,
"problem_type": None,
"pruned_heads": {},
"remove_invalid_values": False,
"repetition_penalty": 1.0,
"return_dict": True,
"return_dict_in_generate": False,
"sep_token_id": None,
"suppress_tokens": None,
"task_specific_params": None,
"temperature": 1.0,
"tf_legacy_loss": False,
"tie_encoder_decoder": False,
"tie_word_embeddings": True,
"tokenizer_class": None,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": "bfloat16",
"torchscript": False,
"typical_p": 1.0,
"use_bfloat16": False
},
"mm_projector_lr": None,
"mm_use_im_patch_token": False,
"mm_use_im_start_end": False,
"mm_vision_select_feature": "cls_patch",
"mm_vision_select_layer": -2,
"model_dtype": "torch.bfloat16",
"model_type": "llava_llama",
"num_video_frames": 8,
"resume_path": "./vlm",
"s2": False,
"s2_max_split_size": 336,
"s2_scales": "336,672,1008",
"transformers_version": "4.36.2",
"tune_language_model": True,
"tune_mm_projector": True,
"tune_vision_tower": True,
"vision_resolution": -1,
"vision_tower_cfg": {
"_name_or_path": "./vision_tower",
"add_cross_attention": False,
"architectures": ["SiglipVisionModel"],
"attention_dropout": 0.0,
"bad_words_ids": None,
"begin_suppress_tokens": None,
"bos_token_id": None,
"chunk_size_feed_forward": 0,
"cross_attention_hidden_size": None,
"decoder_start_token_id": None,
"diversity_penalty": 0.0,
"do_sample": False,
"early_stopping": False,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": None,
"exponential_decay_length_penalty": None,
"finetuning_task": None,
"forced_bos_token_id": None,
"forced_eos_token_id": None,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"image_size": 384,
"intermediate_size": 4304,
"is_decoder": False,
"is_encoder_decoder": False,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layer_norm_eps": 1e-06,
"length_penalty": 1.0,
"max_length": 20,
"min_length": 0,
"model_type": "siglip_vision_model",
"no_repeat_ngram_size": 0,
"num_attention_heads": 16,
"num_beam_groups": 1,
"num_beams": 1,
"num_channels": 3,
"num_hidden_layers": 27,
"num_return_sequences": 1,
"output_attentions": False,
"output_hidden_states": False,
"output_scores": False,
"pad_token_id": None,
"patch_size": 14,
"prefix": None,
"problem_type": None,
"pruned_heads": {},
"remove_invalid_values": False,
"repetition_penalty": 1.0,
"return_dict": True,
"return_dict_in_generate": False,
"sep_token_id": None,
"suppress_tokens": None,
"task_specific_params": None,
"temperature": 1.0,
"tf_legacy_loss": False,
"tie_encoder_decoder": False,
"tie_word_embeddings": True,
"tokenizer_class": None,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": "bfloat16",
"torchscript": False,
"typical_p": 1.0,
"use_bfloat16": False
}
}
def reduce_vila_config(mem_for_full_model: int, config_dict: dict[str, Any]):
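    """Scale the LLM layer count down so the test model fits in available GPU memory."""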
_, total_mem = torch.cuda.mem_get_info()
# scale model down if gpu memory is low
if total_mem < mem_for_full_model:
model_fraction = total_mem / mem_for_full_model
num_layers = int(config_dict['llm_cfg']["num_hidden_layers"] *
model_fraction)
num_layers = min(num_layers, 32)
config_dict['llm_cfg']["num_hidden_layers"] = num_layers
class TestVila(unittest.TestCase):
@parameterized.expand([None])
    # flashinfer's rms_norm needs https://github.com/flashinfer-ai/flashinfer/pull/646
    # to work correctly, so temporarily disable the flashinfer path for this test.
@patch('tensorrt_llm._torch.modules.rms_norm.IS_FLASHINFER_AVAIABLE',
new=False)
def test_vila_sanity(self, quant_algo):
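        """Smoke-test VilaLlamaModel.forward on a mixed context/generation batch."""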
config_dict = deepcopy(VILA_1_5_3B_CONFIG)
        # 3B * sizeof(float16) plus some extra for activations
mem_for_full_model = (2 + 1) * 3 * 2**(30)
reduce_vila_config(mem_for_full_model, config_dict)
if config_dict['llm_cfg']["num_hidden_layers"] <= 0:
self.skipTest("Insufficient memory for a single Llava layer")
vila_config = VilaLlamaConfig.from_dict(config_dict)
if quant_algo:
quant_config = QuantConfig(quant_algo=quant_algo)
else:
quant_config = None
if quant_algo == "FP8" and getSMVersion() < 90:
self.skipTest(
"This test is not supported in pre-Hopper architecture")
device = torch.device('cuda')
model_config = ModelConfig(pretrained_config=vila_config,
quant_config=quant_config)
vila = VilaLlamaModel(model_config).to(device)
dtype = vila.model_dtype
input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
dtype=torch.int,
device=device)
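        # Batch layout: three context requests with 3, 2 and 1 new tokens and no
        # cached history, plus two generation requests decoding one token each on
        # top of 62 and 75 previously cached tokens.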
context_sequence_lengths = [3, 2, 1]
sequence_lengths = context_sequence_lengths + [1, 1]
past_seen_tokens = [0, 0, 0, 62, 75]
batch_size = len(sequence_lengths)
request_ids = list(range(batch_size))
token_nums = (torch.tensor(past_seen_tokens) +
torch.tensor(sequence_lengths)).tolist()
prompt_lens = token_nums[:3] + past_seen_tokens[3:]
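        # KV cache geometry taken from the (possibly reduced) model config; 100
        # blocks of 128 tokens comfortably covers this tiny batch.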
num_blocks = 100
tokens_per_block = 128
head_dim = vila.config.hidden_size // vila.config.num_attention_heads
num_layers = vila.config.num_hidden_layers
num_heads = vila.config.num_attention_heads
num_kv_heads = vila.config.num_key_value_heads
max_seq_len = num_blocks * tokens_per_block
if dtype == torch.half:
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
elif dtype == torch.bfloat16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
else:
raise ValueError(f"Invalid dtype: {dtype}")
mapping = Mapping(world_size=1, tp_size=1, rank=0)
kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
tokens_per_block)
kv_cache_manager = KVCacheManager(
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers, num_heads, num_kv_heads, head_dim, tokens_per_block,
max_seq_len, batch_size, mapping, kv_cache_dtype)
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
attn_metadata = metadata_cls(
seq_lens=torch.tensor(sequence_lengths, dtype=torch.int),
num_contexts=len(context_sequence_lengths),
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=past_seen_tokens,
),
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
max_num_requests=batch_size,
max_num_tokens=8192,
)
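        # Position ids continue from each request's cached-token count: the context
        # requests start at position 0, the generation requests at 62 and 75.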
position_ids = []
for i, tokens in enumerate(past_seen_tokens):
seq_len = context_sequence_lengths[i] if i < len(
context_sequence_lengths) else 1
position_id = torch.arange(tokens,
tokens + seq_len,
device=input_ids.device)
position_ids.append(position_id)
position_ids = torch.cat(position_ids).unsqueeze(0)
with torch.inference_mode():
attn_metadata.prepare()
logits = vila.forward(input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata)
self.assertEqual(len(past_seen_tokens), logits.shape[0])
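        # Second pass: request logits for every input position instead of only
        # the last token of each request.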
with torch.inference_mode():
attn_metadata.prepare()
logits = vila.forward(input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata,
return_context_logits=True)
self.assertEqual(input_ids.shape, logits.shape[:-1])
kv_cache_manager.shutdown()
def test_vila_prepare_multimodal_input(self):
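        """Check that _prepare_inputs_embeds_for_multimodal splices image embeddings into the prompt."""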
config_dict = deepcopy(VILA_1_5_3B_CONFIG)
# 3B * sizeof(float16) plus some extra for activations
mem_for_full_model = (2 + 1) * 3 * 2**(30)
reduce_vila_config(mem_for_full_model, config_dict)
if config_dict['llm_cfg']["num_hidden_layers"] <= 0:
self.skipTest("Insufficient memory for a single Llava layer")
vila_config = VilaLlamaConfig.from_dict(config_dict)
device = torch.device('cuda')
model_config = ModelConfig(pretrained_config=vila_config,
attn_backend="VANILLA")
vila = VilaLlamaModel(model_config).to(device)
dtype = vila.model_dtype
if dtype == torch.half:
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
elif dtype == torch.bfloat16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
else:
raise ValueError(f"Invalid dtype: {dtype}")
context_sequence_lengths = [232]
num_blocks = 100
tokens_per_block = 128
head_dim = vila.config.hidden_size // vila.config.num_attention_heads
num_layers = vila.config.num_hidden_layers
num_heads = vila.config.num_attention_heads
num_kv_heads = vila.config.num_key_value_heads
max_seq_len = num_blocks * tokens_per_block
batch_size = len(context_sequence_lengths)
request_ids = list(range(batch_size))
token_nums = context_sequence_lengths
mapping = Mapping(world_size=1, tp_size=1, rank=0)
kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
tokens_per_block)
kv_cache_manager = KVCacheManager(
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers, num_heads, num_kv_heads, head_dim, tokens_per_block,
max_seq_len, batch_size, mapping, kv_cache_dtype)
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
attn_metadata = metadata_cls(
seq_lens=torch.tensor(context_sequence_lengths, dtype=torch.int),
num_contexts=len(context_sequence_lengths),
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
max_num_requests=batch_size,
max_num_tokens=8192,
)
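        # Tokenized prompt with a run of 196 image placeholder token ids
        # (32000-32195), one per row of the image feature tensor defined below.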
input_ids = torch.tensor([
1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082,
20255, 29889, 12968, 29901, 29871, 32000, 32001, 32002, 32003,
32004, 32005, 32006, 32007, 32008, 32009, 32010, 32011, 32012,
32013, 32014, 32015, 32016, 32017, 32018, 32019, 32020, 32021,
32022, 32023, 32024, 32025, 32026, 32027, 32028, 32029, 32030,
32031, 32032, 32033, 32034, 32035, 32036, 32037, 32038, 32039,
32040, 32041, 32042, 32043, 32044, 32045, 32046, 32047, 32048,
32049, 32050, 32051, 32052, 32053, 32054, 32055, 32056, 32057,
32058, 32059, 32060, 32061, 32062, 32063, 32064, 32065, 32066,
32067, 32068, 32069, 32070, 32071, 32072, 32073, 32074, 32075,
32076, 32077, 32078, 32079, 32080, 32081, 32082, 32083, 32084,
32085, 32086, 32087, 32088, 32089, 32090, 32091, 32092, 32093,
32094, 32095, 32096, 32097, 32098, 32099, 32100, 32101, 32102,
32103, 32104, 32105, 32106, 32107, 32108, 32109, 32110, 32111,
32112, 32113, 32114, 32115, 32116, 32117, 32118, 32119, 32120,
32121, 32122, 32123, 32124, 32125, 32126, 32127, 32128, 32129,
32130, 32131, 32132, 32133, 32134, 32135, 32136, 32137, 32138,
32139, 32140, 32141, 32142, 32143, 32144, 32145, 32146, 32147,
32148, 32149, 32150, 32151, 32152, 32153, 32154, 32155, 32156,
32157, 32158, 32159, 32160, 32161, 32162, 32163, 32164, 32165,
32166, 32167, 32168, 32169, 32170, 32171, 32172, 32173, 32174,
32175, 32176, 32177, 32178, 32179, 32180, 32181, 32182, 32183,
32184, 32185, 32186, 32187, 32188, 32189, 32190, 32191, 32192,
32193, 32194, 32195, 29871, 320, 29876, 20355, 915, 278, 1203, 322,
278, 14826, 4195, 297, 278, 1967, 29889, 2277, 29937, 7900, 22137,
29901, 450
],
device=device,
dtype=torch.int)
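        # A single pre-projected image feature: 196 "patch" embeddings already in
        # the LLM hidden size (2560), standing in for real vision-tower output.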
images = [torch.rand(196, 2560, dtype=dtype, device=device)]
input_ids, input_embeds = vila._prepare_inputs_embeds_for_multimodal(
input_ids, images, attn_metadata)
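        # Token ids are fully consumed into embeddings; the fused sequence keeps
        # its original length of 233 (196 placeholders swapped for 196 image rows).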
self.assertIsNone(input_ids)
self.assertEqual(list(input_embeds.shape), [233, 2560])
kv_cache_manager.shutdown()