# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

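# Overview (editorial comment): this example evaluates a TensorRT-LLM engine,
# and optionally the matching Hugging Face model, on summarization,
# code-completion, or context-perplexity tasks, and reports ROUGE and
# per-token PPL metrics.
#
# A sketch of a typical invocation (hypothetical paths; `--engine_dir` and
# `--hf_model_dir` come from `add_common_args`, which is not shown here):
#
#   python summarize.py --test_trt_llm \
#       --engine_dir ./engine_dir --hf_model_dir ./hf_model_dir \
#       --eval_task summarize --batch_size 1 --max_ite 20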
import argparse
import ast
import os
from pathlib import Path

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from transformers import (AutoModel, AutoModelForCausalLM,
                          AutoModelForSeq2SeqLM, GenerationConfig)
from utils import (DEFAULT_HF_MODEL_DIRS, add_common_args, load_tokenizer,
                   read_model_name, supports_inflight_batching)

import tensorrt_llm
import tensorrt_llm.profiler as profiler
from tensorrt_llm._utils import mpi_broadcast, str_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.models.qwen.utils import make_context
from tensorrt_llm.runtime import PYTHON_BINDINGS, ModelRunner
from tensorrt_llm.tools.ppl import ppl

if PYTHON_BINDINGS:
    from tensorrt_llm.runtime import ModelRunnerCpp

from prompt_lookup.run_dtm_pld import run_dtm_pld

def main(args):
    runtime_rank = tensorrt_llm.mpi_rank()
    logger.set_level(args.log_level)

    test_hf = args.test_hf and runtime_rank == 0  # only run hf on rank 0
    test_trt_llm = args.test_trt_llm
    model_name, model_version = read_model_name(args.engine_dir)
    if args.hf_model_dir is None:
        logger.warning(
            "hf_model_dir is not specified. Trying to infer it from model_name, but this may be incorrect."
        )
        if model_name in DEFAULT_HF_MODEL_DIRS:
            args.hf_model_dir = DEFAULT_HF_MODEL_DIRS[model_name]
        else:
            args.hf_model_dir = None
    if args.tokenizer_dir is None:
        args.tokenizer_dir = args.hf_model_dir

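    # Load the tokenizer once and pick up the pad/end token ids used below.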
    profiler.start('load tokenizer')
    tokenizer, pad_id, end_id = load_tokenizer(
        tokenizer_dir=args.tokenizer_dir,
        vocab_file=args.vocab_file,
        model_name=model_name,
        model_version=model_version,
        tokenizer_type=args.tokenizer_type,
    )
    profiler.stop('load tokenizer')
    logger.info(
        f'Load tokenizer takes: {profiler.elapsed_time_in_sec("load tokenizer")} sec'
    )

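    # Pick the evaluation dataset and its input/reference columns per task.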
    if args.eval_task == 'code_completion':
        dataset_name = "openai_humaneval"
        dataset_revision = None
        dataset_input_key = 'prompt'
        dataset_output_key = 'canonical_solution'
        dataset_split = 'test'
    elif args.eval_task == 'summarize':
        dataset_name = "ccdv/cnn_dailymail"
        dataset_revision = "3.0.0"
        dataset_input_key = 'article'
        dataset_output_key = 'highlights'
        dataset_split = 'test'
    elif args.eval_task == 'summarize_long':
        dataset_name = "tau/zero_scrolls"
        dataset_revision = 'squality'
        dataset_input_key = 'input'
        dataset_output_key = 'output'
        dataset_split = 'validation'  # only this split contains reference strings
    elif args.eval_task == "eval_context_ppl":
        dataset_name = "SlimPajama-6B"
        dataset_revision = None
        dataset_input_key = 'text'
        dataset_output_key = 'text'
        dataset_split = 'test'
        args.output_len = 1  # only compute the PPL of the context
        args.eval_ppl = True
        logger.warning(
            f"Running task '{args.eval_task}': setting 'output_len' to 1 and enabling 'eval_ppl'."
        )
    if args.dataset_dir is not None and isinstance(args.dataset_dir, str):
        args.dataset_dir = args.dataset_dir.rstrip('/')
        if args.dataset_dir.endswith(dataset_name):
            dataset_name = args.dataset_dir
        else:
            dataset_name = f"{args.dataset_dir}/{dataset_name}"
    dataset = load_dataset(dataset_name,
                           dataset_revision,
                           cache_dir=args.dataset_cache_dir,
                           split=dataset_split)

    max_batch_size = args.batch_size

    # runtime parameters
    top_k = args.top_k
    top_p = args.top_p
    output_len = args.output_len
    test_token_num = args.max_input_length
    max_attention_window_size = args.max_attention_window_size
    sink_token_length = args.sink_token_length

    if args.end_id:
        end_id = args.end_id

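    # Optional stop-word / bad-word constraints; GLM-4 gets default stop ids.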
    stop_words_list = None
    if args.stop_words:
        stop_words_list = tensorrt_llm.runtime.decode_words_list(
            args.stop_words, tokenizer)
    if model_version == 'glm4':  # add default stop token ids for GLM-4
        glm4_stop_ids = [[151329], [151336], [151338]]
        if stop_words_list is None:
            stop_words_list = [glm4_stop_ids] * args.batch_size
        else:
            for req_stop_words_list in stop_words_list:
                req_stop_words_list.extend(glm4_stop_ids)

    bad_words_list = None
    if args.bad_words:
        bad_words_list = tensorrt_llm.runtime.decode_words_list(
            args.bad_words, tokenizer)

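    # Beam-search / sampling configuration and the random seed.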
    num_beams = args.num_beams
    num_return_sequences = args.num_return_sequences
    num_sequences = args.num_return_sequences or num_beams
    assert num_beams == 1 or num_sequences <= num_beams

    temperature = args.temperature
    length_penalty = args.length_penalty
    early_stopping = args.early_stopping
    repetition_penalty = args.repetition_penalty
    presence_penalty = args.presence_penalty
    frequency_penalty = args.frequency_penalty
    random_seed = args.random_seed
    torch.manual_seed(random_seed)

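    # Prepare output files (if requested) and one ROUGE metric per returned sequence.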
    output_dir = Path(args.output_dir) if args.output_dir else None
    if output_dir is not None:
        output_dir.mkdir(exist_ok=True, parents=True)
        if test_trt_llm:
            with (output_dir / 'trtllm.out').open('w') as f:
                f.write(f'Engine path: {args.engine_dir}\n')
                f.write(f'Tokenizer path: {args.tokenizer_dir}\n')
        if test_hf:
            with (output_dir / 'hf.out').open('w') as f:
                f.write(f'Model path: {args.hf_model_dir}\n')
                f.write(f'Tokenizer path: {args.tokenizer_dir}\n')

    # TODO: Add random_seed flag in gptj
    rouge_dir = args.rouge_dir if args.rouge_dir and os.path.exists(
        args.rouge_dir) else "rouge"
    metric_tensorrt_llm = [
        evaluate.load(rouge_dir) for _ in range(num_sequences)
    ]
    metric_hf = [evaluate.load(rouge_dir) for _ in range(num_sequences)]
    for i in range(num_sequences):
        metric_tensorrt_llm[i].seed = 0
        metric_hf[i].seed = 0
    ppls_trt_llm = [[] for _ in range(num_sequences)]
    ppls_hf = [[] for _ in range(num_sequences)]

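    # Tokenize a batch of raw texts into input-id tensors, applying
    # model-specific prompt formats and truncating to max_input_length;
    # inputs whose length does not exceed min_input_length are dropped.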
    def _prepare_inputs(batch_input_texts,
                        eval_task='summarize',
                        add_special_tokens=True,
                        min_input_length=0):
        batch_size = len(batch_input_texts)
        append_str = ' TL;DR: ' if eval_task == 'summarize' else ''
        batch_input_ids = []
        for i in range(batch_size):
            curr_text = batch_input_texts[i] + append_str
            curr_text = curr_text.strip().replace(" n't", "n't")

            # TODO: The lines below are kept for compatibility with the original code and may need a fix.
            if 'GLM' in model_name and model_version in ('chatglm2',
                                                         'chatglm3'):
                input_ids = tokenizer.encode(curr_text,
                                             return_tensors='pt').squeeze(0)
                input_ids = input_ids[:test_token_num]
            elif 'qwen' in model_name.lower() and model_version == 'qwen':
                # use make_context to build the prompt
                system_prompt = "You are a useful assistant, please directly output the corresponding summary according to the article entered by the user."
                _, input_id_list = make_context(
                    tokenizer=tokenizer,
                    query=curr_text,
                    history=[],
                    system=system_prompt,
                    max_input_length=test_token_num,
                )
                input_ids = torch.tensor(input_id_list)
            else:
                if 'qwen' in model_name.lower() and 'qwen2' in model_version:
                    messages = [{
                        "role":
                        "system",
                        "content":
                        "You are a helpful assistant, please summarize the article entered by the user with one or two sentences."
                    }, {
                        "role": "user",
                        "content": curr_text
                    }]
                    curr_text = tokenizer.apply_chat_template(
                        messages, tokenize=False, add_generation_prompt=True)
                input_ids = tokenizer.encode(
                    curr_text,
                    return_tensors='pt',
                    add_special_tokens=add_special_tokens,
                    truncation=True,
                    max_length=test_token_num).squeeze(0)

            if input_ids.numel() > min_input_length:
                batch_input_ids.append(input_ids)
        return batch_input_ids

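    # Run one batch through the TensorRT-LLM runner; on rank 0 this returns
    # decoded beams, raw output ids, per-sample PPLs, and length information.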
    def eval_trt_llm(datapoint,
                     eval_task='summarize',
                     eval_ppl=False,
                     add_special_tokens=True,
                     min_input_length=0,
                     runner=None):
        batch_size = len(datapoint[dataset_input_key])
        batch_input_ids = _prepare_inputs(datapoint[dataset_input_key],
                                          eval_task=eval_task,
                                          add_special_tokens=add_special_tokens,
                                          min_input_length=min_input_length)
        batch_size = len(batch_input_ids)
        if batch_size == 0:
            return [], [], [], {}
        input_lengths = [x.size(0) for x in batch_input_ids]

        if args.prompt_lookup_config is not None:
            # Speculative decoding with Prompt-Lookup-Decoding (PLD)
            outputs = run_dtm_pld(batch_input_ids,
                                  args,
                                  runtime_rank,
                                  end_id,
                                  pad_id,
                                  stop_words_list,
                                  bad_words_list,
                                  tokenizer.vocab_size,
                                  target_runner=runner)
            if not args.streaming:  # unpack the runner from the return value in non-streaming mode
                outputs, runner = list(outputs)[0]
        else:  # normal run
            with torch.no_grad():
                outputs = runner.generate(
                    batch_input_ids,
                    max_new_tokens=output_len,
                    max_attention_window_size=max_attention_window_size,
                    sink_token_length=sink_token_length,
                    end_id=end_id,
                    pad_id=pad_id,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    stop_words_list=stop_words_list,
                    bad_words_list=bad_words_list,
                    num_beams=num_beams,
                    num_return_sequences=num_return_sequences,
                    length_penalty=length_penalty,
                    early_stopping=early_stopping,
                    repetition_penalty=repetition_penalty,
                    presence_penalty=presence_penalty,
                    frequency_penalty=frequency_penalty,
                    lora_uids=args.lora_task_uids,
                    lookahead_config=args.lookahead_config,
                    output_sequence_lengths=True,
                    return_dict=True,
                    random_seed=random_seed,
                    medusa_choices=args.medusa_choices,
                    eagle_choices=args.eagle_choices)
                torch.cuda.synchronize()

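        # Post-process on rank 0: decode the beams and optionally compute
        # per-token PPL from the concatenated context and generation logits.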
        # Extract a list of tensors of shape beam_width x output_ids.
        if runtime_rank == 0:
            output_ids = outputs['output_ids']
            output_beams_list = [
                tokenizer.batch_decode(beam_tokens[:, input_lengths[i]:],
                                       skip_special_tokens=True)
                for i, beam_tokens in enumerate(output_ids)
            ]
            output_ids_list = [
                beam_tokens[:, input_lengths[i]:]
                for i, beam_tokens in enumerate(output_ids)
            ]

            ppls = [[] for _ in range(batch_size)]
            lengths_info = {
                'input_lengths': input_lengths,
                'seq_lengths': outputs["sequence_lengths"].cpu().tolist(),
            }
            if eval_ppl:
                seq_lengths = outputs['sequence_lengths']
                context_logits = outputs['context_logits']
                # Drop the first generation logits, which duplicate the last
                # context logits.
                generation_logits = outputs['generation_logits'][:, :, 1:]
                for batch_idx in range(batch_size):
                    # [batch, beam, step]
                    for beam_idx in range(num_sequences):
                        curr_len = seq_lengths[batch_idx, beam_idx]
                        curr_ctx_len = input_lengths[batch_idx]
                        curr_gen_len = curr_len - curr_ctx_len

                        curr_ids = output_ids[batch_idx, beam_idx, 1:curr_len]
                        curr_logits = torch.cat([
                            context_logits[batch_idx],
                            generation_logits[batch_idx,
                                              beam_idx, :curr_gen_len - 1]
                        ],
                                                dim=0)
                        curr_ppl = ppl(curr_logits, curr_ids)
                        logger.debug(f"TensorRT-LLM PPL: {curr_ppl:.3f} | "
                                     f"Generation length: {curr_gen_len}")
                        ppls[batch_idx].append(curr_ppl)
            return output_beams_list, output_ids_list, ppls, lengths_info
        return [], [], [], {}

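    # Reference path: run the same batch through the Hugging Face model.
    # Inputs are left-padded; PPL is only computed for batch size 1.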
    def eval_hf(datapoint,
                eval_task='summarize',
                eval_ppl=False,
                add_special_tokens=True,
                min_input_length=0):
        batch_size = len(datapoint[dataset_input_key])
        if batch_size > 1:
            logger.warning(
                f"HF does not support batch_size > 1 to verify correctness due to padding. Current batch size is {batch_size}"
            )
        batch_input_ids = _prepare_inputs(datapoint[dataset_input_key],
                                          eval_task=eval_task,
                                          add_special_tokens=add_special_tokens,
                                          min_input_length=min_input_length)
        batch_size = len(batch_input_ids)
        if batch_size == 0:
            return [], [], [], [[] for _ in range(batch_size)]
        input_lengths = [x.size(0) for x in batch_input_ids]
        # Left padding for HF
        max_length = max(input_lengths)
        paddings = [
            torch.ones(max_length - l, dtype=torch.int32) * pad_id
            for l in input_lengths
        ]
        batch_input_ids = [
            torch.cat([pad, x]) for x, pad in zip(batch_input_ids, paddings)
        ]
        batch_input_ids = torch.stack(batch_input_ids)
        batch_input_ids = batch_input_ids.cuda()

        # specialization for HF
        if early_stopping in [0, 1]:
            local_early_stopping = bool(early_stopping)
        else:
            local_early_stopping = "never"

        with torch.no_grad():
            hf_config = {}
            if num_beams == 1:
                hf_config.update({
                    "top_k": top_k,
                    "top_p": top_p,
                    "do_sample": True,
                })
            else:
                hf_config.update({
                    "num_beams": num_beams,
                    "early_stopping": local_early_stopping,
                })

            outputs = model.generate(batch_input_ids,
                                     max_new_tokens=output_len,
                                     num_return_sequences=num_sequences,
                                     temperature=temperature,
                                     eos_token_id=end_id,
                                     pad_token_id=pad_id,
                                     length_penalty=length_penalty,
                                     output_scores=True,
                                     return_dict_in_generate=True,
                                     **hf_config)
            if eval_ppl and batch_size == 1:
                # model.generate does not return context logits, so run an
                # extra forward pass here (adds latency).
                context_outputs = model(batch_input_ids)

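        # Decode the generated sequences; for batch size 1 also compute
        # per-token PPL from the context and generation logits.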
        output_ids = outputs['sequences']
        tokens_list = output_ids[:, max_length:].tolist()
        output_ids = output_ids.reshape([batch_size, num_sequences, -1])
        output_lines_list = [
            tokenizer.batch_decode(output_ids[:, i, max_length:],
                                   skip_special_tokens=True)
            for i in range(num_sequences)
        ]

        ppls = [[] for _ in range(batch_size)]
        if eval_ppl and batch_size == 1:
            # Only for batch size of 1
            seq_lens = (output_ids
                        != end_id).logical_and(output_ids != pad_id).sum(dim=-1)
            context_logits = context_outputs['logits']
            # Drop the first generation logits, which duplicate the last context logits.
            generation_logits = outputs['scores'][1:]
            # When output_len is 1, generation_logits is an empty tuple and torch.stack would fail.
            if len(generation_logits) == 0:
                generation_logits = torch.empty(
                    [context_logits.shape[0], 0, context_logits.shape[-1]],
                    device=context_logits.device)
            else:
                generation_logits = torch.stack(generation_logits, dim=1)
            _, max_gen_len, voc_size = generation_logits.size()
            generation_logits = generation_logits.view(batch_size, num_beams,
                                                       max_gen_len, voc_size)
            for batch_idx in range(batch_size):
                for beam_idx in range(num_sequences):
                    curr_len = seq_lens[batch_idx, beam_idx]
                    curr_ctx_len = input_lengths[batch_idx]
                    curr_gen_len = curr_len - curr_ctx_len

                    curr_ids = output_ids[batch_idx, beam_idx, 1:curr_len]
                    curr_logits = torch.cat([
                        context_logits[batch_idx],
                        generation_logits[batch_idx,
                                          beam_idx, :curr_gen_len - 1]
                    ],
                                            dim=0)
                    curr_ppl = ppl(curr_logits, curr_ids)
                    logger.debug(
                        f"HF PPL: {curr_ppl:.3f} | Generation length: {curr_gen_len}"
                    )
                    ppls[batch_idx].append(curr_ppl)

        return output_lines_list, tokens_list, ppls

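    # TensorRT-LLM evaluation: build the runner, then loop over the dataset
    # collecting ROUGE predictions/references (and PPL if enabled).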
    if test_trt_llm:
        if not supports_inflight_batching(args.engine_dir):
            logger.warning(
                "The given engine does not support in-flight batching, falling back to the Python session"
            )
            args.use_py_session = True

        if not PYTHON_BINDINGS and not args.use_py_session:
            logger.warning(
                "Python bindings of the C++ session are unavailable, falling back to the Python session."
            )
            args.use_py_session = True
        if args.return_all_generated_tokens:
            raise ValueError(
                "Returning all the generated tokens at each step is not supported in summarize.py"
            )

        logger.info(
            f"Using {'Python' if args.use_py_session else 'C++'} session")

        runner_cls = ModelRunner if args.use_py_session else ModelRunnerCpp
        runner_kwargs = dict(engine_dir=args.engine_dir,
                             rank=runtime_rank,
                             debug_mode=args.debug_mode,
                             gpu_weights_percent=args.gpu_weights_percent)
        if args.medusa_choices is not None:
            args.medusa_choices = ast.literal_eval(args.medusa_choices)
            assert args.temperature == 1.0, "Medusa should use temperature == 1.0"
            assert args.num_beams == 1, "Medusa should use num_beams == 1"
            runner_kwargs.update(medusa_choices=args.medusa_choices)
        if args.eagle_choices is not None or args.eagle_posterior_threshold is not None:
            args.eagle_choices = ast.literal_eval(args.eagle_choices)
            assert args.num_beams == 1, "Eagle should use num_beams == 1"
            runner_kwargs.update(eagle_choices=args.eagle_choices)
            runner_kwargs.update(
                eagle_posterior_threshold=args.eagle_posterior_threshold)
        if args.lookahead_config is not None:
            args.lookahead_config = ast.literal_eval(args.lookahead_config)
            assert len(
                args.lookahead_config
            ) == 3, "Lookahead needs [max_window_size, max_ngram_size, max_verification_set_size]"
            runner_kwargs.update(lookahead_config=args.lookahead_config)
        if args.prompt_lookup_config is not None:
            assert args.kv_cache_enable_block_reuse, "`--kv_cache_enable_block_reuse` must be specified in speculative decoding."
            assert not args.use_py_session, "`--use_py_session` is not supported in speculative decoding."
            assert args.num_beams == 1, "`--num_beams>1` is not supported in speculative decoding."
            prompt_lookup_num_tokens, _, target_device_list = ast.literal_eval(
                args.prompt_lookup_config)
            args.max_output_len = output_len  # specialization for PLD
            runner_kwargs.update(is_orchestrator_mode=True,
                                 device_ids=target_device_list)

        if not args.use_py_session:
            runner_kwargs.update(
                lora_dir=args.lora_dir,
                lora_ckpt_source=args.lora_ckpt_source,
                max_batch_size=max_batch_size,
                max_input_len=test_token_num,
                max_output_len=output_len,
                max_beam_width=num_beams,
                max_attention_window_size=max_attention_window_size,
                sink_token_length=sink_token_length,
                max_tokens_in_paged_kv_cache=args.max_tokens_in_paged_kv_cache,
                kv_cache_enable_block_reuse=args.kv_cache_enable_block_reuse,
                kv_cache_free_gpu_memory_fraction=args.
                kv_cache_free_gpu_memory_fraction,
                enable_chunked_context=args.enable_chunked_context,
                multi_block_mode=args.multi_block_mode,
                cuda_graph_mode=args.cuda_graph_mode)
            if args.backend is not None:
                runner_kwargs.update(
                    backend=args.backend,
                    py_executor_config={'hf_model_dir': args.hf_model_dir})
            runner_kwargs.update(
                enable_context_fmha_fp32_acc=args.enable_context_fmha_fp32_acc)
            if args.prompt_lookup_config is not None:
                # Specialization for PLD, since many calls to `generate()` are needed
                runner_kwargs.update(max_input_len=test_token_num +
                                     prompt_lookup_num_tokens + output_len)
        runner = runner_cls.from_dir(**runner_kwargs)
        assert not (args.eval_ppl and not (runner.gather_context_logits and runner.gather_generation_logits)), \
            "PPL evaluation requires engine built with gather_all_token_logits enabled"

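        # Sanity-check generation on the first sample, then run the main loop.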
        datapoint = dataset[0:1]
        output, *_ = eval_trt_llm(datapoint,
                                  eval_task=args.eval_task,
                                  eval_ppl=args.eval_ppl,
                                  add_special_tokens=args.add_special_tokens,
                                  min_input_length=args.min_input_length,
                                  runner=runner)
        if runtime_rank == 0 and args.eval_task != "eval_context_ppl":
            logger.info(
                "---------------------------------------------------------")
            logger.info("TensorRT-LLM Generated : ")
            logger.info(f" Input : {datapoint[dataset_input_key]}")
            logger.info(f"\n Reference : {datapoint[dataset_output_key]}")
            logger.info(f"\n Output : {output}")
            logger.info(
                "---------------------------------------------------------")

        ite_count = 0
        data_point_idx = 0
        total_output_token_count_trt_llm = 0  # only valid for runtime_rank == 0
        while (data_point_idx < len(dataset)) and (ite_count < args.max_ite):
            if runtime_rank == 0:
                logger.debug(
                    f"run data_point {data_point_idx} ~ {data_point_idx + max_batch_size}"
                )
            datapoint = dataset[data_point_idx:(data_point_idx +
                                                max_batch_size)]

            profiler.start('tensorrt_llm')
            output_tensorrt_llm, output_ids_trt_llm, curr_ppls_trt_llm, lengths_info = eval_trt_llm(
                datapoint,
                eval_task=args.eval_task,
                eval_ppl=args.eval_ppl,
                add_special_tokens=args.add_special_tokens,
                min_input_length=args.min_input_length,
                runner=runner)
            profiler.stop('tensorrt_llm')

            empty_batch = runtime_rank == 0 and len(output_tensorrt_llm) == 0
            empty_batch = mpi_broadcast(empty_batch, 0)
            if empty_batch:
                # No valid samples in the current batch, skip this iteration
                data_point_idx += max_batch_size
                continue

            if runtime_rank == 0:
                input_lengths = lengths_info['input_lengths']
                seq_lengths = lengths_info['seq_lengths']
                output_token_count_trt_llm = sum(
                    beam_len - input_lengths[batch_idx]
                    for batch_idx, beam_lens in enumerate(seq_lengths)
                    for beam_len in beam_lens)
                total_output_token_count_trt_llm += output_token_count_trt_llm
                for batch_idx, output_beams in enumerate(output_tensorrt_llm):
                    reference = datapoint[dataset_output_key][batch_idx]
                    for beam_idx, output_beam in enumerate(output_beams):
                        metric_tensorrt_llm[beam_idx].add_batch(
                            predictions=[output_beam], references=[reference])
                        if args.eval_ppl:
                            ppls_trt_llm[beam_idx].append(
                                curr_ppls_trt_llm[batch_idx][beam_idx])
                if output_dir is not None:
                    for i in range(len(output_tensorrt_llm[0])):
                        for beam_idx in range(num_sequences):
                            with (output_dir / 'trtllm.out').open('a') as f:
                                f.write(
                                    f'[{data_point_idx + i}] [Beam {beam_idx}] {output_tensorrt_llm[beam_idx][i]}\n'
                                )

                logger.debug('-' * 100)
                logger.debug(f"Input : {datapoint[dataset_input_key]}")
                logger.debug(f'TensorRT-LLM Output: {output_tensorrt_llm}')
                logger.debug(f"Reference : {datapoint[dataset_output_key]}")

            data_point_idx += max_batch_size
            ite_count += 1
        del runner

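    # Hugging Face evaluation (rank 0 only): load the HF model and repeat the
    # same first-sample check and evaluation loop.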
    if test_hf and runtime_rank == 0:
        profiler.start('load HF model')
        dtype_alias_mapping = {
            'fp32': 'float32',
            'fp16': 'float16',
            'bf16': 'bfloat16'
        }
        args.hf_data_type = dtype_alias_mapping.get(args.hf_data_type,
                                                    args.hf_data_type)
        if 'GLM' in model_name and model_version == 'glm':
            auto_model_cls = AutoModelForSeq2SeqLM
        elif 'GLM' in model_name and model_version == 'chatglm':
            auto_model_cls = AutoModel
        else:
            auto_model_cls = AutoModelForCausalLM
        # TODO: args.hf_device_map_auto is not being correctly set;
        # remove in a future version
        if model_name == 'DeepseekV2ForCausalLM':
            args.hf_device_map_auto = True
        model = auto_model_cls.from_pretrained(
            args.hf_model_dir,
            trust_remote_code=True,
            torch_dtype=str_dtype_to_torch(args.hf_data_type),
            device_map='auto' if args.hf_device_map_auto else None)
        try:
            model.to_bettertransformer()
        except Exception as e:
            logger.warning(
                f'Failed to call model.to_bettertransformer(), exception:\n{str(e)}'
            )
        if not args.hf_device_map_auto:
            model.cuda()
        if model_name == 'qwen':
            model.generation_config = GenerationConfig.from_pretrained(
                args.hf_model_dir, trust_remote_code=True)
        profiler.stop('load HF model')
        logger.info(
            f'Load HF model takes: {profiler.elapsed_time_in_sec("load HF model")} sec'
        )

        datapoint = dataset[0:1]
        output, *_ = eval_hf(datapoint,
                             eval_task=args.eval_task,
                             eval_ppl=args.eval_ppl,
                             add_special_tokens=args.add_special_tokens,
                             min_input_length=args.min_input_length)
        if runtime_rank == 0 and args.eval_task != "eval_context_ppl":
            logger.info(
                "---------------------------------------------------------")
            logger.info("HF Generated : ")
            logger.info(f" Input : {datapoint[dataset_input_key]}")
            logger.info(f"\n Reference : {datapoint[dataset_output_key]}")
            logger.info(f"\n Output : {output}")
            logger.info(
                "---------------------------------------------------------")

        ite_count = 0
        data_point_idx = 0
        total_output_token_count_hf = 0  # only valid for runtime_rank == 0
        while (data_point_idx < len(dataset)) and (ite_count < args.max_ite):
            if runtime_rank == 0:
                logger.debug(
                    f"run data_point {data_point_idx} ~ {data_point_idx + max_batch_size}"
                )
            datapoint = dataset[data_point_idx:(data_point_idx +
                                                max_batch_size)]

            profiler.start('hf')
            output_hf, token_list, curr_ppls_hf = eval_hf(
                datapoint,
                eval_task=args.eval_task,
                eval_ppl=args.eval_ppl,
                add_special_tokens=args.add_special_tokens,
                min_input_length=args.min_input_length)
            profiler.stop('hf')

            # HF model runs on rank 0 only
            empty_batch = len(output_hf) == 0
            if empty_batch:
                # No valid samples in the current batch, skip this iteration
                data_point_idx += max_batch_size
                continue

            if runtime_rank == 0:
                seq_lengths = [len(tokens) for tokens in token_list]
                total_output_token_count_hf += sum(seq_lengths)
                for beam_idx in range(num_sequences):
                    for batch_idx in range(len(output_hf[beam_idx])):
                        metric_hf[beam_idx].add_batch(
                            predictions=[output_hf[beam_idx][batch_idx]],
                            references=[
                                datapoint[dataset_output_key][batch_idx]
                            ])
                        if args.eval_ppl and args.batch_size == 1:
                            ppls_hf[beam_idx].append(
                                curr_ppls_hf[batch_idx][beam_idx])
                if output_dir is not None:
                    for i in range(len(output_hf[0])):
                        for beam_idx in range(num_sequences):
                            with (output_dir / 'hf.out').open('a') as f:
                                f.write(
                                    f'[{data_point_idx + i}] [Beam {beam_idx}] {output_hf[beam_idx][i]}\n'
                                )

                logger.debug('-' * 100)
                logger.debug(f"Input : {datapoint[dataset_input_key]}")
                logger.debug(f'HF Output: {output_hf}')
                logger.debug(f"Reference : {datapoint[dataset_output_key]}")

            data_point_idx += max_batch_size
            ite_count += 1
        del model

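    # Final report: latency, token throughput, ROUGE per beam, optional PPL,
    # and (with --check_accuracy) threshold assertions.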
    if runtime_rank == 0 and args.max_ite > 0:
        if test_trt_llm:
            np.random.seed(0)  # ROUGE scoring uses sampling to compute the score
            logger.info(
                f'TensorRT-LLM (total latency: {profiler.elapsed_time_in_sec("tensorrt_llm")} sec)'
            )

            logger.info(
                f'TensorRT-LLM (total output tokens: {total_output_token_count_trt_llm})'
            )
            logger.info(
                f'TensorRT-LLM (tokens per second: {total_output_token_count_trt_llm / profiler.elapsed_time_in_sec("tensorrt_llm")})'
            )
            for beam_idx in range(num_sequences):
                logger.info(f"TensorRT-LLM beam {beam_idx} result")
                if args.eval_task != "eval_context_ppl":
                    computed_metrics_tensorrt_llm = metric_tensorrt_llm[
                        beam_idx].compute()
                    for key in computed_metrics_tensorrt_llm.keys():
                        logger.info(
                            f' {key} : {computed_metrics_tensorrt_llm[key]*100}'
                        )
                    if args.check_accuracy and beam_idx == 0:
                        rouge1 = computed_metrics_tensorrt_llm['rouge1'] * 100
                        assert rouge1 > args.tensorrt_llm_rouge1_threshold, f"[FAILED] rouge1 ({rouge1}) is smaller than threshold ({args.tensorrt_llm_rouge1_threshold})."
                if args.eval_ppl:
                    logger.info(
                        f" Per-token perplexity: {np.mean(ppls_trt_llm[beam_idx])}"
                    )
                    if args.check_accuracy and beam_idx == 0:
                        avg_ppl = np.mean(ppls_trt_llm[beam_idx])
                        assert avg_ppl < args.tensorrt_llm_ppl_threshold, f"[FAILED] average PPL ({avg_ppl}) is larger than threshold ({args.tensorrt_llm_ppl_threshold})."
        if test_hf:
            np.random.seed(0)  # ROUGE scoring uses sampling to compute the score
            logger.info(
                f'Hugging Face (total latency: {profiler.elapsed_time_in_sec("hf")} sec)'
            )
            logger.info(
                f'Hugging Face (total output tokens: {total_output_token_count_hf})'
            )
            logger.info(
                f'Hugging Face (tokens per second: {total_output_token_count_hf / profiler.elapsed_time_in_sec("hf")})'
            )

            for beam_idx in range(num_sequences):
                logger.info(f"HF beam {beam_idx} result")
                computed_metrics_hf = metric_hf[beam_idx].compute()
                if args.eval_task != "eval_context_ppl":
                    for key in computed_metrics_hf.keys():
                        logger.info(f' {key} : {computed_metrics_hf[key]*100}')
                if args.eval_ppl and args.batch_size == 1:
                    logger.info(
                        f" Per-token perplexity: {np.mean(ppls_hf[beam_idx])}")


if __name__ == '__main__':
    # See `add_common_args` for the extended list of arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument('--test_hf', action='store_true')
    parser.add_argument('--test_trt_llm', action='store_true')
    parser.add_argument('--eval_task',
                        type=str,
                        default='summarize',
                        choices=[
                            'summarize', 'summarize_long', 'code_completion',
                            'eval_context_ppl'
                        ])
    parser.add_argument('--check_accuracy', action='store_true')
    parser.add_argument('--tensorrt_llm_rouge1_threshold',
                        type=float,
                        default=15.0)
    parser.add_argument('--eval_ppl', action='store_true')
    parser.add_argument('--tensorrt_llm_ppl_threshold',
                        type=float,
                        default=15.0)
    parser.add_argument(
        '--dataset_dir',
        type=str,
        default=None,
        help="The local directory of the dataset for evaluation; "
        "will download the dataset from the Hugging Face hub if not specified.")
    parser.add_argument(
        '--dataset_cache_dir',
        type=str,
        default=None,
        help="The local cache directory for the dataset; "
        "will use `~/.cache/huggingface/datasets` if not specified.")
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--max_ite', type=int, default=20)
    parser.add_argument('--output_len', type=int, default=100)
    parser.add_argument('--max_input_length', type=int, default=923)
    parser.add_argument(
        '--min_input_length',
        type=int,
        default=0,
        help='Skip sentences that are shorter than min_input_length.')
    parser.add_argument(
        '--output_dir',
        type=str,
        default=None,
        help="Directory where to save output sentences. 'trtllm.out' for "
        "TensorRT-LLM outputs, and 'hf.out' for HF outputs. If None, do not "
        "save outputs.")
    parser.add_argument(
        '--rouge_dir',
        default=None,
        type=str,
        help=
        "evaluate.load('rouge') will attempt to pull the rouge package from HF. Using a cached rouge module avoids depending on the network or the HF hub."
    )
    parser = add_common_args(parser)
    args = parser.parse_args()

    main(args)