TensorRT-LLM/tests/model/test_gpt_e2e.py

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import unittest
from pathlib import Path
from typing import Sequence

import numpy as np
import torch
import torch.multiprocessing as multiprocessing

import tensorrt_llm
from tensorrt_llm.runtime import ModelRunner, SamplingConfig
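
# GPT-2 uses a single special token (<|endoftext|>, id 50256) for both
# end-of-sequence and padding.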
END_ID = 50256
PAD_ID = 50256
work_dir = Path(__file__).parent.resolve() / 'check_gpt'
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.llm_data import llm_models_root
from utils.util import getSMVersion
gpt_example_root = os.path.join(os.path.dirname(__file__), '../../examples/gpt')
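

# Print the command (and working directory) being executed, then run it,
# raising CalledProcessError on a non-zero exit code.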
def run_command(command: Sequence[str], *, cwd=None, **kwargs) -> None:
print(f"Running: cd %s && %s" % (str(cwd or Path.cwd()), " ".join(command)),
flush=True)
subprocess.check_call(command, cwd=cwd, **kwargs)
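

# Convert a Hugging Face GPT checkpoint into the TensorRT-LLM checkpoint format
# using the convert_checkpoint.py script from the GPT example.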
def convert_ckpt(model_dir: str, output_dir: str, *args):
convert_cmd = [
sys.executable, f"{gpt_example_root}/convert_checkpoint.py",
f"--model_dir={model_dir}", f"--output_dir={output_dir}"
] + list(args)
run_command(convert_cmd)
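

# Build a TensorRT engine from a TensorRT-LLM checkpoint by invoking the
# trtllm-build command line tool.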
def build_engine(checkpoint_dir: str, engine_dir: str, *args):
build_cmd = [
"trtllm-build",
f"--checkpoint_dir={checkpoint_dir}",
f"--output_dir={engine_dir}",
'--log_level=verbose',
'--max_batch_size=256',
'--max_input_len=40',
'--max_seq_len=60',
'--max_beam_width=2',
]
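    # Baseline flags: disable the GPT attention plugin, fused context attention,
    # paged KV cache, packed input and XQA so the plain TensorRT path is built;
    # callers re-enable individual features through the extra args.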
legacy_args = [
"--gpt_attention_plugin=disable",
"--context_fmha=disable",
"--paged_kv_cache=disable",
"--remove_input_padding=disable",
"--enable_xqa=disable",
]
build_cmd = build_cmd + legacy_args + list(args)
run_command(build_cmd)
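

# Fetch (or locate) the GPT-2 weights, convert them to fp32 and fp16
# TensorRT-LLM checkpoints, and build every engine variant exercised by the
# checks below.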
def build_engines():
    models_root = llm_models_root()
    gpt2_dir = str(models_root / "gpt2") if models_root is not None else None
    # Clone the model if the cache root is not set or GPT-2 is missing from it.
    if gpt2_dir is None or not os.path.exists(gpt2_dir):
gpt2_dir = work_dir / 'gpt2'
print("Pulling gpt2 from huggingface")
subprocess.check_call(["rm", "-rf", str(gpt2_dir)])
subprocess.check_call(
["git", "clone", "https://huggingface.co/gpt2",
str(gpt2_dir)])
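        # A plain `git clone` may leave only an LFS pointer for the weights, so
        # fetch the real model.safetensors explicitly.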
pytorch_model = str(gpt2_dir / "model.safetensors")
subprocess.check_call(["rm", pytorch_model])
subprocess.check_call([
"wget", "-q",
"https://huggingface.co/gpt2/resolve/main/model.safetensors", "-O",
pytorch_model
])
ckpt_dir = work_dir / 'c-model/gpt2'
engine_dir = work_dir / 'rt_engine/gpt2'
print("\nConverting to fp32")
fp32_ckpt_dir = ckpt_dir / 'fp32/1-gpu'
convert_ckpt(str(gpt2_dir), str(fp32_ckpt_dir), "--dtype=float32")
print("\nBuilding fp32 engines")
build_engine(str(fp32_ckpt_dir), str(engine_dir / 'fp32-default/1-gpu'))
build_engine(str(fp32_ckpt_dir), str(engine_dir / 'fp32-plugin/1-gpu'),
'--gpt_attention_plugin=float32')
print("\nConverting to fp16")
fp16_ckpt_dir = ckpt_dir / 'fp16/1-gpu'
convert_ckpt(str(gpt2_dir), str(fp16_ckpt_dir), "--dtype=float16")
print("\nBuilding fp16 engines")
build_engine(str(fp16_ckpt_dir), str(engine_dir / 'fp16-default/1-gpu'))
build_engine(str(fp16_ckpt_dir), str(engine_dir / 'fp16-plugin/1-gpu'),
'--gpt_attention_plugin=float16')
# Skip tests that are not supported in pre-ampere architecture
if getSMVersion() >= 80:
build_engine(str(fp16_ckpt_dir),
str(engine_dir / 'fp16-plugin-fmha/1-gpu'),
'--gpt_attention_plugin=float16', '--context_fmha=enable')
build_engine(str(fp16_ckpt_dir),
str(engine_dir / 'fp16-plugin-packed/1-gpu'),
'--gpt_attention_plugin=float16',
'--remove_input_padding=enable')
# Skip tests that are not supported in pre-ampere architecture
if getSMVersion() >= 80:
        build_engine(str(fp16_ckpt_dir),
str(engine_dir / 'fp16-plugin-packed-fmha/1-gpu'),
'--gpt_attention_plugin=float16',
'--remove_input_padding=enable', '--context_fmha=enable')
print("Done.")
def check_accuracy(engine_dir, input_tokens, max_output_len):
runtime_rank = tensorrt_llm.mpi_rank()
runner = ModelRunner.from_dir(engine_dir, rank=runtime_rank)
    sampling_config = SamplingConfig(end_id=END_ID, pad_id=PAD_ID)
all_input_ids = [torch.tensor(x, dtype=torch.int32) for x in input_tokens]
all_input_lengths = [len(x) for x in input_tokens]
num_samples = len(input_tokens)
expect_output = None
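    # Sweep the batch size up and back down; the first pass (batch_size=1)
    # provides the reference output for the comparisons below.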
for j, batch_size in enumerate([1, 2, 4, 8, 4, 2, 1]):
output = []
output_with_fake_dim = []
print(f"Running batch size: {batch_size}")
for i in range(num_samples // batch_size):
batch_input_ids = all_input_ids[i * batch_size:(i + 1) * batch_size]
batch_input_lengths = all_input_lengths[i * batch_size:(i + 1) *
batch_size]
max_input_length = max(batch_input_lengths)
output_ids = runner.generate(batch_input_ids,
sampling_config=sampling_config,
max_new_tokens=max_output_len)
torch.cuda.synchronize()
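            # With packed input, also run the low-level session on the inputs
            # concatenated along a fake batch dimension and keep its output for
            # comparison against the ModelRunner result.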
if runner.remove_input_padding:
runner.session.setup(batch_size, max_input_length,
max_output_len)
batch_input_ids_with_fake_dim = torch.concat(
batch_input_ids).unsqueeze(0)
output_ids_with_fake_dim = runner.session.decode(
batch_input_ids_with_fake_dim.cuda(),
torch.tensor(batch_input_lengths, dtype=torch.int32).cuda(),
sampling_config)
outputs_with_fake_dim_list = [
output_ids_with_fake_dim[batch_idx, :,
batch_input_lengths[batch_idx]:
batch_input_lengths[batch_idx] +
max_output_len].cpu()
for batch_idx in range(output_ids_with_fake_dim.shape[0])
]
outputs_with_fake_dim = torch.cat(outputs_with_fake_dim_list)
output_with_fake_dim.append(outputs_with_fake_dim)
outputs_list = [
output_ids[batch_idx, :, batch_input_lengths[batch_idx]:
batch_input_lengths[batch_idx] +
max_output_len].cpu()
for batch_idx in range(output_ids.shape[0])
]
outputs = torch.cat(outputs_list)
output.append(outputs)
output = torch.stack(output, dim=0)
if runner.remove_input_padding:
output_with_fake_dim = torch.stack(output_with_fake_dim, dim=0)
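            # The two paths may disagree on a fraction of tokens because of
            # numerical differences; require the mismatch rate to stay below
            # the threshold.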
error = np.mean(output.cpu().numpy().flatten() !=
output_with_fake_dim.cpu().numpy().flatten())
assert error < 2.0 / 8, f"diff at batch_size={batch_size}, output_with_fake_dim={output_with_fake_dim}, output={output}"
if j == 0:
expect_output = output
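        # Compare this batch size's output against the batch_size=1 reference.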
if expect_output is not None:
error = np.mean(output.cpu().numpy().flatten() !=
expect_output.cpu().numpy().flatten())
assert error < 0.3, f"diff at batch_size={batch_size}, expect_output={expect_output}, output={output}"
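

# Load one engine variant and verify its generations on a fixed batch of
# prompts of varying lengths.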
def check_output(engine: str, max_output_len: int = 8):
engine_dir = work_dir / 'rt_engine/gpt2' / engine / '1-gpu/'
print(f"== Checking output for engine: {engine_dir}")
input_tokens = [
[28524, 287, 5093, 12, 23316, 4881, 11, 30022, 263, 8776, 355, 257],
[28524, 287, 5093, 12, 23316, 4881, 11, 30022, 263, 8776, 355],
[28524, 287, 5093, 12, 23316, 4881, 11, 30022, 263, 8776],
[28524, 287, 5093, 12, 23316, 4881, 11, 30022, 263],
[28524, 287, 5093, 12, 23316, 4881, 11, 30022],
[28524, 287, 5093, 12, 23316, 4881, 11],
[28524, 287, 5093, 12, 23316, 4881],
[28524, 287, 5093, 12, 23316],
]
check_accuracy(engine_dir, input_tokens, max_output_len)
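

# Verify every engine variant that build_engines() produced, skipping the
# variants that require Ampere or newer GPUs.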
def check_outputs():
check_output(engine='fp32-default')
check_output(engine='fp32-plugin')
check_output(engine='fp16-default')
check_output(engine='fp16-plugin')
# Skip tests that are not supported in pre-ampere architecture
if getSMVersion() >= 80:
check_output(engine='fp16-plugin-fmha')
check_output(engine='fp16-plugin-packed')
# Skip tests that are not supported in pre-ampere architecture
if getSMVersion() >= 80:
check_output(engine='fp16-plugin-packed-fmha')


class TestGPTE2E(unittest.TestCase):
def test_check_gpt_e2e(self):
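        # CUDA cannot be re-initialized in forked subprocesses, so use the
        # spawn start method.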
multiprocessing.set_start_method("spawn")
build_engines()
check_outputs()


if __name__ == '__main__':
unittest.main()