Revert "chore: [Breaking Change] Rename cuda_graph_config padding_enabled fie…" (#5818)

nv-guomingz 2025-07-08 12:15:30 +08:00 committed by GitHub
parent 5bc3a15f10
commit 0be41b6524
18 changed files with 70 additions and 71 deletions


@ -195,20 +195,20 @@ We are seeing meaningful speedup using FP8 KV cache, thus refreshing the numbers
#### Benchmark
```bash
cat >./extra-llm-api-config.yml <<EOF
-cuda_graph_config:
-  enable_padding: true
-  batch_sizes:
-  - 896
-  - 512
-  - 256
-  - 128
-  - 64
-  - 32
-  - 16
-  - 8
-  - 4
-  - 2
-  - 1
+use_cuda_graph: true
+cuda_graph_padding_enabled: true
+cuda_graph_batch_sizes:
+- 896
+- 512
+- 256
+- 128
+- 64
+- 32
+- 16
+- 8
+- 4
+- 2
+- 1
print_iter_log: true
kv_cache_dtype: fp8
enable_attention_dp: true
@ -262,19 +262,19 @@ python ${YOUR_WORK_PATH}/benchmarks/cpp/prepare_dataset.py \
YOUR_DATA_PATH=./dataset.txt
cat >./extra-llm-api-config.yml <<EOF
-cuda_graph_config:
-  enable_padding: true
-  batch_sizes:
-  - 1
-  - 2
-  - 4
-  - 8
-  - 16
-  - 32
-  - 64
-  - 128
-  - 256
-  - 384
+use_cuda_graph: true
+cuda_graph_padding_enabled: true
+cuda_graph_batch_sizes:
+- 1
+- 2
+- 4
+- 8
+- 16
+- 32
+- 64
+- 128
+- 256
+- 384
print_iter_log: ${PRINT_ITER_LOG}
enable_attention_dp: true
EOF


@ -151,7 +151,7 @@ These optimizations target the overall execution flow, scheduling, and resource
* CUDA Graph
-This had a significant **22% E2E performance impact** for throughput scenarios. CUDA Graphs allow capturing a sequence of CUDA operations and launching them as a single unit, drastically reducing kernel launch overheads. This is particularly beneficial for models with many small kernels, and especially on the PyTorch flow, because the Python host code normally executes slower than C++. Since a CUDA Graph freezes the kernel launch parameters, which are normally tied to the tensor shapes, it can only be safely used with static shapes, meaning that different CUDA Graphs need to be captured for different batch sizes. Each graph has some cost in memory usage and capture time, so we cannot capture a CUDA Graph for every possible batch size. For non-captured batch sizes, PyTorch eager mode code is executed instead. TensorRT-LLM has a CUDA Graph padding feature that is a good trade-off between the number of CUDA Graphs and the CUDA Graph hit ratio: it pads a batch to the nearest size with a captured CUDA Graph. Normally you should enable CUDA Graph padding to increase the CUDA Graph hit rate, but the padding itself has some overhead due to wasted token computation. Users can opt out of CUDA Graph padding to compare the performance, by setting `cuda_graph_config: enable_padding` to false; see the API here: [PyTorch backend config](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/pyexecutor/config.py#L41)
+This had a significant **22% E2E performance impact** for throughput scenarios. CUDA Graphs allow capturing a sequence of CUDA operations and launching them as a single unit, drastically reducing kernel launch overheads. This is particularly beneficial for models with many small kernels, and especially on the PyTorch flow, because the Python host code normally executes slower than C++. Since a CUDA Graph freezes the kernel launch parameters, which are normally tied to the tensor shapes, it can only be safely used with static shapes, meaning that different CUDA Graphs need to be captured for different batch sizes. Each graph has some cost in memory usage and capture time, so we cannot capture a CUDA Graph for every possible batch size. For non-captured batch sizes, PyTorch eager mode code is executed instead. TensorRT-LLM has a CUDA Graph padding feature that is a good trade-off between the number of CUDA Graphs and the CUDA Graph hit ratio: it pads a batch to the nearest size with a captured CUDA Graph. Normally you should enable CUDA Graph padding to increase the CUDA Graph hit rate, but the padding itself has some overhead due to wasted token computation. Users can opt out of CUDA Graph padding to compare the performance, by setting `cuda_graph_padding_enabled` to false; see the API here: [PyTorch backend config](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/pyexecutor/config.py#L41)
* Overlap Scheduler:

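The padding trade-off described in the CUDA Graph note above can also be exercised programmatically. Below is a minimal sketch, not part of this diff, assuming the `tensorrt_llm.llmapi` import path and a placeholder model path; it mirrors the `CudaGraphConfig` usage that appears in the quickstart change later in this commit. The keyword is `padding_enabled` after this revert (`enable_padding` on the reverted commit).

```python
# Hypothetical sketch: capture CUDA Graphs only for selected batch sizes and let
# padding round other batches up to the nearest captured size.
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import CudaGraphConfig

cuda_graph_config = CudaGraphConfig(
    batch_sizes=[1, 2, 4, 8, 16, 32, 64, 128],  # sizes to capture graphs for
    padding_enabled=True,  # pad non-captured batch sizes up to a captured one
)

llm = LLM(
    model="/path/to/model",  # placeholder path, not from this commit
    cuda_graph_config=cuda_graph_config,
)
```

Setting the flag to `False` falls back to eager execution for any batch size that is not captured, which avoids the wasted-token overhead at the cost of a lower CUDA Graph hit rate.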

@ -201,7 +201,7 @@ trtllm-bench --model $model_name throughput --dataset $dataset_file --backend py
`llm_options.yml`
```yaml
cuda_graph_config:
-  enable_padding: true
+  padding_enabled: true
  batch_sizes:
  - 1
  - 2


@ -190,7 +190,7 @@ def gen_config_file(config_path: str,
'max_seq_len': 8576,
'free_gpu_memory_fraction': gen_gpu_memory_fraction,
'cuda_graph_config': {
-'enable_padding': True,
+'padding_enabled': True,
'batch_sizes': gen_cuda_graph_batch_sizes,
},
'print_iter_log': True,


@ -142,7 +142,7 @@ python /app/tensorrt_llm/benchmarks/cpp/prepare_dataset.py \
cat <<EOF > /tmp/extra-llm-api-config.yml
cuda_graph_config:
-  enable_padding: true
+  padding_enabled: true
  batch_sizes: [1, 4, 8, 12]
EOF
@ -169,7 +169,7 @@ python /app/tensorrt_llm/benchmarks/cpp/prepare_dataset.py \
cat <<EOF > /tmp/extra-llm-api-config.yml
cuda_graph_config:
-  enable_padding: true
+  padding_enabled: true
  batch_sizes: [1, 2]
moe_max_num_tokens: 16384
EOF
@ -237,7 +237,7 @@ To serve the model using `trtllm-serve`:
```bash
cat >./extra-llm-api-config.yml <<EOF
cuda_graph_config:
-  enable_padding: true
+  padding_enabled: true
  batch_sizes:
  - 1
  - 2
@ -316,7 +316,7 @@ export TRTLLM_USE_UCX_KVCACHE=1
cat >./gen-extra-llm-api-config.yml <<EOF
cuda_graph_config:
-  enable_padding: true
+  padding_enabled: true
  batch_sizes:
  - 1
  - 2
@ -538,7 +538,7 @@ python3 /path/to/TensorRT-LLM/benchmarks/cpp/prepare_dataset.py \
cat >/path/to/TensorRT-LLM/extra-llm-api-config.yml <<EOF
cuda_graph_config:
-  enable_padding: true
+  padding_enabled: true
  batch_sizes:
  - 1
  - 2


@ -745,7 +745,7 @@ To serve the model using `trtllm-serve`:
```bash
cat >./extra-llm-api-config.yml <<EOF
cuda_graph_config:
-  enable_padding: true
+  padding_enabled: true
  batch_sizes:
  - 1
  - 2
@ -821,7 +821,7 @@ export TRTLLM_USE_UCX_KVCACHE=1
cat >./gen-extra-llm-api-config.yml <<EOF
cuda_graph_config:
-  enable_padding: true
+  padding_enabled: true
  batch_sizes:
  - 1
  - 2


@ -189,7 +189,7 @@ def setup_llm(args):
cuda_graph_config = CudaGraphConfig(
batch_sizes=args.cuda_graph_batch_sizes,
-enable_padding=args.cuda_graph_padding_enabled,
+padding_enabled=args.cuda_graph_padding_enabled,
) if args.use_cuda_graph else None
llm = LLM(
model=args.model_dir,


@ -196,7 +196,7 @@ def gen_config_file(config_path: str,
'max_seq_len': 2176,
'free_gpu_memory_fraction': gen_gpu_memory_fraction,
'cuda_graph_config': {
-'enable_padding': True,
+'padding_enabled': True,
'batch_sizes': gen_cuda_graph_batch_sizes,
},
'print_iter_log': True,


@ -309,7 +309,7 @@ def get_rank_model_storage(model):
def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],
max_batch_size: int, max_num_tokens: int,
max_draft_len: int,
-enable_padding: bool) -> list[int]:
+padding_enabled: bool) -> list[int]:
# This is the largest possible batch size for a pure decoding batch.
max_cuda_graph_bs = min(max_batch_size,
int(max_num_tokens / (1 + max_draft_len)))
@ -326,8 +326,8 @@ def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],
# is that if the user is OK padding to a batch size B, they should also
# be OK with padding to some size B' < B since the performance will generally
# just be better in the smaller case.
-if enable_padding and (i == 0
-or result[i - 1] != max_cuda_graph_bs):
+if padding_enabled and (i == 0
+or result[i - 1] != max_cuda_graph_bs):
logger.warning(
"CUDA graph padding is enabled, but one of the given CUDA graph "
f"batch sizes ({bs}) is larger than the executor's max batch size "

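For readers following the executor change above, here is a small self-contained reconstruction of the filtering rule that the renamed argument feeds into. Only the lines visible in this hunk are used; the function name, loop structure, and warning text of the real helper differ, so treat this purely as an illustration.

```python
# Illustrative reconstruction (not the real helper): keep only CUDA graph batch
# sizes that a pure decoding batch can actually reach, and warn when padding is
# enabled but a requested size had to be dropped.
def filter_cuda_graph_batch_sizes(batch_sizes, max_batch_size, max_num_tokens,
                                  max_draft_len, padding_enabled):
    # Each decoding request contributes 1 + max_draft_len tokens, so the token
    # budget caps the largest pure-decode batch.
    max_cuda_graph_bs = min(max_batch_size, max_num_tokens // (1 + max_draft_len))
    kept = [bs for bs in sorted(batch_sizes) if bs <= max_cuda_graph_bs]
    dropped = [bs for bs in sorted(batch_sizes) if bs > max_cuda_graph_bs]
    if padding_enabled and dropped:
        print(f"CUDA graph padding is enabled, but batch sizes {dropped} are larger "
              f"than the executor's max batch size {max_cuda_graph_bs}.")
    return kept

# Example: with an 8192-token budget and 3 draft tokens, decode batches cap at 2048.
print(filter_cuda_graph_batch_sizes([1, 8, 64, 4096], 2048, 8192, 3, True))  # [1, 8, 64]
```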

@ -151,7 +151,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
pass
cuda_graph_config = {
"enable_padding": True,
"padding_enabled": True,
"max_batch_size": max_batch_size
}


@ -71,7 +71,7 @@ class CudaGraphConfig(BaseModel):
max_batch_size: int = Field(
default=0, description="Maximum batch size for CUDA graphs.")
-enable_padding: bool = Field(
+padding_enabled: bool = Field(
default=False,
description=
"If true, batches are rounded up to the nearest cuda_graph_batch_size. This is usually a net win for performance."
@ -1853,17 +1853,17 @@ class TorchLlmArgs(BaseLlmArgs):
@staticmethod
def _generate_cuda_graph_batch_sizes(max_batch_size: int,
-enable_padding: bool) -> List[int]:
+padding_enabled: bool) -> List[int]:
"""Generate a list of batch sizes for CUDA graphs.
Args:
max_batch_size: Maximum batch size to generate up to
-enable_padding: Whether padding is enabled, which affects the batch size distribution
+padding_enabled: Whether padding is enabled, which affects the batch size distribution
Returns:
List of batch sizes to create CUDA graphs for
"""
-if enable_padding:
+if padding_enabled:
batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
else:
batch_sizes = list(range(1, 32)) + [32, 64, 128]
@ -1901,7 +1901,7 @@ class TorchLlmArgs(BaseLlmArgs):
config.batch_sizes = sorted(config.batch_sizes)
if config.max_batch_size != 0:
if config.batch_sizes != self._generate_cuda_graph_batch_sizes(
-config.max_batch_size, config.enable_padding):
+config.max_batch_size, config.padding_enabled):
raise ValueError(
"Please don't set both cuda_graph_config.batch_sizes "
"and cuda_graph_config.max_batch_size.\n"
@ -1913,7 +1913,7 @@ class TorchLlmArgs(BaseLlmArgs):
else:
max_batch_size = config.max_batch_size or 128
generated_sizes = self._generate_cuda_graph_batch_sizes(
-max_batch_size, config.enable_padding)
+max_batch_size, config.padding_enabled)
config.batch_sizes = generated_sizes
config.max_batch_size = max_batch_size
@ -1932,9 +1932,9 @@ class TorchLlmArgs(BaseLlmArgs):
cuda_graph_max_batch_size=self.cuda_graph_config.max_batch_size
if self.cuda_graph_config else
CudaGraphConfig.model_fields['max_batch_size'].default,
-cuda_graph_padding_enabled=self.cuda_graph_config.enable_padding
+cuda_graph_padding_enabled=self.cuda_graph_config.padding_enabled
if self.cuda_graph_config else
-CudaGraphConfig.model_fields['enable_padding'].default,
+CudaGraphConfig.model_fields['padding_enabled'].default,
disable_overlap_scheduler=self.disable_overlap_scheduler,
moe_max_num_tokens=self.moe_max_num_tokens,
moe_load_balancer=self.moe_load_balancer,

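To make the generation and validation logic in this hunk concrete, here is a standalone sketch with worked outputs. The cap at `max_batch_size` is an assumption drawn from the docstring (the exact handling, including whether `max_batch_size` itself is appended, is not visible here), so treat the outputs as illustrative only.

```python
# Illustrative re-implementation of the batch-size generation rule shown above.
from typing import List

def generate_cuda_graph_batch_sizes(max_batch_size: int,
                                    padding_enabled: bool) -> List[int]:
    if padding_enabled:
        # Coarser grid: padding rounds odd batches up, so fewer graphs are needed.
        batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
    else:
        # Finer grid: without padding, only exact batch sizes hit a captured graph.
        batch_sizes = list(range(1, 32)) + [32, 64, 128]
    return [bs for bs in batch_sizes if bs <= max_batch_size]  # assumed cap

print(generate_cuda_graph_batch_sizes(32, True))   # [1, 2, 4, 8, 16, 24, 32]
print(generate_cuda_graph_batch_sizes(32, False))  # [1, 2, 3, ..., 31, 32]
```

The validator in the same hunk then treats `cuda_graph_config.batch_sizes` and `cuda_graph_config.max_batch_size` as mutually exclusive: an explicit list is only accepted alongside `max_batch_size` when it matches exactly what would be generated.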

@ -102,7 +102,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
enable_fullgraph=True) if torch_compile else None
pytorch_config = dict(
torch_compile_config=torch_compile_config,
-cuda_graph_config=CudaGraphConfig(enable_padding=torch_compile,
+cuda_graph_config=CudaGraphConfig(padding_enabled=torch_compile,
batch_sizes=[4]),
attn_backend=attn_backend,
disable_overlap_scheduler=torch_compile,
@ -129,7 +129,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
enable_fullgraph=True) if torch_compile else None
pytorch_config = dict(
torch_compile_config=torch_compile_config,
-cuda_graph_config=CudaGraphConfig(enable_padding=torch_compile,
+cuda_graph_config=CudaGraphConfig(padding_enabled=torch_compile,
batch_sizes=[4]),
attn_backend=attn_backend,
disable_overlap_scheduler=torch_compile,
@ -154,7 +154,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
enable_fullgraph=True) if torch_compile else None
pytorch_config = dict(
torch_compile_config=torch_compile_config,
-cuda_graph_config=CudaGraphConfig(enable_padding=torch_compile,
+cuda_graph_config=CudaGraphConfig(padding_enabled=torch_compile,
batch_sizes=[4]),
attn_backend=attn_backend,
disable_overlap_scheduler=torch_compile,
@ -193,7 +193,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
enable_fullgraph=True) if torch_compile else None
pytorch_config = dict(
torch_compile_config=torch_compile_config,
-cuda_graph_config=CudaGraphConfig(enable_padding=torch_compile,
+cuda_graph_config=CudaGraphConfig(padding_enabled=torch_compile,
batch_sizes=[4]),
attn_backend=attn_backend,
disable_overlap_scheduler=torch_compile,
@ -741,7 +741,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
disable_overlap_scheduler=False,
cuda_graph_config=CudaGraphConfig(
max_batch_size=512,
-enable_padding=True,
+padding_enabled=True,
),
)
llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
@ -765,7 +765,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
pytorch_config = dict(
disable_overlap_scheduler=False,
-cuda_graph_config=CudaGraphConfig(enable_padding=True),
+cuda_graph_config=CudaGraphConfig(padding_enabled=True),
)
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
@ -1857,7 +1857,7 @@ class TestKanana_Instruct(LlmapiAccuracyTestHarness):
def test_auto_dtype(self):
"RCCA: https://nvbugspro.nvidia.com/bug/5310520"
pytorch_config = dict(cuda_graph_config=CudaGraphConfig(
-enable_padding=True, max_batch_size=384))
+padding_enabled=True, max_batch_size=384))
with LLM(self.MODEL_PATH, **pytorch_config,
enable_attention_dp=True) as llm:
task = MMLU(self.MODEL_NAME)


@ -17,7 +17,7 @@ generation_servers:
pipeline_parallel_size: 1
enable_attention_dp: true
cuda_graph_config:
-  enable_padding: False
+  padding_enabled: False
disable_overlap_scheduler: False
urls:
- "localhost:8002"


@ -15,7 +15,7 @@ generation_servers:
tensor_parallel_size: 2
pipeline_parallel_size: 1
cuda_graph_config:
-  enable_padding: False
+  padding_enabled: False
disable_overlap_scheduler: False
urls:
- "localhost:8002"


@ -28,7 +28,7 @@ generation_servers:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
cuda_graph_config:
-  enable_padding: True
+  padding_enabled: True
  batch_sizes: [1,4,8,16,24,32]
disable_overlap_scheduler: True
urls:


@ -30,7 +30,7 @@ def get_model_yaml_config(model_label: str,
base_config = {
'print_iter_log': True,
'cuda_graph_config': {
-'enable_padding': True,
+'padding_enabled': True,
},
}
if 'kv_cache_dtype' in model_label:
@ -65,10 +65,9 @@ def get_model_yaml_config(model_label: str,
],
'config': {
'enable_attention_dp': True,
-'cuda_graph_config': {
-'enable_padding': True,
-'batch_sizes': [1, 2, 4, 8, 16, 32, 64, 128, 256, 384]
-}
+'cuda_graph_padding_enabled': True,
+'cuda_graph_batch_sizes':
+[1, 2, 4, 8, 16, 32, 64, 128, 256, 384]
}
},
# DeepSeek R1 model with specific batch size 128
@ -87,7 +86,7 @@ def get_model_yaml_config(model_label: str,
'config': {
'print_iter_log': True,
'cuda_graph_config': {
-'enable_padding': True,
+'padding_enabled': True,
'batch_sizes': [1, 512, 1024, 2048]
}
}


@ -519,7 +519,7 @@ def stress_test(config,
if config.backend == "pytorch":
extra_llm_options.update({
"cuda_graph_config": {
"enable_padding": True,
"padding_enabled": True,
"batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128, 256, 384],
},
"print_iter_log": True,


@ -272,7 +272,7 @@ class TestTorchLlmArgsCudaGraphSettings:
cuda_graph_config=CudaGraphConfig(
batch_sizes=CudaGraphConfig._generate_cuda_graph_batch_sizes(
128, True),
-enable_padding=True,
+padding_enabled=True,
max_batch_size=128))
assert args.cuda_graph_config.batch_sizes == CudaGraphConfig._generate_cuda_graph_batch_sizes(
128, True)
@ -282,14 +282,14 @@ class TestTorchLlmArgsCudaGraphSettings:
# set cuda_graph_batch_sizes only
args = TorchLlmArgs(model=llama_model_path,
cuda_graph_config=CudaGraphConfig(
-batch_sizes=[1, 2, 4], enable_padding=True))
+batch_sizes=[1, 2, 4], padding_enabled=True))
assert args.cuda_graph_config.batch_sizes == [1, 2, 4]
def test_cuda_graph_batch_sizes_case_2(self):
# set cuda_graph_config.max_batch_size only
args = TorchLlmArgs(model=llama_model_path,
cuda_graph_config=CudaGraphConfig(
-max_batch_size=128, enable_padding=True))
+max_batch_size=128, padding_enabled=True))
assert args.cuda_graph_config.batch_sizes == CudaGraphConfig._generate_cuda_graph_batch_sizes(
128, True)
assert args.cuda_graph_config.max_batch_size == 128