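"""Run a multimodal model end-to-end with TensorRT-LLM.

This script pairs a visual TRT engine with a TRT-LLM engine: it loads a
test image (or a local video), runs generation, optionally checks the
output against known-good answers, and can profile per-stage latency.
"""
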
import argparse
import os

import tensorrt_llm
import tensorrt_llm.profiler as profiler
from tensorrt_llm import logger
from tensorrt_llm.runtime import MultimodalModelRunner


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--max_new_tokens', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--log_level', type=str, default='info')
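    # Engine and tokenizer locations.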
    parser.add_argument('--visual_engine_dir',
                        type=str,
                        default=None,
                        help='Directory containing visual TRT engines')
    parser.add_argument('--visual_engine_name',
                        type=str,
                        default='model.engine',
                        help='Name of visual TRT engine')
    parser.add_argument('--llm_engine_dir',
                        type=str,
                        default=None,
                        help='Directory containing TRT-LLM engines')
    parser.add_argument('--hf_model_dir',
                        type=str,
                        default=None,
                        help='Directory containing tokenizer')
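    # Prompt and sampling controls.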
    parser.add_argument('--input_text',
                        type=str,
                        default=None,
                        help='Text prompt to LLM')
    parser.add_argument('--num_beams',
                        type=int,
                        default=1,
                        help='Use beam search if num_beams > 1')
    parser.add_argument('--top_k', type=int, default=1)
    parser.add_argument('--top_p', type=float, default=0.0)
    parser.add_argument('--temperature', type=float, default=1.0)
    parser.add_argument('--repetition_penalty', type=float, default=1.0)
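    # Profiling and accuracy-check options.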
    parser.add_argument('--run_profiling',
                        action='store_true',
                        help='Profile runtime over several iterations')
    parser.add_argument('--profiling_iterations',
                        type=int,
                        default=20,
                        help='Number of iterations to run profiling')
    parser.add_argument('--check_accuracy',
                        action='store_true',
                        help='Check correctness of text output')
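    # Media inputs: a local video file or one or more image paths.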
    parser.add_argument('--video_path',
                        type=str,
                        default=None,
                        help='Path to your local video file')
    parser.add_argument('--image_path',
                        type=str,
                        default=None,
                        help='List of input image paths, separated by the '
                        '--path_sep symbol')
    parser.add_argument('--path_sep',
                        type=str,
                        default=',',
                        help='Path separator symbol')
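    # Runtime/session tuning knobs.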
    parser.add_argument('--enable_context_fmha_fp32_acc',
                        action='store_true',
                        default=None,
                        help='Enable FMHA runner FP32 accumulation.')
    parser.add_argument(
        '--enable_chunked_context',
        action='store_true',
        help='Enables chunked context (only available with cpp session).',
    )
    parser.add_argument(
        '--use_py_session',
        default=False,
        action='store_true',
        help='Whether or not to use the Python runtime session. By default '
        'the C++ runtime session is used for the LLM.')
    parser.add_argument(
        '--kv_cache_free_gpu_memory_fraction',
        default=0.9,
        type=float,
        help='Specify the free gpu memory fraction.',
    )
    parser.add_argument(
        '--cross_kv_cache_fraction',
        default=0.5,
        type=float,
        help='Specify the kv cache fraction reserved for cross attention. '
        'Only applicable for encoder-decoder models. By default 0.5 for '
        'self and 0.5 for cross.',
    )
    return parser.parse_args()


def print_result(model, input_text, output_text, args):
    logger.info("---------------------------------------------------------")
    if model.model_type != 'nougat':
        logger.info(f"\n[Q] {input_text}")
    for i in range(len(output_text)):
        logger.info(f"\n[A]: {output_text[i]}")
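
    # With a single beam, report the generated token count by re-tokenizing
    # the first output.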
    if args.num_beams == 1:
        output_ids = model.tokenizer(output_text[0][0],
                                     add_special_tokens=False)['input_ids']
        logger.info(f"Generated {len(output_ids)} tokens")
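
    # Accuracy checks assert against known-good outputs for the sample
    # inputs shipped with each supported model type.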
    if args.check_accuracy:
        if model.model_type != 'nougat':
            if model.model_type == "vila":
                for i in range(len(args.image_path.split(args.path_sep))):
                    if i % 2 == 0:
                        assert output_text[i][0].lower(
                        ) == "the image captures a bustling city intersection teeming with life. from the perspective of a car's dashboard camera, we see"
                    else:
                        assert output_text[i][0].lower(
                        ) == "the image captures the iconic merlion statue in singapore, a renowned worldwide landmark. the merlion, a mythical"
            elif model.model_type == 'fuyu':
                assert output_text[0][0].lower() == '4'
            elif model.model_type == "pix2struct":
                assert "characteristic | cat food, day | cat food, wet | cat treats" in output_text[
                    0][0].lower()
            elif model.model_type in [
                    'blip2', 'neva', 'phi-3-vision', 'llava_next'
            ]:
                assert 'singapore' in output_text[0][0].lower()
            elif model.model_type == 'video-neva':
                assert 'robot' in output_text[0][0].lower()
            elif model.model_type == 'kosmos-2':
                assert 'snowman' in output_text[0][0].lower()
            elif model.model_type == "mllama":
                if "<|image|><|begin_of_text|>If I had to write a haiku for this one" in input_text:
                    assert "it would be:.\\nPeter Rabbit is a rabbit.\\nHe lives in a" in output_text[
                        0][0]
                elif "The key to life is" in input_text:
                    assert "to find your passion and pursue it with all your heart." in output_text[
                        0][0]
            else:
                assert output_text[0][0].lower() == 'singapore'
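
    # Report per-stage latency in milliseconds, averaged per batch over the
    # profiling iterations.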
    if args.run_profiling:
        msec_per_batch = lambda name: 1000 * profiler.elapsed_time_in_sec(
            name) / args.profiling_iterations
        logger.info('Latencies per batch (msec)')
        logger.info('TRT vision encoder: %.1f' % (msec_per_batch('Vision')))
        logger.info('TRTLLM LLM generate: %.1f' % (msec_per_batch('LLM')))
        logger.info('Multimodal generate: %.1f' % (msec_per_batch('Generate')))

    logger.info("---------------------------------------------------------")


if __name__ == '__main__':
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    args = parse_arguments()
    logger.set_level(args.log_level)

    model = MultimodalModelRunner(args)
    raw_image = model.load_test_image()
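
    # When profiling, repeat generation to average out per-iteration noise;
    # otherwise run once.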
    num_iters = args.profiling_iterations if args.run_profiling else 1
    for _ in range(num_iters):
        input_text, output_text = model.run(args.input_text, raw_image,
                                            args.max_new_tokens)
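
    # In multi-GPU runs, only rank 0 prints so results are not duplicated.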
    runtime_rank = tensorrt_llm.mpi_rank()
    if runtime_rank == 0:
        print_result(model, input_text, output_text, args)
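
# Example invocation (illustrative only: the paths are placeholders and the
# prompt is a generic VQA-style question; adjust both to your engine layout
# and model type):
#   python run.py \
#       --hf_model_dir <tokenizer_dir> \
#       --visual_engine_dir <visual_engine_dir> \
#       --llm_engine_dir <llm_engine_dir> \
#       --image_path img1.png,img2.png \
#       --input_text "Question: which city is this? Answer:"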