feat: support ring attention with arbitrary KV sequence lengths #13545
Changes from 4 commits

@@ -371,6 +371,40 @@ We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention and Ulys

From the above table, it is clear that Ulysses Anything Attention offers better compatibility with arbitrary sequence lengths while maintaining the same performance as the standard Ulysses Attention.

### Ring Anything Attention

The default Ring Attention requires the sequence length of the hidden states to be evenly divisible by the ring degree. Ring Anything Attention is a variant of Ring Attention that supports arbitrary (non-evenly divisible) sequence lengths. It pads each rank's local KV to the global maximum sequence length, all-gathers the padded KV buffers, and slices each received chunk back to its true length before running attention.

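The key/value exchange can be illustrated with a small standalone sketch (illustrative only, not the diffusers implementation; the function name, tensor layout, and process-group handling are assumptions):

```py
import torch
import torch.distributed as dist


def gather_kv_any_length(key: torch.Tensor, value: torch.Tensor, group=None):
    """Pad local K/V to the global max sequence length, all-gather, then un-pad."""
    # key / value: [batch, local_seq_len, heads, head_dim]; local_seq_len may differ per rank.
    world_size = dist.get_world_size(group)

    # Exchange every rank's true KV length.
    local_len = torch.tensor([key.shape[1]], device=key.device)
    all_lens = [torch.empty_like(local_len) for _ in range(world_size)]
    dist.all_gather(all_lens, local_len, group=group)
    all_lens = [int(t.item()) for t in all_lens]
    s_max = max(all_lens)

    def pad(t: torch.Tensor) -> torch.Tensor:
        # Zero-pad the sequence dimension (dim=1) up to the global maximum.
        pad_len = s_max - t.shape[1]
        if pad_len == 0:
            return t
        return torch.cat([t, t.new_zeros(t.shape[0], pad_len, *t.shape[2:])], dim=1)

    key_padded, value_padded = pad(key).contiguous(), pad(value).contiguous()
    k_buf = [torch.empty_like(key_padded) for _ in range(world_size)]
    v_buf = [torch.empty_like(value_padded) for _ in range(world_size)]
    dist.all_gather(k_buf, key_padded, group=group)
    dist.all_gather(v_buf, value_padded, group=group)

    # Slice each gathered chunk back to its true (unpadded) length before attention.
    keys = [k[:, :n] for k, n in zip(k_buf, all_lens)]
    values = [v[:, :n] for v, n in zip(v_buf, all_lens)]
    return keys, values
```
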
[`ContextParallelConfig`] supports Ring Anything Attention through the `ring_degree` and `ring_anything` options. Note that Ring Anything Attention is not currently supported by Unified Attention. Pass a [`ContextParallelConfig`] with `ring_degree` greater than 1 and `ring_anything=True` to [`~ModelMixin.enable_parallelism`].

```py
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=2, ring_anything=True))
```

> [!TIP]
> To avoid multiple forced CUDA syncs caused by H2D and D2H transfers, add the **gloo** backend in `init_process_group`. This will significantly reduce communication latency.

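For example, with a standard `torchrun` launch (a minimal sketch; the device-to-backend mapping string is regular `torch.distributed` usage, and the rest of the launch setup is assumed):

```py
import torch.distributed as dist

# Register NCCL for the GPU collectives and Gloo for the small CPU-side exchanges
# (e.g. per-rank sequence lengths), so they don't force extra H2D/D2H syncs.
dist.init_process_group(backend="cpu:gloo,cuda:nccl")
```
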
> [!NOTE]
> Backward is not implemented yet; this mode is currently inference-only.
> `attn_mask` must be `None`; non-None attention masks are not supported.

**Member:** This might be better after the first paragraph.

We ran a benchmark for FLUX.1-dev with Ulysses, Ring, Unified Attention, Ulysses Anything Attention, and Ring Anything Attention on a node with 4 RTX 4090 (48GB) GPUs. The results are summarized as follows:

**Contributor (author):** Thanks for the suggestions! I've incorporated all the documentation refinements (including the note formatting and benchmark descriptions) in the latest commit.

| CP Backend       | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) | Shape (HxW) |
|------------------|------------------|-------------|------------------|-------------|
| ulysses          | 259.07           | 3.86        | 33.83            | 1024x1024   |
| ring             | 338.98           | 2.95        | 33.83            | 1024x1024   |
| unified_balanced | 321.54           | 3.11        | 33.83            | 1024x1024   |
| ulysses_anything | 259.07           | 3.86        | 33.83            | 1024x1024   |
| ring_anything    | 340.14           | 2.94        | 33.83            | 1024x1024   |
| ulysses          | failed           | failed      | failed           | 1008x1008   |
| ring             | failed           | failed      | failed           | 1008x1008   |
| unified_balanced | failed           | failed      | failed           | 1008x1008   |
| ulysses_anything | 253.16           | 3.95        | 33.75            | 1008x1008   |
| ring_anything    | 335.57           | 2.98        | 33.75            | 1008x1008   |

From the above table, Ring Anything Attention offers compatibility with arbitrary sequence lengths while maintaining performance comparable to the standard Ring Attention.

### parallel_config

Pass `parallel_config` during model initialization to enable context parallelism.

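As an illustration only (a hedged sketch: the `AutoModel` entry point, keyword placement, model ID, and degree values are assumptions inferred from the sentence above, not confirmed by this diff):

```py
import torch
from diffusers import AutoModel, ContextParallelConfig

# Hypothetical usage: assumes `parallel_config` is accepted at model-initialization
# time as the counterpart of `enable_parallelism(config=...)` shown earlier.
transformer = AutoModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
    parallel_config=ContextParallelConfig(ring_degree=2),
)
```
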
@@ -2076,6 +2076,119 @@ def backward(

```py
        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None


class TemplatedRingAnythingAttention(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: torch.Tensor | None,
        dropout_p: float,
        is_causal: bool,
        scale: float | None,
        enable_gqa: bool,
        return_lse: bool,
        forward_op,
        backward_op,
        _parallel_config: "ParallelConfig" | None = None,
    ):
        # Ring attention for arbitrary sequence lengths.
        if attn_mask is not None:
```

**Member:** Seems like a pretty big limitation no? This would make it incompatible with models like QwenImage, right?

**Contributor (author):** Yeah, that's fair. For this PR I'm keeping […]

```py
            raise ValueError(
                "TemplatedRingAnythingAttention does not support non-None attn_mask: "
                "non-uniform sequence lengths across ranks make cross-rank mask slicing ambiguous."
            )
        ring_mesh = _parallel_config.context_parallel_config._ring_mesh
        group = ring_mesh.get_group()
        rank = _parallel_config.context_parallel_config._ring_local_rank
        world_size = _parallel_config.context_parallel_config.ring_degree
        next_rank = (rank + 1) % world_size
        prev_out = prev_lse = None

        ctx.forward_op = forward_op
        ctx.backward_op = backward_op
        ctx.q_shape = query.shape
        ctx.kv_shape = key.shape
        ctx._parallel_config = _parallel_config

        kv_seq_len = key.shape[1]  # local S_KV (may differ across ranks)
        all_kv_seq_lens = gather_size_by_comm(kv_seq_len, group)
        s_max = max(all_kv_seq_lens)

        # Padding is applied on the sequence dimension (dim=1) at the end.
        def pad_to_s_max(t: torch.Tensor) -> torch.Tensor:
            pad_len = s_max - t.shape[1]
            if pad_len == 0:
                return t
            pad_shape = list(t.shape)
            pad_shape[1] = pad_len
            return torch.cat([t, torch.zeros(pad_shape, dtype=t.dtype, device=t.device)], dim=1)

        key_padded = pad_to_s_max(key)
```

**Member:** Would add a small explainer comment.

**Contributor (author):** Added a comment here.

```py
        value_padded = pad_to_s_max(value)

        kv_buffer = torch.cat([key_padded.flatten(), value_padded.flatten()]).contiguous()
        kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=group)
        kv_buffer = kv_buffer.chunk(world_size)

        # numel per-rank in the padded layout
        kv_padded_numel = key_padded.numel()

        for i in range(world_size):
            if i > 0:
                true_seq_len = all_kv_seq_lens[next_rank]
                kv = kv_buffer[next_rank]
                # Reshape to padded shape, then slice to true sequence length
                key = kv[:kv_padded_numel].reshape_as(key_padded)[:, :true_seq_len]
                value = kv[kv_padded_numel:].reshape_as(value_padded)[:, :true_seq_len]
                next_rank = (next_rank + 1) % world_size
            else:
                # i == 0: use local (unpadded) key/value
                key = key_padded[:, :kv_seq_len]
                value = value_padded[:, :kv_seq_len]

            out, lse = forward_op(
                ctx,
                query,
                key,
                value,
                attn_mask,
                dropout_p,
                is_causal,
                scale,
                enable_gqa,
                True,
```

Comment on lines +2141 to +2164

**Member:** @claude can we use […]

**Contributor (author):** Thanks for pointing this out. I tried removing the branch, but it regressed on PCIe since the local chunk has to be read back from the all-gather buffer. Keeping the local fast path for now, but happy to revisit if there's a cleaner way to express it.

```py
                _save_ctx=i == 0,
                _parallel_config=_parallel_config,
            )

            if _parallel_config.context_parallel_config.convert_to_fp32:
                out = out.to(torch.float32)
                lse = lse.to(torch.float32)

            if is_torch_version("<", "2.9.0"):
                lse = lse.unsqueeze(-1)
            if prev_out is not None:
                out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
                lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
            prev_out = out
            prev_lse = lse

        out = out.to(query.dtype)
        lse = lse.squeeze(-1)

        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args,
    ):
        raise NotImplementedError("Backward pass for Ring Anything Attention in diffusers is not implemented yet.")


class TemplatedUlyssesAnythingAttention(torch.autograd.Function):
    @staticmethod
    def forward(
```

@@ -2254,20 +2367,36 @@ def _templated_context_parallel_attention(

```diff
             _parallel_config,
         )
     elif _parallel_config.context_parallel_config.ring_degree > 1:
-        return TemplatedRingAttention.apply(
-            query,
-            key,
-            value,
-            attn_mask,
-            dropout_p,
-            is_causal,
-            scale,
-            enable_gqa,
-            return_lse,
-            forward_op,
-            backward_op,
-            _parallel_config,
-        )
+        if _parallel_config.context_parallel_config.ring_anything:
+            return TemplatedRingAnythingAttention.apply(
+                query,
+                key,
+                value,
+                attn_mask,
+                dropout_p,
+                is_causal,
+                scale,
+                enable_gqa,
+                return_lse,
+                forward_op,
+                backward_op,
+                _parallel_config,
+            )
+        else:
+            return TemplatedRingAttention.apply(
+                query,
+                key,
+                value,
+                attn_mask,
+                dropout_p,
+                is_causal,
+                scale,
+                enable_gqa,
+                return_lse,
+                forward_op,
+                backward_op,
+                _parallel_config,
+            )
     elif _parallel_config.context_parallel_config.ulysses_degree > 1:
         if _parallel_config.context_parallel_config.ulysses_anything:
             # For Any sequence lengths and Any head num support
```
