import argparse
import json
import math

# include plugins
# yapf: disable
import sys
from pathlib import Path

import torch
from fmha_triton import fused_attention

from tensorrt_llm import profiler
from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt,
                                 trt_dtype_to_torch)
from tensorrt_llm.logger import logger
from tensorrt_llm.runtime.session import Session, TensorInfo

# from tensorrt_llm.plugin import get_engine_name

sys.path.append('./tmp')
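# NOTE: the plugin wrapper `functional.py` is assumed to be generated into
# ./tmp by the build step. fused_attention_kernel is not called below;
# importing it presumably loads and registers the Triton plugin so the
# serialized engine can be deserialized.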
from functional import fused_attention_kernel  # isort:skip
# yapf: enable

def get_engine_name(head_size, dtype):
    return f'fmha_{head_size}_{dtype}.engine'


def run(engine_dir,
        batch_size,
        seq_len,
        num_heads,
        head_size,
        do_benchmark=False):
    # Load the serialized TRT engine.
    engine_dir = Path(engine_dir)
    config_path = engine_dir / 'config.json'
    with config_path.open('r') as f:
        config = json.load(f)
    dtype = config['builder_config']['precision']
    serialize_path = engine_dir / get_engine_name(head_size, dtype)

    with open(serialize_path, 'rb') as f:
        session = Session.from_serialized_engine(f.read())

    # Prepare input tensors.
    torch_dtype = str_dtype_to_torch(dtype) if isinstance(dtype, str) else dtype
    shape = (batch_size, num_heads, seq_len, head_size)
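    # The three inputs get distinct, nonzero means, presumably so that a
    # mixed-up or dropped operand shows up clearly in the output comparison.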
    q = torch.normal(mean=0.1,
                     std=0.2,
                     size=shape,
                     dtype=torch_dtype,
                     device='cuda')
    k = torch.normal(mean=0.4,
                     std=0.2,
                     size=shape,
                     dtype=torch_dtype,
                     device='cuda')
    v = torch.normal(mean=0.3,
                     std=0.2,
                     size=shape,
                     dtype=torch_dtype,
                     device='cuda')
    batch_size = q.shape[0]
    seq_len = q.shape[2]

    inputs = {'Q': q, 'K': k, 'V': v}

    # Prepare output tensors.
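    # infer_shapes propagates the input names/dtypes/shapes through the
    # engine and returns a TensorInfo for each output.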
    output_info = session.infer_shapes([
        TensorInfo(name, str_dtype_to_trt(dtype), tensor.shape)
        for name, tensor in inputs.items()
    ])
    logger.debug(f'output info {output_info}')
    outputs = {
        t.name: torch.empty(tuple(t.shape),
                            dtype=trt_dtype_to_torch(t.dtype),
                            device='cuda')
        for t in output_info
    }

    # Execute model inference.
    stream = torch.cuda.current_stream()
    ok = session.run(inputs=inputs, outputs=outputs, stream=stream.cuda_stream)
    assert ok, 'Engine execution failed'

    # Sanity check against the Triton reference implementation.
    stream.synchronize()
    sm_scale = 1.0 / math.sqrt(head_size)
    ref = fused_attention(q, k, v, sm_scale)
    out = outputs['out']
    logger.debug(
        f'Out: vals: {out.view(1, -1)} abs_sum: {out.float().abs().sum()}')
    logger.debug(
        f'Ref: vals: {ref.view(1, -1)} abs_sum: {ref.float().abs().sum()}')
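    # assert_close picks its default rtol/atol from the tensor dtype (looser
    # tolerances for fp16/bf16 than for fp32).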
    torch.testing.assert_close(out, ref)

    if do_benchmark:
        n_repeats = 10

        # For a fair comparison, pre-allocate the L/m/o buffers, as the TRT
        # plugin does.
        shape = (q.shape[0] * q.shape[1], q.shape[2])
        L = torch.empty(shape, device=q.device, dtype=torch.float32)
        m = torch.empty(shape, device=q.device, dtype=torch.float32)
        o = torch.empty_like(q)

        # Triton warm-up.
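        # The first call triggers Triton's JIT compilation and autotuning,
        # so it is kept out of the timed loop below.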
        fused_attention(q, k, v, sm_scale, l_buf=L, m_buf=m, o_buf=o)
        stream.synchronize()
        for _ in range(n_repeats):
            profiler.start('Triton')
            fused_attention(q, k, v, sm_scale, l_buf=L, m_buf=m, o_buf=o)
            stream.synchronize()
            profiler.stop('Triton')

        # TRT warm-up.
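        # Likewise, run the engine once outside the timed loop so one-time
        # setup cost is not measured.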
        stream.synchronize()
        ok = session.run(inputs=inputs,
                         outputs=outputs,
                         stream=stream.cuda_stream)
        stream.synchronize()
        for _ in range(n_repeats):
            profiler.start('TRT Plugin')
            ok = session.run(inputs=inputs,
                             outputs=outputs,
                             stream=stream.cuda_stream)
            stream.synchronize()
            profiler.stop('TRT Plugin')
        assert ok
        profiler.summary()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--seq_len', type=int, default=128)
    parser.add_argument('--num_heads', type=int, default=8)
    parser.add_argument('--head_size', type=int, default=64)
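    # NOTE: --head_size must match the value the engine was built with,
    # since it is part of the engine file name (see get_engine_name).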
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument(
        '--engine_dir',
        type=Path,
        default='outputs',
        help='The directory where the serialized engine files are located.')
    parser.add_argument(
        '--benchmark',
        action='store_true',
        help='Run a performance benchmark against the Triton baseline.')
    args = parser.parse_args()

    logger.set_level(args.log_level)
    logger.info('Parameters'.center(40, '='))
    for k, v in vars(args).items():
        logger.info(f' - {k.ljust(15, ".")}: {v}')
    logger.info(''.center(40, '='))

    assert args.engine_dir.exists(), \
        f"Engine directory {str(args.engine_dir)} doesn't exist."

    logger.info('Inference using the built TensorRT engine.')
    run(args.engine_dir,
        args.batch_size,
        args.seq_len,
        args.num_heads,
        args.head_size,
        do_benchmark=args.benchmark)
    logger.info('Done.')