Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-23 12:12:39 +08:00)
import os
import sys
import unittest
from copy import deepcopy
from typing import Any
from unittest.mock import patch

import torch
from parameterized import parameterized

import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_vila import (VilaLlamaConfig,
                                                      VilaLlamaModel)
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import getSMVersion

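# HF-style configuration for VILA 1.5 3B: a Llama LLM, an mlp_downsample
# multimodal projector and a SigLIP vision tower.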
VILA_1_5_3B_CONFIG = config = {
    "_name_or_path": "./vlm",
    "architectures": ["LlavaLlamaModel"],
    "drop_path_rate": 0.0,
    "hidden_size": 2560,
    "image_aspect_ratio": "resize",
    "interpolate_mode": "linear",
    "llm_cfg": {
        "_name_or_path": "./llm",
        "add_cross_attention": False,
        "architectures": ["LlamaForCausalLM"],
        "attention_bias": False,
        "attention_dropout": 0.0,
        "bad_words_ids": None,
        "begin_suppress_tokens": None,
        "bos_token_id": 1,
        "chunk_size_feed_forward": 0,
        "cross_attention_hidden_size": None,
        "decoder_start_token_id": None,
        "diversity_penalty": 0.0,
        "do_sample": False,
        "early_stopping": False,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": 2,
        "exponential_decay_length_penalty": None,
        "finetuning_task": None,
        "forced_bos_token_id": None,
        "forced_eos_token_id": None,
        "hidden_act": "silu",
        "hidden_size": 2560,
        "id2label": {
            "0": "LABEL_0",
            "1": "LABEL_1"
        },
        "initializer_range": 0.02,
        "intermediate_size": 6912,
        "is_decoder": False,
        "is_encoder_decoder": False,
        "label2id": {
            "LABEL_0": 0,
            "LABEL_1": 1
        },
        "length_penalty": 1.0,
        "max_length": 20,
        "max_position_embeddings": 4096,
        "min_length": 0,
        "model_max_length": 4096,
        "model_type": "llama",
        "no_repeat_ngram_size": 0,
        "num_attention_heads": 20,
        "num_beam_groups": 1,
        "num_beams": 1,
        "num_hidden_layers": 32,
        "num_key_value_heads": 20,
        "num_return_sequences": 1,
        "output_attentions": False,
        "output_hidden_states": False,
        "output_scores": False,
        "pad_token_id": 0,
        "prefix": None,
        "pretraining_tp": 1,
        "problem_type": None,
        "pruned_heads": {},
        "remove_invalid_values": False,
        "repetition_penalty": 1.0,
        "return_dict": True,
        "return_dict_in_generate": False,
        "rms_norm_eps": 1e-5,
        "rope_scaling": None,
        "rope_theta": 10000.0,
        "sep_token_id": None,
        "suppress_tokens": None,
        "task_specific_params": None,
        "temperature": 1.0,
        "tf_legacy_loss": False,
        "tie_encoder_decoder": False,
        "tie_word_embeddings": False,
        "tokenizer_class": None,
        "tokenizer_model_max_length": 4096,
        "tokenizer_padding_side": "right",
        "top_k": 50,
        "top_p": 1.0,
        "torch_dtype": "bfloat16",
        "torchscript": False,
        "typical_p": 1.0,
        "use_bfloat16": False,
        "use_cache": True,
        "vocab_size": 32000
    },
    "mm_hidden_size": 1152,
    "mm_projector_cfg": {
        "_name_or_path": "./mm_projector",
        "add_cross_attention": False,
        "architectures": ["MultimodalProjector"],
        "bad_words_ids": None,
        "begin_suppress_tokens": None,
        "bos_token_id": None,
        "chunk_size_feed_forward": 0,
        "cross_attention_hidden_size": None,
        "decoder_start_token_id": None,
        "diversity_penalty": 0.0,
        "do_sample": False,
        "early_stopping": False,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": None,
        "exponential_decay_length_penalty": None,
        "finetuning_task": None,
        "forced_bos_token_id": None,
        "forced_eos_token_id": None,
        "id2label": {
            "0": "LABEL_0",
            "1": "LABEL_1"
        },
        "is_decoder": False,
        "is_encoder_decoder": False,
        "label2id": {
            "LABEL_0": 0,
            "LABEL_1": 1
        },
        "length_penalty": 1.0,
        "max_length": 20,
        "min_length": 0,
        "mm_projector_type": "mlp_downsample",
        "model_type": "v2l_projector",
        "no_repeat_ngram_size": 0,
        "num_beam_groups": 1,
        "num_beams": 1,
        "num_return_sequences": 1,
        "output_attentions": False,
        "output_hidden_states": False,
        "output_scores": False,
        "pad_token_id": None,
        "prefix": None,
        "problem_type": None,
        "pruned_heads": {},
        "remove_invalid_values": False,
        "repetition_penalty": 1.0,
        "return_dict": True,
        "return_dict_in_generate": False,
        "sep_token_id": None,
        "suppress_tokens": None,
        "task_specific_params": None,
        "temperature": 1.0,
        "tf_legacy_loss": False,
        "tie_encoder_decoder": False,
        "tie_word_embeddings": True,
        "tokenizer_class": None,
        "top_k": 50,
        "top_p": 1.0,
        "torch_dtype": "bfloat16",
        "torchscript": False,
        "typical_p": 1.0,
        "use_bfloat16": False
    },
    "mm_projector_lr": None,
    "mm_use_im_patch_token": False,
    "mm_use_im_start_end": False,
    "mm_vision_select_feature": "cls_patch",
    "mm_vision_select_layer": -2,
    "model_dtype": "torch.bfloat16",
    "model_type": "llava_llama",
    "num_video_frames": 8,
    "resume_path": "./vlm",
    "s2": False,
    "s2_max_split_size": 336,
    "s2_scales": "336,672,1008",
    "transformers_version": "4.36.2",
    "tune_language_model": True,
    "tune_mm_projector": True,
    "tune_vision_tower": True,
    "vision_resolution": -1,
    "vision_tower_cfg": {
        "_name_or_path": "./vision_tower",
        "add_cross_attention": False,
        "architectures": ["SiglipVisionModel"],
        "attention_dropout": 0.0,
        "bad_words_ids": None,
        "begin_suppress_tokens": None,
        "bos_token_id": None,
        "chunk_size_feed_forward": 0,
        "cross_attention_hidden_size": None,
        "decoder_start_token_id": None,
        "diversity_penalty": 0.0,
        "do_sample": False,
        "early_stopping": False,
        "encoder_no_repeat_ngram_size": 0,
        "eos_token_id": None,
        "exponential_decay_length_penalty": None,
        "finetuning_task": None,
        "forced_bos_token_id": None,
        "forced_eos_token_id": None,
        "hidden_act": "gelu_pytorch_tanh",
        "hidden_size": 1152,
        "id2label": {
            "0": "LABEL_0",
            "1": "LABEL_1"
        },
        "image_size": 384,
        "intermediate_size": 4304,
        "is_decoder": False,
        "is_encoder_decoder": False,
        "label2id": {
            "LABEL_0": 0,
            "LABEL_1": 1
        },
        "layer_norm_eps": 1e-06,
        "length_penalty": 1.0,
        "max_length": 20,
        "min_length": 0,
        "model_type": "siglip_vision_model",
        "no_repeat_ngram_size": 0,
        "num_attention_heads": 16,
        "num_beam_groups": 1,
        "num_beams": 1,
        "num_channels": 3,
        "num_hidden_layers": 27,
        "num_return_sequences": 1,
        "output_attentions": False,
        "output_hidden_states": False,
        "output_scores": False,
        "pad_token_id": None,
        "patch_size": 14,
        "prefix": None,
        "problem_type": None,
        "pruned_heads": {},
        "remove_invalid_values": False,
        "repetition_penalty": 1.0,
        "return_dict": True,
        "return_dict_in_generate": False,
        "sep_token_id": None,
        "suppress_tokens": None,
        "task_specific_params": None,
        "temperature": 1.0,
        "tf_legacy_loss": False,
        "tie_encoder_decoder": False,
        "tie_word_embeddings": True,
        "tokenizer_class": None,
        "top_k": 50,
        "top_p": 1.0,
        "torch_dtype": "bfloat16",
        "torchscript": False,
        "typical_p": 1.0,
        "use_bfloat16": False
    }
}


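# Shrink the LLM's num_hidden_layers when the available GPU memory cannot
# hold the full model.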
def reduce_vila_config(mem_for_full_model: int, config_dict: dict[str, Any]):
    _, total_mem = torch.cuda.mem_get_info()
    # scale model down if gpu memory is low
    if total_mem < mem_for_full_model:
        model_fraction = total_mem / mem_for_full_model
        num_layers = int(config_dict['llm_cfg']["num_hidden_layers"] *
                         model_fraction)
        num_layers = min(num_layers, 32)
        config_dict['llm_cfg']["num_hidden_layers"] = num_layers


class TestVila(unittest.TestCase):

    @parameterized.expand([None])
    # rms_norm in flashinfer needs the following fix to work
    # https://github.com/flashinfer-ai/flashinfer/pull/646
    # temporarily disable it for testing
    @patch('tensorrt_llm._torch.modules.rms_norm.IS_FLASHINFER_AVAIABLE',
           new=False)
    def test_vila_sanity(self, quant_algo):
        config_dict = deepcopy(VILA_1_5_3B_CONFIG)
        # 3B * sizeof(float16) plus some extra for activations
        mem_for_full_model = (2 + 1) * 3 * 2**(30)
        reduce_vila_config(mem_for_full_model, config_dict)
        if config_dict['llm_cfg']["num_hidden_layers"] <= 0:
            self.skipTest("Insufficient memory for a single Llava layer")
        vila_config = VilaLlamaConfig.from_dict(config_dict)
        if quant_algo:
            quant_config = QuantConfig(quant_algo=quant_algo)
        else:
            quant_config = None
        if quant_algo == "FP8" and getSMVersion() < 90:
            self.skipTest(
                "This test is not supported in pre-Hopper architecture")
        device = torch.device('cuda')

        model_config = ModelConfig(pretrained_config=vila_config,
                                   quant_config=quant_config)
        vila = VilaLlamaModel(model_config).to(device)

        dtype = vila.model_dtype

        input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
                                 dtype=torch.int,
                                 device=device)

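        # Mixed batch: three context (prefill) requests followed by two
        # generation requests that resume from 62 and 75 cached tokens.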
        context_sequence_lengths = [3, 2, 1]
        sequence_lengths = context_sequence_lengths + [1, 1]
        past_seen_tokens = [0, 0, 0, 62, 75]
        batch_size = len(sequence_lengths)
        request_ids = list(range(batch_size))
        token_nums = (torch.tensor(past_seen_tokens) +
                      torch.tensor(sequence_lengths)).tolist()
        prompt_lens = token_nums[:3] + past_seen_tokens[3:]

        num_blocks = 100
        tokens_per_block = 128
        head_dim = vila.config.hidden_size // vila.config.num_attention_heads
        num_layers = vila.config.num_hidden_layers
        num_heads = vila.config.num_attention_heads
        num_kv_heads = vila.config.num_key_value_heads
        max_seq_len = num_blocks * tokens_per_block

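        # Map the model dtype to the KV-cache dtype used by the bindings.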
        if dtype == torch.half:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
        elif dtype == torch.bfloat16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
        else:
            raise ValueError(f"Invalid dtype: {dtype}")

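        # Allocate a paged KV cache and register dummy requests so the cache
        # accounts for the tokens each sequence has already seen.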
        mapping = Mapping(world_size=1, tp_size=1, rank=0)
        kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
                                        tokens_per_block)
        kv_cache_manager = KVCacheManager(
            kv_cache_config,
            tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
            num_layers, num_heads, num_kv_heads, head_dim, tokens_per_block,
            max_seq_len, batch_size, mapping, kv_cache_dtype)
        kv_cache_manager.add_dummy_requests(request_ids, token_nums)

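        # Build attention metadata for the backend selected by ModelConfig.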
        metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
        attn_metadata = metadata_cls(
            seq_lens=torch.tensor(sequence_lengths, dtype=torch.int),
            num_contexts=len(context_sequence_lengths),
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=past_seen_tokens,
            ),
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
            max_num_requests=batch_size,
            max_num_tokens=8192,
        )

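        # Position ids continue from the number of tokens each request has
        # already cached.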
        position_ids = []
        for i, tokens in enumerate(past_seen_tokens):
            seq_len = context_sequence_lengths[i] if i < len(
                context_sequence_lengths) else 1
            position_id = torch.arange(tokens,
                                       tokens + seq_len,
                                       device=input_ids.device)
            position_ids.append(position_id)

        position_ids = torch.cat(position_ids).unsqueeze(0)

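        # Default forward: one logits row per request.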
        with torch.inference_mode():
            attn_metadata.prepare()
            logits = vila.forward(input_ids=input_ids,
                                  position_ids=position_ids,
                                  attn_metadata=attn_metadata)

        self.assertEqual(len(past_seen_tokens), logits.shape[0])

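        # With return_context_logits=True the model returns one logits row per
        # input token.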
        with torch.inference_mode():
            attn_metadata.prepare()
            logits = vila.forward(input_ids=input_ids,
                                  position_ids=position_ids,
                                  attn_metadata=attn_metadata,
                                  return_context_logits=True)
        self.assertEqual(input_ids.shape, logits.shape[:-1])

        kv_cache_manager.shutdown()

    def test_vila_prepare_multimodal_input(self):
        config_dict = deepcopy(VILA_1_5_3B_CONFIG)
        # 3B * sizeof(float16) plus some extra for activations
        mem_for_full_model = (2 + 1) * 3 * 2**(30)
        reduce_vila_config(mem_for_full_model, config_dict)
        if config_dict['llm_cfg']["num_hidden_layers"] <= 0:
            self.skipTest("Insufficient memory for a single Llava layer")
        vila_config = VilaLlamaConfig.from_dict(config_dict)

        device = torch.device('cuda')

        model_config = ModelConfig(pretrained_config=vila_config,
                                   attn_backend="VANILLA")
        vila = VilaLlamaModel(model_config).to(device)

        dtype = vila.model_dtype
        if dtype == torch.half:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
        elif dtype == torch.bfloat16:
            kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
        else:
            raise ValueError(f"Invalid dtype: {dtype}")

        context_sequence_lengths = [232]
        num_blocks = 100
        tokens_per_block = 128
        head_dim = vila.config.hidden_size // vila.config.num_attention_heads
        num_layers = vila.config.num_hidden_layers
        num_heads = vila.config.num_attention_heads
        num_kv_heads = vila.config.num_key_value_heads
        max_seq_len = num_blocks * tokens_per_block
        batch_size = len(context_sequence_lengths)
        request_ids = list(range(batch_size))
        token_nums = context_sequence_lengths

        mapping = Mapping(world_size=1, tp_size=1, rank=0)
        kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
                                        tokens_per_block)
        kv_cache_manager = KVCacheManager(
            kv_cache_config,
            tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
            num_layers, num_heads, num_kv_heads, head_dim, tokens_per_block,
            max_seq_len, batch_size, mapping, kv_cache_dtype)
        kv_cache_manager.add_dummy_requests(request_ids, token_nums)

        metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
        attn_metadata = metadata_cls(
            seq_lens=torch.tensor(context_sequence_lengths, dtype=torch.int),
            num_contexts=len(context_sequence_lengths),
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            max_num_requests=batch_size,
            max_num_tokens=8192,
        )
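        # Prompt token ids: text tokens interleaved with 196 image placeholder
        # ids (32000-32195) that will be replaced by image embeddings.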
        input_ids = torch.tensor([
            1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082,
            20255, 29889, 12968, 29901, 29871, 32000, 32001, 32002, 32003,
            32004, 32005, 32006, 32007, 32008, 32009, 32010, 32011, 32012,
            32013, 32014, 32015, 32016, 32017, 32018, 32019, 32020, 32021,
            32022, 32023, 32024, 32025, 32026, 32027, 32028, 32029, 32030,
            32031, 32032, 32033, 32034, 32035, 32036, 32037, 32038, 32039,
            32040, 32041, 32042, 32043, 32044, 32045, 32046, 32047, 32048,
            32049, 32050, 32051, 32052, 32053, 32054, 32055, 32056, 32057,
            32058, 32059, 32060, 32061, 32062, 32063, 32064, 32065, 32066,
            32067, 32068, 32069, 32070, 32071, 32072, 32073, 32074, 32075,
            32076, 32077, 32078, 32079, 32080, 32081, 32082, 32083, 32084,
            32085, 32086, 32087, 32088, 32089, 32090, 32091, 32092, 32093,
            32094, 32095, 32096, 32097, 32098, 32099, 32100, 32101, 32102,
            32103, 32104, 32105, 32106, 32107, 32108, 32109, 32110, 32111,
            32112, 32113, 32114, 32115, 32116, 32117, 32118, 32119, 32120,
            32121, 32122, 32123, 32124, 32125, 32126, 32127, 32128, 32129,
            32130, 32131, 32132, 32133, 32134, 32135, 32136, 32137, 32138,
            32139, 32140, 32141, 32142, 32143, 32144, 32145, 32146, 32147,
            32148, 32149, 32150, 32151, 32152, 32153, 32154, 32155, 32156,
            32157, 32158, 32159, 32160, 32161, 32162, 32163, 32164, 32165,
            32166, 32167, 32168, 32169, 32170, 32171, 32172, 32173, 32174,
            32175, 32176, 32177, 32178, 32179, 32180, 32181, 32182, 32183,
            32184, 32185, 32186, 32187, 32188, 32189, 32190, 32191, 32192,
            32193, 32194, 32195, 29871, 320, 29876, 20355, 915, 278, 1203, 322,
            278, 14826, 4195, 297, 278, 1967, 29889, 2277, 29937, 7900, 22137,
            29901, 450
        ],
                                 device=device,
                                 dtype=torch.int)
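        # A single image feature with 196 rows of hidden size 2560; fusing it
        # into the prompt keeps one embedding row per prompt token.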
        images = [torch.rand(196, 2560, dtype=dtype, device=device)]
        input_ids, input_embeds = vila._prepare_inputs_embeds_for_multimodal(
            input_ids, images, attn_metadata)
        self.assertIsNone(input_ids)
        self.assertEqual(list(input_embeds.shape), [233, 2560])

        kv_cache_manager.shutdown()