166 changes: 124 additions & 42 deletions src/optimization/compatibility.py
@@ -445,6 +445,10 @@ def call_sage_attn_2_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, m
return out.to(out_dtype) if out.dtype != out_dtype else out


# Track if we've logged SA3 grouping info (to avoid spamming logs)
_sa3_grouping_logged = False


@torch._dynamo.disable
def call_sage_attn_3_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, **kwargs):
"""
@@ -453,8 +457,12 @@ def call_sage_attn_3_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, m
SageAttention 3 / Blackwell provides maximum performance on RTX 50xx series GPUs.
However, it only supports batched attention (uniform sequence lengths), not varlen.

This wrapper detects uniform-length batches and reshapes accordingly.
For variable-length sequences, it automatically falls back to SageAttention 2.
This wrapper handles variable-length sequences by grouping them by their
(query length, key length) pair and processing each group as a separate SA3
batch, preserving batch efficiency for sequences of the same length. This is
particularly important for window attention, where boundary windows may have
different sizes from interior windows.

If the grouping approach fails, it falls back to SageAttention 2 when available.

This function is excluded from torch.compile because:
1. SageAttention is a C++ extension that can't be compiled anyway
@@ -475,6 +483,8 @@ def call_sage_attn_3_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, m
Returns:
Attention output tensor (total_seq, heads, head_dim)
"""
global _sa3_grouping_logged

if not SAGE_ATTN_3_AVAILABLE:
raise ImportError("SageAttention 3 (Blackwell) is not available")

@@ -484,35 +494,13 @@ def call_sage_attn_3_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, m
if torch.is_tensor(max_seqlen_k):
max_seqlen_k = int(max_seqlen_k.item())

# Check if all sequences have uniform length (required for SA3 batched API)
# SA3/Blackwell uses batched attention, not varlen, so we need uniform lengths
# Compute sequence lengths for each batch item
seq_lens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
seq_lens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
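# Worked example (hypothetical values): for sequences of lengths [3, 5, 3],
# cu_seqlens is [0, 3, 8, 11], and the subtraction above recovers [3, 5, 3].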

uniform_q = (seq_lens_q == seq_lens_q[0]).all()
uniform_k = (seq_lens_k == seq_lens_k[0]).all()

if not (uniform_q and uniform_k):
# Fall back to SA2 for variable-length sequences
# This is expected behavior - SA3 Blackwell doesn't support varlen natively
if SAGE_ATTN_2_AVAILABLE:
return call_sage_attn_2_varlen(
q, k, v, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, **kwargs
)
raise RuntimeError(
"SageAttention 3 (Blackwell) requires uniform sequence lengths, "
"and SageAttention 2 is not available as fallback. "
"Please install sageattention package or use flash_attn/sdpa instead."
)

# Extract batch dimensions
batch_size = len(cu_seqlens_q) - 1
seq_len_q = int(seq_lens_q[0].item())
seq_len_k = int(seq_lens_k[0].item())
heads = q.shape[1]
dim = q.shape[2]

# SageAttention requires half precision (fp16/bf16)
out_dtype = q.dtype
half_dtypes = (torch.float16, torch.bfloat16)
@@ -526,23 +514,117 @@ def call_sage_attn_3_varlen(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, m
k = k.to(torch.bfloat16)
v = v.to(torch.bfloat16)

# Reshape varlen (total_seq, heads, dim) -> batched (batch, seq, heads, dim)
q_batched = q.view(batch_size, seq_len_q, heads, dim)
k_batched = k.view(batch_size, seq_len_k, heads, dim)
v_batched = v.view(batch_size, seq_len_k, heads, dim)

# SA3/Blackwell expects (batch, heads, seq, dim) layout
q_batched = q_batched.transpose(1, 2) # (batch, heads, seq, dim)
k_batched = k_batched.transpose(1, 2)
v_batched = v_batched.transpose(1, 2)

# Call SA3 Blackwell
out = sageattn_blackwell(q_batched, k_batched, v_batched, per_block_mean=False)

# Reshape back to varlen format (total_seq, heads, dim)
out = out.transpose(1, 2).reshape(-1, heads, dim).contiguous()
heads = q.shape[1]
dim = q.shape[2]
batch_size = len(cu_seqlens_q) - 1

return out.to(out_dtype) if out.dtype != out_dtype else out
if uniform_q and uniform_k:
# Fast path: all sequences have uniform length, can batch directly
seq_len_q = int(seq_lens_q[0].item())
seq_len_k = int(seq_lens_k[0].item())

# Reshape varlen (total_seq, heads, dim) -> batched (batch, seq, heads, dim)
q_batched = q.view(batch_size, seq_len_q, heads, dim)
k_batched = k.view(batch_size, seq_len_k, heads, dim)
v_batched = v.view(batch_size, seq_len_k, heads, dim)

# SA3/Blackwell expects (batch, heads, seq, dim) layout
q_batched = q_batched.transpose(1, 2)
k_batched = k_batched.transpose(1, 2)
v_batched = v_batched.transpose(1, 2)

# Call SA3 Blackwell
out = sageattn_blackwell(q_batched, k_batched, v_batched, per_block_mean=False)

# Reshape back to varlen format (total_seq, heads, dim)
out = out.transpose(1, 2).reshape(-1, heads, dim).contiguous()

return out.to(out_dtype) if out.dtype != out_dtype else out

# Variable-length path: group sequences by their (q_len, k_len) and process each group
# This handles window attention where boundary windows have different sizes
try:
# Create length pairs for grouping
len_pairs = torch.stack([seq_lens_q, seq_lens_k], dim=1) # (batch, 2)

# Find unique length combinations and group indices
unique_pairs, inverse_indices = torch.unique(len_pairs, dim=0, return_inverse=True)
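# Worked example (hypothetical values): with seq_lens_q = seq_lens_k = [4, 4, 3]
# (window attention with one short boundary window), len_pairs is
# [[4, 4], [4, 4], [3, 3]], unique_pairs is [[3, 3], [4, 4]], and
# inverse_indices is [1, 1, 0]: sequences 0 and 1 form one SA3 batch,
# sequence 2 another.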
num_groups = len(unique_pairs)

# Log grouping info once per session
if not _sa3_grouping_logged:
print(f"🔄 SA3 varlen: Using grouped attention for {batch_size} sequences in {num_groups} length-groups")
_sa3_grouping_logged = True

# Pre-allocate output tensor
output = torch.empty_like(q)

# Process each group of sequences with the same lengths
for group_idx in range(num_groups):
# Find batch indices belonging to this group
mask = (inverse_indices == group_idx)
batch_indices = torch.where(mask)[0]

if len(batch_indices) == 0:
continue

# Gather Q, K, V for this group
q_group_list = []
k_group_list = []
v_group_list = []
output_slices = []

for idx in batch_indices:
idx = int(idx.item())
q_start = int(cu_seqlens_q[idx].item())
q_end = int(cu_seqlens_q[idx + 1].item())
k_start = int(cu_seqlens_k[idx].item())
k_end = int(cu_seqlens_k[idx + 1].item())

q_group_list.append(q[q_start:q_end])
k_group_list.append(k[k_start:k_end])
v_group_list.append(v[k_start:k_end])
output_slices.append((q_start, q_end))

# Stack into batched tensors: (group_batch, seq, heads, dim)
q_group = torch.stack(q_group_list)
k_group = torch.stack(k_group_list)
v_group = torch.stack(v_group_list)
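# torch.stack requires identical shapes; this holds because every sequence
# in this group shares the same (q_len, k_len) by construction.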

# SA3/Blackwell expects (batch, heads, seq, dim) layout
q_group = q_group.transpose(1, 2)
k_group = k_group.transpose(1, 2)
v_group = v_group.transpose(1, 2)

# Call SA3 Blackwell
out_group = sageattn_blackwell(q_group, k_group, v_group, per_block_mean=False)

# Reshape back: (batch, heads, seq, dim) -> (batch, seq, heads, dim)
out_group = out_group.transpose(1, 2)

# Scatter results back to output tensor
for i, (q_start, q_end) in enumerate(output_slices):
output[q_start:q_end] = out_group[i]

result = output.to(out_dtype) if output.dtype != out_dtype else output
return result

except Exception as e:
# SA3 grouping failed, try falling back to SA2
if SAGE_ATTN_2_AVAILABLE:
print(f"⚠️ SA3 varlen grouping failed: {e}")
print("🔄 Falling back to SageAttention 2 for this attention call")
return call_sage_attn_2_varlen(
q, k, v, cu_seqlens_q, cu_seqlens_k,
max_seqlen_q, max_seqlen_k, **kwargs
)
else:
# No SA2 available, re-raise the original error with context
raise RuntimeError(
f"SageAttention 3 grouping failed and SageAttention 2 is not available as fallback. "
f"Original error: {e}\n"
f"Please install sageattention package or use flash_attn/sdpa instead."
) from e
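
# --- Illustrative usage sketch ---
# Hypothetical example of driving this wrapper; the lengths, head count, and
# device below are assumptions for illustration, not part of this PR. Two
# distinct sequence lengths exercise the grouped SA3 path above.
#
#     import torch
#     heads, dim = 8, 64
#     lens = [4, 4, 3]  # e.g. window attention with one short boundary window
#     cu = torch.tensor([0, 4, 8, 11], dtype=torch.int32, device="cuda")
#     q = torch.randn(11, heads, dim, device="cuda", dtype=torch.bfloat16)
#     k, v = torch.randn_like(q), torch.randn_like(q)
#     out = call_sage_attn_3_varlen(q, k, v, cu, cu, max(lens), max(lens))
#     assert out.shape == (11, heads, dim)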


# 2. Triton - Required for torch.compile with inductor backend
@@ -952,4 +1034,4 @@ def __setattr__(self, name, value):
setattr(self.dit_model, name, value)
else:
super().__setattr__(name, value)