TensorRT-LLMs/examples/llama/summarize_long.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import json
import os

import torch
from datasets import load_dataset, load_metric
from transformers import AutoModelForCausalLM, LlamaTokenizer

import tensorrt_llm
import tensorrt_llm.profiler as profiler
from tensorrt_llm.bindings import KVCacheType
from tensorrt_llm.logger import logger
from tensorrt_llm.quantization import QuantMode
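
# Illustrative invocation (a sketch only: the checkpoint path and the engine
# directory are placeholders and must point at a real HF model directory and a
# prebuilt TensorRT-LLM engine; all flags below are defined in parse_args()):
#
#   python summarize_long.py \
#       --hf_model_location /tmp/models/Mistral-7B-v0.1 \
#       --engine_dir <engine_dir> \
#       --test_trt_llm \
#       --test_hf \
#       --max_attention_window_size 4096 \
#       --output_len 128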


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--hf_model_location',
                        type=str,
                        default='/tmp/models/Mistral-7B-v0.1')
    parser.add_argument('--test_hf', action='store_true')
    parser.add_argument('--test_trt_llm', action='store_true')
    parser.add_argument('--data_type',
                        type=str,
                        choices=['fp16'],
                        default='fp16')
    parser.add_argument('--dataset_path', type=str, default='/tmp/data')
    parser.add_argument(
        '--max_attention_window_size',
        type=int,
        default=4096,
        help=
        'The attention window size that controls the sliding window attention / cyclic kv cache behavior'
    )
    parser.add_argument(
        '--max_input_len',
        type=int,
        default=6400,
        help='The max input length the TensorRT-LLM engine was built with')
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument('--max_ite', type=int, default=5)
    parser.add_argument(
        '--engine_dir',
        type=str,
        default='/code/tensorrt_llm/mistral_trtllm/llama_style_merge_long_v2')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--num_beams', type=int, default=1)
    parser.add_argument('--top_k', type=int, default=1)
    parser.add_argument('--output_len', type=int, default=128)
    parser.add_argument('--temperature', type=float, default=1)
    parser.add_argument('--check_accuracy', action='store_true')
    parser.add_argument('--tensorrt_llm_rouge1_threshold',
                        type=float,
                        default=15.0)
    parser.add_argument(
        '--rouge_dir',
        default=None,
        type=str,
        help=
        "datasets.load_metric('rouge') will attempt to pull the rouge package from HF. Using a cached rouge package avoids failures when the host or HF is unreachable."
    )
    parser.add_argument(
        '--multi_block_mode',
        type=lambda s: s.lower() in
        ("yes", "true", "t", "1"),  # convert the input string to a boolean
        default=True,
        help=
        "Distribute the work across multiple CUDA thread-blocks on the GPU for the masked MHA kernel."
    )
    parser.add_argument('--enable_context_fmha_fp32_acc',
                        action='store_true',
                        help="Enable FMHA runner FP32 accumulation.")
    args = parser.parse_args()
    return args


def TRTLLaMA(args, config):
    pretrained_config = config['pretrained_config']
    quantization_config = pretrained_config['quantization']

    build_config = config['build_config']
    kv_cache_type = KVCacheType(build_config['kv_cache_type'])
    plugin_config = build_config['plugin_config']

    dtype = pretrained_config['dtype']
    tp_size = pretrained_config['mapping']['tp_size']
    pp_size = pretrained_config['mapping']['pp_size']
    world_size = tp_size * pp_size
    assert world_size == tensorrt_llm.mpi_world_size(), \
        f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'

    num_heads = pretrained_config['num_attention_heads'] // tp_size
    hidden_size = pretrained_config['hidden_size'] // tp_size

    max_batch_size = build_config['max_batch_size']
    vocab_size = pretrained_config['vocab_size']
    num_layers = pretrained_config['num_hidden_layers']

    use_gpt_attention_plugin = bool(plugin_config['gpt_attention_plugin'])
    remove_input_padding = plugin_config['remove_input_padding']
    num_kv_heads = pretrained_config['num_key_value_heads']
    tokens_per_block = plugin_config['tokens_per_block']

    quant_mode = QuantMode.from_quant_algo(
        quant_algo=quantization_config['quant_algo'],
        kv_cache_quant_algo=quantization_config['kv_cache_quant_algo'])
    if pretrained_config.get('multi_query_mode', False):
        tensorrt_llm.logger.warning(
            "`multi_query_mode` config is deprecated. Please rebuild the engine."
        )
        num_kv_heads = 1
    num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size

    model_config = tensorrt_llm.runtime.ModelConfig(
        max_batch_size=max_batch_size,
        max_beam_width=args.num_beams,
        vocab_size=vocab_size,
        num_layers=num_layers,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        hidden_size=hidden_size,
        kv_cache_type=kv_cache_type,
        tokens_per_block=tokens_per_block,
        gpt_attention_plugin=use_gpt_attention_plugin,
        remove_input_padding=remove_input_padding,
        dtype=dtype,
        quant_mode=quant_mode)

    runtime_rank = tensorrt_llm.mpi_rank()
    runtime_mapping = tensorrt_llm.Mapping(world_size,
                                           runtime_rank,
                                           tp_size=tp_size,
                                           pp_size=pp_size)
    torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node)

    engine_name = f'rank{runtime_rank}.engine'
    serialize_path = os.path.join(args.engine_dir, engine_name)

    tensorrt_llm.logger.set_level(args.log_level)

    profiler.start('load tensorrt_llm engine')
    with open(serialize_path, 'rb') as f:
        engine_buffer = f.read()
    decoder = tensorrt_llm.runtime.GenerationSession(model_config,
                                                     engine_buffer,
                                                     runtime_mapping)
    profiler.stop('load tensorrt_llm engine')
    tensorrt_llm.logger.info(
        f'Load engine takes: {profiler.elapsed_time_in_sec("load tensorrt_llm engine")} sec'
    )
    return decoder
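
# Note: TRTLLaMA asserts that tp_size * pp_size from the engine config matches
# tensorrt_llm.mpi_world_size(), so engines built for more than one GPU are
# expected to be run under MPI. An illustrative launch (world size must match
# the engine's tp_size * pp_size):
#
#   mpirun -n <tp_size * pp_size> python summarize_long.py --test_trt_llm ...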


def get_long_texts(dataset_openweb):
    # Yield articles whose approximate token count falls between the attention
    # window and the engine's max input length. Relies on the module-level
    # `args` populated in the __main__ block.
    for datapoint in dataset_openweb["train"]:
        text = datapoint["text"]
        approximate_tokens = len(text.split())
        if (approximate_tokens > args.max_attention_window_size) and (
                approximate_tokens < args.max_input_len):
            yield text


def prepare_prompt(text):
    text = text.replace("\n", " ")
    text = text + '\n TL;DR: '
    text = text.strip()
    text = text.replace(" n't", "n't")
    return text


def summarize_hf(datapoint, tokenizer, hf_model, args):
    line_encoded = tokenizer(datapoint,
                             return_tensors='pt',
                             padding=True,
                             truncation=True)["input_ids"].type(torch.int32)
    line_encoded = line_encoded.cuda()

    with torch.no_grad():
        output = hf_model.generate(line_encoded,
                                   max_new_tokens=args.output_len,
                                   temperature=args.temperature,
                                   eos_token_id=tokenizer.eos_token_id,
                                   pad_token_id=tokenizer.pad_token_id,
                                   num_beams=args.num_beams,
                                   top_k=args.top_k,
                                   do_sample=True,
                                   early_stopping=True)

    tokens_list = output[:, len(line_encoded[0]):].tolist()
    output = output.reshape([args.batch_size, args.num_beams, -1])
    output_lines_list = [
        tokenizer.batch_decode(output[:, i, len(line_encoded[0]):],
                               skip_special_tokens=True)
        for i in range(args.num_beams)
    ]
    return output_lines_list, tokens_list


def summarize_tensorrt_llm(datapoint, tokenizer, tensorrt_llm_llama, args):
    line_encoded = []
    input_id = tokenizer.encode(datapoint,
                                return_tensors='pt').type(torch.int32)
    line_encoded.append(input_id)
    input_lengths = []
    input_lengths.append(input_id.shape[-1])

    max_length = max(input_lengths)

    pad_id = tokenizer.encode(tokenizer.pad_token, add_special_tokens=False)[0]
    end_id = tokenizer.encode(tokenizer.eos_token, add_special_tokens=False)[0]

    if tensorrt_llm_llama.remove_input_padding:
        line_encoded = [
            torch.tensor(t, dtype=torch.int32).cuda() for t in line_encoded
        ]
    else:
        # do padding, should move outside the profiling to prevent the overhead
        for i in range(args.batch_size):
            pad_size = max_length - input_lengths[i]
            pad = torch.ones([1, pad_size]).type(torch.int32) * pad_id
            line_encoded[i] = torch.cat(
                [torch.tensor(line_encoded[i], dtype=torch.int32), pad],
                axis=-1)
        line_encoded = torch.cat(line_encoded, axis=0).cuda()
        input_lengths = torch.tensor(input_lengths, dtype=torch.int32).cuda()

    sampling_config = tensorrt_llm.runtime.SamplingConfig(
        end_id=end_id,
        pad_id=pad_id,
        top_k=args.top_k,
        num_beams=args.num_beams)

    with torch.no_grad():
        tensorrt_llm_llama.setup(
            batch_size=args.batch_size,
            max_context_length=max_length,
            max_new_tokens=args.output_len,
            beam_width=args.num_beams,
            max_attention_window_size=args.max_attention_window_size,
            multi_block_mode=args.multi_block_mode,
            enable_context_fmha_fp32_acc=args.enable_context_fmha_fp32_acc)
        logger.info(f"Generation session set up with the parameters: \
            batch_size: {tensorrt_llm_llama.batch_size}, \
            max_context_length: {tensorrt_llm_llama.max_context_length}, \
            max_new_tokens: {tensorrt_llm_llama.max_new_tokens}, \
            beam_width: {tensorrt_llm_llama.beam_width}, \
            max_attention_window_size: {tensorrt_llm_llama.max_attention_window_size}, \
            multi_block_mode: {tensorrt_llm_llama.multi_block_mode}, \
            enable_context_fmha_fp32_acc: {tensorrt_llm_llama.enable_context_fmha_fp32_acc}"
                    )

        if tensorrt_llm_llama.remove_input_padding:
            output_ids = tensorrt_llm_llama.decode_batch(
                line_encoded, sampling_config)
        else:
            output_ids = tensorrt_llm_llama.decode(
                line_encoded,
                input_lengths,
                sampling_config,
            )

        torch.cuda.synchronize()

    logger.info(f"Decoded output of shape {output_ids.shape}")
    # Extract a list of tensors of shape beam_width x output_ids.
    if tensorrt_llm_llama.mapping.is_first_pp_rank():
        output_beams_list = [
            tokenizer.batch_decode(output_ids[batch_idx, :,
                                              input_lengths[batch_idx]:],
                                   skip_special_tokens=True)
            for batch_idx in range(args.batch_size)
        ]
        return output_beams_list, output_ids[:, :, max_length:].tolist()
    return [], []


def main(args):
    runtime_rank = tensorrt_llm.mpi_rank()
    logger.set_level(args.log_level)

    profiler.start('load tokenizer')
    tokenizer = LlamaTokenizer.from_pretrained(args.hf_model_location,
                                               legacy=False,
                                               padding_side='left')
    profiler.stop('load tokenizer')
    tensorrt_llm.logger.info(
        f'Load tokenizer takes: {profiler.elapsed_time_in_sec("load tokenizer")} sec'
    )
    tokenizer.pad_token = tokenizer.eos_token

    dataset_openweb = load_dataset("stas/openwebtext-10k",
                                   cache_dir=args.dataset_path,
                                   trust_remote_code=True)
    long_texts = get_long_texts(dataset_openweb)  # generator

    # get datapoints
    try:
        datapoints = [
            prepare_prompt(next(long_texts)) for i in range(args.max_ite)
        ]
    except StopIteration:
        logger.warning(
            f"No test data of sufficient length ({args.max_attention_window_size}). Try decreasing the max_attention_window_size parameter"
        )
        return

    if args.test_trt_llm:
        config_path = os.path.join(args.engine_dir, 'config.json')
        with open(config_path, 'r') as f:
            config = json.load(f)

        tensorrt_llm_llama = TRTLLaMA(args, config)

        trt_llm_summary = []
        for ite in range(args.max_ite):
            trt_llm_summary.append(
                summarize_tensorrt_llm(datapoints[ite], tokenizer,
                                       tensorrt_llm_llama, args)[0])

        if runtime_rank == 0:
            logger.info(
                "---------------------------------------------------------")
            logger.info("TRT LLM Generated : ")
            logger.info(f" Article : {datapoints[0]}")
            logger.info(f"\n Summary : {trt_llm_summary[0]}")
            logger.info(
                "---------------------------------------------------------")

        del tensorrt_llm_llama

    test_hf = args.test_hf and runtime_rank == 0  # only run hf on rank 0
    if test_hf:
        profiler.start('load HF model')
        hf_model = AutoModelForCausalLM.from_pretrained(
            args.hf_model_location,
            torch_dtype=torch.float16,
            use_flash_attention_2=True)
        profiler.stop('load HF model')
        tensorrt_llm.logger.info(
            f'Load HF model takes: {profiler.elapsed_time_in_sec("load HF model")} sec'
        )
        hf_model.cuda()

        hf_summary = []
        for ite in range(args.max_ite):
            hf_summary.append(
                summarize_hf(datapoints[ite], tokenizer, hf_model, args)[0])

        logger.info(
            "---------------------------------------------------------")
        logger.info("HF Generated : ")
        logger.info(f" Article : {datapoints[0]}")
        logger.info(f"\n Summary : {hf_summary[0]}")
        logger.info(
            "---------------------------------------------------------")

    # no ground truth, compare with hf
    if runtime_rank == 0 and args.test_hf and args.test_trt_llm:
        rouge_dir = args.rouge_dir if args.rouge_dir and os.path.exists(
            args.rouge_dir) else "rouge"
        metric_tensorrt_llm = [
            load_metric(rouge_dir) for _ in range(args.num_beams)
        ]
        for i in range(args.num_beams):
            metric_tensorrt_llm[i].seed = 0

        for ite in range(args.max_ite):
            for batch_idx in range(len(trt_llm_summary[0])):
                for beam_idx in range(args.num_beams):
                    metric_tensorrt_llm[beam_idx].add_batch(
                        predictions=[trt_llm_summary[ite][batch_idx][beam_idx]],
                        references=[hf_summary[ite][beam_idx][batch_idx]])

        for beam_idx in range(args.num_beams):
            logger.info(f"TensorRT-LLM beam {beam_idx} result")
            computed_metrics_tensorrt_llm = metric_tensorrt_llm[
                beam_idx].compute()
            for key in computed_metrics_tensorrt_llm.keys():
                logger.info(
                    f' {key} : {computed_metrics_tensorrt_llm[key].mid[2]*100}'
                )

            if args.check_accuracy and beam_idx == 0:
                assert computed_metrics_tensorrt_llm['rouge1'].mid[
                    2] * 100 > args.tensorrt_llm_rouge1_threshold


if __name__ == '__main__':
    args = parse_args()
    main(args)