# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import sys
import tempfile
import unittest

import numpy as np
import torch
from parameterized import parameterized
from transformers import MistralConfig, MistralForCausalLM

import tensorrt_llm
from tensorrt_llm import Builder
from tensorrt_llm.models import PretrainedConfig
from tensorrt_llm.models.llama.convert import load_weights_from_hf_model
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import (skip_bf16_pre_ampere, skip_fp32_accum_pre_ampere,
                        unittest_name_func)


class TestArctic(unittest.TestCase):
    EOS_TOKEN = 2
    PAD_TOKEN = 2

    def _gen_tensorrt_llm_network(self, network, hf_mistral,
                                  mistral_config: MistralConfig, batch_size,
                                  beam_width, input_len, output_len, dtype,
                                  rank, tensor_parallel):
        with net_guard(network):
            config = {
                'architecture': "LlamaForCausalLM",
                'dtype': dtype,
                'logits_dtype': 'float32',
                'num_hidden_layers': mistral_config.num_hidden_layers,
                'num_attention_heads': mistral_config.num_attention_heads,
                'hidden_size': mistral_config.hidden_size,
                'intermediate_size': mistral_config.intermediate_size,
                'num_key_value_heads': mistral_config.num_key_value_heads,
                'vocab_size': mistral_config.vocab_size,
                'position_embedding_type': 'rope_gpt_neox',
                'max_position_embeddings':
                mistral_config.max_position_embeddings,
                'hidden_act': mistral_config.hidden_act,
                'rotary_base': getattr(mistral_config, 'rotary_base', 10000.0),
                'rotary_scaling': getattr(mistral_config, 'rotary_scaling',
                                          None),
                'norm_epsilon': mistral_config.rms_norm_eps,
                'residual_mlp': mistral_config.residual_mlp,
                'mapping': {
                    'world_size': tensor_parallel,
                    'tp_size': tensor_parallel,
                    'rank': rank,
                },
                'use_parallel_embedding': False,
                'embedding_sharding_dim': 0,
                'moe': {
                    'num_experts': 0,
                    'top_k': 0,
                    'tp_mode': 1,
                    'normalization_mode': 1,
                },
                'use_fused_mlp': False,
            }

            # Initialize model
            config = PretrainedConfig.from_dict(config)
            tensorrt_llm_mistral = tensorrt_llm.models.LLaMAForCausalLM(config)
            if not mistral_config.residual_mlp:
                weights = load_weights_from_hf_model(hf_mistral, config)
                tensorrt_llm_mistral.load(weights)

            # Prepare
            network.set_named_parameters(
                tensorrt_llm_mistral.named_parameters())
            inputs = tensorrt_llm_mistral.prepare_inputs(
                max_batch_size=batch_size,
                max_input_len=input_len,
                max_seq_len=input_len + output_len,
                max_num_tokens=batch_size * input_len,
                use_cache=True,
                max_beam_width=beam_width)

            # Forward
            tensorrt_llm_mistral(**inputs)

        return network
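    # Note: there is no dedicated Arctic model class in this test. Arctic is
    # built on the LLaMA graph ('architecture': "LlamaForCausalLM" via
    # tensorrt_llm.models.LLaMAForCausalLM) with residual_mlp enabled, which
    # is the Arctic-specific feature the test cases below exercise.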
    def _gen_tensorrt_llm_engine(self,
                                 dtype,
                                 rank,
                                 world_size,
                                 llama_config,
                                 hf_llama,
                                 model_name,
                                 use_plugin,
                                 batch_size,
                                 beam_width,
                                 input_len,
                                 output_len,
                                 use_refit,
                                 fast_building=False,
                                 context_fmha_flag=ContextFMHAType.disabled,
                                 enable_remove_input_padding=False):
        builder = Builder()

        with tempfile.TemporaryDirectory() as tmpdirname:
            builder_config = builder.create_builder_config(
                name=model_name,
                precision=dtype,
                timing_cache='model.cache',
                tensor_parallel=world_size,  # TP only
                use_refit=use_refit,
                strongly_typed=True,
            )
            network = builder.create_network()
            network.plugin_config.to_legacy_setting()
            if use_plugin:
                network.plugin_config.gpt_attention_plugin = dtype
            if fast_building:
                network.plugin_config.gemm_plugin = dtype
            if enable_remove_input_padding:
                network.plugin_config.remove_input_padding = True
            network.plugin_config.set_context_fmha(context_fmha_flag)

            self._gen_tensorrt_llm_network(network, hf_llama, llama_config,
                                           batch_size, beam_width, input_len,
                                           output_len, dtype, rank, world_size)

            engine_buffer = builder.build_engine(network, builder_config)
            return engine_buffer

    def _gen_tensorrt_llm_runtime(self,
                                  log_level,
                                  dtype,
                                  world_size,
                                  rank,
                                  llama_config,
                                  hf_llama,
                                  model_name,
                                  use_plugin,
                                  batch_size,
                                  beam_width,
                                  input_len,
                                  output_len,
                                  use_refit,
                                  fast_building=False,
                                  context_fmha_flag=ContextFMHAType.disabled,
                                  enable_remove_input_padding=False):
        tensorrt_llm.logger.set_level(log_level)
        mapping = tensorrt_llm.Mapping(world_size, rank, tp_size=world_size)
        engine_buffer = self._gen_tensorrt_llm_engine(
            dtype, rank, world_size, llama_config, hf_llama, model_name,
            use_plugin, batch_size, beam_width, input_len, output_len,
            use_refit, fast_building, context_fmha_flag,
            enable_remove_input_padding)
        runtime = tensorrt_llm.runtime.generation._Runtime(
            engine_buffer, mapping)
        return runtime, engine_buffer

    def load_test_cases():
        test_cases = []
        test_cases.append((False, True, ContextFMHAType.disabled, False,
                           'bfloat16', 56, True))  # arctic MHA
        return test_cases
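    # Each tuple in load_test_cases maps positionally onto test_arctic's
    # parameters: (use_refit, fast_building, context_fmha_flag,
    # enable_remove_input_padding, dtype, num_kv_heads, residual_mlp).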
    @parameterized.expand(load_test_cases, name_func=unittest_name_func)
    def test_arctic(self, use_refit, fast_building, context_fmha_flag,
                    enable_remove_input_padding, dtype, num_kv_heads,
                    residual_mlp):
        # Simplified from the Mistral test:
        # - Arctic is not officially supported in HuggingFace yet, so the
        #   results comparison against a reference model is skipped
        # - Model loader tests are skipped
        skip_hf = True

        # Skip tests that are not supported on pre-Ampere architectures
        skip_bf16_pre_ampere(dtype)
        skip_fp32_accum_pre_ampere(context_fmha_flag)

        PRECHECKED_GOOD_RANDOM_SEEDS = [1, 4, 5, 8]
        model = 'llama'
        log_level = 'error'
        use_plugin = True  # gpt attention plugin
        batch_size = 4
        beam_width = 1
        input_len = 4
        output_len = 2
        max_seq_len = input_len + output_len
        world_size = 1
        head_size = 32
        rank = 0

        mistral_config = MistralConfig()
        mistral_config.hidden_act = 'silu'
        mistral_config.num_hidden_layers = 2
        mistral_config.max_position_embeddings = 64
        mistral_config.vocab_size = 128
        mistral_config.num_attention_heads = num_kv_heads
        mistral_config.hidden_size = mistral_config.num_attention_heads * head_size
        # Round 8/3 * hidden_size up to a multiple of head_size
        mistral_config.intermediate_size = ((
            (mistral_config.hidden_size * 4 * 2 // 3) + head_size - 1) //
                                            head_size) * head_size
        mistral_config.num_key_value_heads = num_kv_heads
        assert (mistral_config.num_attention_heads %
                mistral_config.num_key_value_heads) == 0
        mistral_config.pad_token_id = self.PAD_TOKEN
        mistral_config.eos_token_id = self.EOS_TOKEN
        mistral_config.residual_mlp = residual_mlp

        seed_idx = random.randint(0, len(PRECHECKED_GOOD_RANDOM_SEEDS) - 1)
        torch.manual_seed(PRECHECKED_GOOD_RANDOM_SEEDS[seed_idx])

        hf_mistral = None
        if not skip_hf:
            hf_mistral = MistralForCausalLM(mistral_config).cuda()
        runtime, _ = self._gen_tensorrt_llm_runtime(
            log_level, dtype, world_size, rank, mistral_config, hf_mistral,
            model, use_plugin, batch_size, beam_width, input_len, output_len,
            use_refit, fast_building, context_fmha_flag,
            enable_remove_input_padding)

        key_value_cache_buffers = []
        head_size = mistral_config.hidden_size // mistral_config.num_attention_heads
        for i in range(mistral_config.num_hidden_layers):
            key_value_cache_buffers.append(
                torch.zeros((
                    batch_size,
                    2,
                    mistral_config.num_key_value_heads,
                    max_seq_len,
                    head_size,
                ),
                            dtype=tensorrt_llm._utils.str_dtype_to_torch(
                                dtype),
                            device='cuda'))

        # compare context
        step = 0
        ctx_ids = torch.randint(100, (batch_size, input_len)).int().cuda()
        ctx_context_lengths = input_len * torch.ones(
            (batch_size), dtype=torch.int32, device='cuda')
        ctx_position_ids = torch.tensor(range(input_len),
                                        dtype=torch.int32).reshape([
                                            1, input_len
                                        ]).expand([batch_size,
                                                   input_len]).cuda()
        ctx_last_token_ids = ctx_context_lengths.clone()
        ctx_host_request_types = torch.tensor([0] * batch_size,
                                              dtype=torch.int32)

        # sequence_lengths must start as context_lengths for step 0;
        # it is incremented by one after each generation step.
        sequence_length_buffer = ctx_context_lengths.detach().clone()

        if not skip_hf:
            with torch.no_grad():
                hf_outputs = hf_mistral.forward(ctx_ids)
            torch.cuda.synchronize()
            ref = hf_outputs.logits[:, -1, :]

        if enable_remove_input_padding:
            ctx_ids = ctx_ids.view([batch_size * input_len])
            ctx_position_ids = ctx_position_ids.view([batch_size * input_len])
            ctx_last_token_ids = torch.cumsum(ctx_last_token_ids, dim=0).int()

        cache_indirections = [
            torch.full((
                batch_size,
                beam_width,
                max_seq_len,
            ),
                       0,
                       dtype=torch.int32,
                       device='cuda'),
            torch.full((
                batch_size,
                beam_width,
                max_seq_len,
            ),
                       0,
                       dtype=torch.int32,
                       device='cuda')
        ]  # ping-pong buffers

        ctx_buffer = {
            'input_ids': ctx_ids,
            'context_lengths': ctx_context_lengths,
            'position_ids': ctx_position_ids,
            'last_token_ids': ctx_last_token_ids,
            'cache_indirection': cache_indirections[0],
            'host_request_types': ctx_host_request_types,
        }
        if enable_remove_input_padding:
            ctx_buffer['host_context_lengths'] = ctx_context_lengths.cpu()
        ctx_shape = {k: v.shape for k, v in ctx_buffer.items()}

        kv_shape = (batch_size, 2, mistral_config.num_key_value_heads,
                    max_seq_len, head_size)
        ctx_buffer['host_max_attention_window_sizes'] = torch.tensor(
            [max_seq_len] * mistral_config.num_hidden_layers,
            dtype=torch.int32)
        ctx_shape['host_max_attention_window_sizes'] = (
            mistral_config.num_hidden_layers, )
        for i in range(mistral_config.num_hidden_layers):
            ctx_shape[f'past_key_value_{i}'] = kv_shape
            ctx_buffer[f'past_key_value_{i}'] = key_value_cache_buffers[i]
            ctx_buffer[f'present_key_value_{i}'] = key_value_cache_buffers[i]
        ctx_buffer['sequence_length'] = sequence_length_buffer
        ctx_shape['sequence_length'] = ctx_buffer['sequence_length'].shape
        ctx_shape['host_past_key_value_lengths'] = (batch_size, )
        ctx_buffer['host_past_key_value_lengths'] = torch.tensor(
            [0] * batch_size, dtype=torch.int32)
        ctx_shape['host_sink_token_length'] = (1, )
        ctx_buffer['host_sink_token_length'] = torch.tensor([0],
                                                            dtype=torch.int32)

        context = runtime.ctx_context
        runtime._set_shape(context, ctx_shape)
        runtime._set_buffer(context, ctx_buffer)
        runtime._run(context)
        torch.cuda.synchronize()
        res = ctx_buffer['logits']

        if not skip_hf:
            np.testing.assert_allclose(ref.to(torch.float32).cpu().numpy(),
                                       res.to(torch.float32).cpu().numpy(),
                                       atol=0.12)
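        # In the gpt_attention plugin's request-type convention used above,
        # 0 marks a context (prefill) request and 1 a generation request;
        # the generation step below therefore switches host_request_types
        # to 1 and uses the second (ping-pong) cache_indirection buffer.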
        # compare generation
        step = 1
        step1_id = torch.randint(100, (batch_size, 1)).int().cuda()
        gen_context_lengths = ctx_context_lengths.clone()
        gen_position_ids = torch.ones_like(step1_id).int().cuda() * input_len
        gen_last_token_ids = torch.zeros_like(gen_context_lengths).int().cuda()
        gen_host_request_types = torch.tensor([1] * batch_size,
                                              dtype=torch.int32)

        if not skip_hf:
            with torch.no_grad():
                hf_outputs = hf_mistral.forward(
                    step1_id,
                    past_key_values=hf_outputs.past_key_values,
                    use_cache=True)
            torch.cuda.synchronize()
            ref = hf_outputs.logits[:, -1, :]

        if enable_remove_input_padding:
            step1_id = step1_id.view([batch_size])
            gen_position_ids = gen_position_ids.view([batch_size])
            gen_last_token_ids = torch.ones_like(
                gen_context_lengths).int().cuda()
            gen_last_token_ids = torch.cumsum(gen_last_token_ids,
                                              dim=0).int()

        step1_buffer = {
            'input_ids': step1_id,
            'context_lengths': gen_context_lengths,
            'position_ids': gen_position_ids,
            'last_token_ids': gen_last_token_ids,
            'host_request_types': gen_host_request_types,
            'cache_indirection': cache_indirections[1],
        }
        if enable_remove_input_padding:
            step1_buffer['host_context_lengths'] = gen_context_lengths.cpu()
        step1_shape = {k: v.shape for k, v in step1_buffer.items()}

        step1_shape['host_max_attention_window_sizes'] = (
            mistral_config.num_hidden_layers, )
        step1_buffer['host_max_attention_window_sizes'] = torch.tensor(
            [max_seq_len] * mistral_config.num_hidden_layers,
            dtype=torch.int32)
        for i in range(mistral_config.num_hidden_layers):
            step1_shape[f'past_key_value_{i}'] = kv_shape
        step1_shape['sequence_length'] = (batch_size, )
        step1_shape['host_past_key_value_lengths'] = (batch_size, )
        step1_shape['host_sink_token_length'] = (1, )
        for i in range(mistral_config.num_hidden_layers):
            step1_buffer[f'past_key_value_{i}'] = key_value_cache_buffers[i]
            step1_buffer[f'present_key_value_{i}'] = key_value_cache_buffers[i]
        step1_buffer[
            'host_past_key_value_lengths'] = sequence_length_buffer.cpu()
        sequence_length_buffer = torch.add(sequence_length_buffer, step)
        step1_buffer['sequence_length'] = sequence_length_buffer
        step1_buffer['host_sink_token_length'] = torch.tensor(
            [0], dtype=torch.int32)

        context = runtime.context_1
        runtime._set_shape(context, step1_shape)
        runtime._set_buffer(context, step1_buffer)
        runtime._run(context)
        torch.cuda.synchronize()
        res = step1_buffer['logits']

        if not skip_hf:
            np.testing.assert_allclose(ref.to(torch.float32).cpu().numpy(),
                                       res.to(torch.float32).cpu().numpy(),
                                       atol=0.12)


if __name__ == '__main__':
    unittest.main()
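# A typical invocation, assuming this file is saved as test_arctic.py
# (the filename is an assumption; adjust to the actual path in the repo):
#   python -m unittest test_arctic -v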