[Bugfix] Reject non-positive values for ParallelConfig int knobs (#44057)

Signed-off-by: jwzheng96 <jianweizheng@pku.edu.cn>
Signed-off-by: JianweiZheng <32029023+jwzheng96@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
JianweiZheng
2026-06-04 23:46:50 +08:00
committed by GitHub
parent 4cc78c9d5d
commit 99ef652907
+20 -18
View File
@@ -109,22 +109,24 @@ class EPLBConfig:
class ParallelConfig:
"""Configuration for the distributed execution."""
pipeline_parallel_size: int = 1
pipeline_parallel_size: int = Field(default=1, ge=1)
"""Number of pipeline parallel groups."""
tensor_parallel_size: int = 1
tensor_parallel_size: int = Field(default=1, ge=1)
"""Number of tensor parallel groups."""
prefill_context_parallel_size: int = 1
prefill_context_parallel_size: int = Field(default=1, ge=1)
"""Number of prefill context parallel groups."""
data_parallel_size: int = 1
data_parallel_size: int = Field(default=1, ge=1)
"""Number of data parallel groups. MoE layers will be sharded according to
the product of the tensor parallel size and data parallel size."""
data_parallel_size_local: int = 1
"""Number of local data parallel groups."""
data_parallel_rank: int = 0
"""Rank of the data parallel group."""
data_parallel_size_local: int = Field(default=1, ge=0)
"""Number of local data parallel groups. A value of 0 is a sentinel used by
the engine-args layer to signal that data parallelism was specified
externally (see `ParallelConfig.__post_init__`)."""
data_parallel_rank: int = Field(default=0, ge=0)
"""Rank of the data parallel group. The runtime check at
``__post_init__`` further bounds this by ``data_parallel_size``."""
data_parallel_rank_local: int | None = None
"""Local rank of the data parallel group,
set only in SPMD mode."""
"""Local rank of the data parallel group, set only in SPMD mode."""
data_parallel_master_ip: str = "127.0.0.1"
"""IP of the data parallel master."""
data_parallel_rpc_port: int = 29550
@@ -184,7 +186,7 @@ class ParallelConfig:
- "flashinfer_nvlink_two_sided": Use flashinfer two-sided kernels for mnnvl
- "flashinfer_nvlink_one_sided": Use flashinfer high-throughput a2a kernels"""
max_parallel_loading_workers: int | None = None
max_parallel_loading_workers: int | None = Field(default=None, ge=1)
"""Maximum number of parallel loading workers when loading model
sequentially in multiple batches. To avoid RAM OOM when using tensor
parallel and large models."""
@@ -197,15 +199,15 @@ class ParallelConfig:
enable_dbo: bool = False
"""Enable dual batch overlap for the model executor."""
ubatch_size: int = 0
ubatch_size: int = Field(default=0, ge=0)
"""Number of ubatch size."""
dbo_decode_token_threshold: int = 32
dbo_decode_token_threshold: int = Field(default=32, ge=0)
"""The threshold for dual batch overlap for batches only containing decodes.
If the number of tokens in the request is greater than this threshold,
microbatching will be used. Otherwise, the request will be processed in a
single batch."""
dbo_prefill_token_threshold: int = 512 # TODO(lucas): tune
dbo_prefill_token_threshold: int = Field(default=512, ge=0) # TODO(lucas): tune
"""The threshold for dual batch overlap for batches that contain one or more
prefills. If the number of tokens in the request is greater than this
threshold, microbatching will be used. Otherwise, the request will be
@@ -260,10 +262,10 @@ class ParallelConfig:
master_port: int = 29501
"""distributed master port for multi-node distributed
inference when distributed_executor_backend is mp."""
node_rank: int = 0
"""distributed node rank for multi-node distributed
node_rank: int = Field(default=0, ge=0)
"""distributed node rank for multi-node distributed
inference when distributed_executor_backend is mp."""
nnodes: int = 1
nnodes: int = Field(default=1, ge=1)
"""num of nodes for multi-node distributed
inference when distributed_executor_backend is mp."""
numa_bind: bool = False
@@ -318,7 +320,7 @@ class ParallelConfig:
"""Port of the coordination TCPStore. Can be set by the API server; workers
connect as clients to exchange self-picked group ports at runtime."""
decode_context_parallel_size: int = 1
decode_context_parallel_size: int = Field(default=1, ge=1)
"""Number of decode context parallel groups, because the world size does
not change by dcp, it simply reuse the GPUs of TP group, and tp_size
needs to be divisible by dcp_size."""