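"""Run the TensorRT engine built from the auto-generated Triton
fused-attention plugin, verify its output against the original Triton kernel,
and optionally benchmark the two.

Example invocation (a sketch; it assumes the serialized engine lives under
--engine_dir and the generated `functional` wrapper module under ./tmp):

    python run_engine.py --engine_dir outputs --benchmark
"""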
import argparse
import json
import math
# include plugins
# yapf: disable
import sys
from pathlib import Path
import torch
from fmha_triton import fused_attention
from tensorrt_llm import profiler
from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
                                 trt_dtype_to_torch)
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session, TensorInfo
# from tensorrt_llm.plugin import get_engine_name
sys.path.append('./tmp')
from functional import fused_attention_kernel # isort:skip
# yapf: enable


def get_engine_name(head_size, dtype):
    # e.g. head_size=64, dtype='float16' -> 'fmha_64_float16.engine'
    return f'fmha_{head_size}_{dtype}.engine'


def run(engine_dir,
        batch_size,
        seq_len,
        num_heads,
        head_size,
        do_benchmark=False):
    # Load trt engine.
    engine_dir = Path(engine_dir)
    config_path = engine_dir / 'config.json'
    with config_path.open('r') as f:
        config = json.load(f)
    dtype = config['builder_config']['precision']
    serialize_path = engine_dir / get_engine_name(head_size, dtype)
    with open(serialize_path, 'rb') as f:
        session = Session.from_serialized_engine(f.read())
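
    # The precision recorded in config.json at build time determines both the
    # engine file name above and the torch dtype of the test tensors below.
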
    # Prepare input tensors.
    torch_dtype = (str_dtype_to_torch(dtype)
                   if isinstance(dtype, str) else dtype)
    shape = (batch_size, num_heads, seq_len, head_size)
    q = torch.normal(mean=0.1,
                     std=0.2,
                     size=shape,
                     dtype=torch_dtype,
                     device='cuda')
    k = torch.normal(mean=0.4,
                     std=0.2,
                     size=shape,
                     dtype=torch_dtype,
                     device='cuda')
    v = torch.normal(mean=0.3,
                     std=0.2,
                     size=shape,
                     dtype=torch_dtype,
                     device='cuda')
    batch_size = q.shape[0]
    seq_len = q.shape[2]
    inputs = {'Q': q, 'K': k, 'V': v}
    # Prepare output tensors. Output names, shapes and dtypes are inferred
    # from the engine instead of being hard-coded.
    output_info = session.infer_shapes([
        TensorInfo(name, str_dtype_to_trt(dtype), tensor.shape)
        for name, tensor in inputs.items()
    ])
    logger.debug(f'output info {output_info}')
    outputs = {
        t.name: torch.empty(tuple(t.shape),
                            dtype=trt_dtype_to_torch(t.dtype),
                            device='cuda')
        for t in output_info
    }
    # Execute model inference.
    stream = torch.cuda.current_stream()
    ok = session.run(inputs=inputs, outputs=outputs,
                     stream=stream.cuda_stream)
    assert ok, 'Engine execution failed'

    # Sanity check.
    stream.synchronize()
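    # Informally, the reference computes scaled dot-product attention
    # (a sketch; exact masking/accumulation semantics are defined by the
    # Triton kernel in fmha_triton):
    #   P   = softmax(Q @ K^T * sm_scale)
    #   ref = P @ V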
    sm_scale = 1.0 / math.sqrt(head_size)
    ref = fused_attention(q, k, v, sm_scale)
    out = outputs["out"]
    logger.debug(
        f'Out: vals: {out.view(1, -1)} abs_sum: {out.float().abs().sum()}')
    logger.debug(
        f'Ref: vals: {ref.view(1, -1)} abs_sum: {ref.float().abs().sum()}')
    torch.testing.assert_close(out, ref)

    if do_benchmark:
        n_repeats = 10

        # For fair comparison, pre-allocate buffers as trt plugin does.
        shape = (q.shape[0] * q.shape[1], q.shape[2])
        L = torch.empty(shape, device=q.device, dtype=torch.float32)
        m = torch.empty(shape, device=q.device, dtype=torch.float32)
        o = torch.empty_like(q)

        # Triton warm-up.
        fused_attention(q, k, v, sm_scale, l_buf=L, m_buf=m, o_buf=o)
        stream.synchronize()

        for _ in range(n_repeats):
            profiler.start('Triton')
            fused_attention(q, k, v, sm_scale, l_buf=L, m_buf=m, o_buf=o)
            stream.synchronize()
            profiler.stop('Triton')

        # TRT warm-up.
        stream.synchronize()
        ok = session.run(inputs=inputs,
                         outputs=outputs,
                         stream=stream.cuda_stream)
        stream.synchronize()

        for _ in range(n_repeats):
            profiler.start('TRT Plugin')
            ok = session.run(inputs=inputs,
                             outputs=outputs,
                             stream=stream.cuda_stream)
            stream.synchronize()
            profiler.stop('TRT Plugin')
        assert ok

        profiler.summary()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--seq_len', type=int, default=128)
    parser.add_argument('--num_heads', type=int, default=8)
    parser.add_argument('--head_size', type=int, default=64)
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument(
        '--engine_dir',
        type=Path,
        default='outputs',
        help='The directory where the serialized engine files are located.')
    parser.add_argument(
        '--benchmark',
        action='store_true',
        help='Benchmark performance against the Triton baseline.')
    args = parser.parse_args()

    logger.set_level(args.log_level)
    logger.info('Parameters'.center(40, '='))
    for k, v in vars(args).items():
        logger.info(f' - {k.ljust(15, ".")}: {v}')
    logger.info(''.center(40, '='))

    assert args.engine_dir.exists(), \
        f"Engine directory {str(args.engine_dir)} doesn't exist."

    logger.info('Inference using the built TensorRT engine.')
    run(args.engine_dir,
        args.batch_size,
        args.seq_len,
        args.num_heads,
        args.head_size,
        do_benchmark=args.benchmark)
    logger.info('Done.')