mirror of
https://github.com/vllm-project/vllm.git
synced 2026-06-06 00:16:14 +00:00
[torch.compile]: Disable Sequence Parallelism (SP) for piecewise compilation (#38373)
Signed-off-by: SouthWest7 <am1ao@qq.com> Signed-off-by: Xinan Miao <1403572259@qq.com> Co-authored-by: SouthWest7 <am1ao@qq.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: OpenAI Codex <codex@openai.com> Co-authored-by: Wang Xingran <72983099+wangxingran222@users.noreply.github.com> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
This commit is contained in:
@@ -261,6 +261,8 @@ def _compare_sp(
|
||||
},
|
||||
"use_inductor_graph_partition": use_inductor_graph_partition,
|
||||
}
|
||||
if not use_inductor_graph_partition:
|
||||
compilation_config["splitting_ops"] = []
|
||||
|
||||
tp_sp_args = [
|
||||
*common_args,
|
||||
|
||||
@@ -19,6 +19,7 @@ from vllm.config import (
|
||||
VllmConfig,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import (
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_reduce_scatter,
|
||||
@@ -288,6 +289,22 @@ def test_async_tp_pass_replace(
|
||||
run_torch_spawn(async_tp_pass_on_test_model, num_processes)
|
||||
|
||||
|
||||
def test_async_tp_pass_requires_full_graph_compilation():
    """AsyncTPPass must assert when the graph is split and partitioning is off."""
    # Piecewise setup: Inductor graph partition off, one splitting op present.
    piecewise_config = VllmConfig().compilation_config
    piecewise_config.use_inductor_graph_partition = False
    piecewise_config.splitting_ops = ["vllm::unified_attention_with_output"]

    # Allocate the pass without running __init__; only the attribute set
    # below is provided for the applicability check.
    async_tp_pass = object.__new__(AsyncTPPass)
    async_tp_pass.compilation_config = piecewise_config

    single_size_range = Range(start=8, end=8)
    expected_message = "AsyncTPPass requires full-graph compilation"
    with pytest.raises(AssertionError, match=expected_message):
        async_tp_pass.is_applicable_for_range(single_size_range)
|
||||
|
||||
|
||||
def async_tp_pass_on_test_model(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
|
||||
@@ -22,6 +22,7 @@ from vllm.config import (
|
||||
get_current_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.config.utils import Range
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
init_distributed_environment,
|
||||
@@ -216,6 +217,24 @@ def test_sequence_parallelism_pass(
|
||||
run_torch_spawn(sequence_parallelism_pass_on_test_model, num_processes)
|
||||
|
||||
|
||||
def test_sequence_parallelism_pass_requires_full_graph_compilation():
    """SequenceParallelismPass must assert when the graph is split and
    Inductor graph partition is off."""
    # Piecewise setup: Inductor graph partition off, one splitting op present.
    piecewise_config = VllmConfig().compilation_config
    piecewise_config.use_inductor_graph_partition = False
    piecewise_config.splitting_ops = ["vllm::unified_attention_with_output"]

    # Allocate the pass without running __init__ and set only the two
    # attributes this test provides for the applicability check.
    sequence_parallelism_pass = object.__new__(SequenceParallelismPass)
    sequence_parallelism_pass.compilation_config = piecewise_config
    sequence_parallelism_pass.min_token_num = 1

    single_size_range = Range(start=8, end=8)
    expected_message = "SequenceParallelismPass requires full-graph compilation"
    with pytest.raises(AssertionError, match=expected_message):
        sequence_parallelism_pass.is_applicable_for_range(single_size_range)
|
||||
|
||||
|
||||
def sequence_parallelism_pass_on_test_model(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
|
||||
@@ -407,7 +407,7 @@ def test_should_split():
|
||||
(None, 257, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 256),
|
||||
# max from list
|
||||
([1, 2, 4, 15], None, 1, False, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 15),
|
||||
# filtered out 15 due to SP
|
||||
# SP forces full-graph compilation, sizes are filtered by TP
|
||||
([1, 2, 4, 15], None, 2, True, 2048, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
|
||||
# limited by the max_tokens
|
||||
([1, 2, 4, 15], None, 1, False, 8, CUDAGraphMode.FULL_AND_PIECEWISE, 4),
|
||||
@@ -465,6 +465,123 @@ def test_cudagraph_sizes_post_init(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not current_platform.support_static_graph_mode(),
    reason="Skip if not cudagraph mode supported",
)
@pytest.mark.parametrize(
    (
        "cudagraph_mode",
        "use_inductor_graph_partition",
        "expected_enable_sp",
        "expected_cudagraph_mode",
        "expected_piecewise_compile",
        "expected_capture_sizes",
        "expected_max_size",
    ),
    [
        # PIECEWISE without graph partition: expect cudagraph_mode FULL and
        # empty splitting_ops.
        (CUDAGraphMode.PIECEWISE, False, True, CUDAGraphMode.FULL, False, [2, 4], 4),
        # FULL_DECODE_ONLY without graph partition: mode is expected to stay
        # as requested, with empty splitting_ops.
        (
            CUDAGraphMode.FULL_DECODE_ONLY,
            False,
            True,
            CUDAGraphMode.FULL_DECODE_ONLY,
            False,
            [2, 4],
            4,
        ),
        # FULL_AND_PIECEWISE without graph partition: expect cudagraph_mode
        # FULL and empty splitting_ops.
        (
            CUDAGraphMode.FULL_AND_PIECEWISE,
            False,
            True,
            CUDAGraphMode.FULL,
            False,
            [2, 4],
            4,
        ),
        # FULL_AND_PIECEWISE with graph partition: mode is expected to stay
        # as requested and splitting_ops to remain non-empty.
        (
            CUDAGraphMode.FULL_AND_PIECEWISE,
            True,
            True,
            CUDAGraphMode.FULL_AND_PIECEWISE,
            True,
            [2, 4],
            4,
        ),
    ],
)
def test_sequence_parallelism_requires_full_graph_compilation(
    cudagraph_mode: CUDAGraphMode,
    use_inductor_graph_partition: bool,
    expected_enable_sp: bool,
    expected_cudagraph_mode: CUDAGraphMode,
    expected_piecewise_compile: bool,
    expected_capture_sizes: list[int],
    expected_max_size: int,
):
    """Check the resolved compilation config when SP is requested with TP=2.

    A config with ``enable_sp`` and ``fuse_gemm_comms`` turned on is built for
    each requested cudagraph mode / graph-partition combination, the
    splitting-op, compile-range and cudagraph-size setup steps are run, and
    the resulting fields are compared with the parametrized expectations.

    Args:
        cudagraph_mode: Mode passed to ``CompilationConfig``.
        use_inductor_graph_partition: Value passed to ``CompilationConfig``.
        expected_enable_sp: Expected final ``enable_sp`` and
            ``fuse_gemm_comms``; also whether 511 must appear among the
            compile range endpoints.
        expected_cudagraph_mode: Expected final ``cudagraph_mode``.
        expected_piecewise_compile: Expected truthiness of ``splitting_ops``.
        expected_capture_sizes: Expected final ``cudagraph_capture_sizes``.
        expected_max_size: Expected final ``max_cudagraph_capture_size``.
    """
    # NOTE(review): indentation was reconstructed from a flattened diff view;
    # the extent of this `with` block is an assumption — confirm upstream.
    with patch.object(current_platform, "device_count", return_value=2):
        vllm_config = VllmConfig(
            parallel_config=ParallelConfig(tensor_parallel_size=2),
            scheduler_config=SchedulerConfig(
                max_num_seqs=128,
                max_num_batched_tokens=2048,
                max_model_len=2048,
                is_encoder_decoder=False,
            ),
        )
        # Stand-in model config exposing only the attributes set here.
        vllm_config.model_config = MagicMock(
            dtype=torch.float16,
            enforce_eager=False,
            is_moe=False,
            disable_cascade_attn=False,
            get_hidden_size=MagicMock(return_value=4096),
        )
        # Requested sizes [1, 2, 4, 15]; every case expects [2, 4] afterwards
        # (the sizes divisible by tensor_parallel_size=2).
        vllm_config.compilation_config = CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            cudagraph_capture_sizes=[1, 2, 4, 15],
            max_cudagraph_capture_size=None,
            compile_sizes=["cudagraph_capture_sizes"],
            use_inductor_graph_partition=use_inductor_graph_partition,
            pass_config=PassConfig(
                enable_sp=True,
                fuse_gemm_comms=True,
                fuse_norm_quant=True,
                fuse_act_quant=True,
                eliminate_noops=True,
                sp_min_token_num=512,
            ),
            cudagraph_mode=cudagraph_mode,
        )
        # Run the individual setup steps explicitly on the hand-built config.
        vllm_config.compilation_config.set_splitting_ops_for_v1(
            all2all_backend=vllm_config.parallel_config.all2all_backend,
            data_parallel_size=1,
        )
        vllm_config._set_compile_ranges()
        vllm_config._set_cudagraph_sizes()

    # The graph-partition flag itself must not be altered by the setup steps.
    assert (
        vllm_config.compilation_config.use_inductor_graph_partition
        == use_inductor_graph_partition
    )
    # Non-empty splitting_ops means the graph is still compiled piecewise.
    assert (
        bool(vllm_config.compilation_config.splitting_ops) == expected_piecewise_compile
    )
    assert vllm_config.compilation_config.pass_config.enable_sp == expected_enable_sp
    assert (
        vllm_config.compilation_config.pass_config.fuse_gemm_comms == expected_enable_sp
    )
    assert vllm_config.compilation_config.cudagraph_mode == expected_cudagraph_mode
    assert (
        vllm_config.compilation_config.cudagraph_capture_sizes == expected_capture_sizes
    )
    assert (
        vllm_config.compilation_config.max_cudagraph_capture_size == expected_max_size
    )
    # With sp_min_token_num=512, 511 is expected among the compile range
    # endpoints exactly when SP stays enabled.
    assert (
        511 in vllm_config.compilation_config.compile_ranges_endpoints
    ) == expected_enable_sp
|
||||
|
||||
|
||||
def test_cached_compilation_config(default_vllm_config):
|
||||
import torch
|
||||
from torch._inductor.utils import run_and_get_code
|
||||
|
||||
@@ -406,16 +406,13 @@ class AsyncTPPass(VllmPatternMatcherPass):
|
||||
self.dump_patterns(config, self.patterns)
|
||||
|
||||
def is_applicable_for_range(self, compile_range: Range) -> bool:
|
||||
# This pass is applied on top of the sequence parallelism pass.
|
||||
# It inherits the same applicability condition as `SequenceParallelismPass`.
|
||||
# See `SequenceParallelismPass.is_applicable` for more details.
|
||||
if (
|
||||
not self.compilation_config.splitting_ops
|
||||
or self.compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
return True
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
|
||||
# This pass is applied on top of the sequence parallelism pass,
|
||||
# which is only supported in fullgraph compilation mode.
|
||||
assert (
|
||||
self.compilation_config.use_inductor_graph_partition
|
||||
or not self.compilation_config.splitting_ops
|
||||
), "AsyncTPPass requires full-graph compilation"
|
||||
return True
|
||||
|
||||
@VllmInductorPass.time_and_log
|
||||
def __call__(self, graph: fx.Graph) -> None:
|
||||
|
||||
@@ -341,22 +341,18 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
significantly reduce communication overhead and improve overall model
|
||||
performance.
|
||||
|
||||
This pass is only supported when compiling the whole graph (fullgraph
|
||||
mode, i.e. using Inductor graph partition or empty splitting_ops).
|
||||
Piecewise compilation is not supported because the residual tensor
|
||||
gets split across TP ranks, causing size mismatches at subgraph
|
||||
boundaries.
|
||||
|
||||
This pass splits up the residual tensor across TP ranks and hence divides its size.
|
||||
Because the pattern matcher starts at the end of the graph, the replacement
|
||||
contains a slice that temporarily conforms the input residual to the correct size.
|
||||
After all patterns have been matched, we use a NoOpEliminationPass to clean up
|
||||
what have now become no-op slices.
|
||||
|
||||
Note that an older version of the pass did not need this as it operated only on
|
||||
custom rms_norm and fused_rms_norm_add custom ops which did not complain about
|
||||
mismatched shapes during replacement. So this approach has the same assumption that
|
||||
correctness is only maintained if all rms_norm operations are split across ranks.
|
||||
|
||||
Correctness-wise, this is approach strictly better than before - before,
|
||||
the graph was incorrect semantically and shape-wise during the pass.
|
||||
With this approach there's only semantic incorrectness during the pass.
|
||||
Both approaches restore a correct graph once all patterns are matched.
|
||||
This pass splits up the residual tensor across TP ranks and hence
|
||||
divides its size. Because the pattern matcher starts at the end of
|
||||
the graph, the replacement contains a slice that temporarily conforms
|
||||
the input residual to the correct size. After all patterns have been
|
||||
matched, we use a NoOpEliminationPass to clean up what have now
|
||||
become no-op slices.
|
||||
"""
|
||||
|
||||
@enable_fake_mode
|
||||
@@ -419,19 +415,13 @@ class SequenceParallelismPass(VllmPatternMatcherPass):
|
||||
and gathering tensors across TP ranks outweighs the benefits.
|
||||
|
||||
Returns False (SP disabled) when:
|
||||
- Using piecewise compilation with non-concrete or TP-indivisible sizes
|
||||
- min_token_num is None (SP disabled for this device/config)
|
||||
- The compile range starts below the minimum token threshold
|
||||
"""
|
||||
# For piecewise compilation (not using inductor graph partition),
|
||||
# we need concrete sizes that are divisible by TP for correct splitting
|
||||
if (
|
||||
not self.compilation_config.use_inductor_graph_partition
|
||||
and self.compilation_config.splitting_ops
|
||||
):
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if not compile_range.is_single_size() or compile_range.end % tp_size != 0:
|
||||
return False
|
||||
assert (
|
||||
self.compilation_config.use_inductor_graph_partition
|
||||
or not self.compilation_config.splitting_ops
|
||||
), "SequenceParallelismPass requires full-graph compilation"
|
||||
|
||||
# min_token_num is None when SP is disabled for this device/config
|
||||
# (e.g., non-CUDA platform, unsupported GPU, or small hidden_size)
|
||||
|
||||
@@ -1148,6 +1148,25 @@ class CompilationConfig:
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||
self.splitting_ops = []
|
||||
|
||||
if (
|
||||
not self.use_inductor_graph_partition
|
||||
and (self.pass_config.enable_sp or self.pass_config.fuse_gemm_comms)
|
||||
and self.splitting_ops
|
||||
):
|
||||
logger.warning_once(
|
||||
"Sequence parallelism requires full-graph compilation when "
|
||||
"use_inductor_graph_partition is off. Setting splitting_ops "
|
||||
"to an empty list to preserve SP and async TP."
|
||||
)
|
||||
self.splitting_ops = []
|
||||
if self.cudagraph_mode.has_piecewise_cudagraphs():
|
||||
logger.warning_once(
|
||||
"Sequence parallelism is incompatible with piecewise "
|
||||
"cudagraph when use_inductor_graph_partition is off. "
|
||||
"Setting cudagraph_mode to FULL."
|
||||
)
|
||||
self.cudagraph_mode = CUDAGraphMode.FULL
|
||||
|
||||
# Disable CUDA graphs for DeepEP high-throughput since it's not CG compatible
|
||||
if (
|
||||
all2all_backend == "deepep_high_throughput"
|
||||
|
||||
+17
-28
@@ -983,19 +983,16 @@ class VllmConfig:
|
||||
)
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
# async tp is built on top of sequence parallelism
|
||||
# and requires it to be enabled.
|
||||
if self.compilation_config.pass_config.fuse_gemm_comms:
|
||||
self.compilation_config.pass_config.enable_sp = True
|
||||
if self.compilation_config.pass_config.enable_sp:
|
||||
# async tp is built on top of sequence parallelism and requires it.
|
||||
pass_config = self.compilation_config.pass_config
|
||||
if pass_config.fuse_gemm_comms:
|
||||
pass_config.enable_sp = True
|
||||
if pass_config.enable_sp:
|
||||
if self.parallel_config.tensor_parallel_size == 1:
|
||||
logger.warning("Sequence Parallelism requires TP>1, disabling")
|
||||
self.compilation_config.pass_config.enable_sp = False
|
||||
self.compilation_config.pass_config.fuse_gemm_comms = False
|
||||
pass_config.enable_sp = False
|
||||
pass_config.fuse_gemm_comms = False
|
||||
else:
|
||||
# Compute SP threshold early; disable if None (model too
|
||||
# small for SP to be beneficial).
|
||||
pass_config = self.compilation_config.pass_config
|
||||
if pass_config.sp_min_token_num is None:
|
||||
from vllm.compilation.passes.fusion.sequence_parallelism import (
|
||||
get_sequence_parallelism_threshold,
|
||||
@@ -1015,8 +1012,8 @@ class VllmConfig:
|
||||
"threshold heuristic, disabling. To force SP, "
|
||||
"set pass_config.sp_min_token_num manually."
|
||||
)
|
||||
self.compilation_config.pass_config.enable_sp = False
|
||||
self.compilation_config.pass_config.fuse_gemm_comms = False
|
||||
pass_config.enable_sp = False
|
||||
pass_config.fuse_gemm_comms = False
|
||||
|
||||
from vllm.utils.torch_utils import HAS_OPAQUE_TYPE
|
||||
|
||||
@@ -1098,6 +1095,7 @@ class VllmConfig:
|
||||
self.compilation_config.cudagraph_num_of_warmups = 1
|
||||
|
||||
self._set_cudagraph_sizes()
|
||||
|
||||
else:
|
||||
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
|
||||
|
||||
@@ -1171,8 +1169,8 @@ class VllmConfig:
|
||||
)
|
||||
|
||||
if self.compilation_config.pass_config.enable_sp:
|
||||
# With pipeline parallelism or dynamo partitioning,
|
||||
# native rms norm tracing errors due to incorrect residual shape.
|
||||
# With pipeline parallelism, native rms norm tracing errors due to
|
||||
# incorrect residual shape.
|
||||
# Use custom rms norm to unblock. In the future,
|
||||
# the pass will operate on higher-level IR to avoid the issue.
|
||||
# TODO: https://github.com/vllm-project/vllm/issues/27894
|
||||
@@ -1183,24 +1181,15 @@ class VllmConfig:
|
||||
self.compilation_config.mode,
|
||||
)
|
||||
|
||||
is_fullgraph = (
|
||||
self.compilation_config.use_inductor_graph_partition
|
||||
or len(self.compilation_config.splitting_ops or []) == 0
|
||||
)
|
||||
if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph:
|
||||
if self.parallel_config.pipeline_parallel_size > 1:
|
||||
if "-rms_norm" not in self.compilation_config.custom_ops:
|
||||
self.compilation_config.custom_ops.append("+rms_norm")
|
||||
else:
|
||||
regime = (
|
||||
"Dynamo partition"
|
||||
if not is_fullgraph
|
||||
else "pipeline parallelism"
|
||||
)
|
||||
logger.warning_once(
|
||||
"Sequence parallelism not supported with "
|
||||
"native rms_norm when using %s, "
|
||||
"this will likely lead to an error.",
|
||||
regime,
|
||||
"pipeline parallelism",
|
||||
)
|
||||
|
||||
# final check of cudagraph mode after all possible updates
|
||||
@@ -1212,9 +1201,9 @@ class VllmConfig:
|
||||
and not self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs() # noqa: E501
|
||||
):
|
||||
logger.warning_once(
|
||||
"No piecewise cudagraph for executing cascade attention."
|
||||
" Will fall back to eager execution if a batch runs "
|
||||
"into cascade attentions."
|
||||
"No piecewise cudagraph for executing cascade attention. "
|
||||
"Will fall back to eager execution if a batch runs into "
|
||||
"cascade attentions."
|
||||
)
|
||||
|
||||
if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():
|
||||
|
||||
+8
-15
@@ -519,12 +519,8 @@ def is_residual_scattered_for_sp(
|
||||
"""Check if the residual tensor is scattered for sequence parallelism.
|
||||
|
||||
The residual tensor is scattered across tensor parallel ranks when sequence
|
||||
parallelism and tensor parallelism is enabled.
|
||||
|
||||
This follows the same logic as SequenceParallelismPass.is_applicable_for_range():
|
||||
- In full-graph compilation mode (no splitting ops or using inductor graph
|
||||
partition), SP is always applied
|
||||
- Otherwise, SP is only applied for specific shapes in compile_sizes
|
||||
parallelism and tensor parallelism are enabled. SP is only supported in
|
||||
full-graph compilation mode.
|
||||
"""
|
||||
if not vllm_config.compilation_config.pass_config.enable_sp:
|
||||
return False
|
||||
@@ -534,16 +530,13 @@ def is_residual_scattered_for_sp(
|
||||
if tp == 1:
|
||||
return False
|
||||
|
||||
assert (
|
||||
vllm_config.compilation_config.use_inductor_graph_partition
|
||||
or not vllm_config.compilation_config.splitting_ops
|
||||
), "Sequence parallelism requires full-graph compilation"
|
||||
|
||||
# When sequence parallelism is enabled, we always pad num_input_tokens
|
||||
# to be a multiple of tensor_parallel_size (tp) earlier.
|
||||
assert num_input_tokens % tp == 0
|
||||
|
||||
if (
|
||||
not vllm_config.compilation_config.splitting_ops
|
||||
or vllm_config.compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
return True
|
||||
compile_sizes = vllm_config.compilation_config.compile_sizes
|
||||
if compile_sizes is None:
|
||||
return False
|
||||
return num_input_tokens in compile_sizes
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user