[torch.compile]: Disable Sequence Parallelism (SP) for piecewise compilation (#38373)

Signed-off-by: SouthWest7 <am1ao@qq.com>
Signed-off-by: Xinan Miao <1403572259@qq.com>
Co-authored-by: SouthWest7 <am1ao@qq.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
Co-authored-by: Wang Xingran <72983099+wangxingran222@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
Xinan Miao
2026-04-27 01:44:42 +08:00
committed by GitHub
parent b39c266dae
commit 32e45636e3
9 changed files with 222 additions and 79 deletions
@@ -261,6 +261,8 @@ def _compare_sp(
},
"use_inductor_graph_partition": use_inductor_graph_partition,
}
if not use_inductor_graph_partition:
compilation_config["splitting_ops"] = []
tp_sp_args = [
*common_args,
@@ -19,6 +19,7 @@ from vllm.config import (
VllmConfig,
set_current_vllm_config,
)
from vllm.config.utils import Range
from vllm.distributed import (
tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter,
@@ -288,6 +289,22 @@ def test_async_tp_pass_replace(
run_torch_spawn(async_tp_pass_on_test_model, num_processes)
def test_async_tp_pass_requires_full_graph_compilation():
vllm_config = VllmConfig()
vllm_config.compilation_config.use_inductor_graph_partition = False
vllm_config.compilation_config.splitting_ops = [
"vllm::unified_attention_with_output"
]
async_tp_pass = object.__new__(AsyncTPPass)
async_tp_pass.compilation_config = vllm_config.compilation_config
with pytest.raises(
AssertionError, match="AsyncTPPass requires full-graph compilation"
):
async_tp_pass.is_applicable_for_range(Range(start=8, end=8))
def async_tp_pass_on_test_model(
local_rank: int,
world_size: int,
@@ -22,6 +22,7 @@ from vllm.config import (
get_current_vllm_config,
set_current_vllm_config,
)
from vllm.config.utils import Range
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
init_distributed_environment,
@@ -216,6 +217,24 @@ def test_sequence_parallelism_pass(
run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)
def test_sequence_parallelism_pass_requires_full_graph_compilation():
vllm_config = VllmConfig()
vllm_config.compilation_config.use_inductor_graph_partition = False
vllm_config.compilation_config.splitting_ops = [
"vllm::unified_attention_with_output"
]
sequence_parallelism_pass = object.__new__(SequenceParallelismPass)
sequence_parallelism_pass.compilation_config = vllm_config.compilation_config
sequence_parallelism_pass.min_token_num = 1
with pytest.raises(
AssertionError,
match="SequenceParallelismPass requires full-graph compilation",
):
sequence_parallelism_pass.is_applicable_for_range(Range(start=8, end=8))
def sequence_parallelism_pass_on_test_model(
local_rank: int,
world_size: int,
+118 -1
View File
@@ -407,7 +407,7 @@ def test_should_split():
(None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
# max from list
([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
# filtered out 15 due to SP
# SP forces full-graph compilation, sizes are filtered by TP
([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
# limited by the max_tokens
([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
@@ -465,6 +465,123 @@ def test_cudagraph_sizes_post_init(
)
@pytest.mark.skipif(
not current_platform.support_static_graph_mode(),
reason="Skip if not cudagraph mode supported",
)
@pytest.mark.parametrize(
(
"cudagraph_mode",
"use_inductor_graph_partition",
"expected_enable_sp",
"expected_cudagraph_mode",
"expected_piecewise_compile",
"expected_capture_sizes",
"expected_max_size",
),
[
(CUDAGraphMode.PIECEWISE, False, True, CUDAGraphMode.FULL, False, [2, 4], 4),
(
CUDAGraphMode.FULL_DECODE_ONLY,
False,
True,
CUDAGraphMode.FULL_DECODE_ONLY,
False,
[2, 4],
4,
),
(
CUDAGraphMode.FULL_AND_PIECEWISE,
False,
True,
CUDAGraphMode.FULL,
False,
[2, 4],
4,
),
(
CUDAGraphMode.FULL_AND_PIECEWISE,
True,
True,
CUDAGraphMode.FULL_AND_PIECEWISE,
True,
[2, 4],
4,
),
],
)
def test_sequence_parallelism_requires_full_graph_compilation(
cudagraph_mode: CUDAGraphMode,
use_inductor_graph_partition: bool,
expected_enable_sp: bool,
expected_cudagraph_mode: CUDAGraphMode,
expected_piecewise_compile: bool,
expected_capture_sizes: list[int],
expected_max_size: int,
):
with patch.object(current_platform, "device_count", return_value=2):
vllm_config = VllmConfig(
parallel_config=ParallelConfig(tensor_parallel_size=2),
scheduler_config=SchedulerConfig(
max_num_seqs=128,
max_num_batched_tokens=2048,
max_model_len=2048,
is_encoder_decoder=False,
),
)
vllm_config.model_config = MagicMock(
dtype=torch.float16,
enforce_eager=False,
is_moe=False,
disable_cascade_attn=False,
get_hidden_size=MagicMock(return_value=4096),
)
vllm_config.compilation_config = CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
cudagraph_capture_sizes=[1, 2, 4, 15],
max_cudagraph_capture_size=None,
compile_sizes=["cudagraph_capture_sizes"],
use_inductor_graph_partition=use_inductor_graph_partition,
pass_config=PassConfig(
enable_sp=True,
fuse_gemm_comms=True,
fuse_norm_quant=True,
fuse_act_quant=True,
eliminate_noops=True,
sp_min_token_num=512,
),
cudagraph_mode=cudagraph_mode,
)
vllm_config.compilation_config.set_splitting_ops_for_v1(
all2all_backend=vllm_config.parallel_config.all2all_backend,
data_parallel_size=1,
)
vllm_config._set_compile_ranges()
vllm_config._set_cudagraph_sizes()
assert (
vllm_config.compilation_config.use_inductor_graph_partition
== use_inductor_graph_partition
)
assert (
bool(vllm_config.compilation_config.splitting_ops) == expected_piecewise_compile
)
assert vllm_config.compilation_config.pass_config.enable_sp == expected_enable_sp
assert (
vllm_config.compilation_config.pass_config.fuse_gemm_comms == expected_enable_sp
)
assert vllm_config.compilation_config.cudagraph_mode == expected_cudagraph_mode
assert (
vllm_config.compilation_config.cudagraph_capture_sizes == expected_capture_sizes
)
assert (
vllm_config.compilation_config.max_cudagraph_capture_size == expected_max_size
)
assert (
511 in vllm_config.compilation_config.compile_ranges_endpoints
) == expected_enable_sp
def test_cached_compilation_config(default_vllm_config):
import torch
from torch._inductor.utils import run_and_get_code
@@ -406,16 +406,13 @@ class AsyncTPPass(VllmPatternMatcherPass):
self.dump_patterns(config, self.patterns)
def is_applicable_for_range(self, compile_range: Range) -> bool:
# This pass is applied on top of the sequence parallelism pass.
# It inherits the same applicability condition as `SequenceParallelismPass`.
# See `SequenceParallelismPass.is_applicable` for more details.
if (
not self.compilation_config.splitting_ops
or self.compilation_config.use_inductor_graph_partition
):
return True
tp_size = get_tensor_model_parallel_world_size()
return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
# This pass is applied on top of the sequence parallelism pass,
# which is only supported in fullgraph compilation mode.
assert (
self.compilation_config.use_inductor_graph_partition
or not self.compilation_config.splitting_ops
), "AsyncTPPass requires full-graph compilation"
return True
@VllmInductorPass.time_and_log
def __call__(self, graph: fx.Graph) -> None:
@@ -341,22 +341,18 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
significantly reduce communication overhead and improve overall model
performance.
This pass is only supported when compiling the whole graph (fullgraph
mode, i.e. using Inductor graph partition or empty splitting_ops).
Piecewise compilation is not supported because the residual tensor
gets split across TP ranks, causing size mismatches at subgraph
boundaries.
This pass splits up the residual tensor across TP ranks and hence divides its size.
Because the pattern matcher starts at the end of the graph, the replacement
contains a slice that temporarily conforms the input residual to the correct size.
After all patterns have been matched, we use a NoOpEliminationPass to clean up
what have now become no-op slices.
Note that an older version of the pass did not need this as it operated only on
custom rms_norm and fused_rms_norm_add custom ops which did not complain about
mismatched shapes during replacement. So this approach has the same assumption that
correctness is only maintained if all rms_norm operations are split across ranks.
Correctness-wise, this is approach strictly better than before - before,
the graph was incorrect semantically and shape-wise during the pass.
With this approach there's only semantic incorrectness during the pass.
Both approaches restore a correct graph once all patterns are matched.
This pass splits up the residual tensor across TP ranks and hence
divides its size. Because the pattern matcher starts at the end of
the graph, the replacement contains a slice that temporarily conforms
the input residual to the correct size. After all patterns have been
matched, we use a NoOpEliminationPass to clean up what have now
become no-op slices.
"""
@enable_fake_mode
@@ -419,19 +415,13 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
and gathering tensors across TP ranks outweighs the benefits.
Returns False (SP disabled) when:
- Using piecewise compilation with non-concrete or TP-indivisible sizes
- min_token_num is None (SP disabled for this device/config)
- The compile range starts below the minimum token threshold
"""
# For piecewise compilation (not using inductor graph partition),
# we need concrete sizes that are divisible by TP for correct splitting
if (
not self.compilation_config.use_inductor_graph_partition
and self.compilation_config.splitting_ops
):
tp_size = get_tensor_model_parallel_world_size()
if not compile_range.is_single_size() or compile_range.end % tp_size != 0:
return False
assert (
self.compilation_config.use_inductor_graph_partition
or not self.compilation_config.splitting_ops
), "SequenceParallelismPass requires full-graph compilation"
# min_token_num is None when SP is disabled for this device/config
# (e.g., non-CUDA platform, unsupported GPU, or small hidden_size)
+19
View File
@@ -1148,6 +1148,25 @@ class CompilationConfig:
self.cudagraph_mode = CUDAGraphMode.FULL
self.splitting_ops = []
if (
not self.use_inductor_graph_partition
and (self.pass_config.enable_sp or self.pass_config.fuse_gemm_comms)
and self.splitting_ops
):
logger.warning_once(
"Sequence parallelism requires full-graph compilation when "
"use_inductor_graph_partition is off. Setting splitting_ops "
"to an empty list to preserve SP and async TP."
)
self.splitting_ops = []
if self.cudagraph_mode.has_piecewise_cudagraphs():
logger.warning_once(
"Sequence parallelism is incompatible with piecewise "
"cudagraph when use_inductor_graph_partition is off. "
"Setting cudagraph_mode to FULL."
)
self.cudagraph_mode = CUDAGraphMode.FULL
# Disable CUDA graphs for DeepEP high-throughput since its not CG compatible
if (
all2all_backend == "deepep_high_throughput"
+17 -28
View File
@@ -983,19 +983,16 @@ class VllmConfig:
)
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
# async tp is built on top of sequence parallelism
# and requires it to be enabled.
if self.compilation_config.pass_config.fuse_gemm_comms:
self.compilation_config.pass_config.enable_sp = True
if self.compilation_config.pass_config.enable_sp:
# async tp is built on top of sequence parallelism and requires it.
pass_config = self.compilation_config.pass_config
if pass_config.fuse_gemm_comms:
pass_config.enable_sp = True
if pass_config.enable_sp:
if self.parallel_config.tensor_parallel_size == 1:
logger.warning("Sequence Parallelism requires TP>1, disabling")
self.compilation_config.pass_config.enable_sp = False
self.compilation_config.pass_config.fuse_gemm_comms = False
pass_config.enable_sp = False
pass_config.fuse_gemm_comms = False
else:
# Compute SP threshold early; disable if None (model too
# small for SP to be beneficial).
pass_config = self.compilation_config.pass_config
if pass_config.sp_min_token_num is None:
from vllm.compilation.passes.fusion.sequence_parallelism import (
get_sequence_parallelism_threshold,
@@ -1015,8 +1012,8 @@ class VllmConfig:
"threshold heuristic, disabling. To force SP, "
"set pass_config.sp_min_token_num manually."
)
self.compilation_config.pass_config.enable_sp = False
self.compilation_config.pass_config.fuse_gemm_comms = False
pass_config.enable_sp = False
pass_config.fuse_gemm_comms = False
from vllm.utils.torch_utils import HAS_OPAQUE_TYPE
@@ -1098,6 +1095,7 @@ class VllmConfig:
self.compilation_config.cudagraph_num_of_warmups = 1
self._set_cudagraph_sizes()
else:
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
@@ -1171,8 +1169,8 @@ class VllmConfig:
)
if self.compilation_config.pass_config.enable_sp:
# With pipeline parallelism or dynamo partitioning,
# native rms norm tracing errors due to incorrect residual shape.
# With pipeline parallelism, native rms norm tracing errors due to
# incorrect residual shape.
# Use custom rms norm to unblock. In the future,
# the pass will operate on higher-level IR to avoid the issue.
# TODO: https://github.com/vllm-project/vllm/issues/27894
@@ -1183,24 +1181,15 @@ class VllmConfig:
self.compilation_config.mode,
)
is_fullgraph = (
self.compilation_config.use_inductor_graph_partition
or len(self.compilation_config.splitting_ops or []) == 0
)
if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph:
if self.parallel_config.pipeline_parallel_size > 1:
if "-rms_norm" not in self.compilation_config.custom_ops:
self.compilation_config.custom_ops.append("+rms_norm")
else:
regime = (
"Dynamo partition"
if not is_fullgraph
else "pipeline parallelism"
)
logger.warning_once(
"Sequence parallelism not supported with "
"native rms_norm when using %s, "
"this will likely lead to an error.",
regime,
"pipeline parallelism",
)
# final check of cudagraph mode after all possible updates
@@ -1212,9 +1201,9 @@ class VllmConfig:
and not self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs() # noqa: E501
):
logger.warning_once(
"No piecewise cudagraph for executing cascade attention."
" Will fall back to eager execution if a batch runs "
"into cascade attentions."
"No piecewise cudagraph for executing cascade attention. "
"Will fall back to eager execution if a batch runs into "
"cascade attentions."
)
if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
+8 -15
View File
@@ -519,12 +519,8 @@ def is_residual_scattered_for_sp(
"""Check if the residual tensor is scattered for sequence parallelism.
The residual tensor is scattered across tensor parallel ranks when sequence
parallelism and tensor parallelism is enabled.
This follows the same logic as SequenceParallelismPass.is_applicable_for_range():
- In full-graph compilation mode (no splitting ops or using inductor graph
partition), SP is always applied
- Otherwise, SP is only applied for specific shapes in compile_sizes
parallelism and tensor parallelism is enabled. SP is only supported in
full-graph compilation mode.
"""
if not vllm_config.compilation_config.pass_config.enable_sp:
return False
@@ -534,16 +530,13 @@ def is_residual_scattered_for_sp(
if tp == 1:
return False
assert (
vllm_config.compilation_config.use_inductor_graph_partition
or not vllm_config.compilation_config.splitting_ops
), "Sequence parallelism requires full-graph compilation"
# When sequence parallelism is enabled, we always pad num_input_tokens
# to be a multiple of tensor_parallel_size (tp) earlier.
assert num_input_tokens % tp == 0
if (
not vllm_config.compilation_config.splitting_ops
or vllm_config.compilation_config.use_inductor_graph_partition
):
return True
compile_sizes = vllm_config.compilation_config.compile_sizes
if compile_sizes is None:
return False
return num_input_tokens in compile_sizes
return True