# TensorRT-LLM/examples/enc_dec/run.py
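"""Run a TensorRT-LLM encoder-decoder (T5-family) model end to end.

Loads prebuilt encoder/decoder engines, tokenizes a sample translation prompt
with the HF tokenizer, runs the encoder through a raw TRT session, and
generates tokens autoregressively with the decoder GenerationSession.

Example invocation (paths and flags are the argparse defaults below):

    python run.py --engine_dir trt_engines --engine_name enc_dec --tokenizer t5-small --max_new_tokens 64
"""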
import argparse
import json
import os
from pathlib import Path

import tensorrt as trt
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

import tensorrt_llm
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.runtime import ModelConfig, SamplingConfig

from build import get_engine_name  # isort:skip


def print_tensor(tensor_name, tensor, num_elements=10):
    print(
        f'{tensor_name}: mean={tensor.abs().mean().item():.3f}, '
        f'sum={tensor.abs().sum().item():.3f}, '
        f'max={tensor.abs().max().item():.3f}')
    # Passing num_elements=-1 prints the whole tensor.
    if num_elements < 0:
        num_elements = torch.numel(tensor)
    print(f'{tensor.flatten()[:num_elements]}')
    print("Tensor Shape: ", tensor.size())
    print("")
def read_config(config_path: Path):
    with open(config_path, "r") as f:
        config = json.load(f)
    use_gpt_attention_plugin = config["plugin_config"]["gpt_attention_plugin"]
    remove_input_padding = config["plugin_config"]["remove_input_padding"]
    world_size = config["builder_config"]["tensor_parallel"]
    assert world_size == tensorrt_llm.mpi_world_size(), (
        f"Engine world size ({world_size}) != "
        f"Runtime world size ({tensorrt_llm.mpi_world_size()})")
    num_heads = config["builder_config"]["num_heads"] // world_size
    hidden_size = config["builder_config"]["hidden_size"] // world_size
    vocab_size = config["builder_config"]["vocab_size"]
    num_layers = config["builder_config"]["num_layers"]
    cross_attention = config["builder_config"]["cross_attention"]
    has_position_embedding = config["builder_config"]["has_position_embedding"]
    has_token_type_embedding = config["builder_config"][
        "has_token_type_embedding"]
    num_kv_heads = num_heads

    model_config = ModelConfig(
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        hidden_size=hidden_size,
        vocab_size=vocab_size,
        num_layers=num_layers,
        gpt_attention_plugin=use_gpt_attention_plugin,
        remove_input_padding=remove_input_padding,
        cross_attention=cross_attention,
        has_position_embedding=has_position_embedding,
        has_token_type_embedding=has_token_type_embedding,
    )

    dtype = config["builder_config"]["precision"]
    max_input_len = config["builder_config"]["max_input_len"]
    return model_config, world_size, dtype, max_input_len
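

# CLI flags; engine_dir is expected to contain encoder/ and decoder/
# subdirectories, each holding a config.json and a serialized engine.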
def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_new_tokens", type=int, required=True)
    parser.add_argument("--log_level", type=str, default="error")
    parser.add_argument("--engine_dir", type=str, default="trt_engines")
    parser.add_argument("--engine_name", type=str, default="enc_dec")
    parser.add_argument("--tokenizer",
                        type=str,
                        help="HF tokenizer config path",
                        default="t5-small")
    parser.add_argument("--num_beams",
                        type=int,
                        help="Use beam search if num_beams > 1",
                        default=1)
    # A store_true flag avoids argparse's type=bool pitfall, where any
    # non-empty string (even "False") parses as True.
    parser.add_argument("--debug_mode",
                        action="store_true",
                        help="Whether or not to turn on the debug mode")
    return parser.parse_args()
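

# The encoder runs once per request, so it is driven through a raw TRT
# Session; the decoder is autoregressive, so it uses GenerationSession, which
# manages the KV cache and the token-by-token decode loop.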
class TRTLLMEncDecModel:

    def __init__(self, engine_name, engine_dir, debug_mode=False):
        engine_dir = Path(engine_dir)

        # model config
        encoder_config_path = engine_dir / "encoder" / "config.json"
        encoder_model_config, world_size, dtype, max_input_len = read_config(
            encoder_config_path)
        decoder_config_path = engine_dir / "decoder" / "config.json"
        decoder_model_config, _, _, _ = read_config(decoder_config_path)
        self.encoder_model_config = encoder_model_config
        self.decoder_model_config = decoder_model_config

        # MGMN (multi-GPU, multi-node) config
        runtime_rank = tensorrt_llm.mpi_rank()
        runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
        torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)

        # load engine
        engine_name = get_engine_name(engine_name, dtype, world_size,
                                      runtime_rank)
        with open(engine_dir / "encoder" / engine_name, "rb") as f:
            encoder_engine_buffer = f.read()
        with open(engine_dir / "decoder" / engine_name, "rb") as f:
            decoder_engine_buffer = f.read()

        # session setup
        self.encoder_session = tensorrt_llm.runtime.Session.from_serialized_engine(
            encoder_engine_buffer)
        self.decoder_session = tensorrt_llm.runtime.GenerationSession(
            decoder_model_config,
            decoder_engine_buffer,
            runtime_mapping,
            debug_mode=debug_mode)
        self.stream = torch.cuda.current_stream().cuda_stream

    @classmethod
    def from_engine(cls, engine_name, engine_dir, debug_mode=False):
        return cls(engine_name, engine_dir, debug_mode=debug_mode)
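
    # The encoder is driven through the low-level TRT I/O API: set input
    # shapes, allocate output buffers with the engine-reported dtypes, then
    # launch on the shared CUDA stream.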
    def encoder_run(self,
                    input_ids,
                    position_ids=None,
                    token_type_ids=None,
                    debug_mode=False):
        batch_size = input_ids.shape[0]
        input_lengths = torch.tensor([len(x) for x in input_ids],
                                     dtype=torch.int32,
                                     device='cuda')
        max_input_length = torch.max(input_lengths).item()

        # set input tensors and shapes
        inputs = {
            'input_ids': input_ids,
            'input_lengths': input_lengths,
        }
        if self.encoder_model_config.has_position_embedding:
            inputs['position_ids'] = position_ids
        if self.encoder_model_config.has_token_type_embedding:
            inputs['token_type_ids'] = token_type_ids
        for k, v in inputs.items():
            self.encoder_session.context.set_input_shape(k, v.shape)

        # set output tensors and shapes
        outputs = {
            'encoder_output':
            torch.empty((batch_size, max_input_length,
                         self.encoder_model_config.hidden_size),
                        dtype=trt_dtype_to_torch(
                            self.encoder_session.engine.get_tensor_dtype(
                                'encoder_output')),
                        device='cuda')
        }

        # -------------------------------------------
        if debug_mode:
            engine = self.encoder_session.engine
            context = self.encoder_session.context
            # setup debugging buffer for the encoder
            for i in range(self.encoder_session.engine.num_io_tensors):
                name = engine.get_tensor_name(i)
                if engine.get_tensor_mode(
                        name
                ) == trt.TensorIOMode.OUTPUT and name not in outputs.keys():
                    dtype = engine.get_tensor_dtype(name)
                    shape = context.get_tensor_shape(name)
                    outputs[name] = torch.zeros(tuple(shape),
                                                dtype=trt_dtype_to_torch(dtype),
                                                device='cuda')
                    context.set_tensor_address(name, outputs[name].data_ptr())
        # -------------------------------------------

        # TRT session run
        ok = self.encoder_session.run(inputs, outputs, self.stream)
        assert ok, "Runtime execution failed"
        torch.cuda.synchronize()

        # -------------------------------------------
        if debug_mode:
            torch.cuda.synchronize()
            # use print_tensor() to print the tensors registered in the encoder network
            print("--------------------------------------")
            print("Debug output for Encoder")
            print("--------------------------------------")
            print("Registered output tensors are: ", outputs.keys())
            print_tensor('encoder_output', outputs['encoder_output'])
            print("--------------------------------------")
        # -------------------------------------------

        return outputs
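
    # Two-phase generation: run the encoder once, then feed its hidden states
    # to the decoder via cross-attention while decoding autoregressively.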
    def generate(
        self,
        encoder_input_ids,
        decoder_input_ids,
        max_new_tokens,
        num_beams=1,
        pad_token_id=None,
        eos_token_id=None,
        bos_token_id=None,
        debug_mode=False,
    ):
        # encoder run
        encoder_input_lengths = torch.tensor(
            [len(x) for x in encoder_input_ids],
            dtype=torch.int32,
            device='cuda')
        encoder_outputs = self.encoder_run(encoder_input_ids,
                                           debug_mode=debug_mode)
        encoder_output = encoder_outputs['encoder_output']
        torch.cuda.synchronize()

        # decoder_batch_size = decoder_input_ids.shape[0]
        decoder_input_lengths = torch.tensor(
            [len(x) for x in decoder_input_ids],
            dtype=torch.int32,
            device='cuda')
        decoder_max_input_length = torch.max(decoder_input_lengths).item()

        # generation config
        sampling_config = SamplingConfig(end_id=eos_token_id,
                                         pad_id=pad_token_id,
                                         num_beams=num_beams)

        # decoder autoregressive generation
        self.decoder_session.setup(
            decoder_input_lengths.size(0),
            decoder_max_input_length,
            max_new_tokens,
            num_beams,
            encoder_max_input_length=encoder_output.shape[1])
        torch.cuda.synchronize()

        output_ids = self.decoder_session.decode(
            decoder_input_ids,
            decoder_input_lengths,
            sampling_config,
            encoder_output=encoder_output,
            encoder_input_lengths=encoder_input_lengths,
        )
        torch.cuda.synchronize()

        return output_ids
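

# Demo: translate one English sentence to German with a T5-family engine and
# print the generated tokens.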
if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = parse_arguments()
    tensorrt_llm.logger.set_level(args.log_level)

    input_text = "translate English to German: The house is wonderful, radiating timeless charm and offering a warm, inviting interior with beautiful details and a serene backyard."
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    # HF tokenizers emit int64 ids by default; cast to int32, otherwise TRT
    # OOTB or the lookup plugin will reinterpret the buffer as
    # [a, 0, b, 0, c, 0, ...]
    input_ids = tokenizer(input_text,
                          return_tensors="pt").input_ids.type(
                              torch.IntTensor).cuda()
    max_new_tokens = args.max_new_tokens

    print("--------------------------------------")
    print(
        f"BOS={tokenizer.bos_token_id}, PAD={tokenizer.pad_token_id}, EOS={tokenizer.eos_token_id}"
    )
    print("input ids: ", input_ids)
    print("input length: ", input_ids.shape[1])
    print("input tokens: ", input_text)
    print("--------------------------------------")

    hf_model = T5ForConditionalGeneration.from_pretrained(
        args.tokenizer).cuda().eval()
    # start_id for the decoder (could add more input_ids as forced_decoder_ids)
    decoder_input_ids = torch.IntTensor(
        [[hf_model.config.decoder_start_token_id]]).cuda()
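
    # Aside (a sketch, not part of the original flow): if only
    # decoder_start_token_id were needed, loading the full HF model onto the
    # GPU would be overkill; transformers' AutoConfig exposes the same value:
    #
    #   from transformers import AutoConfig
    #   start_id = AutoConfig.from_pretrained(args.tokenizer).decoder_start_token_id
    #
    # The full model is kept here so the HF reference output can be compared
    # against TRT-LLM (see the sanity-check sketch at the end of the file).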

    # TRT-LLM runtime
    tllm_model = TRTLLMEncDecModel.from_engine(args.engine_name,
                                               args.engine_dir,
                                               debug_mode=args.debug_mode)
    tllm_output_ids = tllm_model.generate(
        encoder_input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        max_new_tokens=max_new_tokens,
        num_beams=args.num_beams,
        bos_token_id=tokenizer.bos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug_mode=args.debug_mode,
    )

    print("--------------------------------------")
    print("TRTLLM output_ids: ", tllm_output_ids)
    print("TRTLLM output length: ", tllm_output_ids[0].shape[1])
    print("TRTLLM tokens: ",
          tokenizer.decode(tllm_output_ids[0][0], skip_special_tokens=True))
    print("--------------------------------------")