2023-12-01 22:27:51 +08:00


Performance of TensorRT-LLM

This document summarizes performance measurements of TensorRT-LLM on H100 (Hopper), L40S (Ada) and A100 (Ampere) GPUs for a few key models.

The data in the following tables is provided as a reference point to help users validate observed performance. It should not be considered the peak performance that TensorRT-LLM can deliver.

Methodology

The different performance numbers below were collected using the methodology described in the benchmarks folder.

High Throughput

The following tables provide reference data at large batch sizes, representative of high-throughput workloads.

H100 GPUs (FP8)

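| Model | Batch Size | TP (1) | Input Length | Output Length | Throughput (out tok/s) |
| :--- | ---: | ---: | ---: | ---: | ---: |
| GPT-J 6B | 64 | 1 | 128 | 128 | 10,907 |
| GPT-J 6B | 64 | 1 | 128 | 2048 | 6,179 |
| GPT-J 6B | 64 | 1 | 2048 | 128 | 2,229 |
| GPT-J 6B | 64 | 1 | 2048 | 2048 | 2,980 |
| LLaMA 7B | 64 | 1 | 128 | 128 | 9,193 |
| LLaMA 7B | 64 | 1 | 128 | 2048 | 5,367 |
| LLaMA 7B | 64 | 1 | 2048 | 128 | 2,058 |
| LLaMA 7B | 32 | 1 | 2048 | 2048 | 2,230 |
| LLaMA 70B | 64 | 4 | 128 | 128 | 3,317 |
| LLaMA 70B | 64 | 4 | 128 | 2048 | 2,616 |
| LLaMA 70B | 64 | 4 | 2048 | 128 | 843 |
| LLaMA 70B | 64 | 4 | 2048 | 2048 | 1,583 |
| Falcon 180B | 96 | 8 | 128 | 128 | 2,686 |
| Falcon 180B | 96 | 8 | 128 | 2048 | 2,073 |
| Falcon 180B | 64 | 8 | 2048 | 128 | 465 |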

L40S GPUs (FP8)

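| Model | Batch Size | TP (1) | Input Length | Output Length | Throughput (out tok/s) |
| :--- | ---: | ---: | ---: | ---: | ---: |
| GPT-J 6B | 64 | 1 | 128 | 128 | 3,630 |
| GPT-J 6B | 64 | 1 | 128 | 2048 | 1,859 |
| GPT-J 6B | 32 | 1 | 2048 | 128 | 616 |
| GPT-J 6B | 32 | 1 | 2048 | 2048 | 757 |
| LLaMA 7B | 64 | 1 | 128 | 128 | 3,240 |
| LLaMA 7B | 64 | 1 | 128 | 2048 | 1,622 |
| LLaMA 7B | 32 | 1 | 2048 | 128 | 581 |
| LLaMA 7B | 16 | 1 | 2048 | 2048 | 531 |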

A100 GPUs (FP16)

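| Model | Batch Size | TP (1) | Input Length | Output Length | Throughput (out tok/s) |
| :--- | ---: | ---: | ---: | ---: | ---: |
| GPT-J 6B | 64 | 1 | 128 | 128 | 3,679 |
| GPT-J 6B | 32 | 1 | 128 | 2048 | 1,558 |
| GPT-J 6B | 32 | 1 | 2048 | 128 | 526 |
| GPT-J 6B | 16 | 1 | 2048 | 2048 | 650 |
| LLaMA 7B | 64 | 1 | 128 | 128 | 3,486 |
| LLaMA 7B | 32 | 1 | 128 | 2048 | 1,459 |
| LLaMA 7B | 32 | 1 | 2048 | 128 | 529 |
| LLaMA 7B | 16 | 1 | 2048 | 2048 | 592 |
| LLaMA 70B | 64 | 4 | 128 | 128 | 1,237 |
| LLaMA 70B | 64 | 4 | 128 | 2048 | 1,181 |
| LLaMA 70B | 64 | 4 | 2048 | 128 | 272 |
| LLaMA 70B | 64 | 4 | 2048 | 2048 | 738 |
| Falcon 180B | 64 | 8 | 128 | 128 | 929 |
| Falcon 180B | 64 | 8 | 128 | 2048 | 923 |
| Falcon 180B | 64 | 8 | 2048 | 128 | 202 |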

(1) TP stands for Tensor Parallelism.
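The sketch below converts these throughput figures into wall-clock time by estimating how long one full batch takes to generate. It assumes that throughput is reported as batch size × output length ÷ end-to-end batch latency. That formula is our assumption and is not stated above. `implied_batch_latency_s` is a hypothetical helper.

```bash
#!/usr/bin/env bash
# Hypothetical helper: estimate the end-to-end latency (seconds) of one batch,
# ASSUMING throughput = batch_size * output_len / latency.
implied_batch_latency_s() {
	local batch_size=$1 output_len=$2 tok_per_s=$3
	awk -v bs="$batch_size" -v osl="$output_len" -v tps="$tok_per_s" \
		'BEGIN { printf "%.3f\n", bs * osl / tps }'
}

# Example: GPT-J 6B on H100, BS 64, ISL/OSL 128/128, 10,907 out tok/s
implied_batch_latency_s 64 128 10907
```

Under that assumption, GPT-J 6B on H100 at ISL/OSL 128/128 completes a batch of 64 in about 0.75 s.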

Low Latency

The following tables provide first-token latency at batch size 1, representative of the latency an end user perceives in online streaming applications.

H100 GPUs (FP8)

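| Model | Batch Size | TP (1) | Input Length | 1st Token Latency (ms) |
| :--- | ---: | ---: | ---: | ---: |
| GPT-J 6B | 1 | 1 | 128 | 7 |
| GPT-J 6B | 1 | 1 | 2048 | 29 |
| LLaMA 7B | 1 | 1 | 128 | 7 |
| LLaMA 7B | 1 | 1 | 2048 | 36 |
| LLaMA 70B | 1 | 4 | 128 | 26 |
| LLaMA 70B | 1 | 4 | 2048 | 109 |
| Falcon 180B | 1 | 8 | 128 | 27 |
| Falcon 180B | 1 | 8 | 2048 | 205 |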

L40S GPUs (FP8)

| Model | Batch Size | TP (1) | Input Length | 1st Token Latency (ms) |
| :--- | ---: | ---: | ---: | ---: |
| GPT-J 6B | 1 | 1 | 128 | 12 |
| GPT-J 6B | 1 | 1 | 2048 | 71 |
| LLaMA 7B | 1 | 1 | 128 | 14 |
| LLaMA 7B | 1 | 1 | 2048 | 73 |

A100 GPUs (FP16)

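| Model | Batch Size | TP (1) | Input Length | 1st Token Latency (ms) |
| :--- | ---: | ---: | ---: | ---: |
| GPT-J 6B | 1 | 1 | 128 | 12 |
| GPT-J 6B | 1 | 1 | 2048 | 129 |
| LLaMA 7B | 1 | 1 | 128 | 16 |
| LLaMA 7B | 1 | 1 | 2048 | 133 |
| LLaMA 70B | 1 | 4 | 128 | 47 |
| LLaMA 70B | 1 | 4 | 2048 | 377 |
| Falcon 180B | 1 | 8 | 128 | 61 |
| Falcon 180B | 1 | 8 | 2048 | 509 |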

(1) TP stands for Tensor Parallelism.
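One rough way to read the first-token latencies is as a prefill rate: input length divided by first-token latency. The hypothetical helper `prefill_tok_per_s` below assumes batch size 1, as in the tables. It also treats the whole first-token time as prompt processing, so it somewhat understates the true prefill rate.

```bash
#!/usr/bin/env bash
# Hypothetical helper: rough prefill rate (input tokens/s) from a first-token
# latency in milliseconds, ASSUMING batch size 1 and that the whole
# first-token time is spent processing the prompt.
prefill_tok_per_s() {
	local input_len=$1 ttft_ms=$2
	awk -v isl="$input_len" -v ms="$ttft_ms" \
		'BEGIN { printf "%.0f\n", isl / (ms / 1000) }'
}

# LLaMA 7B, ISL 2048: H100 FP8 (36 ms) vs. A100 FP16 (133 ms)
prefill_tok_per_s 2048 36
prefill_tok_per_s 2048 133
```

For LLaMA 7B at ISL 2048, this gives roughly 57k input tokens/s on H100 (FP8) versus roughly 15k on A100 (FP16).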

Known Issues

The following issues are being addressed to improve the efficiency of TensorRT-LLM.

Fused Matmul + Gated-SiLU (LLaMA)

The current implementation combines the two Matmul operations into a single Matmul followed by a separate SwiGLU kernel (when `--use_fused_mlp` is enabled). A future release will include a more efficient implementation that runs a single fused Matmul + SwiGLU kernel.
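For reference, the gated-SiLU (SwiGLU) MLP can be written as below. The notation is ours: $W_{\text{gate}}, W_{\text{up}}, W_{\text{down}}$ are the MLP weight matrices, $d$ is the hidden size, and $h$ is the intermediate size. In these terms, the current implementation produces $[\,G \mid U\,]$ with one Matmul and then runs a separate SwiGLU kernel for $\operatorname{SiLU}(G) \odot U$. The planned improvement would perform both steps in a single fused kernel.

```latex
\begin{aligned}
\operatorname{SiLU}(z) &= z\,\sigma(z) = \frac{z}{1 + e^{-z}} \\
\operatorname{MLP}(x) &= \bigl(\operatorname{SiLU}(x W_{\text{gate}}) \odot (x W_{\text{up}})\bigr) W_{\text{down}} \\
[\,G \mid U\,] &= x\,[\,W_{\text{gate}} \mid W_{\text{up}}\,], \qquad W_{\text{gate}}, W_{\text{up}} \in \mathbb{R}^{d \times h} \\
\operatorname{MLP}(x) &= \bigl(\operatorname{SiLU}(G) \odot U\bigr) W_{\text{down}}, \qquad W_{\text{down}} \in \mathbb{R}^{h \times d}
\end{aligned}
```

For Llama2-7b, $d = 4096$ and $h = 11008$, matching the `--n_embd` and `--inter_size` values in its build command later in this document.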

Reproducing Benchmarked Results

Building the TensorRT-LLM Container


To benchmark TensorRT-LLM, follow the Quick Start build process to create a baseline container for building a wheel. The development container also needs a copy of the source code to build the wheel and the benchmarking scripts. To create the build environment, run the following:

```bash
git clone https://github.com/NVIDIA/TensorRT-LLM.git
cd TensorRT-LLM
git submodule update --init --recursive
git lfs install
git lfs pull
make -C docker build
make -C docker run LOCAL_USER=1
```

Warning

If you have elevated privileges on your system, skip the `make -C docker run LOCAL_USER=1` command above. With `LOCAL_USER=1`, the build forces the container's UID and GID to match those of your non-elevated user, which can prevent access to some required system libraries inside the container. In some cases the container is booted as root (e.g., on some SLURM systems with the pyxis plugin), which will cause libraries to be missing.

If you are benchmarking in a shared environment, specify the GPU indices you want the container to use. Otherwise, the Makefile starts the container with all GPUs on the system. For example, to restrict the container to the first four GPUs (indices 0 to 3):

```bash
# Export NV_GPU so that make can expand ${NV_GPU} inside GPU_OPTS.
export NV_GPU=0,1,2,3
make -C docker run LOCAL_USER=1 GPU_OPTS='--gpus \"device=${NV_GPU}\"'
```
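To avoid hand-editing the quoting, you can generate the `GPU_OPTS` value from a list of indices. The helper below (`gpu_opts`, hypothetical) is a minimal sketch. It assumes the Makefile passes `GPU_OPTS` through a shell to `docker run`, as in the command above. For that reason it keeps the escaped quotes, so that docker receives `--gpus "device=0,1,2,3"`.

```bash
#!/usr/bin/env bash
# Hypothetical helper: turn GPU indices into the GPU_OPTS string used above,
# i.e. --gpus \"device=0,1,2,3\". The backslash-escaped quotes survive make
# and are consumed by the shell that runs docker, which then sees real quotes.
gpu_opts() {
	local idx
	for idx in "$@"; do
		# Reject anything that is not a non-negative integer.
		case $idx in
			'' | *[!0-9]*) echo "invalid GPU index: '$idx'" >&2; return 1 ;;
		esac
	done
	local IFS=,   # "$*" joins the arguments with the first character of IFS
	printf '%s\n' "--gpus \\\"device=$*\\\""
}

# Usage:
# make -C docker run LOCAL_USER=1 GPU_OPTS="$(gpu_opts 0 1 2 3)"
```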

To access persistent storage or previously built engines, mount the corresponding directories into the container (replace `/source` and `/destination` with the appropriate paths):

```bash
make -C docker run LOCAL_USER=1 DOCKER_RUN_ARGS="-v /source:/destination"
```

Once the container starts, build the wheel and the benchmarking scripts. From the code root (the default directory when the container starts), the following commands build the TensorRT-LLM wheel together with the benchmark binaries, then install the wheel and its dependencies:

```bash
python3 ./scripts/build_wheel.py --benchmarks --trt_root /usr/local/tensorrt
pip install ./build/tensorrt_llm*.whl
```

Methodology

Engine Building Setups

Each engine must be built before it can be benchmarked, and building it requires the source code of the corresponding build script. For smaller models, it is fine to build the engine on the fly in the container. For larger models, we recommend pre-building the engine and mounting the directory that contains it, because engine files are large and take time to rebuild repeatedly. A built engine can also serve any input length, output length, and batch size up to the limits set at build time, so a single engine can benchmark multiple input configurations.
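The rule that an engine serves any configuration up to its build limits can be written as a simple check. The sketch below is hypothetical (`config_fits_engine` is our name). It compares only batch size, input length, and output length against the `--max_batch_size`, `--max_input_len`, and `--max_output_len` values the engine was built with. It does not model memory or any other build-time limit.

```bash
#!/usr/bin/env bash
# Hypothetical helper: does a "BS:ISL,OSL" configuration (the format used by
# the benchmark loops later in this section) fit an engine's build limits?
# Returns 0 if it fits, 1 otherwise.
config_fits_engine() {
	local config=$1 max_bs=$2 max_isl=$3 max_osl=$4
	local bs=${config%%:*}      # text before ':'
	local dims=${config#*:}     # text after ':'
	local isl=${dims%%,*}       # text before ','
	local osl=${dims#*,}        # text after ','
	(( bs <= max_bs && isl <= max_isl && osl <= max_osl ))
}

# Example: the GPT-J engine is built with 64 / 2048 / 2048.
for cfg in "64:128,128" "64:2048,2048" "96:128,128"; do
	if config_fits_engine "$cfg" 64 2048 2048; then
		echo "$cfg: fits"
	else
		echo "$cfg: needs a rebuild"
	fi
done
```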

Our engine-building scheme for the benchmarked networks is as follows:

  • The GPT-J, Llama2-7b, and Llama2-70b benchmarks were run using a single engine per network, built for the maximum expected throughput.
  • For Falcon-180B, where memory limits and model size have a larger impact on running the model, the benchmarks switch to a separate engine build per configuration.

Below, we document how to benchmark each model on an H100-HBM3-80GB system and reproduce the throughput numbers in the [Performance section](#performance-of-tensorrt-llm).

Running on A100

To run the benchmarks below on A100, remove the `--enable_fp8` and `--fp8_kv_cache` options from each engine build command. FP8 is not supported on Ampere GPUs; it requires an Ada Lovelace (such as L40S) or Hopper (such as H100) GPU, or newer.
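The sketch below shows one way to automate this adjustment. It uses hypothetical helpers that check a GPU's compute capability and drop the two FP8 flags. It assumes FP8 support corresponds to compute capability 8.9 (Ada) or higher. A100 is 8.0, L40S is 8.9, and H100 is 9.0.

```bash
#!/usr/bin/env bash
# Hypothetical helpers for adapting the build commands below to A100.

# Takes a "major.minor" compute capability; returns 0 if it is >= 8.9,
# which we assume is the FP8 threshold (Ada and Hopper).
supports_fp8() {
	local major=${1%%.*} minor=${1#*.}
	(( major > 8 || (major == 8 && minor >= 9) ))
}

# Print the given build arguments, one per line, without the FP8-only flags.
strip_fp8_flags() {
	local arg
	for arg in "$@"; do
		case $arg in
			--enable_fp8 | --fp8_kv_cache) ;;   # drop
			*) printf '%s\n' "$arg" ;;
		esac
	done
}

# Usage sketch (the compute_cap query field needs a reasonably recent driver):
# cap=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1)
# args=(--dtype float16 --enable_fp8 --fp8_kv_cache --strongly_typed)
# supports_fp8 "$cap" || mapfile -t args < <(strip_fp8_flags "${args[@]}")
# python examples/gptj/build.py "${args[@]}"
```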

Reproducing First Token Latency

To measure first-token latency, build the engines as specified below, removing the FP8 options on A100 as described above. Then run the benchmark with a single output token; the reported latency is the time to the first token. We provide the command lines for each benchmarked model below, and you can use the same method to benchmark other models available in TensorRT-LLM.
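This method can be expressed as a transformation of the benchmark configurations: keep the input length and set the output length to 1. The hypothetical helper below defaults the batch size to 1 to match the latency tables above. Pass a second argument to use a different batch size.

```bash
#!/usr/bin/env bash
# Hypothetical helper: derive a time-to-first-token configuration from a
# throughput configuration "BS:ISL,OSL" by setting the output length to 1.
# The batch size defaults to 1 (as in the latency tables); override with $2.
to_ttft_config() {
	local config=$1 batch_size=${2:-1}
	local dims=${config#*:}           # "ISL,OSL"
	printf '%s:%s,1\n' "$batch_size" "${dims%%,*}"
}

for cfg in "64:128,128" "64:2048,2048"; do
	to_ttft_config "$cfg"
done
```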

Benchmarking per Model

GPT-J 6B


```bash
python examples/gptj/build.py \
	--enable_context_fmha \
	--parallel_build \
	--output_dir /tmp/engines/gptj \
	--dtype float16 \
	--use_gpt_attention_plugin float16 \
	--world_size 1 \
	--max_batch_size 64 \
	--max_input_len 2048 \
	--max_output_len 2048 \
	--hidden_act gelu \
	--enable_fp8 \
	--fp8_kv_cache \
	--strongly_typed \
	--n_layer 28 \
	--n_head 16 \
	--n_embd 4096 \
	--n_positions 2048 \
	--enable_two_optimization_profiles
```
Throughput Benchmark
```bash
in_out_sizes=("64:128,128" "64:128,2048" "64:2048,128" "64:2048,2048")
for in_out in "${in_out_sizes[@]}"
do
	batch_size=$(echo "$in_out" | awk -F':' '{ print $1 }')
	in_out_dims=$(echo "$in_out" | awk -F':' '{ print $2 }')
	echo "BS: $batch_size, ISL/OSL: $in_out_dims"

	./cpp/build/benchmarks/gptSessionBenchmark --model gptj --engine_dir /tmp/engines/gptj/ --warm_up 1 --batch_size "$batch_size" --duration 0 --num_runs 5 --input_output_len "$in_out_dims"
done
```
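To compare runs later, it can help to keep one log file per configuration. The helpers below are a hypothetical sketch (`bench_log_name`, `run_logged`, and `LOG_DIR` are our names). They only capture the benchmark's output and make no assumption about its format.

```bash
#!/usr/bin/env bash
# Hypothetical helpers: keep one log file per benchmark configuration.

# "gptj" + "64:128,2048" -> gptj_bs64_isl128_osl2048.log
bench_log_name() {
	local model=$1 config=$2
	local bs=${config%%:*} dims=${config#*:}
	printf '%s_bs%s_isl%s_osl%s.log\n' "$model" "$bs" "${dims%%,*}" "${dims#*,}"
}

# Run a command, show its output, and save a copy to the given log file.
# pipefail makes the function return the command's status, not tee's.
run_logged() {
	local log=$1
	shift
	( set -o pipefail; "$@" 2>&1 | tee "$log" )
}

# Usage inside the throughput loop above:
# run_logged "${LOG_DIR:-.}/$(bench_log_name gptj "$in_out")" \
#	./cpp/build/benchmarks/gptSessionBenchmark --model gptj --engine_dir /tmp/engines/gptj/ \
#	--warm_up 1 --batch_size "$batch_size" --duration 0 --num_runs 5 --input_output_len "$in_out_dims"
```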
First Token Latency Benchmark
```bash
# Batch size 1 and a single output token, matching the first-token latency tables.
in_out_sizes=("1:128,1" "1:2048,1")
for in_out in "${in_out_sizes[@]}"
do
	batch_size=$(echo "$in_out" | awk -F':' '{ print $1 }')
	in_out_dims=$(echo "$in_out" | awk -F':' '{ print $2 }')
	echo "BS: $batch_size, ISL/OSL: $in_out_dims"

	./cpp/build/benchmarks/gptSessionBenchmark --model gptj --engine_dir /tmp/engines/gptj/ --warm_up 1 --batch_size "$batch_size" --duration 0 --num_runs 5 --input_output_len "$in_out_dims"
done
```

Llama2-7b


```bash
pip install -r examples/llama/requirements.txt
python examples/llama/build.py \
	--remove_input_padding \
	--enable_context_fmha \
	--parallel_build \
	--output_dir /tmp/engines/llama/7b \
	--dtype float16 \
	--use_gpt_attention_plugin float16 \
	--world_size 1 \
	--tp_size 1 \
	--pp_size 1 \
	--max_batch_size 64 \
	--max_input_len 2048 \
	--max_output_len 2048 \
	--enable_fp8 \
	--fp8_kv_cache \
	--strongly_typed \
	--n_layer 32 \
	--n_head 32 \
	--n_embd 4096 \
	--inter_size 11008 \
	--vocab_size 32000 \
	--n_positions 4096 \
	--hidden_act silu
```
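The build parameters above are enough for a rough KV-cache size estimate. This helps show how memory grows with batch size and sequence length. The sketch below is our own back-of-the-envelope formula. It ignores paged-block rounding, workspace, and activations, and it assumes the FP8 KV cache uses 1 byte per element.

```bash
#!/usr/bin/env bash
# Rough KV-cache size (a sketch, not TensorRT-LLM's allocator). Per token,
# every layer stores one key and one value vector for each KV head:
#   bytes = 2 * n_layer * n_kv_head * head_dim * bytes_per_elem
#           * batch_size * seq_len
kv_cache_bytes() {
	local n_layer=$1 n_kv_head=$2 head_dim=$3 bytes_per_elem=$4
	local batch_size=$5 seq_len=$6
	echo $(( 2 * n_layer * n_kv_head * head_dim * bytes_per_elem * batch_size * seq_len ))
}

to_gib() { awk -v b="$1" 'BEGIN { printf "%.1f\n", b / 1024 / 1024 / 1024 }'; }

# Llama2-7b from the build command: 32 layers, 32 heads,
# head_dim = n_embd / n_head = 4096 / 32 = 128, FP8 KV cache (1 byte).
for bs in 64 32; do
	echo "BS $bs, ISL+OSL 4096: $(to_gib "$(kv_cache_bytes 32 32 128 1 "$bs" 4096)") GiB"
done
```

At ISL/OSL 2048/2048, this gives about 64 GiB at batch size 64 and 32 GiB at batch size 32, on top of the weights. That may help explain the smaller batch size listed for that configuration, though this is our reading rather than a documented reason. For Llama2-70b (80 layers, `--n_kv_head 8`, head_dim 8192 / 64 = 128), `kv_cache_bytes 80 8 128 1 64 4096` gives 40 GiB in total, before it is split across the 4 tensor-parallel GPUs.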
Throughput Benchmark
```bash
in_out_sizes=("64:128,128" "64:128,2048" "64:2048,128" "32:2048,2048")
for in_out in "${in_out_sizes[@]}"
do
	batch_size=$(echo "$in_out" | awk -F':' '{ print $1 }')
	in_out_dims=$(echo "$in_out" | awk -F':' '{ print $2 }')
	echo "BS: $batch_size, ISL/OSL: $in_out_dims"

	./cpp/build/benchmarks/gptSessionBenchmark --model llama --engine_dir /tmp/engines/llama/7b --warm_up 1 --batch_size "$batch_size" --duration 0 --num_runs 5 --input_output_len "$in_out_dims"
done
```
First Token Latency Benchmark
```bash
# Batch size 1 and a single output token, matching the first-token latency tables.
in_out_sizes=("1:128,1" "1:2048,1")
for in_out in "${in_out_sizes[@]}"
do
	batch_size=$(echo "$in_out" | awk -F':' '{ print $1 }')
	in_out_dims=$(echo "$in_out" | awk -F':' '{ print $2 }')
	echo "BS: $batch_size, ISL/OSL: $in_out_dims"

	./cpp/build/benchmarks/gptSessionBenchmark --model llama --engine_dir /tmp/engines/llama/7b --warm_up 1 --batch_size "$batch_size" --duration 0 --num_runs 5 --input_output_len "$in_out_dims"
done
```

Llama2-70b

```bash
pip install -r examples/llama/requirements.txt
python examples/llama/build.py \
	--remove_input_padding \
	--enable_context_fmha \
	--parallel_build \
	--output_dir /tmp/engines/llama/70b \
	--dtype float16 \
	--use_gpt_attention_plugin float16 \
	--world_size 4 \
	--tp_size 4 \
	--pp_size 1 \
	--max_batch_size 64 \
	--max_input_len 2048 \
	--max_output_len 2048 \
	--enable_fp8 \
	--fp8_kv_cache \
	--strongly_typed \
	--n_layer 80 \
	--n_head 64 \
	--n_kv_head 8 \
	--n_embd 8192 \
	--inter_size 28672 \
	--vocab_size 32000 \
	--n_positions 4096 \
	--hidden_act silu \
	--ffn_dim_multiplier 1.3 \
	--multiple_of 4096
```
Throughput Benchmark
```bash
in_out_sizes=("64:128,128" "64:128,2048" "64:2048,128" "64:2048,2048")
for in_out in "${in_out_sizes[@]}"
do
	batch_size=$(echo "$in_out" | awk -F':' '{ print $1 }')
	in_out_dims=$(echo "$in_out" | awk -F':' '{ print $2 }')
	echo "BS: $batch_size, ISL/OSL: $in_out_dims"

	mpirun -n 4 --allow-run-as-root --oversubscribe ./cpp/build/benchmarks/gptSessionBenchmark --model llama --engine_dir /tmp/engines/llama/70b --warm_up 1 --batch_size "$batch_size" --duration 0 --num_runs 5 --input_output_len "$in_out_dims"
done
```
First Token Latency Benchmark
```bash
# Batch size 1 and a single output token, for input lengths 128 and 2048
# (matching the first-token latency tables).
in_out_sizes=("1:128,1" "1:2048,1")
for in_out in "${in_out_sizes[@]}"
do
	batch_size=$(echo "$in_out" | awk -F':' '{ print $1 }')
	in_out_dims=$(echo "$in_out" | awk -F':' '{ print $2 }')
	echo "BS: $batch_size, ISL/OSL: $in_out_dims"

	mpirun -n 4 --allow-run-as-root --oversubscribe ./cpp/build/benchmarks/gptSessionBenchmark --model llama --engine_dir /tmp/engines/llama/70b --warm_up 1 --batch_size "$batch_size" --duration 0 --num_runs 5 --input_output_len "$in_out_dims"
done
```

Falcon-180B


Benchmarking Falcon-180B requires a separate engine for each combination of batch size and input/output sequence length, because of the model's large memory footprint and the large input length of 2048. You can build and benchmark each engine one at a time with the following loop.

```bash
# Benchmark specific batch size:isl:osl combinations.
in_out_sizes=("96:128,128" "96:128,2048" "64:2048,128")
for in_out in "${in_out_sizes[@]}"
do
	batch_size=$(echo "$in_out" | awk -F':' '{ print $1 }')
	in_out_dims=$(echo "$in_out" | awk -F':' '{ print $2 }')
	isl=$(echo "$in_out_dims" | awk -F',' '{ print $1 }')
	osl=$(echo "$in_out_dims" | awk -F',' '{ print $2 }')
	engine_path="/tmp/engines/falcon/180b/${batch_size}_${isl}_${osl}"
	echo "BS: $batch_size, ISL/OSL: ${isl},${osl}"

	# Build the specific engine for the BS,ISL,OSL combination
	python examples/falcon/build.py \
		--use_inflight_batching \
		--paged_kv_cache \
		--remove_input_padding \
		--enable_context_fmha \
		--parallel_build \
		--output_dir "$engine_path" \
		--dtype float16 \
		--use_gemm_plugin float16 \
		--use_gpt_attention_plugin float16 \
		--world_size 8 \
		--tp 8 \
		--max_batch_size "$batch_size" \
		--max_input_len "$isl" \
		--max_output_len "$osl" \
		--enable_fp8 \
		--fp8_kv_cache \
		--n_layer 80 \
		--n_head 232 \
		--n_kv_head 8 \
		--n_embd 14848 \
		--vocab_size 65024 \
		--new_decoder_architecture
	# Throughput benchmark
	mpirun -n 8 --allow-run-as-root --oversubscribe ./cpp/build/benchmarks/gptSessionBenchmark --model falcon --engine_dir "$engine_path" --warm_up 1 --batch_size "$batch_size" --duration 0 --num_runs 5 --input_output_len "${isl},${osl}"
	# Time to first token benchmark (batch size 1, matching the latency tables)
	mpirun -n 8 --allow-run-as-root --oversubscribe ./cpp/build/benchmarks/gptSessionBenchmark --model falcon --engine_dir "$engine_path" --warm_up 1 --batch_size 1 --duration 0 --num_runs 5 --input_output_len "${isl},1"

	# The Falcon-180B engines are large; remove each one after benchmarking to free disk space.
	# Delete this line if you want to keep the engines.
	rm -r "$engine_path"
done
```
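Simple arithmetic gives a sense of Falcon-180B's large footprint. The sketch below estimates weight memory per GPU under tensor parallelism. It assumes about 1 byte per parameter with FP8 and 2 with FP16, and weights split evenly across GPUs. It ignores replicated embeddings, the KV cache, and activations. `weights_gib_per_gpu` is a hypothetical helper.

```bash
#!/usr/bin/env bash
# Back-of-the-envelope weight memory per GPU (GiB), ASSUMING weights are split
# evenly across the tensor-parallel ranks and ignoring everything else.
weights_gib_per_gpu() {
	local params_billion=$1 bytes_per_param=$2 tp=$3
	awk -v p="$params_billion" -v b="$bytes_per_param" -v tp="$tp" \
		'BEGIN { printf "%.1f\n", p * 1e9 * b / tp / (1024 ^ 3) }'
}

# Falcon-180B on 8 GPUs: FP8 (about 1 byte/param) vs. FP16 (2 bytes/param).
weights_gib_per_gpu 180 1 8
weights_gib_per_gpu 180 2 8
```

Under these assumptions, this works out to roughly 21 GiB of weights per GPU with FP8 versus about 42 GiB with FP16, before any KV cache is allocated.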