# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Benchmark offline inference throughput."""
import argparse
import json
import os
import random
import time
from typing import List, Tuple
import torch
from run import QWenForCausalLMGenerationSession, get_model
from tqdm import tqdm, trange
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from utils.utils import get_stop_words_ids, make_context
now_dir = os.path.dirname(os.path.abspath(__file__))
MAX_INPUT_LEN = 2048
MAX_SEQ_LEN = 4096
TRT_MAX_BATCH_SIZE = 2
TEMPERATURE = 1.0
TOP_P = 0.5
TOP_K = 1


def sample_requests(
tokenizer: PreTrainedTokenizerBase,
dataset_path: str,
num_requests: int,
chat_format: str = "chatml",
) -> List[Tuple[str, int, int]]:
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [(data["conversations"][0]["value"],
data["conversations"][1]["value"]) for data in dataset]
# Tokenize the prompts and completions.
tokenized_dataset = []
for i in trange(len(dataset), desc="Tokenizing for sample"):
prompt = dataset[i][0]
output_text = dataset[i][1]
raw_text, prompt_tokens = make_context(tokenizer=tokenizer,
query=prompt,
max_input_length=MAX_INPUT_LEN,
chat_format=chat_format)
new_token_len = len(tokenizer(output_text).input_ids)
tokenized_dataset.append((raw_text, prompt_tokens, new_token_len))
# Filter out too long sequences.
filtered_dataset: List[Tuple[str, int, int]] = []
for prompt, prompt_token_ids, new_token_len in tokenized_dataset:
prompt_len = len(prompt_token_ids)
if prompt_len < 4 or new_token_len < 4:
# Prune too short sequences.
continue
        if prompt_len > MAX_INPUT_LEN or (prompt_len +
                                          new_token_len) > MAX_SEQ_LEN:
            # Prune too long sequences (bounded by MAX_INPUT_LEN and MAX_SEQ_LEN).
            continue
        filtered_dataset.append((prompt, prompt_len, new_token_len))
# Sample the requests.
sampled_requests = random.sample(filtered_dataset, num_requests)
return sampled_requests


def run_trt_llm(
requests: List[Tuple[str, int, int]],
engine_dir: str,
tokenizer_dir: str,
n: int,
max_batch_size: int,
) -> Tuple[float, int]:
global_max_input_len = MAX_INPUT_LEN
global_max_output_len = MAX_SEQ_LEN
    if max_batch_size > TRT_MAX_BATCH_SIZE:
        raise ValueError(
            "max_batch_size {} must not exceed TRT_MAX_BATCH_SIZE {}".format(
                max_batch_size, TRT_MAX_BATCH_SIZE))
(model_config, sampling_config, runtime_mapping, runtime_rank,
serialize_path, remove_input_padding, tokenizer, eos_token_id,
pad_token_id) = get_model(
tokenizer_dir=tokenizer_dir,
engine_dir=engine_dir,
)
with open(serialize_path, 'rb') as f:
engine_buffer = f.read()
decoder = QWenForCausalLMGenerationSession(model_config, engine_buffer,
runtime_mapping)
    # Configure the sampling parameters.
sampling_config.num_beams = n
sampling_config.temperature = 0.0 if n > 1 else TEMPERATURE
sampling_config.top_p = TOP_P
sampling_config.top_k = TOP_K
start = time.time()
pad_id = tokenizer.im_end_id
batch: List[str] = []
max_new_tokens = 0
total_num_tokens = []
for i, (prompt, prompt_len, new_token_len) in tqdm(enumerate(requests),
total=len(requests)):
# Add the prompt to the batch.
batch.append(prompt)
max_new_tokens = max(max_new_tokens, new_token_len)
if len(batch) < max_batch_size and i < len(requests) - 1:
continue
input_ids = []
input_lengths = []
for input_text in batch:
input_id = tokenizer(
input_text,
return_tensors="pt",
truncation=True,
max_length=global_max_input_len,
).input_ids.type(torch.int32)
input_ids.append(input_id)
input_lengths.append(input_id.shape[-1])
        # pad every prompt in the batch to the longest prompt length
        max_length = max(input_lengths)
        # NOTE: padding could be done outside the timed region to keep its
        # overhead out of the benchmark
        for j in range(len(input_ids)):
            pad_size = max_length - input_lengths[j]
            pad = torch.ones([1, pad_size], dtype=torch.int32) * pad_id
            input_ids[j] = torch.cat([input_ids[j], pad], dim=-1)
# do inference
        input_ids = torch.cat(input_ids, dim=0).cuda()
        input_lengths = torch.IntTensor(input_lengths).cuda()
output_ids = decoder.generate(
input_ids=input_ids,
input_lengths=input_lengths,
sampling_config=sampling_config,
max_new_tokens=min(max_new_tokens,
global_max_output_len - input_ids.shape[1]),
)
        pure_output_ids = []
        for batch_idx in range(len(batch)):
            temp_ids = output_ids[batch_idx, input_lengths[batch_idx]:]
            pure_ids = []
            # truncate at the first im_start / im_end token (keeping that token)
            for j in range(len(temp_ids)):
                if temp_ids[j] in [tokenizer.im_start_id, tokenizer.im_end_id]:
                    pure_ids = temp_ids[:j + 1]
                    break
            if len(pure_ids) == 0:
                pure_ids = temp_ids
            pure_output_ids.append(pure_ids)
# get the output text
output_texts = [
tokenizer.decode(out_ids, skip_special_tokens=True)
for out_ids in pure_output_ids
]
# get the total num of tokens
output_lengths = [len(out_ids) for out_ids in pure_output_ids]
assert len(output_lengths) == len(batch)
        for input_len, new_token_len in zip(input_lengths, output_lengths):
            # cast to a plain int so the final sum is a Python number
            total_num_tokens.append(int(input_len) + new_token_len)
batch = []
max_new_tokens = 0
end = time.time()
    elapsed_time = end - start
    sum_total_num_tokens = sum(total_num_tokens)
    return elapsed_time, sum_total_num_tokens


def run_hf(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: PreTrainedTokenizerBase,
n: int,
max_batch_size: int,
chat_format: str = "chatml",
) -> Tuple[float, int]:
global_max_input_len = MAX_INPUT_LEN
global_max_output_len = MAX_SEQ_LEN
llm = AutoModelForCausalLM.from_pretrained(model,
torch_dtype=torch.bfloat16,
trust_remote_code=True)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token
elif llm.config.model_type == "qwen":
tokenizer.pad_token = tokenizer.decode(tokenizer.im_end_id)
llm = llm.cuda()
    stop_words_ids = get_stop_words_ids(chat_format, tokenizer)
    # flat list of every stop token id, used below to find where generation stopped
    stop_words_ids2 = [idx for ids in stop_words_ids for idx in ids]
pbar = tqdm(total=len(requests))
start = time.time()
total_num_tokens = []
batch: List[str] = []
input_lengths: List[int] = []
max_prompt_len = 0
max_new_tokens = 0
for i in range(len(requests)):
prompt, prompt_len, new_token_len = requests[i]
# Add the prompt to the batch.
batch.append(prompt)
input_lengths.append(prompt_len)
max_prompt_len = max(max_prompt_len, prompt_len)
max_new_tokens = max(max_new_tokens, new_token_len)
if len(batch) < max_batch_size and i != len(requests) - 1:
# Check if we can add more requests to the batch.
_, next_prompt_len, next_output_len = requests[i + 1]
temp_input_max = max(max_prompt_len, next_prompt_len)
temp_new_token_max = max(max_new_tokens, next_output_len)
if temp_input_max <= global_max_input_len and \
(temp_input_max + temp_new_token_max) <= global_max_output_len:
continue
        # Generate the sequences.
        encodings = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=global_max_input_len,
        )
        input_ids = encodings.input_ids
        # limit max_new_tokens so prompt + generation stays within MAX_SEQ_LEN
        max_new_tokens = min(max_new_tokens,
                             global_max_output_len - input_ids.shape[1])
        llm_outputs = llm.generate(
            input_ids=input_ids.cuda(),
            # pass the attention mask so left padding is handled correctly
            attention_mask=encodings.attention_mask.cuda(),
do_sample=True,
stop_words_ids=stop_words_ids,
num_return_sequences=n,
top_k=TOP_K,
top_p=TOP_P,
temperature=TEMPERATURE,
use_cache=True,
max_new_tokens=max_new_tokens,
)
pure_output_ids = llm_outputs[:, input_ids.shape[-1]:]
# get the output text
output_texts = tokenizer.batch_decode(pure_output_ids,
skip_special_tokens=True)
        output_lengths = []
        for out_ids in pure_output_ids:
            early_stop = False
            for j in range(len(out_ids)):
                if out_ids[j] in stop_words_ids2:
                    output_lengths.append(j + 1)
                    early_stop = True
                    break
            if not early_stop:
                output_lengths.append(len(out_ids))
assert len(output_lengths) == len(batch)
for input_len, new_token_len in zip(input_lengths, output_lengths):
total_num_tokens.append(input_len + new_token_len)
pbar.update(len(batch))
# Clear the batch.
batch = []
input_lengths = []
max_prompt_len = 0
max_new_tokens = 0
end = time.time()
    elapsed_time = end - start
    sum_total_num_tokens = sum(total_num_tokens)
    return elapsed_time, sum_total_num_tokens


def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
# Sample the requests.
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer_dir,
padding_side='left',
trust_remote_code=True,
)
requests = sample_requests(tokenizer=tokenizer,
dataset_path=args.dataset,
num_requests=args.num_prompts,
chat_format=args.chat_format)
if args.backend == "trt_llm":
elapsed_time, total_num_tokens = run_trt_llm(
requests=requests,
engine_dir=args.engine_dir,
tokenizer_dir=args.tokenizer_dir,
n=args.n,
max_batch_size=args.trt_max_batch_size,
)
elif args.backend == "hf":
elapsed_time, total_num_tokens = run_hf(
requests=requests,
model=args.hf_model_dir,
tokenizer=tokenizer,
n=args.n,
max_batch_size=args.hf_max_batch_size,
)
else:
raise ValueError(f"Unknown backend: {args.backend}")
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Benchmark the throughput.")
parser.add_argument(
"--backend",
type=str,
choices=["trt_llm", "hf"],
default="trt_llm",
)
parser.add_argument("--dataset",
type=str,
default=os.path.join(
now_dir,
"ShareGPT_V3_unfiltered_cleaned_split.json"),
help="Path to the dataset.")
parser.add_argument("--hf_model_dir", type=str, default=None)
parser.add_argument("--tokenizer_dir",
type=str,
default=".",
help="Directory containing the tokenizer.model.")
parser.add_argument('--engine_dir', type=str, default='qwen_outputs')
parser.add_argument("--n",
type=int,
default=1,
help="Number of generated sequences per prompt.")
parser.add_argument("--num-prompts",
type=int,
default=100,
help="Number of prompts to process.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hf_max_batch_size",
type=int,
default=1,
help="Maximum batch size for HF backend.")
parser.add_argument("--trt_max_batch_size",
type=int,
default=1,
help="Maximum batch size for TRT-LLM backend.")
parser.add_argument("--chat-format",
type=str,
default="chatml",
choices=["chatml", "raw"],
help="choice the model format, base or chat")
args = parser.parse_args()
if args.backend == "trt-llm":
if args.trt_max_batch_size is None:
raise ValueError(
"trt max batch size is required for TRT-LLM backend.")
elif args.backend == "hf":
if args.hf_max_batch_size is None:
raise ValueError("hf max batch size is required for HF backend.")
if args.tokenizer_dir is None:
            args.tokenizer_dir = args.hf_model_dir
main(args)