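"""Run a TensorRT-LLM encoder-decoder (e.g. T5) example end to end.

The script loads the encoder and decoder engines produced by build.py, runs the
encoder once, and then drives autoregressive decoding through GenerationSession.
An illustrative invocation (the script name, engine directory, and engine name
here are assumptions and must match whatever build.py actually produced):

    python run.py --engine_dir trt_engines --engine_name enc_dec \
        --tokenizer t5-small --max_new_tokens 64 --num_beams 1
"""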
import argparse
import json
from pathlib import Path

import tensorrt as trt
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

import tensorrt_llm
from tensorrt_llm._utils import trt_dtype_to_torch
from tensorrt_llm.runtime import ModelConfig, SamplingConfig

from build import get_engine_name  # isort:skip

def print_tensor(tensor_name, tensor, num_elements=10):
    print(
        f'{tensor_name}: mean={tensor.abs().mean().item():.3f}, sum={tensor.abs().sum().item():.3f}, max={tensor.abs().max().item():.3f}'
    )
    # Passing num_elements=-1 prints the whole tensor
    if num_elements < 0:
        num_elements = torch.numel(tensor)
    print(f'{tensor.flatten()[:num_elements]}')
    print("Tensor Shape: ", tensor.size())
    print("")

def read_config(config_path: Path):
    with open(config_path, "r") as f:
        config = json.load(f)
    use_gpt_attention_plugin = config["plugin_config"]["gpt_attention_plugin"]
    remove_input_padding = config["plugin_config"]["remove_input_padding"]
    world_size = config["builder_config"]["tensor_parallel"]
    assert (
        world_size == tensorrt_llm.mpi_world_size()
    ), f"Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})"
    num_heads = config["builder_config"]["num_heads"] // world_size
    hidden_size = config["builder_config"]["hidden_size"] // world_size
    vocab_size = config["builder_config"]["vocab_size"]
    num_layers = config["builder_config"]["num_layers"]
    cross_attention = config["builder_config"]["cross_attention"]
    has_position_embedding = config["builder_config"]["has_position_embedding"]
    has_token_type_embedding = config["builder_config"][
        "has_token_type_embedding"]
    num_kv_heads = num_heads

    model_config = ModelConfig(
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        hidden_size=hidden_size,
        vocab_size=vocab_size,
        num_layers=num_layers,
        gpt_attention_plugin=use_gpt_attention_plugin,
        remove_input_padding=remove_input_padding,
        cross_attention=cross_attention,
        has_position_embedding=has_position_embedding,
        has_token_type_embedding=has_token_type_embedding,
    )

    dtype = config["builder_config"]["precision"]
    max_input_len = config["builder_config"]["max_input_len"]

    return model_config, world_size, dtype, max_input_len

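# For reference, a minimal sketch of the config.json fields consumed by
# read_config() above (values are illustrative, roughly matching t5-small;
# the real file is generated by build.py and may contain additional keys):
#
#   {
#     "builder_config": {
#       "tensor_parallel": 1, "num_heads": 8, "hidden_size": 512,
#       "vocab_size": 32128, "num_layers": 6, "cross_attention": true,
#       "has_position_embedding": false, "has_token_type_embedding": false,
#       "precision": "float16", "max_input_len": 512
#     },
#     "plugin_config": {
#       "gpt_attention_plugin": "float16", "remove_input_padding": false
#     }
#   }
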
def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_new_tokens", type=int, required=True)
    parser.add_argument("--log_level", type=str, default="error")
    parser.add_argument("--engine_dir", type=str, default="trt_engines")
    parser.add_argument("--engine_name", type=str, default="enc_dec")
    parser.add_argument("--tokenizer",
                        type=str,
                        help="HF tokenizer config path",
                        default="t5-small")
    parser.add_argument("--num_beams",
                        type=int,
                        help="Use beam search if num_beams > 1",
                        default=1)
    parser.add_argument("--debug_mode",
                        action="store_true",
                        help="Whether or not to turn on the debug mode")
    return parser.parse_args()

class TRTLLMEncDecModel:

    def __init__(self, engine_name, engine_dir, debug_mode=False):
        engine_dir = Path(engine_dir)

        # model config
        encoder_config_path = engine_dir / "encoder" / "config.json"
        encoder_model_config, world_size, dtype, max_input_len = read_config(
            encoder_config_path)
        decoder_config_path = engine_dir / "decoder" / "config.json"
        decoder_model_config, _, _, _ = read_config(decoder_config_path)
        self.encoder_model_config = encoder_model_config
        self.decoder_model_config = decoder_model_config

        # MGMN config
        runtime_rank = tensorrt_llm.mpi_rank()
        runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank)
        torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)

        # load engine
        engine_name = get_engine_name(engine_name, dtype, world_size,
                                      runtime_rank)
        with open(engine_dir / "encoder" / engine_name, "rb") as f:
            encoder_engine_buffer = f.read()
        with open(engine_dir / "decoder" / engine_name, "rb") as f:
            decoder_engine_buffer = f.read()

        # session setup
        self.encoder_session = tensorrt_llm.runtime.Session.from_serialized_engine(
            encoder_engine_buffer)
        self.decoder_session = tensorrt_llm.runtime.GenerationSession(
            decoder_model_config,
            decoder_engine_buffer,
            runtime_mapping,
            debug_mode=debug_mode)
        self.stream = torch.cuda.current_stream().cuda_stream

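    # The constructor above assumes an engine directory laid out roughly as
    # follows (the engine file name comes from get_engine_name(), so the
    # dtype/world_size/rank parts below are only an illustration):
    #
    #   <engine_dir>/
    #       encoder/
    #           config.json
    #           <engine_name>_<dtype>_tp<world_size>_rank<rank>.engine
    #       decoder/
    #           config.json
    #           <engine_name>_<dtype>_tp<world_size>_rank<rank>.engine
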
    @classmethod
    def from_engine(cls, engine_name, engine_dir, debug_mode=False):
        return cls(engine_name, engine_dir, debug_mode=debug_mode)

    def encoder_run(self,
                    input_ids,
                    position_ids=None,
                    token_type_ids=None,
                    debug_mode=False):
        batch_size = input_ids.shape[0]
        input_lengths = torch.tensor([len(x) for x in input_ids],
                                     dtype=torch.int32,
                                     device='cuda')
        max_input_length = torch.max(input_lengths).item()

        # set input tensors and shapes
        inputs = {
            'input_ids': input_ids,
            'input_lengths': input_lengths,
        }
        if self.encoder_model_config.has_position_embedding:
            inputs['position_ids'] = position_ids
        if self.encoder_model_config.has_token_type_embedding:
            inputs['token_type_ids'] = token_type_ids
        for k, v in inputs.items():
            self.encoder_session.context.set_input_shape(k, v.shape)

        # set output tensors and shapes
        outputs = {
            'encoder_output':
            torch.empty((batch_size, max_input_length,
                         self.encoder_model_config.hidden_size),
                        dtype=trt_dtype_to_torch(
                            self.encoder_session.engine.get_tensor_dtype(
                                'encoder_output')),
                        device='cuda')
        }

        # -------------------------------------------
        if debug_mode:
            engine = self.encoder_session.engine
            context = self.encoder_session.context
            # setup debugging buffer for the encoder
            for i in range(self.encoder_session.engine.num_io_tensors):
                name = engine.get_tensor_name(i)
                if engine.get_tensor_mode(
                        name
                ) == trt.TensorIOMode.OUTPUT and name not in outputs.keys():
                    dtype = engine.get_tensor_dtype(name)
                    shape = context.get_tensor_shape(name)
                    outputs[name] = torch.zeros(tuple(shape),
                                                dtype=trt_dtype_to_torch(dtype),
                                                device='cuda')
                    context.set_tensor_address(name, outputs[name].data_ptr())
        # -------------------------------------------

        # TRT session run
        ok = self.encoder_session.run(inputs, outputs, self.stream)
        assert ok, "Runtime execution failed"
        torch.cuda.synchronize()

        # -------------------------------------------
        if debug_mode:
            torch.cuda.synchronize()
            # use print_tensor() to print the tensors registered in the encoder network
            print("--------------------------------------")
            print("Debug output for Encoder")
            print("--------------------------------------")
            print("Registered output tensors are: ", outputs.keys())
            print_tensor('encoder_output', outputs['encoder_output'])
            print("--------------------------------------")
        # -------------------------------------------

        return outputs

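    # A minimal standalone-usage sketch for encoder_run() (illustrative only;
    # generate() below is the intended entry point and wraps this call):
    #
    #   model = TRTLLMEncDecModel.from_engine("enc_dec", "trt_engines")
    #   enc_out = model.encoder_run(input_ids)['encoder_output']
    #   # enc_out: (batch_size, max_input_length, hidden_size) on the GPU
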
    def generate(
        self,
        encoder_input_ids,
        decoder_input_ids,
        max_new_tokens,
        num_beams=1,
        pad_token_id=None,
        eos_token_id=None,
        bos_token_id=None,
        debug_mode=False,
    ):
        # encoder run
        encoder_input_lengths = torch.tensor(
            [len(x) for x in encoder_input_ids],
            dtype=torch.int32,
            device='cuda')
        encoder_outputs = self.encoder_run(encoder_input_ids,
                                           debug_mode=debug_mode)
        encoder_output = encoder_outputs['encoder_output']
        torch.cuda.synchronize()

        # decoder_batch_size = decoder_input_ids.shape[0]
        decoder_input_lengths = torch.tensor(
            [len(x) for x in decoder_input_ids],
            dtype=torch.int32,
            device='cuda')
        decoder_max_input_length = torch.max(decoder_input_lengths).item()

        # generation config
        sampling_config = SamplingConfig(end_id=eos_token_id,
                                         pad_id=pad_token_id,
                                         num_beams=num_beams)

        # decoder autoregressive generation
        self.decoder_session.setup(
            decoder_input_lengths.size(0),
            decoder_max_input_length,
            max_new_tokens,
            num_beams,
            encoder_max_input_length=encoder_output.shape[1])

        torch.cuda.synchronize()
        output_ids = self.decoder_session.decode(
            decoder_input_ids,
            decoder_input_lengths,
            sampling_config,
            encoder_output=encoder_output,
            encoder_input_lengths=encoder_input_lengths,
        )
        torch.cuda.synchronize()

        return output_ids

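# Note: the output ids returned by GenerationSession.decode() above carry a
# beam dimension, which is why the __main__ block below indexes them as
# output_ids[batch][beam] (e.g. tllm_output_ids[0][0] for batch 0, beam 0).
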
if __name__ == "__main__":
    import os

    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = parse_arguments()
    tensorrt_llm.logger.set_level(args.log_level)

    input_text = "translate English to German: The house is wonderful, radiating timeless charm and offering a warm, inviting interior with beautiful details and a serene backyard."
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.type(
        torch.IntTensor
    ).cuda(
    )  # by default int64, must cast to int32! otherwise TRT OOTB or lookup plugin will interpret as [a, 0, b, 0, c, 0, ...]
    max_new_tokens = args.max_new_tokens

    print("--------------------------------------")
    print(
        f"BOS={tokenizer.bos_token_id}, PAD={tokenizer.pad_token_id}, EOS={tokenizer.eos_token_id}"
    )
    print("input ids: ", input_ids)
    print("input length: ", input_ids.shape[1])
    print("input tokens: ", input_text)
    print("--------------------------------------")

    hf_model = T5ForConditionalGeneration.from_pretrained(
        args.tokenizer).cuda().eval()
    # start_id for decoder (could add more input_ids as forced_decoder_ids)
    decoder_input_ids = torch.IntTensor(
        [[hf_model.config.decoder_start_token_id]]).cuda()

    # TRT-LLM runtime
    tllm_model = TRTLLMEncDecModel.from_engine(args.engine_name,
                                               args.engine_dir,
                                               debug_mode=args.debug_mode)
    tllm_output_ids = tllm_model.generate(
        encoder_input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        max_new_tokens=max_new_tokens,
        num_beams=args.num_beams,
        bos_token_id=tokenizer.bos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        debug_mode=args.debug_mode,
    )

    print("--------------------------------------")
    print("TRTLLM output_ids: ", tllm_output_ids)
    print("TRTLLM output length: ", tllm_output_ids[0].shape[1])
    print("TRTLLM tokens: ",
          tokenizer.decode(tllm_output_ids[0][0], skip_special_tokens=True))
    print("--------------------------------------")
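
    # Optional: sanity-check against the HuggingFace model loaded above
    # (an illustrative, commented-out sketch; uncomment to compare outputs):
    # hf_output_ids = hf_model.generate(input_ids.to(torch.int64),
    #                                   max_new_tokens=max_new_tokens,
    #                                   num_beams=args.num_beams)
    # print("HF tokens: ",
    #       tokenizer.decode(hf_output_ids[0], skip_special_tokens=True))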