# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from collections import OrderedDict
from itertools import product

import numpy as np
import parameterized
import tensorrt as trt
import torch
from parameterized import parameterized
from transformers import BertConfig, BertForQuestionAnswering, BertModel

import tensorrt_llm
import tensorrt_llm.runtime
from tensorrt_llm import Builder
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.runtime import TensorInfo


def extract_layer_idx(name):
    ss = name.split('.')
    for s in ss:
        if s.isdigit():
            return s
    return None


def split(v, tp_size, idx, dim=0):
    if tp_size == 1:
        return v
    if len(v.shape) == 1:
        return np.ascontiguousarray(np.split(v, tp_size)[idx])
    elif len(v.shape) == 2:
        return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx])
    return None
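

# A minimal sketch of the sharding convention implemented by split() above
# (illustrative values only, not part of the test): with 2-way tensor
# parallelism, a [768, 768] weight split along dim=0 leaves each rank a
# [384, 768] shard, while dim=1 splits the other axis.
#
#     w = np.zeros((768, 768), dtype=np.float32)
#     assert split(w, 2, 0).shape == (384, 768)
#     assert split(w, 2, 1, dim=1).shape == (768, 384)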


def load_from_hf_bert(tensorrt_llm_bert,
                      hf_bert,
                      hf_bert_config,
                      rank=0,
                      tensor_parallel=1,
                      fp16=False):
    # Q/K/V weights are separate tensors in the HF checkpoint but fused into
    # a single QKV matrix per layer in the TensorRT-LLM model, so collect
    # them here and concatenate at the end.
    qkv_weight = [[None, None, None]
                  for _ in range(hf_bert_config.num_hidden_layers)]
    qkv_bias = [[None, None, None]
                for _ in range(hf_bert_config.num_hidden_layers)]
    torch_dtype = torch.float16 if fp16 else torch.float32
    for k, v in hf_bert.state_dict().items():
        v = v.to(torch_dtype).cpu().numpy()
        if 'embeddings.word_embeddings.weight' in k:
            tensorrt_llm_bert.embedding.vocab_embedding.weight.value = v
        elif 'embeddings.position_embeddings.weight' in k:
            tensorrt_llm_bert.embedding.position_embedding.weight.value = v
        elif 'embeddings.token_type_embeddings.weight' in k:
            tensorrt_llm_bert.embedding.token_embedding.weight.value = v
        elif 'embeddings.LayerNorm.weight' in k:
            tensorrt_llm_bert.embedding.embedding_ln.weight.value = v
        elif 'embeddings.LayerNorm.bias' in k:
            tensorrt_llm_bert.embedding.embedding_ln.bias.value = v
        else:
            layer_idx = extract_layer_idx(k)
            if layer_idx is None:
                continue
            idx = int(layer_idx)
            if 'attention.output.dense.weight' in k:
                # Row-parallel linear: split along the input dimension.
                tensorrt_llm_bert.layers[
                    idx].attention.dense.weight.value = split(v,
                                                              tensor_parallel,
                                                              rank,
                                                              dim=1)
            elif 'attention.output.dense.bias' in k:
                tensorrt_llm_bert.layers[idx].attention.dense.bias.value = v
            elif 'attention.output.LayerNorm.weight' in k:
                tensorrt_llm_bert.layers[idx].input_layernorm.weight.value = v
            elif 'attention.output.LayerNorm.bias' in k:
                tensorrt_llm_bert.layers[idx].input_layernorm.bias.value = v
            elif 'intermediate.dense.weight' in k:
                # Column-parallel linear: split along the output dimension.
                tensorrt_llm_bert.layers[idx].mlp.fc.weight.value = split(
                    v, tensor_parallel, rank)
            elif 'intermediate.dense.bias' in k:
                tensorrt_llm_bert.layers[idx].mlp.fc.bias.value = split(
                    v, tensor_parallel, rank)
            elif 'output.dense.weight' in k:
                tensorrt_llm_bert.layers[idx].mlp.proj.weight.value = split(
                    v, tensor_parallel, rank, dim=1)
            elif 'output.dense.bias' in k:
                tensorrt_llm_bert.layers[idx].mlp.proj.bias.value = v
            elif 'output.LayerNorm.weight' in k:
                tensorrt_llm_bert.layers[idx].post_layernorm.weight.value = v
            elif 'output.LayerNorm.bias' in k:
                tensorrt_llm_bert.layers[idx].post_layernorm.bias.value = v
            elif 'attention.self.query.weight' in k:
                qkv_weight[idx][0] = v
            elif 'attention.self.query.bias' in k:
                qkv_bias[idx][0] = v
            elif 'attention.self.key.weight' in k:
                qkv_weight[idx][1] = v
            elif 'attention.self.key.bias' in k:
                qkv_bias[idx][1] = v
            elif 'attention.self.value.weight' in k:
                qkv_weight[idx][2] = v
            elif 'attention.self.value.bias' in k:
                qkv_bias[idx][2] = v

    for i in range(hf_bert_config.num_hidden_layers):
        tensorrt_llm_bert.layers[i].attention.qkv.weight.value = split(
            np.concatenate(qkv_weight[i]), tensor_parallel, rank)
        tensorrt_llm_bert.layers[i].attention.qkv.bias.value = split(
            np.concatenate(qkv_bias[i]), tensor_parallel, rank)


def load_from_hf_qa_bert(tensorrt_llm_qa_bert,
                         hf_qa_bert,
                         hf_bert_config,
                         rank=0,
                         tensor_parallel=1,
                         fp16=False):
    # Load the shared BERT encoder weights first, then the QA head on top.
    load_from_hf_bert(tensorrt_llm_qa_bert.bert, hf_qa_bert, hf_bert_config,
                      rank, tensor_parallel, fp16)
    states = hf_qa_bert.state_dict()
    torch_dtype = torch.float16 if fp16 else torch.float32
    tensorrt_llm_qa_bert.qa_outputs.weight.value = states[
        'qa_outputs.weight'].to(torch_dtype).cpu().numpy()
    tensorrt_llm_qa_bert.qa_outputs.bias.value = states['qa_outputs.bias'].to(
        torch_dtype).cpu().numpy()


class TestBert(unittest.TestCase):

    def load_test_cases():
        models = [BertModel.__name__, BertForQuestionAnswering.__name__]
        test_cases = []
        # float32 baseline without plugins; float16 with the attention and
        # GEMM plugins across every context-FMHA mode.
        test_cases += product(models, [False], [False], [False],
                              [ContextFMHAType.disabled], ['float32'])
        test_cases += product(models, [False], [True], [True], [
            ContextFMHAType.disabled, ContextFMHAType.enabled,
            ContextFMHAType.enabled_with_fp32_acc
        ], ['float16'])
        return test_cases

    def custom_name_func(testcase_func, param_num, param):
        return "%s_%s" % (
            testcase_func.__name__,
            parameterized.to_safe_name("_".join(str(x) for x in param.args)),
        )

    @parameterized.expand(load_test_cases, name_func=custom_name_func)
    def test_bert(self, model, use_refit, use_plugin, fast_building,
                  context_fmha_type, dtype):
        tensorrt_llm.logger.set_level('error')
        fp16 = (dtype == 'float16')
        world_size = 1
        rank = 0
        batch_size = 8
        input_len = 128
        vocab_size = 51200
        num_layers = 12
        num_heads = 12
        hidden_act = 'gelu'
        max_position_embeddings = 512
        hidden_size = 768
        bs_range = [1, (batch_size + 1) // 2, batch_size]
        inlen_range = [1, (input_len + 1) // 2, input_len]
        torch_dtype = torch.float16 if fp16 else torch.float32
        trt_dtype = trt.float16 if fp16 else trt.float32
        timing_cache = 'model.cache'

        torch.manual_seed(0)

        builder = Builder()
        with tempfile.TemporaryDirectory() as tmpdirname:
            builder_config = builder.create_builder_config(
                name=model,
                precision='float16' if fp16 else 'float32',
                timing_cache=timing_cache,
                tensor_parallel=world_size,  # TP only
                use_refit=use_refit)

            network = builder.create_network()
            if use_plugin:
                network.plugin_config.set_bert_attention_plugin(dtype)
            if fast_building:
                network.plugin_config.set_gemm_plugin(dtype)
            network.plugin_config.set_context_fmha(context_fmha_type)

            with net_guard(network):
                # Prepare inputs
                # TODO: could a class be better than a dict for profiles?
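                # Each [min, opt, max] triple in dim_range below becomes a
                # TensorRT optimization profile bound for that dynamic
                # dimension: batch_size may range from 1 to 8 (opt 4) and
                # input_len from 1 to 128 (opt 64) at runtime.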
                input_ids = tensorrt_llm.Tensor(name='input_ids',
                                                dtype=trt.int32,
                                                shape=[-1, -1],
                                                dim_range=OrderedDict([
                                                    ('batch_size', [bs_range]),
                                                    ('input_len', [inlen_range])
                                                ]))
                input_lengths = tensorrt_llm.Tensor(name='input_lengths',
                                                    dtype=trt.int32,
                                                    shape=[-1],
                                                    dim_range=OrderedDict([
                                                        ('batch_size',
                                                         [bs_range])
                                                    ]))

                # Initialize model
                bert_config = BertConfig(
                    vocab_size=vocab_size,
                    hidden_size=hidden_size,
                    num_hidden_layers=num_layers,
                    num_attention_heads=num_heads,
                    intermediate_size=4 * hidden_size,
                    hidden_act=hidden_act,
                    max_position_embeddings=max_position_embeddings,
                    torch_dtype=torch_dtype,
                )

                output_name = "hidden_states"
                if model == BertModel.__name__:
                    hf_bert = BertModel(
                        bert_config,
                        add_pooling_layer=False).cuda().to(torch_dtype).eval()
                    tensorrt_llm_bert = tensorrt_llm.models.BertModel(
                        num_layers=num_layers,
                        num_heads=num_heads,
                        hidden_size=hidden_size,
                        vocab_size=vocab_size,
                        hidden_act=hidden_act,
                        max_position_embeddings=max_position_embeddings,
                        type_vocab_size=bert_config.type_vocab_size,
                        mapping=tensorrt_llm.Mapping(
                            world_size=world_size,
                            rank=rank,
                            tp_size=world_size),  # TP only
                        dtype=trt_dtype)
                    load_from_hf_bert(tensorrt_llm_bert,
                                      hf_bert,
                                      bert_config,
                                      rank=rank,
                                      tensor_parallel=world_size,
                                      fp16=fp16)
                elif model == BertForQuestionAnswering.__name__:
                    hf_bert = BertForQuestionAnswering(bert_config).cuda().to(
                        torch_dtype).eval()
                    output_name = "logits"
                    tensorrt_llm_bert = tensorrt_llm.models.BertForQuestionAnswering(
                        num_layers=num_layers,
                        num_heads=num_heads,
                        hidden_size=hidden_size,
                        vocab_size=vocab_size,
                        hidden_act=hidden_act,
                        max_position_embeddings=max_position_embeddings,
                        type_vocab_size=bert_config.type_vocab_size,
                        num_labels=2,  # fixed at 2 (start/end logits); not worth a config option
                        mapping=tensorrt_llm.Mapping(
                            world_size=world_size,
                            rank=rank,
                            tp_size=world_size),  # TP only
                        dtype=trt_dtype)
                    load_from_hf_qa_bert(tensorrt_llm_bert,
                                         hf_bert,
                                         bert_config,
                                         rank=rank,
                                         tensor_parallel=world_size,
                                         fp16=fp16)
                else:
                    assert False, f"Unknown model {model}"

                # Prepare
                network.set_named_parameters(
                    tensorrt_llm_bert.named_parameters())

                # Forward
                output = tensorrt_llm_bert(input_ids=input_ids,
                                           input_lengths=input_lengths)

                # Mark outputs
                output_dtype = trt.float16 if fp16 else trt.float32
                output.mark_output(output_name, output_dtype)

            # Build engine
            engine_buffer = builder.build_engine(network, builder_config)
            session = tensorrt_llm.runtime.Session.from_serialized_engine(
                engine_buffer)

            stream = torch.cuda.current_stream().cuda_stream

            # Inference
            # The dtype of input_ids should be queried from the engine;
            # for testing purposes, int32 is fine for now.
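            # Runtime flow below: infer_shapes() resolves output shapes for
            # the concrete input shapes so buffers can be pre-allocated, then
            # run() enqueues the engine on the given CUDA stream.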
            input_ids = torch.randint(100,
                                      (batch_size, input_len)).int().cuda()
            input_lengths = input_len * torch.ones(
                (batch_size, ), dtype=torch.int32, device='cuda')
            output_info = session.infer_shapes([
                TensorInfo('input_ids', trt.DataType.INT32,
                           (batch_size, input_len)),
                TensorInfo('input_lengths', trt.DataType.INT32,
                           (batch_size, ))
            ])
            session._print_engine_info()
            # Pre-allocate output buffers from the inferred shapes.
            outputs = {
                t.name: torch.empty(tuple(t.shape),
                                    dtype=trt_dtype_to_torch(t.dtype),
                                    device='cuda')
                for t in output_info
            }
            assert output_name in outputs, f'{output_name} not found in outputs'
            session.run(inputs={
                'input_ids': input_ids,
                'input_lengths': input_lengths
            },
                        outputs=outputs,
                        stream=stream)
            torch.cuda.synchronize()
            res = outputs[output_name]

            with torch.no_grad():
                hf_outputs = hf_bert.forward(input_ids)
            torch.cuda.synchronize()

            if model == BertModel.__name__:
                ref = hf_outputs.last_hidden_state
                np.testing.assert_allclose(ref.cpu().numpy(),
                                           res.cpu().numpy(),
                                           atol=1e-2,
                                           rtol=1e-2)
            elif model == BertForQuestionAnswering.__name__:
                # The engine returns fused [batch, seq_len, 2] logits; split
                # them into start/end logits to match the HF reference.
                res_start_logits, res_end_logits = torch.split(res, 1, -1)
                res_start_logits = res_start_logits.squeeze()
                res_end_logits = res_end_logits.squeeze()
                ref_start_logits = hf_outputs.start_logits
                ref_end_logits = hf_outputs.end_logits
                np.testing.assert_allclose(ref_start_logits.cpu().numpy(),
                                           res_start_logits.cpu().numpy(),
                                           atol=1.5e-2)
                np.testing.assert_allclose(ref_end_logits.cpu().numpy(),
                                           res_end_logits.cpu().numpy(),
                                           atol=1.5e-2)


if __name__ == '__main__':
    unittest.main()
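
# Typical invocation, assuming this file is saved as test_bert.py and a CUDA
# GPU with a matching TensorRT-LLM installation is available (individual case
# names are generated by custom_name_func above and appear in the -v output):
#
#     python test_bert.py -v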