Skip to content

Commit 19b1e78

Browse files
authored
fix: disable sequence parallelism for piecewise compilation (#1650)
Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent f2764f3 commit 19b1e78

5 files changed

Lines changed: 64 additions & 68 deletions

File tree

aphrodite/compilation/passes/fusion/collective_fusion.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -368,13 +368,12 @@ def __init__(self, config: AphroditeConfig) -> None:
368368
self.dump_patterns(config, self.patterns)
369369

370370
def is_applicable_for_range(self, compile_range: Range) -> bool:
371-
# This pass is applied on top of the sequence parallelism pass.
372-
# It inherits the same applicability condition as `SequenceParallelismPass`.
373-
# See `SequenceParallelismPass.is_applicable` for more details.
374-
if not self.compilation_config.splitting_ops or self.compilation_config.use_inductor_graph_partition:
375-
return True
376-
tp_size = get_tensor_model_parallel_world_size()
377-
return bool(compile_range.is_single_size() and compile_range.end % tp_size == 0)
371+
# This pass is applied on top of the sequence parallelism pass,
372+
# which is only supported in fullgraph compilation mode.
373+
assert self.compilation_config.use_inductor_graph_partition or not self.compilation_config.splitting_ops, (
374+
"AsyncTPPass requires full-graph compilation"
375+
)
376+
return True
378377

379378
@AphroditeInductorPass.time_and_log
380379
def __call__(self, graph: fx.Graph) -> None:

aphrodite/compilation/passes/fusion/sequence_parallelism.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -330,21 +330,18 @@ class SequenceParallelismPass(AphroditePatternMatcherPass):
330330
performance.
331331
332332
333-
This pass splits up the residual tensor across TP ranks and hence divides its size.
334-
Because the pattern matcher starts at the end of the graph, the replacement
335-
contains a slice that temporarily conforms the input residual to the correct size.
336-
After all patterns have been matched, we use a NoOpEliminationPass to clean up
337-
what have now become no-op slices.
338-
339-
Note that an older version of the pass did not need this as it operated only on
340-
custom rms_norm and fused_rms_norm_add custom ops which did not complain about
341-
mismatched shapes during replacement. So this approach has the same assumption that
342-
correctness is only maintained if all rms_norm operations are split across ranks.
343-
344-
Correctness-wise, this is approach strictly better than before - before,
345-
the graph was incorrect semantically and shape-wise during the pass.
346-
With this approach there's only semantic incorrectness during the pass.
347-
Both approaches restore a correct graph once all patterns are matched.
333+
This pass is only supported when compiling the whole graph (fullgraph
334+
mode, i.e. using Inductor graph partition or empty splitting_ops).
335+
Piecewise compilation is not supported because the residual tensor
336+
gets split across TP ranks, causing size mismatches at subgraph
337+
boundaries.
338+
339+
This pass splits up the residual tensor across TP ranks and hence
340+
divides its size. Because the pattern matcher starts at the end of
341+
the graph, the replacement contains a slice that temporarily conforms
342+
the input residual to the correct size. After all patterns have been
343+
matched, we use a NoOpEliminationPass to clean up what have now
344+
become no-op slices.
348345
"""
349346

350347
@enable_fake_mode
@@ -397,16 +394,12 @@ def is_applicable_for_range(self, compile_range: Range) -> bool:
397394
and gathering tensors across TP ranks outweighs the benefits.
398395
399396
Returns False (SP disabled) when:
400-
- Using piecewise compilation with non-concrete or TP-indivisible sizes
401397
- min_token_num is None (SP disabled for this device/config)
402398
- The compile range starts below the minimum token threshold
403399
"""
404-
# For piecewise compilation (not using inductor graph partition),
405-
# we need concrete sizes that are divisible by TP for correct splitting
406-
if not self.compilation_config.use_inductor_graph_partition and self.compilation_config.splitting_ops:
407-
tp_size = get_tensor_model_parallel_world_size()
408-
if not compile_range.is_single_size() or compile_range.end % tp_size != 0:
409-
return False
400+
assert self.compilation_config.use_inductor_graph_partition or not self.compilation_config.splitting_ops, (
401+
"SequenceParallelismPass requires full-graph compilation"
402+
)
410403

411404
# min_token_num is None when SP is disabled for this device/config
412405
# (e.g., non-CUDA platform, unsupported GPU, or small hidden_size)

aphrodite/config/aphrodite.py

Lines changed: 16 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -931,19 +931,16 @@ def has_blocked_weights():
931931
)
932932
self.compilation_config.cudagraph_mode = CUDAGraphMode.NONE
933933

934-
# async tp is built on top of sequence parallelism
935-
# and requires it to be enabled.
936-
if self.compilation_config.pass_config.fuse_gemm_comms:
937-
self.compilation_config.pass_config.enable_sp = True
938-
if self.compilation_config.pass_config.enable_sp:
934+
# async tp is built on top of sequence parallelism and requires it.
935+
pass_config = self.compilation_config.pass_config
936+
if pass_config.fuse_gemm_comms:
937+
pass_config.enable_sp = True
938+
if pass_config.enable_sp:
939939
if self.parallel_config.tensor_parallel_size == 1:
940940
logger.warning("Sequence Parallelism requires TP>1, disabling")
941-
self.compilation_config.pass_config.enable_sp = False
942-
self.compilation_config.pass_config.fuse_gemm_comms = False
941+
pass_config.enable_sp = False
942+
pass_config.fuse_gemm_comms = False
943943
else:
944-
# Compute SP threshold early; disable if None (model too
945-
# small for SP to be beneficial).
946-
pass_config = self.compilation_config.pass_config
947944
if pass_config.sp_min_token_num is None:
948945
from aphrodite.compilation.passes.fusion.sequence_parallelism import (
949946
get_sequence_parallelism_threshold,
@@ -963,8 +960,8 @@ def has_blocked_weights():
963960
"threshold heuristic, disabling. To force SP, "
964961
"set pass_config.sp_min_token_num manually."
965962
)
966-
self.compilation_config.pass_config.enable_sp = False
967-
self.compilation_config.pass_config.fuse_gemm_comms = False
963+
pass_config.enable_sp = False
964+
pass_config.fuse_gemm_comms = False
968965

969966
from aphrodite.utils.torch_utils import HAS_OPAQUE_TYPE
970967

@@ -1102,8 +1099,8 @@ def has_blocked_weights():
11021099
)
11031100

11041101
if self.compilation_config.pass_config.enable_sp:
1105-
# With pipeline parallelism or dynamo partitioning,
1106-
# native rms norm tracing errors due to incorrect residual shape.
1102+
# With pipeline parallelism, native rms norm tracing errors due to
1103+
# incorrect residual shape.
11071104
# Use custom rms norm to unblock. In the future,
11081105
# the pass will operate on higher-level IR to avoid the issue.
11091106
# TODO: https://github.com/aphrodite-project/aphrodite/issues/27894
@@ -1113,20 +1110,15 @@ def has_blocked_weights():
11131110
self.compilation_config.mode,
11141111
)
11151112

1116-
is_fullgraph = (
1117-
self.compilation_config.use_inductor_graph_partition
1118-
or len(self.compilation_config.splitting_ops or []) == 0
1119-
)
1120-
if self.parallel_config.pipeline_parallel_size > 1 or not is_fullgraph:
1113+
if self.parallel_config.pipeline_parallel_size > 1:
11211114
if "-rms_norm" not in self.compilation_config.custom_ops:
11221115
self.compilation_config.custom_ops.append("+rms_norm")
11231116
else:
1124-
regime = "Dynamo partition" if not is_fullgraph else "pipeline parallelism"
11251117
logger.warning_once(
11261118
"Sequence parallelism not supported with "
11271119
"native rms_norm when using %s, "
11281120
"this will likely lead to an error.",
1129-
regime,
1121+
"pipeline parallelism",
11301122
)
11311123

11321124
# final check of cudagraph mode after all possible updates
@@ -1138,9 +1130,9 @@ def has_blocked_weights():
11381130
and not self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs() # noqa: E501
11391131
):
11401132
logger.warning_once(
1141-
"No piecewise cudagraph for executing cascade attention."
1142-
" Will fall back to eager execution if a batch runs "
1143-
"into cascade attentions."
1133+
"No piecewise cudagraph for executing cascade attention. "
1134+
"Will fall back to eager execution if a batch runs into "
1135+
"cascade attentions."
11441136
)
11451137

11461138
if self.compilation_config.cudagraph_mode.requires_piecewise_compilation():

aphrodite/config/compilation.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,25 @@ def set_splitting_ops_for_v1(self, all2all_backend: str, data_parallel_size: int
11051105
self.cudagraph_mode = CUDAGraphMode.FULL
11061106
self.splitting_ops = []
11071107

1108+
if (
1109+
not self.use_inductor_graph_partition
1110+
and (self.pass_config.enable_sp or self.pass_config.fuse_gemm_comms)
1111+
and self.splitting_ops
1112+
):
1113+
logger.warning_once(
1114+
"Sequence parallelism requires full-graph compilation when "
1115+
"use_inductor_graph_partition is off. Setting splitting_ops "
1116+
"to an empty list to preserve SP and async TP."
1117+
)
1118+
self.splitting_ops = []
1119+
if self.cudagraph_mode.has_piecewise_cudagraphs():
1120+
logger.warning_once(
1121+
"Sequence parallelism is incompatible with piecewise "
1122+
"cudagraph when use_inductor_graph_partition is off. "
1123+
"Setting cudagraph_mode to FULL."
1124+
)
1125+
self.cudagraph_mode = CUDAGraphMode.FULL
1126+
11081127
# Disable CUDA graphs for DeepEP high-throughput since its not CG compatible
11091128
if (
11101129
all2all_backend == "deepep_high_throughput"

aphrodite/v1/worker/utils.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -491,12 +491,8 @@ def is_residual_scattered_for_sp(aphrodite_config: AphroditeConfig, num_input_to
491491
"""Check if the residual tensor is scattered for sequence parallelism.
492492
493493
The residual tensor is scattered across tensor parallel ranks when sequence
494-
parallelism and tensor parallelism is enabled.
495-
496-
This follows the same logic as SequenceParallelismPass.is_applicable_for_range():
497-
- In full-graph compilation mode (no splitting ops or using inductor graph
498-
partition), SP is always applied
499-
- Otherwise, SP is only applied for specific shapes in compile_sizes
494+
parallelism and tensor parallelism is enabled. SP is only supported in
495+
full-graph compilation mode.
500496
"""
501497
if not aphrodite_config.compilation_config.pass_config.enable_sp:
502498
return False
@@ -506,16 +502,13 @@ def is_residual_scattered_for_sp(aphrodite_config: AphroditeConfig, num_input_to
506502
if tp == 1:
507503
return False
508504

505+
assert (
506+
aphrodite_config.compilation_config.use_inductor_graph_partition
507+
or not aphrodite_config.compilation_config.splitting_ops
508+
), "Sequence parallelism requires full-graph compilation"
509+
509510
# When sequence parallelism is enabled, we always pad num_input_tokens
510511
# to be a multiple of tensor_parallel_size (tp) earlier.
511512
assert num_input_tokens % tp == 0
512513

513-
if (
514-
not aphrodite_config.compilation_config.splitting_ops
515-
or aphrodite_config.compilation_config.use_inductor_graph_partition
516-
):
517-
return True
518-
compile_sizes = aphrodite_config.compilation_config.compile_sizes
519-
if compile_sizes is None:
520-
return False
521-
return num_input_tokens in compile_sizes
514+
return True

0 commit comments

Comments
 (0)