chore: [Breaking Change] Rename cuda_graph_config padding_enabled fie… (#6003)

Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
nv-guomingz 2025-07-15 14:50:03 +08:00 committed by GitHub
parent d811843a08
commit 4e4d18826f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 156 additions and 139 deletions

View File

@ -138,7 +138,8 @@ YOUR_DATA_PATH=<your dataset file following the format>
cat >./extra-llm-api-config.yml<<EOF
cuda_graph_config: {}
moe_backend: TRTLLM
moe_config:
backend: TRTLLM
speculative_config:
decoding_type: MTP
num_nextn_predict_layers: 3
@ -196,7 +197,7 @@ We are seeing meaningful speedup using FP8 KV cache, thus refreshing the numbers
```bash
cat >./extra-llm-api-config.yml <<EOF
cuda_graph_config:
padding_enabled: true
enable_padding: true
batch_sizes:
- 896
- 512
@ -263,7 +264,7 @@ YOUR_DATA_PATH=./dataset.txt
cat >./extra-llm-api-config.yml <<EOF
cuda_graph_config:
padding_enabled: true
enable_padding: true
batch_sizes:
- 1
- 2

View File

@ -124,7 +124,8 @@ YOUR_DATA_PATH=<your dataset file following the format>
cat >./extra-llm-api-config.yml<<EOF
cuda_graph_config: {}
moe_backend: TRTLLM
moe_config:
backend: TRTLLM
speculative_config:
decoding_type: MTP
num_nextn_predict_layers: 3
@ -179,7 +180,8 @@ YOUR_DATA_PATH=<your dataset file following the format>
cat >./extra-llm-api-config.yml<<EOF
cuda_graph_config: {}
moe_backend: TRTLLM
moe_config:
backend: TRTLLM
speculative_config:
decoding_type: MTP
num_nextn_predict_layers: 3

View File

@ -157,7 +157,7 @@ These optimizations target the overall execution flow, scheduling, and resource
There is a feature called CUDA Graph padding in TensorRT-LLM, which offers a good trade-off between the number of captured CUDA Graphs and the CUDA Graph hit ratio: it pads a batch up to the nearest size that has a captured CUDA Graph. Normally you should enable CUDA Graph padding to increase the hit rate, but the padding itself carries some overhead from computing the wasted tokens.
Users can opt out of the CUDA Graph padding feature to see the performance benefit of skipping that overhead by setting `cuda_graph_config:\n padding_enabled: False`; see the API in [Pytorch backend config](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/pyexecutor/config.py#L41)
Users can opt out of the CUDA Graph padding feature to see the performance benefit of skipping that overhead by setting `cuda_graph_config:\n enable_padding: False`; see the API in [Pytorch backend config](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/pyexecutor/config.py#L41), as sketched below.
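For reference, the same opt-out can be expressed through the LLM API using the renamed field. The snippet below is a minimal sketch based only on the arguments that appear in this commit; the model path is a placeholder.

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import CudaGraphConfig

# Capture CUDA graphs only for these exact batch sizes and skip padding,
# so a batch that does not match a captured size runs without a graph
# instead of being padded up to one.
cuda_graph_config = CudaGraphConfig(
    batch_sizes=[1, 2, 4],
    enable_padding=False,  # renamed from padding_enabled in this commit
)

llm = LLM(
    model="/path/to/model",  # placeholder model directory
    cuda_graph_config=cuda_graph_config,
)
```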
* Overlap Scheduler:

View File

@ -623,7 +623,8 @@ Run 36-way expert parallelism inference with the EPLB configuration incorporated
cat > ./extra_llm_api_options_eplb.yaml <<EOF
enable_attention_dp: true
cuda_graph_config: {}
moe_load_balancer: ./moe_load_balancer.yaml
moe_config:
load_balancer: ./moe_load_balancer.yaml
EOF
trtllm-llmapi-launch \

View File

@ -201,7 +201,7 @@ trtllm-bench --model $model_name throughput --dataset $dataset_file --backend py
`llm_options.yml`
```yaml
cuda_graph_config:
padding_enabled: true
enable_padding: true
batch_sizes:
- 1
- 2

View File

@ -190,12 +190,14 @@ def gen_config_file(config_path: str,
'max_seq_len': 8576,
'free_gpu_memory_fraction': gen_gpu_memory_fraction,
'cuda_graph_config': {
'padding_enabled': True,
'enable_padding': True,
'batch_sizes': gen_cuda_graph_batch_sizes,
},
'print_iter_log': True,
'kv_cache_dtype': 'fp8',
'moe_backend': 'TRTLLM',
'moe_config': {
'backend': 'TRTLLM',
},
'cache_transceiver_config': {
'max_num_tokens': 8320,
},

View File

@ -21,7 +21,7 @@ def example_cuda_graph_config():
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1, 2, 4],
padding_enabled=True,
enable_padding=True,
)
llm = LLM(

View File

@ -2,7 +2,7 @@ import argparse
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import (CudaGraphConfig, DraftTargetDecodingConfig,
EagleDecodingConfig, KvCacheConfig,
EagleDecodingConfig, KvCacheConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig,
TorchCompileConfig)
@ -188,7 +188,7 @@ def setup_llm(args):
cuda_graph_config = CudaGraphConfig(
batch_sizes=args.cuda_graph_batch_sizes,
padding_enabled=args.cuda_graph_padding_enabled,
enable_padding=args.cuda_graph_padding_enabled,
) if args.use_cuda_graph else None
llm = LLM(
model=args.model_dir,
@ -207,7 +207,7 @@ def setup_llm(args):
enable_piecewise_cuda_graph= \
args.use_piecewise_cuda_graph)
if args.use_torch_compile else None,
moe_backend=args.moe_backend,
moe_config=MoeConfig(backend=args.moe_backend),
enable_trtllm_sampler=args.enable_trtllm_sampler,
max_seq_len=args.max_seq_len,
max_batch_size=args.max_batch_size,

View File

@ -142,7 +142,7 @@ python /app/tensorrt_llm/benchmarks/cpp/prepare_dataset.py \
cat <<EOF > /tmp/extra-llm-api-config.yml
cuda_graph_config:
padding_enabled: true
enable_padding: true
batch_sizes: [1, 4, 8, 12]
EOF
@ -169,9 +169,10 @@ python /app/tensorrt_llm/benchmarks/cpp/prepare_dataset.py \
cat <<EOF > /tmp/extra-llm-api-config.yml
cuda_graph_config:
padding_enabled: true
enable_padding: true
batch_sizes: [1, 2]
moe_max_num_tokens: 16384
moe_config:
max_num_tokens: 16384
EOF
trtllm-bench -m deepseek-ai/DeepSeek-R1 --model_path ${DS_R1_NVFP4_MODEL_PATH} throughput \
@ -237,7 +238,7 @@ To serve the model using `trtllm-serve`:
```bash
cat >./extra-llm-api-config.yml <<EOF
cuda_graph_config:
padding_enabled: true
enable_padding: true
batch_sizes:
- 1
- 2
@ -316,7 +317,7 @@ export TRTLLM_USE_UCX_KVCACHE=1
cat >./gen-extra-llm-api-config.yml <<EOF
cuda_graph_config:
padding_enabled: true
enable_padding: true
batch_sizes:
- 1
- 2
@ -539,7 +540,7 @@ python3 /path/to/TensorRT-LLM/benchmarks/cpp/prepare_dataset.py \
cat >/path/to/TensorRT-LLM/extra-llm-api-config.yml <<EOF
cuda_graph_config:
padding_enabled: true
enable_padding: true
batch_sizes:
- 1
- 2

View File

@ -1557,14 +1557,14 @@ cat >./extra-llm-api-config.yml <<EOF
stream_interval: 2
cuda_graph_config:
max_batch_size: 1024
padding_enabled: true
enable_padding: true
EOF
```
Explanation:
- `stream_interval`: The iteration interval to create responses under the streaming mode.
- `cuda_graph_config`: CUDA Graph config.
- `max_batch_size`: Max CUDA graph batch size to capture.
- `padding_enabled`: Whether to enable CUDA graph padding.
- `enable_padding`: Whether to enable CUDA graph padding.
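For reference, the same `cuda_graph_config` block can be built through the LLM API. This is a minimal sketch that assumes only the `CudaGraphConfig` fields touched by this commit:

```python
from tensorrt_llm.llmapi import CudaGraphConfig

# Mirror of the YAML above: capture CUDA graphs up to batch size 1024 and
# round smaller batches up to the nearest captured size.
cuda_graph_config = CudaGraphConfig(
    max_batch_size=1024,
    enable_padding=True,  # renamed from padding_enabled in this commit
)
```

When only `max_batch_size` is given, the `TorchLlmArgs` validator in this commit derives `batch_sizes` automatically; setting both fields to conflicting values raises a `ValueError`.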
### Launch trtllm-serve OpenAI-compatible API server

View File

@ -29,7 +29,7 @@ enable_attention_dp: true
stream_interval: 2
cuda_graph_config:
max_batch_size: 512
padding_enabled: true
enable_padding: true
EOF
```
Explanation:
@ -37,7 +37,7 @@ Explanation:
- `stream_interval`: The iteration interval to create responses under the streaming mode.
- `cuda_graph_config`: CUDA Graph config.
- `max_batch_size`: Max CUDA graph batch size to capture.
- `padding_enabled`: Whether to enable CUDA graph padding.
- `enable_padding`: Whether to enable CUDA graph padding.
#### 2. Launch trtllm-serve OpenAI-compatible API server
@ -81,7 +81,7 @@ enable_min_latency: true
stream_interval: 2
cuda_graph_config:
max_batch_size: 8
padding_enabled: true
enable_padding: true
EOF
```
Explanation:
@ -90,7 +90,7 @@ Explanation:
- `stream_interval`: The iteration interval to create responses under the streaming mode.
- `cuda_graph_config`: CUDA Graph config.
- `max_batch_size`: Max CUDA graph batch size to capture.
- `padding_enabled`: Whether to enable CUDA graph padding.
- `enable_padding`: Whether to enable CUDA graph padding.
#### 2. Launch trtllm-serve OpenAI-compatible API server

View File

@ -745,7 +745,7 @@ To serve the model using `trtllm-serve`:
```bash
cat >./extra-llm-api-config.yml <<EOF
cuda_graph_config:
padding_enabled: true
enable_padding: true
batch_sizes:
- 1
- 2
@ -821,7 +821,7 @@ export TRTLLM_USE_UCX_KVCACHE=1
cat >./gen-extra-llm-api-config.yml <<EOF
cuda_graph_config:
padding_enabled: true
enable_padding: true
batch_sizes:
- 1
- 2

View File

@ -28,7 +28,9 @@ cat > ./extra_llm_api_options.yaml <<EOF
enable_attention_dp: true
cuda_graph_config: {}
moe_backend: WideEP
moe_max_num_tokens: 8192
moe_config:
backend: WideEP
max_num_tokens: 8192
EOF
trtllm-llmapi-launch \
@ -117,9 +119,10 @@ Run 36-way expert parallelism inference with the EPLB configuration incorporated
cat > ./extra_llm_api_options_eplb.yaml <<EOF
enable_attention_dp: true
cuda_graph_config: {}
moe_backend: WideEP
moe_max_num_tokens: 9216
moe_load_balancer: ./moe_load_balancer.yaml
moe_config:
backend: WideEP
max_num_tokens: 9216
load_balancer: ./moe_load_balancer.yaml
EOF
trtllm-llmapi-launch \
@ -183,9 +186,10 @@ Run 36-way expert parallelism inference with the EPLB configuration incorporated
cat > ./extra_llm_api_options_eplb.yaml <<EOF
enable_attention_dp: true
cuda_graph_config: {}
moe_backend: WideEP
moe_max_num_tokens: 9216
moe_load_balancer: ./moe_load_balancer.yaml
moe_config:
backend: WideEP
max_num_tokens: 9216
load_balancer: ./moe_load_balancer.yaml
EOF
trtllm-llmapi-launch \
@ -204,9 +208,9 @@ trtllm-bench --model ${MODEL_NAME} \
> **Note:** Similar to offline EP Load Balancer, you can enable expert ID counting to verify the effectiveness of EPLB, but remember to disable it when running inference for benchmarking or production purposes.
> **Explanation on moe_max_num_tokens:** For Large Scale EP, there can be extreme conditions in which all ranks send their tokens to a single rank because they all want the same expert.
> **Explanation on max_num_tokens of moe_config:** For Large Scale EP, there can be extreme conditions in which all ranks send their tokens to a single rank because they all want the same expert.
In that case, that rank has too many tokens to compute. To keep the hot rank from running out of memory, one strategy is to chunk the tokens when there are too many.
`moe_max_num_tokens` is the parameter that controls the maximum chunk size. However, chunking can carry a performance penalty when memory is already sufficient, because the effective batch size becomes smaller.
`max_num_tokens` of moe_config is the parameter that controls the maximum chunk size. However, chunking can carry a performance penalty when memory is already sufficient, because the effective batch size becomes smaller.
So by default it is set to a value that lets all tokens complete in one wave. However, if the EP size is large, we may need to trade this off to avoid OOM or other runtime errors caused by lack of memory.
The good news is that if memory allows, we can set `moe_max_num_tokens` to `max_batch_size * ep_size` so that all generation requests can be processed in one chunk.
For example, if `ep_size` is 36 and `max_batch_size` is 256, we may set `moe_max_num_tokens` to 9216.
The good news is that if memory allows, we can set `max_num_tokens` to `max_batch_size * ep_size` so that all generation requests can be processed in one chunk.
For example, if `ep_size` is 36 and `max_batch_size` is 256, we may set `max_num_tokens` to 9216, as in the sketch below.
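The same sizing rule can be written with the new `MoeConfig` object from this commit; a hedged sketch in which the numbers mirror the example above and the load-balancer path is a placeholder:

```python
from tensorrt_llm.llmapi import MoeConfig

ep_size = 36          # expert parallel size from the example above
max_batch_size = 256  # generation max batch size

moe_config = MoeConfig(
    backend="WIDEEP",  # Literal accepted by MoeConfig; the YAML above writes it as WideEP
    # 256 * 36 = 9216: large enough to process all generation requests in one chunk.
    max_num_tokens=max_batch_size * ep_size,
    load_balancer="./moe_load_balancer.yaml",  # placeholder; a MoeLoadBalancerConfig also works
)
```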

View File

@ -196,7 +196,7 @@ def gen_config_file(config_path: str,
'max_seq_len': 2176,
'free_gpu_memory_fraction': gen_gpu_memory_fraction,
'cuda_graph_config': {
'padding_enabled': True,
'enable_padding': True,
'batch_sizes': gen_cuda_graph_batch_sizes,
},
'print_iter_log': True,

View File

@ -309,7 +309,7 @@ def get_rank_model_storage(model):
def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],
max_batch_size: int, max_num_tokens: int,
max_draft_len: int,
padding_enabled: bool) -> list[int]:
enable_padding: bool) -> list[int]:
# This is the largest possible batch size for a pure decoding batch.
max_cuda_graph_bs = min(max_batch_size,
int(max_num_tokens / (1 + max_draft_len)))
@ -326,8 +326,8 @@ def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],
# is that if the user is OK padding to a batch size B, they should also
# be OK with padding to some size B' < B since the performance will generally
# just be better in the smaller case.
if padding_enabled and (i == 0
or result[i - 1] != max_cuda_graph_bs):
if enable_padding and (i == 0
or result[i - 1] != max_cuda_graph_bs):
logger.warning(
"CUDA graph padding is enabled, but one of the given CUDA graph "
f"batch sizes ({bs}) is larger than the executor's max batch size "

View File

@ -152,7 +152,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
pass
cuda_graph_config = {
"padding_enabled": True,
"enable_padding": True,
"max_batch_size": max_batch_size
}

View File

@ -9,7 +9,7 @@ from .llm_args import (BatchingType, CacheTransceiverConfig, CalibConfig,
CudaGraphConfig, DraftTargetDecodingConfig,
DynamicBatchConfig, EagleDecodingConfig,
ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
LookaheadDecodingConfig, MedusaDecodingConfig,
LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
UserProvidedDecodingConfig)
@ -27,6 +27,7 @@ __all__ = [
'KvCacheConfig',
'KvCacheRetentionConfig',
'CudaGraphConfig',
'MoeConfig',
'LookaheadDecodingConfig',
'MedusaDecodingConfig',
'EagleDecodingConfig',

View File

@ -72,7 +72,7 @@ class CudaGraphConfig(BaseModel):
max_batch_size: int = Field(
default=0, description="Maximum batch size for CUDA graphs.")
padding_enabled: bool = Field(
enable_padding: bool = Field(
default=False,
description=
"If true, batches are rounded up to the nearest cuda_graph_batch_size. This is usually a net win for performance."
@ -88,6 +88,30 @@ class CudaGraphConfig(BaseModel):
return v
class MoeConfig(BaseModel):
"""
Configuration for MoE.
"""
backend: Literal["CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM",
"VANILLA"] = Field(default='CUTLASS',
description="MoE backend to use.")
max_num_tokens: Optional[int] = Field(
default=None,
description=
"If set, at most max_num_tokens tokens will be sent to torch.ops.trtllm.fused_moe at the same time. If the number of tokens exceeds max_num_tokens, the input tensors will be split into chunks and a for loop will be used."
)
load_balancer: Optional[Union[object, str]] = Field(
default=None,
description="Configuration for MoE load balancing.",
json_schema_extra={"type": "Union[MoeLoadBalancerConfig, str]"})
@classmethod
def from_dict(cls, data: dict):
return cls(**data)
@dataclass
class _ParallelConfig:
''' The model distribution configs for LLM. '''
@ -1768,26 +1792,12 @@ class TorchLlmArgs(BaseLlmArgs):
disable_overlap_scheduler: bool = Field(
default=False, description="Disable the overlap scheduler.")
moe_max_num_tokens: Optional[int] = Field(
default=None,
description=
"If set, at most moe_max_num_tokens tokens will be sent to torch.ops.trtllm.fused_moe at the same time. If the number of tokens exceeds moe_max_num_tokens, the input tensors will be split into chunks and a for loop will be used."
)
moe_load_balancer: Optional[Union[object, str]] = Field(
default=None,
description="Configuration for MoE load balancing.",
json_schema_extra={
"type":
"Union[tensorrt_llm._torch.model_config.MoeLoadBalancerConfig, str, None]"
})
moe_config: MoeConfig = Field(default_factory=MoeConfig,
description="MoE config.")
attn_backend: str = Field(default='TRTLLM',
description="Attention backend to use.")
moe_backend: str = Field(default='CUTLASS',
description="MoE backend to use.")
enable_mixed_sampler: bool = Field(
default=False,
description=
@ -1889,25 +1899,6 @@ class TorchLlmArgs(BaseLlmArgs):
def extra_resource_managers(self, value: Dict[str, object]) -> None:
self._extra_resource_managers = value
@model_validator(mode="after")
def validate_moe_load_balancer(self):
from .._torch.model_config import MoeLoadBalancerConfig
if isinstance(self.moe_load_balancer, str):
if not os.path.exists(self.moe_load_balancer):
raise FileNotFoundError(
f"MoE load balancer config file not found: {self.moe_load_balancer}"
)
try:
with open(self.moe_load_balancer) as f:
moe_load_balancer_config = yaml.safe_load(f)
self.moe_load_balancer = MoeLoadBalancerConfig(
**moe_load_balancer_config)
except Exception as e:
raise ValueError(
f"Failed to load MoE load balancer config file: {self.moe_load_balancer}"
) from e
return self
@model_validator(mode="after")
def validate_stream_interval(self):
if self.stream_interval <= 0:
@ -1917,17 +1908,17 @@ class TorchLlmArgs(BaseLlmArgs):
@staticmethod
def _generate_cuda_graph_batch_sizes(max_batch_size: int,
padding_enabled: bool) -> List[int]:
enable_padding: bool) -> List[int]:
"""Generate a list of batch sizes for CUDA graphs.
Args:
max_batch_size: Maximum batch size to generate up to
padding_enabled: Whether padding is enabled, which affects the batch size distribution
enable_padding: Whether padding is enabled, which affects the batch size distribution
Returns:
List of batch sizes to create CUDA graphs for
"""
if padding_enabled:
if enable_padding:
batch_sizes = [1, 2, 4] + [i * 8 for i in range(1, 17)]
else:
batch_sizes = list(range(1, 32)) + [32, 64, 128]
@ -1947,6 +1938,25 @@ class TorchLlmArgs(BaseLlmArgs):
return batch_sizes
@model_validator(mode="after")
def validate_load_balancer(self) -> 'TorchLlmArgs':
from .._torch import MoeLoadBalancerConfig
if isinstance(self.moe_config.load_balancer, str):
if not os.path.exists(self.moe_config.load_balancer):
raise FileNotFoundError(
f"MoE load balancer config file not found: {self.moe_config.load_balancer}"
)
try:
with open(self.moe_config.load_balancer) as f:
moe_load_balancer_config = yaml.safe_load(f)
self.moe_config.load_balancer = MoeLoadBalancerConfig(
**moe_load_balancer_config)
except Exception as e:
raise ValueError(
f"Failed to load MoE load balancer config file: {self.load_balancer}"
) from e
return self
@model_validator(mode='after')
def validate_cuda_graph_config(self) -> 'TorchLlmArgs':
"""Validate CUDA graph configuration.
@ -1965,7 +1975,7 @@ class TorchLlmArgs(BaseLlmArgs):
config.batch_sizes = sorted(config.batch_sizes)
if config.max_batch_size != 0:
if config.batch_sizes != self._generate_cuda_graph_batch_sizes(
config.max_batch_size, config.padding_enabled):
config.max_batch_size, config.enable_padding):
raise ValueError(
"Please don't set both cuda_graph_config.batch_sizes "
"and cuda_graph_config.max_batch_size.\n"
@ -1977,7 +1987,7 @@ class TorchLlmArgs(BaseLlmArgs):
else:
max_batch_size = config.max_batch_size or 128
generated_sizes = self._generate_cuda_graph_batch_sizes(
max_batch_size, config.padding_enabled)
max_batch_size, config.enable_padding)
config.batch_sizes = generated_sizes
config.max_batch_size = max_batch_size
@ -1996,14 +2006,14 @@ class TorchLlmArgs(BaseLlmArgs):
cuda_graph_max_batch_size=self.cuda_graph_config.max_batch_size
if self.cuda_graph_config else
CudaGraphConfig.model_fields['max_batch_size'].default,
cuda_graph_padding_enabled=self.cuda_graph_config.padding_enabled
cuda_graph_padding_enabled=self.cuda_graph_config.enable_padding
if self.cuda_graph_config else
CudaGraphConfig.model_fields['padding_enabled'].default,
CudaGraphConfig.model_fields['enable_padding'].default,
disable_overlap_scheduler=self.disable_overlap_scheduler,
moe_max_num_tokens=self.moe_max_num_tokens,
moe_load_balancer=self.moe_load_balancer,
moe_max_num_tokens=self.moe_config.max_num_tokens,
moe_load_balancer=self.moe_config.load_balancer,
attn_backend=self.attn_backend,
moe_backend=self.moe_backend,
moe_backend=self.moe_config.backend,
enable_mixed_sampler=self.enable_mixed_sampler,
enable_trtllm_sampler=self.enable_trtllm_sampler,
kv_cache_dtype=self.kv_cache_dtype,
@ -2046,6 +2056,7 @@ def update_llm_args_with_extra_dict(
"enable_build_cache": BuildCacheConfig,
"speculative_config": DecodingBaseConfig,
"lora_config": LoraConfig,
"moe_config": MoeConfig,
}
for field_name, field_type in field_mapping.items():
if field_name in llm_args_dict:

View File

@ -19,7 +19,7 @@ import pytest
from tensorrt_llm import LLM
from tensorrt_llm._torch.pyexecutor.config import MoeLoadBalancerConfig
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
KvCacheConfig, MTPDecodingConfig,
KvCacheConfig, MoeConfig, MTPDecodingConfig,
NGramDecodingConfig, SamplingParams,
TorchCompileConfig)
from tensorrt_llm.models.modeling_utils import QuantConfig
@ -97,7 +97,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
enable_fullgraph=True) if torch_compile else None
pytorch_config = dict(
torch_compile_config=torch_compile_config,
cuda_graph_config=CudaGraphConfig(padding_enabled=torch_compile,
cuda_graph_config=CudaGraphConfig(enable_padding=torch_compile,
batch_sizes=[4]),
attn_backend=attn_backend,
disable_overlap_scheduler=torch_compile,
@ -123,7 +123,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
enable_fullgraph=True) if torch_compile else None
pytorch_config = dict(
torch_compile_config=torch_compile_config,
cuda_graph_config=CudaGraphConfig(padding_enabled=torch_compile,
cuda_graph_config=CudaGraphConfig(enable_padding=torch_compile,
batch_sizes=[4]),
attn_backend=attn_backend,
disable_overlap_scheduler=torch_compile,
@ -147,7 +147,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
enable_fullgraph=True) if torch_compile else None
pytorch_config = dict(
torch_compile_config=torch_compile_config,
cuda_graph_config=CudaGraphConfig(padding_enabled=torch_compile,
cuda_graph_config=CudaGraphConfig(enable_padding=torch_compile,
batch_sizes=[4]),
attn_backend=attn_backend,
disable_overlap_scheduler=torch_compile,
@ -185,7 +185,7 @@ class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
enable_fullgraph=True) if torch_compile else None
pytorch_config = dict(
torch_compile_config=torch_compile_config,
cuda_graph_config=CudaGraphConfig(padding_enabled=torch_compile,
cuda_graph_config=CudaGraphConfig(enable_padding=torch_compile,
batch_sizes=[4]),
attn_backend=attn_backend,
disable_overlap_scheduler=torch_compile,
@ -719,7 +719,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
disable_overlap_scheduler=not overlap_scheduler,
use_cuda_graph=cuda_graph,
torch_compile_config=torch_compile_config,
moe_backend="CUTEDSL",
moe_config=MoeConfig(backend="CUTEDSL"),
)
quant_config = QuantConfig()
@ -759,7 +759,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
disable_overlap_scheduler=False,
cuda_graph_config=CudaGraphConfig(
max_batch_size=512,
padding_enabled=True,
enable_padding=True,
),
)
with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
@ -782,7 +782,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
pytorch_config = dict(
disable_overlap_scheduler=False,
cuda_graph_config=CudaGraphConfig(padding_enabled=True),
cuda_graph_config=CudaGraphConfig(enable_padding=True),
)
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.FP8_BLOCK_SCALES
@ -899,7 +899,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
disable_overlap_scheduler=not overlap_scheduler,
use_cuda_graph=cuda_graph,
torch_compile_config=torch_compile_config,
moe_backend="CUTEDSL",
moe_config=MoeConfig(backend="CUTEDSL"),
)
quant_config = QuantConfig()
@ -948,8 +948,9 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
initial_global_assignments=initial_global_assignments,
layer_updates_per_iter=0)
pytorch_backend_options = dict(cuda_graph_config=CudaGraphConfig(),
moe_backend="WIDEEP",
moe_load_balancer=eplb_config)
moe_config=MoeConfig(
backend="WIDEEP",
load_balancer=eplb_config))
with LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
tensor_parallel_size=4,
moe_expert_parallel_size=4,
@ -968,8 +969,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
eplb_config = MoeLoadBalancerConfig(num_slots=num_slots,
layer_updates_per_iter=2)
pytorch_config = dict(cuda_graph_config=CudaGraphConfig(),
moe_backend="WIDEEP",
moe_load_balancer=eplb_config)
moe_config=MoeConfig(backend="WIDEEP",
load_balancer=eplb_config))
mtp_config = None
if mtp_nextn > 0:
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
@ -992,8 +993,9 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
eplb_config = MoeLoadBalancerConfig(num_slots=num_slots,
layer_updates_per_iter=2)
pytorch_backend_options = dict(cuda_graph_config=CudaGraphConfig(),
moe_backend="WIDEEP",
moe_load_balancer=eplb_config)
moe_config=MoeConfig(
backend="WIDEEP",
load_balancer=eplb_config))
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.NVFP4
if fp8kv:
@ -1035,8 +1037,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
torch_compile_config=torch_compile_config,
moe_backend=moe_backend,
)
moe_config=MoeConfig(backend=moe_backend))
mtp_config = None
if mtp_nextn > 0:
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
@ -1095,7 +1096,7 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
torch_compile_config=torch_compile_config,
moe_backend=moe_backend,
moe_config=MoeConfig(backend=moe_backend),
)
mtp_config = None
@ -1331,7 +1332,7 @@ class TestDeepSeekR1(LlmapiAccuracyTestHarness):
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
moe_backend=moe_backend)
moe_config=MoeConfig(backend=moe_backend))
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.NVFP4
@ -1726,7 +1727,7 @@ class TestQwen3_30B_A3B(LlmapiAccuracyTestHarness):
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
moe_backend=moe_backend,
moe_config=MoeConfig(backend=moe_backend),
)
with LLM(
@ -1808,7 +1809,7 @@ class TestQwen3_235B_A22B(LlmapiAccuracyTestHarness):
pytorch_config = dict(
disable_overlap_scheduler=not overlap_scheduler,
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
moe_backend=moe_backend)
moe_config=MoeConfig(backend=moe_backend))
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.6)
with LLM(
@ -1854,7 +1855,7 @@ class TestKanana_Instruct(LlmapiAccuracyTestHarness):
def test_auto_dtype(self):
"RCCA: https://nvbugspro.nvidia.com/bug/5310520"
pytorch_config = dict(cuda_graph_config=CudaGraphConfig(
padding_enabled=True, max_batch_size=384))
enable_padding=True, max_batch_size=384))
with LLM(self.MODEL_PATH, **pytorch_config,
enable_attention_dp=True) as llm:
task = MMLU(self.MODEL_NAME)

View File

@ -17,7 +17,7 @@ generation_servers:
pipeline_parallel_size: 1
enable_attention_dp: true
cuda_graph_config:
padding_enabled: False
enable_padding: False
disable_overlap_scheduler: False
urls:
- "localhost:8002"

View File

@ -15,7 +15,7 @@ generation_servers:
tensor_parallel_size: 2
pipeline_parallel_size: 1
cuda_graph_config:
padding_enabled: False
enable_padding: False
disable_overlap_scheduler: False
urls:
- "localhost:8002"

View File

@ -28,7 +28,7 @@ generation_servers:
free_gpu_memory_fraction: 0.2
enable_partial_reuse: False
cuda_graph_config:
padding_enabled: True
enable_padding: True
batch_sizes: [1,4,8,16,24,32]
disable_overlap_scheduler: True
urls:

View File

@ -30,7 +30,7 @@ def get_model_yaml_config(model_label: str,
base_config = {
'print_iter_log': True,
'cuda_graph_config': {
'padding_enabled': True,
'enable_padding': True,
},
}
if 'kv_cache_dtype' in model_label:
@ -66,7 +66,7 @@ def get_model_yaml_config(model_label: str,
'config': {
'enable_attention_dp': True,
'cuda_graph_config': {
'padding_enabled': True,
'enable_padding': True,
'batch_sizes': [1, 2, 4, 8, 16, 32, 64, 128, 256, 384]
}
}
@ -89,7 +89,7 @@ def get_model_yaml_config(model_label: str,
'config': {
'print_iter_log': True,
'cuda_graph_config': {
'padding_enabled': True,
'enable_padding': True,
'batch_sizes': [1, 512, 1024, 2048]
}
}

View File

@ -519,7 +519,7 @@ def stress_test(config,
if config.backend == "pytorch":
extra_llm_options.update({
"cuda_graph_config": {
"padding_enabled": True,
"enable_padding": True,
"batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128, 256, 384],
},
"print_iter_log": True,

View File

@ -8,7 +8,7 @@ from utils.llm_data import llm_models_root
from utils.util import getSMVersion
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig, MTPDecodingConfig
from tensorrt_llm.llmapi import KvCacheConfig, MoeConfig, MTPDecodingConfig
from tensorrt_llm.llmapi.utils import get_total_gpu_memory
@ -71,7 +71,7 @@ def test_deepseek_trtllmgen(model_name):
kv_cache_dtype="auto",
attn_backend="TRTLLM",
load_format="dummy",
moe_backend="TRTLLM",
moe_config=MoeConfig(backend="TRTLLM"),
)
model_dir = str(llm_models_root() / Path(f"DeepSeek-R1/{model_name}"))

View File

@ -8,7 +8,7 @@ from utils.llm_data import llm_models_root
from utils.util import getSMVersion
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig
from tensorrt_llm.llmapi import KvCacheConfig, MoeConfig
from tensorrt_llm.llmapi.utils import get_total_gpu_memory
@ -65,9 +65,8 @@ def test_deepseek_streaming(model_name, backend, quant, tp_size):
disable_overlap_scheduler=True,
kv_cache_dtype="auto",
attn_backend=backend,
moe_max_num_tokens=moe_max_num_tokens,
)
moe_config = MoeConfig(max_num_tokens=moe_max_num_tokens)
model_dir = str(llm_models_root() / model_name / model_path[quant])
assert Path(model_dir).exists()
@ -76,6 +75,7 @@ def test_deepseek_streaming(model_name, backend, quant, tp_size):
tensor_parallel_size=tp_size,
enable_chunked_prefill=False,
**pytorch_config,
moe_config=moe_config,
moe_expert_parallel_size=-1,
moe_tensor_parallel_size=-1,
enable_attention_dp=enable_attention_dp,

View File

@ -307,8 +307,7 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
"CUDA graphs should be disabled when cuda_graph_config=None")
# Test 4: Custom CudaGraphConfig with specific settings
custom_config = CudaGraphConfig(max_batch_size=256,
padding_enabled=True)
custom_config = CudaGraphConfig(max_batch_size=256, enable_padding=True)
llm_args_custom = LlmArgs.from_kwargs(model="dummy_model",
cuda_graph_config=custom_config)
pytorch_config_custom = llm_args_custom.get_pytorch_backend_config()
@ -317,7 +316,7 @@ class PyTorchModelEngineTestCase(unittest.TestCase):
self.assertEqual(pytorch_config_custom.cuda_graph_max_batch_size, 256,
"Custom max_batch_size should be respected")
self.assertTrue(pytorch_config_custom.cuda_graph_padding_enabled,
"Custom padding_enabled should be respected")
"Custom enable_padding should be respected")
if __name__ == "__main__":

View File

@ -69,18 +69,12 @@ methods:
disable_overlap_scheduler:
annotation: bool
default: False
moe_max_num_tokens:
annotation: Optional[int]
default: null
moe_load_balancer:
annotation: Union[tensorrt_llm._torch.MoeLoadBalancerConfig, str, None]
moe_config:
annotation: tensorrt_llm.llmapi.llm_args.MoeConfig
default: null
attn_backend:
annotation: str
default: TRTLLM
moe_backend:
annotation: str
default: CUTLASS
enable_mixed_sampler:
annotation: bool
default: False

View File

@ -272,7 +272,7 @@ class TestTorchLlmArgsCudaGraphSettings:
cuda_graph_config=CudaGraphConfig(
batch_sizes=CudaGraphConfig._generate_cuda_graph_batch_sizes(
128, True),
padding_enabled=True,
enable_padding=True,
max_batch_size=128))
assert args.cuda_graph_config.batch_sizes == CudaGraphConfig._generate_cuda_graph_batch_sizes(
128, True)
@ -282,14 +282,14 @@ class TestTorchLlmArgsCudaGraphSettings:
# set cuda_graph_batch_sizes only
args = TorchLlmArgs(model=llama_model_path,
cuda_graph_config=CudaGraphConfig(
batch_sizes=[1, 2, 4], padding_enabled=True))
batch_sizes=[1, 2, 4], enable_padding=True))
assert args.cuda_graph_config.batch_sizes == [1, 2, 4]
def test_cuda_graph_batch_sizes_case_2(self):
# set cuda_graph_config.max_batch_size only
args = TorchLlmArgs(model=llama_model_path,
cuda_graph_config=CudaGraphConfig(
max_batch_size=128, padding_enabled=True))
max_batch_size=128, enable_padding=True))
assert args.cuda_graph_config.batch_sizes == CudaGraphConfig._generate_cuda_graph_batch_sizes(
128, True)
assert args.cuda_graph_config.max_batch_size == 128