# TensorRT-LLM: tests/_torch/modeling/test_modeling_mllama.py
import os
import sys
import unittest
from copy import deepcopy

import torch
from parameterized import parameterized
from transformers import MllamaConfig
from transformers import \
MllamaForConditionalGeneration as HFMllamaForConditionalGeneration
import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_mllama import \
MllamaForConditionalGeneration
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
DecodingCUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.mapping import Mapping

sys.path.append(os.path.dirname(__file__))
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from test_modeling_llama import Scenario, reduce_llama_config

LLAMA_3_2_11B_VISION_CONFIG = {
'architectures': ['MllamaForConditionalGeneration'],
'image_token_index': 128256,
'model_type': 'mllama',
'text_config': {
'_name_or_path': '',
'add_cross_attention': False,
'architectures': None,
'bad_words_ids': None,
'begin_suppress_tokens': None,
'bos_token_id': 128000,
'chunk_size_feed_forward': 0,
'cross_attention_hidden_size': None,
'cross_attention_layers': [3, 8, 13, 18, 23, 28, 33, 38],
'decoder_start_token_id': None,
'diversity_penalty': 0.0,
'do_sample': False,
'dropout': 0,
'early_stopping': False,
'encoder_no_repeat_ngram_size': 0,
'eos_token_id': 128001,
'exponential_decay_length_penalty': None,
'finetuning_task': None,
'forced_bos_token_id': None,
'forced_eos_token_id': None,
'hidden_act': 'silu',
'hidden_size': 4096,
'id2label': {
'0': 'LABEL_0',
'1': 'LABEL_1'
},
'initializer_range': 0.02,
'intermediate_size': 14336,
'is_decoder': False,
'is_encoder_decoder': False,
'label2id': {
'LABEL_0': 0,
'LABEL_1': 1
},
'length_penalty': 1.0,
'max_length': 20,
'max_position_embeddings': 131072,
'min_length': 0,
'model_type': 'mllama_text_model',
'no_repeat_ngram_size': 0,
'num_attention_heads': 32,
'num_beam_groups': 1,
'num_beams': 1,
'num_hidden_layers': 40,
'num_key_value_heads': 8,
'num_return_sequences': 1,
'output_attentions': False,
'output_hidden_states': False,
'output_scores': False,
'pad_token_id': 128004,
'prefix': None,
'problem_type': None,
'pruned_heads': {},
'remove_invalid_values': False,
'repetition_penalty': 1.0,
'return_dict': True,
'return_dict_in_generate': False,
'rms_norm_eps': 1e-05,
'rope_scaling': {
'factor': 8.0,
'high_freq_factor': 4.0,
'low_freq_factor': 1.0,
'original_max_position_embeddings': 8192,
'rope_type': 'llama3'
},
'rope_theta': 500000.0,
'sep_token_id': None,
'suppress_tokens': None,
'task_specific_params': None,
'temperature': 1.0,
'tf_legacy_loss': False,
'tie_encoder_decoder': False,
'tie_word_embeddings': False,
'tokenizer_class': None,
'top_k': 50,
'top_p': 1.0,
'torch_dtype': 'bfloat16',
'torchscript': False,
'typical_p': 1.0,
'use_bfloat16': False,
'use_cache': True,
'vocab_size': 128256
},
'torch_dtype': 'bfloat16',
'transformers_version': '4.45.0.dev0',
    'vision_config': {
        '_name_or_path': '',
        'add_cross_attention': False,
        'architectures': None,
        'attention_heads': 16,
        'bad_words_ids': None,
        'begin_suppress_tokens': None,
        'bos_token_id': None,
        'chunk_size_feed_forward': 0,
        'cross_attention_hidden_size': None,
        'decoder_start_token_id': None,
        'diversity_penalty': 0.0,
        'do_sample': False,
        'early_stopping': False,
        'encoder_no_repeat_ngram_size': 0,
        'eos_token_id': None,
        'exponential_decay_length_penalty': None,
        'finetuning_task': None,
        'forced_bos_token_id': None,
        'forced_eos_token_id': None,
        'hidden_act': 'gelu',
        'hidden_size': 1280,
        'id2label': {
            '0': 'LABEL_0',
            '1': 'LABEL_1'
        },
        'image_size': 448,
        'intermediate_layers_indices': [3, 7, 15, 23, 30],
        'intermediate_size': 5120,
        'is_decoder': False,
        'is_encoder_decoder': False,
        'label2id': {
            'LABEL_0': 0,
            'LABEL_1': 1
        },
        'length_penalty': 1.0,
        'max_length': 20,
        'max_num_tiles': 4,
        'min_length': 0,
        'model_type': 'mllama_vision_model',
        'no_repeat_ngram_size': 0,
        'norm_eps': 1e-05,
        'num_beam_groups': 1,
        'num_beams': 1,
        'num_channels': 3,
        'num_global_layers': 8,
        'num_hidden_layers': 32,
        'num_return_sequences': 1,
        'output_attentions': False,
        'output_hidden_states': False,
        'output_scores': False,
        'pad_token_id': None,
        'patch_size': 14,
        'prefix': None,
        'problem_type': None,
        'pruned_heads': {},
        'remove_invalid_values': False,
        'repetition_penalty': 1.0,
        'return_dict': True,
        'return_dict_in_generate': False,
        'sep_token_id': None,
        'supported_aspect_ratios': [[1, 1], [1, 2], [1, 3], [1, 4], [2, 1],
                                    [2, 2], [3, 1], [4, 1]],
        'suppress_tokens': None,
        'task_specific_params': None,
        'temperature': 1.0,
        'tf_legacy_loss': False,
        'tie_encoder_decoder': False,
        'tie_word_embeddings': True,
        'tokenizer_class': None,
        'top_k': 50,
        'top_p': 1.0,
        'torch_dtype': 'bfloat16',
        'torchscript': False,
        'typical_p': 1.0,
        'use_bfloat16': False,
        'vision_output_dim': 7680
    }
}


class TestMLlama(unittest.TestCase):
@parameterized.expand([
Scenario(backend="VANILLA"),
Scenario(backend="FLASHINFER"),
Scenario(backend="FLASHINFER", use_cuda_graph=True),
Scenario(backend="TRTLLM"),
Scenario(backend="TRTLLM", use_cuda_graph=True),
], lambda testcase_func, param_num, param:
f"{testcase_func.__name__}[{param.args[0]}]")
@torch.no_grad()
def test_mllama_allclose_to_hf_text_only(self, scenario: Scenario) -> None:
"""
Compare output to HF
"""
backend = scenario.backend
metadata_cls = get_attention_backend(backend).Metadata
torch.random.manual_seed(0)
config_dict = deepcopy(LLAMA_3_2_11B_VISION_CONFIG)
dtype = MllamaConfig.from_dict(config_dict['text_config']).torch_dtype
dtype_bytes = dtype.itemsize
        # Memory estimate: ~11B params * dtype size, doubled since both the HF
        # reference and the TRT-LLM model live on the GPU, plus ~1.3x headroom
        # for activations. Mllama also has a vision encoder part; 11B is used
        # as an upper bound that covers it.
activation_factor = 1.3
model_params = 11 * (10**9)
mem_for_full_model = 2 * model_params * dtype_bytes * activation_factor
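        # With bfloat16 (2 bytes per param) this comes to roughly
        # 2 * 11e9 * 2 * 1.3 ≈ 57 GB; reduce_llama_config shrinks the text
        # config when less free GPU memory than that is available.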
reduce_llama_config(mem_for_full_model, config_dict['text_config'], 8)
if config_dict['text_config']['num_hidden_layers'] <= 0:
self.skipTest('Insufficient memory for a single Llama layer')
        # For the text-only path, downscale the vision encoder to a single
        # layer before the configs and models are built.
        config_dict['vision_config']['num_hidden_layers'] = 1
        mllama_config = MllamaConfig.from_dict(config_dict)
device = torch.device('cuda')
hf_mllama = HFMllamaForConditionalGeneration(mllama_config).to(
dtype).to(device).eval()
mllama = MllamaForConditionalGeneration(
ModelConfig(pretrained_config=mllama_config,
attn_backend=backend)).to(dtype).to(device)
mllama.load_weights(hf_mllama.state_dict())
# KV cache setup
num_blocks = 1
tokens_per_block = 128
head_dim = mllama.config.hidden_size // mllama.config.num_attention_heads
num_layers = mllama.config.num_hidden_layers
num_heads = mllama.config.num_attention_heads
num_kv_heads = mllama.config.num_key_value_heads
max_seq_len = num_blocks * tokens_per_block
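        # Rough self-attention KV-cache footprint with the unreduced config:
        # 2 (K and V) * 40 layers * 8 KV heads * 128 head_dim * 128 tokens
        # * 2 bytes ≈ 20 MiB for the single 128-token block allocated here.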
batch_size = 1
if dtype == torch.half:
kv_cache_dtype = tensorrt_llm.bindings.DataType.HALF
elif dtype == torch.bfloat16:
kv_cache_dtype = tensorrt_llm.bindings.DataType.BF16
else:
raise ValueError("Invalid dtype")
mapping = Mapping(world_size=1, tp_size=1, rank=0)
kv_cache_config = KvCacheConfig(max_tokens=num_blocks *
tokens_per_block)
kv_cache_manager = KVCacheManager(
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers, num_heads, num_kv_heads, head_dim, tokens_per_block,
max_seq_len, batch_size, mapping, kv_cache_dtype)
# context
input_ids = torch.tensor([100, 200, 300, 100, 200, 100, 400, 500],
dtype=torch.int,
device=device)
num_cached_tokens_per_seq = [0]
request_ids = [1]
token_nums = [input_ids.size(-1)]
prompt_lens = [input_ids.size(-1)]
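        # Reserve KV-cache blocks for this request id; only the capacity for
        # token_nums tokens matters here, not the token values themselves.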
kv_cache_manager.add_dummy_requests(request_ids, token_nums)
attn_metadata = metadata_cls(
seq_lens=torch.tensor([input_ids.size(-1)], dtype=torch.int),
num_contexts=1,
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=num_cached_tokens_per_seq),
max_num_requests=1,
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens)
        # Note: no CUDA graphs for prefill; the graph runner is built for
        # decoding only.
position_ids = [torch.arange(0, input_ids.size(-1))]
position_ids = torch.cat(position_ids).unsqueeze(0).cuda()
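        # Shapes: the TRT-LLM model takes a flat (seq_len,) input_ids tensor
        # and a (1, seq_len) position_ids tensor; the HF reference below gets
        # input_ids.unsqueeze(0) to add the batch dimension.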
with torch.inference_mode():
attn_metadata.prepare()
logits = mllama.forward(input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata)
ref = hf_mllama.forward(input_ids=input_ids.unsqueeze(0),
position_ids=position_ids,
use_cache=True)
torch.testing.assert_close(logits,
ref.logits[:, -1].float(),
atol=0.3,
rtol=0.3)
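        # Loose tolerances: the two implementations use different attention
        # kernels and reduction orders, so only approximate agreement on the
        # last-token logits is expected.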
        # generation (decode) step
gen_input_ids = torch.tensor([600], dtype=torch.int, device=device)
num_cached_tokens_per_seq = [input_ids.size(-1)]
attn_metadata = metadata_cls(
seq_lens=torch.tensor([gen_input_ids.size(-1)], dtype=torch.int),
num_contexts=0,
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=num_cached_tokens_per_seq),
max_num_requests=1,
max_num_tokens=8192,
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens)
gen_position_ids = [
torch.arange(input_ids.size(-1),
input_ids.size(-1) + gen_input_ids.size(-1))
]
gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda()
def run_forward(input_ids, position_ids, attn_metadata):
attn_metadata.prepare()
if not scenario.use_cuda_graph:
return mllama.forward(input_ids=input_ids,
position_ids=position_ids,
attn_metadata=attn_metadata)
else:
graph_runner = DecodingCUDAGraphRunner(
attn_metadata.max_num_requests, "cuda", attn_metadata)
graph_runner.capture(lambda inputs: mllama.forward(**inputs))
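                # capture() records a single decode step as a CUDA graph;
                # run() replays it against the captured buffers, which is why
                # prepare() must not reallocate them (checked by the loop
                # below).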
for _ in range(2):
# Run it twice. This helps us catch problems if buffers are accidentally reallocated
# in prepare().
attn_metadata.prepare()
logits = graph_runner.run({
"input_ids": input_ids,
"position_ids": position_ids,
"attn_metadata": attn_metadata,
})
return logits
if scenario.use_cuda_graph:
attn_metadata = attn_metadata.create_cuda_graph_metadata(1)
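            # Build a CUDA-graph-compatible copy of the metadata for batch
            # size 1 whose buffers can stay fixed across graph replays.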
with torch.inference_mode():
logits = run_forward(input_ids=gen_input_ids,
position_ids=gen_position_ids,
attn_metadata=attn_metadata)
ref = hf_mllama.forward(input_ids=gen_input_ids.unsqueeze(0),
position_ids=gen_position_ids,
past_key_values=ref.past_key_values,
use_cache=True)
torch.testing.assert_close(logits,
ref.logits[:, -1].float(),
atol=0.3,
rtol=0.3)
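
# To run just this file, e.g. from the repo root (assuming pytest is
# installed):
#     pytest tests/_torch/modeling/test_modeling_mllama.py -k FLASHINFER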