# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import tempfile
import unittest
from itertools import product

import numpy as np
import pytest
import torch
from parameterized import parameterized
from transformers import AutoConfig, AutoModelForCausalLM

import tensorrt_llm
from tensorrt_llm import Builder
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from examples.phi.convert_checkpoint import convert_hf_phi  # isort:skip

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import getSMVersion

# The remote-code revision is pinned on purpose: an updated upstream code
# revision or config can break the tests.
HF_CODE_REVISION = "cb2f4533604d8b67de604e7df03bfe6f3ca22869"


def compare_max_abs_error(ref, res, str):
    """Print the maximum absolute element-wise difference between two tensors.

    Args:
        ref: Reference tensor (the HuggingFace output in this file).
        res: Result tensor (the TensorRT-LLM output in this file).
        str: Label printed in front of the error value. The name shadows the
            builtin; it is kept so existing callers stay valid.
    """
    reference = ref.cpu().numpy().flatten()
    result = res.cpu().numpy().flatten()
    max_abs_error = np.max(np.abs(result - reference))
    print(str, "max abs error = ", max_abs_error)


class TestPhi(unittest.TestCase):
    """Compares TensorRT-LLM phi logits against the HuggingFace model."""

    def setUp(self):
        """Seed torch's RNG so every test starts from the same state."""
        super().setUp()
        # Fix random seed for the reproducibility.
        torch.random.manual_seed(1773)

    def generate_hf_model(self, dtype: str):
        """Create a 2-layer phi-2 HuggingFace model on the GPU in eval mode.

        Args:
            dtype: Data type name, converted with
                ``tensorrt_llm._utils.str_dtype_to_torch``.

        Returns:
            Tuple ``(config, model)``; ``config`` is the phi-2 config with
            ``num_hidden_layers`` overridden to 2.
        """
        # Need to use the latest remote code for config and model class.
        gpt_config = AutoConfig.from_pretrained("microsoft/phi-2",
                                                code_revision=HF_CODE_REVISION,
                                                trust_remote_code=True)
        gpt_config.num_hidden_layers = 2
        torch_dtype = tensorrt_llm._utils.str_dtype_to_torch(dtype)
        model = AutoModelForCausalLM.from_config(
            gpt_config,
            code_revision=HF_CODE_REVISION,
            trust_remote_code=True).cuda().to(torch_dtype).eval()
        return gpt_config, model

    def initialize_network(self, network: tensorrt_llm.Network, hf_model,
                           hf_config, dtype: str, batch_size: int,
                           beam_width: int, input_len: int, output_len: int,
                           tensor_parallel: int, rank: int):
        """Populate ``network`` with a TensorRT-LLM phi model built from HF.

        The HF weights are converted with ``convert_hf_phi``, loaded into a
        ``PhiForCausalLM`` and the model is invoked once on the prepared
        inputs inside ``net_guard(network)``.

        Args:
            network: Network to fill in.
            hf_model: HuggingFace model providing the weights.
            hf_config: HuggingFace config describing the architecture.
            dtype: Data type name used for the model and weights.
            batch_size: Maximum batch size for the prepared inputs.
            beam_width: Maximum beam width for the prepared inputs.
            input_len: Maximum input length.
            output_len: Maximum number of generated tokens; the maximum
                sequence length passed on is ``input_len + output_len``.
            tensor_parallel: Tensor-parallel world size.
            rank: Rank of this process.
        """
        config = {
            'architecture': hf_config.architectures[0],
            'dtype': dtype,
            'num_hidden_layers': hf_config.num_hidden_layers,
            # Read the attention-head count from the matching HF field; the
            # test sizes heads and head_size from num_attention_heads too.
            'num_attention_heads': hf_config.num_attention_heads,
            'partial_rotary_factor': hf_config.partial_rotary_factor,
            'rope_theta': hf_config.rope_theta,
            'hidden_size': hf_config.hidden_size,
            'intermediate_size': hf_config.intermediate_size,
            'vocab_size': hf_config.vocab_size,
            'max_position_embeddings': hf_config.max_position_embeddings,
            'hidden_act': hf_config.hidden_act,
            'quantization': {
                'use_weight_only': False,
                'weight_only_precision': False,
            },
            'mapping': {
                'world_size': tensor_parallel,
                'tp_size': tensor_parallel,
            },
            'use_parallel_embedding': False,
            'embedding_sharding_dim': 0,
            'share_embedding_table': False,
        }
        config = tensorrt_llm.models.PretrainedConfig.from_dict(config)
        config.set_rank(rank)
        weights = convert_hf_phi(hf_model,
                                 rank=rank,
                                 tensor_parallel=tensor_parallel,
                                 dtype=dtype,
                                 use_parallel_embedding=False,
                                 sharding_dim=0)
        trtllm_model = tensorrt_llm.models.PhiForCausalLM(config)
        trtllm_model.load(weights)

        with net_guard(network):
            # Initialize model
            network.set_named_parameters(trtllm_model.named_parameters())
            inputs = trtllm_model.prepare_inputs(batch_size,
                                                 input_len,
                                                 input_len + output_len,
                                                 use_cache=True,
                                                 max_beam_width=beam_width)
            # Prepare
            trtllm_model(**inputs)

    def generate_trtllm_runtime(self,
                                log_level,
                                dtype,
                                world_size,
                                rank,
                                hf_config,
                                hf_model,
                                model,
                                use_attention_plugin,
                                batch_size,
                                beam_width,
                                input_len,
                                output_len,
                                use_refit,
                                use_ln_gemm_plugin,
                                apply_query_key_layer_scaling,
                                context_fmha_flag=ContextFMHAType.disabled,
                                enable_remove_input_padding=False):
        """Build a TensorRT-LLM engine for the HF model and wrap it in a runtime.

        Args:
            log_level: Level passed to ``tensorrt_llm.logger.set_level``.
            dtype: Data type name; also used as builder precision and as the
                plugin data type. ``'float16'`` enables strong typing.
            world_size: Tensor-parallel world size.
            rank: Rank of this process.
            hf_config: HuggingFace config of the model.
            hf_model: HuggingFace model providing the weights.
            model: Not used by this method; kept for interface parity.
            use_attention_plugin: Enable the GPT attention plugin.
            batch_size: Maximum batch size.
            beam_width: Maximum beam width.
            input_len: Maximum input length.
            output_len: Maximum number of generated tokens.
            use_refit: Forwarded to the builder config.
            use_ln_gemm_plugin: Enable the GEMM plugin.
            apply_query_key_layer_scaling: Not used by this method; kept for
                interface parity.
            context_fmha_flag: Context FMHA mode set on the plugin config.
            enable_remove_input_padding: Enable packed (padding-free) inputs.

        Returns:
            Tuple ``(runtime, engine_buffer)``.

        Raises:
            AssertionError: If the timing cache cannot be saved.
        """
        # Honour the caller's level (it used to be hard-coded to 'error').
        tensorrt_llm.logger.set_level(log_level)
        mapping = tensorrt_llm.Mapping(world_size, rank, tp_size=world_size)

        builder = Builder()
        fp16 = (dtype == 'float16')
        # Relative path: the cache is read from and written to the current
        # working directory.
        timing_cache_path = 'model.cache'

        # NOTE(review): nothing below writes into this temporary directory;
        # it is kept only to leave the original side effects untouched.
        with tempfile.TemporaryDirectory():
            builder_config = builder.create_builder_config(
                name='phi',
                precision=dtype,
                timing_cache=timing_cache_path,
                tensor_parallel=world_size,  # TP only
                use_refit=use_refit,
                strongly_typed=fp16,
            )
            network = builder.create_network()
            plugin_config = network.plugin_config
            plugin_config.to_legacy_setting()
            if use_attention_plugin:
                plugin_config.set_gpt_attention_plugin(dtype)
            if use_ln_gemm_plugin:
                plugin_config.set_gemm_plugin(dtype)
            if enable_remove_input_padding:
                plugin_config.enable_remove_input_padding()
            plugin_config.set_context_fmha(context_fmha_flag)

            self.initialize_network(network, hf_model, hf_config, dtype,
                                    batch_size, beam_width, input_len,
                                    output_len, world_size, rank)

            engine_buffer = builder.build_engine(network, builder_config)
            runtime = tensorrt_llm.runtime.generation._Runtime(
                engine_buffer, mapping)

            ok = builder.save_timing_cache(builder_config, timing_cache_path)
            assert ok, "Failed to save timing cache."

        return runtime, engine_buffer

    def load_test_cases():
        """Return every (context FMHA mode, remove-input-padding) pair.

        Called without an instance from the class body by
        ``parameterized.expand``, hence no ``self``.
        """
        test_cases = product([
            ContextFMHAType.disabled, ContextFMHAType.enabled,
            ContextFMHAType.enabled_with_fp32_acc
        ], [False, True])
        return test_cases

    @parameterized.expand(load_test_cases)
    def test_phi(self, context_fmha_flag, enable_remove_input_padding):
        """Check context and first-generation-step logits against HuggingFace.

        Args:
            context_fmha_flag: Context FMHA mode to build the engine with.
            enable_remove_input_padding: Whether inputs are packed (flattened
                to one dimension) instead of padded.
        """
        # Skip tests that are not supported in pre-ampere architecture
        if getSMVersion() < 80:
            if context_fmha_flag == ContextFMHAType.enabled_with_fp32_acc:
                pytest.skip(
                    "ContextFMHAType with fp32 acc is not supported in pre-ampere architecture"
                )

        torch.random.manual_seed(0)
        use_refit = False
        apply_query_key_layer_scaling = False
        model = 'phi'

        log_level = 'error'
        dtype = 'float16'
        world_size = 1
        rank = 0
        max_length = 128
        batch_size = 1
        beam_width = 1
        seq_len = 128
        total_seq_len = max_length + seq_len
        use_attention_plugin = True
        use_ln_gemm_plugin = True

        gpt_config, hf_gpt = self.generate_hf_model(dtype)
        runtime, _ = self.generate_trtllm_runtime(
            log_level, dtype, world_size, rank, gpt_config, hf_gpt, model,
            use_attention_plugin, batch_size, beam_width, seq_len, max_length,
            use_refit, use_ln_gemm_plugin, apply_query_key_layer_scaling,
            context_fmha_flag, enable_remove_input_padding)

        num_layers = gpt_config.num_hidden_layers
        head_size = gpt_config.hidden_size // gpt_config.num_attention_heads
        # One buffer per layer, used as both past and present KV cache.
        kv_cache_shape = (
            batch_size,
            2,
            gpt_config.num_attention_heads,
            total_seq_len,
            head_size,
        )
        torch_dtype = tensorrt_llm._utils.str_dtype_to_torch(dtype)
        key_value_cache_buffers = [
            torch.zeros(kv_cache_shape, dtype=torch_dtype, device='cuda')
            for _ in range(num_layers)
        ]

        # compare context
        step = 0
        ctx_ids = torch.randint(100, (batch_size, seq_len)).int().cuda()

        with torch.no_grad():
            hf_outputs = hf_gpt.forward(ctx_ids, use_cache=True)
        torch.cuda.synchronize()
        ref = hf_outputs.logits[:, -1, :]

        ctx_context_lengths = seq_len * torch.ones(
            (batch_size), dtype=torch.int32, device='cuda')
        ctx_host_request_types = torch.tensor([0] * batch_size,
                                              dtype=torch.int32)
        ctx_position_ids = torch.arange(seq_len, dtype=torch.int32).reshape(
            [1, seq_len]).expand([batch_size, seq_len]).cuda()
        ctx_last_token_ids = ctx_context_lengths.clone()

        # We need sequence_lengths start as context_lengths for step 0,
        # and it will be added one after each step.
        sequence_length_buffer = ctx_context_lengths.detach().clone()

        if enable_remove_input_padding:
            # Packed mode: flatten the inputs and turn per-sequence lengths
            # into cumulative end offsets.
            ctx_ids = ctx_ids.view([batch_size * seq_len])
            ctx_position_ids = ctx_position_ids.view([batch_size * seq_len])
            ctx_last_token_ids = torch.cumsum(ctx_last_token_ids, dim=0).int()

        cache_indirection_shape = (
            batch_size,
            beam_width,
            total_seq_len,
        )
        cache_indirections = [
            torch.full(cache_indirection_shape,
                       0,
                       dtype=torch.int32,
                       device='cuda') for _ in range(2)
        ]  # ping-pong buffers

        ctx_buffer = {
            'input_ids': ctx_ids,
            'context_lengths': ctx_context_lengths,
            'host_request_types': ctx_host_request_types,
            'position_ids': ctx_position_ids,
            'last_token_ids': ctx_last_token_ids,
            'cache_indirection': cache_indirections[0],
        }
        if enable_remove_input_padding:
            ctx_buffer['host_context_lengths'] = ctx_context_lengths.cpu()
        ctx_shape = {k: v.shape for k, v in ctx_buffer.items()}

        ctx_buffer['host_max_attention_window_sizes'] = torch.tensor(
            [total_seq_len] * num_layers, dtype=torch.int32)
        ctx_shape['host_max_attention_window_sizes'] = (num_layers, )
        for i, kv_cache in enumerate(key_value_cache_buffers):
            ctx_shape[f'past_key_value_{i}'] = kv_cache_shape
            ctx_buffer[f'past_key_value_{i}'] = kv_cache
            ctx_buffer[f'present_key_value_{i}'] = kv_cache
        ctx_buffer['sequence_length'] = sequence_length_buffer
        # step is 0 here, so the values are unchanged; torch.add returns a
        # new tensor, so the buffer handed to the context step is not aliased.
        sequence_length_buffer = torch.add(sequence_length_buffer, step)
        ctx_shape['sequence_length'] = ctx_buffer['sequence_length'].shape
        ctx_buffer['host_past_key_value_lengths'] = ctx_context_lengths.cpu()
        ctx_shape['host_past_key_value_lengths'] = ctx_buffer[
            'host_past_key_value_lengths'].shape
        ctx_buffer['host_sink_token_length'] = torch.tensor([0],
                                                            dtype=torch.int32)
        ctx_shape['host_sink_token_length'] = (1, )

        context = runtime.ctx_context
        runtime._set_shape(context, ctx_shape)
        runtime._set_buffer(context, ctx_buffer)
        runtime._run(context)
        torch.cuda.synchronize()
        res = ctx_buffer['logits']

        # Report the error before asserting so it is visible on failure too.
        compare_max_abs_error(ref, res, "context logits")
        np.testing.assert_allclose(ref.cpu().numpy(),
                                   res.cpu().numpy(),
                                   atol=1e-1)

        # compare generation
        step = 1
        step1_id = torch.randint(100, (batch_size, 1)).int().cuda()
        gen_position_ids = torch.ones_like(step1_id).int().cuda() * seq_len
        gen_context_lengths = ctx_context_lengths.clone()
        gen_host_request_types = torch.tensor([1] * batch_size,
                                              dtype=torch.int32)
        gen_last_token_ids = torch.zeros_like(gen_context_lengths).int().cuda()

        with torch.no_grad():
            # ctx_ids may have been flattened above; restore the batch
            # dimension (previously hard-coded to 1) before appending.
            hf_input_ids = torch.cat(
                (ctx_ids.reshape(batch_size, seq_len), step1_id), 1)
            hf_outputs = hf_gpt.forward(hf_input_ids, use_cache=True)
        torch.cuda.synchronize()
        ref = hf_outputs.logits[:, -1, :]

        if enable_remove_input_padding:
            step1_id = step1_id.view([batch_size])
            gen_position_ids = gen_position_ids.view([batch_size])
            gen_last_token_ids = torch.ones_like(
                gen_context_lengths).int().cuda()
            gen_last_token_ids = torch.cumsum(gen_last_token_ids, dim=0).int()

        step1_buffer = {
            'input_ids': step1_id,
            'context_lengths': gen_context_lengths,
            'host_request_types': gen_host_request_types,
            'position_ids': gen_position_ids,
            'last_token_ids': gen_last_token_ids,
            'cache_indirection': cache_indirections[1],
        }
        if enable_remove_input_padding:
            step1_buffer['host_context_lengths'] = gen_context_lengths.cpu()
        step1_shape = {k: v.shape for k, v in step1_buffer.items()}

        step1_shape['host_max_attention_window_sizes'] = (num_layers, )
        for i in range(num_layers):
            step1_shape[f'past_key_value_{i}'] = kv_cache_shape
        step1_shape['sequence_length'] = (batch_size, )
        step1_shape['host_past_key_value_lengths'] = (batch_size, )
        step1_shape['host_sink_token_length'] = (1, )

        step1_buffer['host_max_attention_window_sizes'] = torch.tensor(
            [total_seq_len] * num_layers, dtype=torch.int32)
        for i, kv_cache in enumerate(key_value_cache_buffers):
            step1_buffer[f'past_key_value_{i}'] = kv_cache
            step1_buffer[f'present_key_value_{i}'] = kv_cache
        # For step 1, the sequence_lengths = context_lengths + 1.
        sequence_length_buffer = torch.add(sequence_length_buffer, step)
        step1_buffer['sequence_length'] = sequence_length_buffer
        step1_buffer['host_past_key_value_lengths'] = torch.tensor(
            [seq_len + step - 1] * batch_size, dtype=torch.int32)
        step1_buffer['host_sink_token_length'] = torch.tensor(
            [0], dtype=torch.int32)

        context = runtime.context_1
        runtime._set_shape(context, step1_shape)
        runtime._set_buffer(context, step1_buffer)
        runtime._run(context)
        torch.cuda.synchronize()
        res = step1_buffer['logits']

        # Report the error before asserting so it is visible on failure too.
        compare_max_abs_error(ref, res, "generation logits")
        np.testing.assert_allclose(ref.cpu().numpy(),
                                   res.cpu().numpy(),
                                   atol=1e-1)


if __name__ == '__main__':
    unittest.main()