Skip to content

Commit 636d3b7

Browse files
[misc] attention hot-path cleanup + denoising loop hoists (hao-ai-lab#1272)
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent e3a5c69 commit 636d3b7

6 files changed

Lines changed: 117 additions & 88 deletions

File tree

fastvideo/attention/backends/bsa_attn.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -358,12 +358,14 @@ def _flash_attn_single_mask(
358358
flat_k = torch.cat(k_list, dim=0)
359359
flat_v = torch.cat(v_list, dim=0)
360360

361+
# Compute max_seqlen_k from the Python list before moving to GPU to
362+
# avoid a `.item()` round-trip that would force a host/device sync.
363+
max_seqlen_q = Sq
364+
max_seqlen_k = max(b - a for a, b in zip(cu_seqlens_k[:-1], cu_seqlens_k[1:], strict=False))
365+
361366
cu_seqlens_q_t = torch.tensor(cu_seqlens_q, dtype=torch.int32, device=device)
362367
cu_seqlens_k_t = torch.tensor(cu_seqlens_k, dtype=torch.int32, device=device)
363368

364-
max_seqlen_q = Sq
365-
max_seqlen_k = int((cu_seqlens_k_t[1:] - cu_seqlens_k_t[:-1]).max().item())
366-
367369
orig_dtype = flat_q.dtype
368370
compute_dtype = orig_dtype
369371
if compute_dtype not in (torch.float16, torch.bfloat16):
@@ -445,12 +447,14 @@ def _flash_attn_single_head(
445447
flat_k = torch.cat(k_list, dim=0)
446448
flat_v = torch.cat(v_list, dim=0)
447449

450+
# Compute max_seqlen_k from the Python list before moving to GPU to
451+
# avoid a `.item()` round-trip that would force a host/device sync.
452+
max_seqlen_q = Sq
453+
max_seqlen_k = max(b - a for a, b in zip(cu_seqlens_k[:-1], cu_seqlens_k[1:], strict=False))
454+
448455
cu_seqlens_q_t = torch.tensor(cu_seqlens_q, dtype=torch.int32, device=device)
449456
cu_seqlens_k_t = torch.tensor(cu_seqlens_k, dtype=torch.int32, device=device)
450457

451-
max_seqlen_q = Sq
452-
max_seqlen_k = int((cu_seqlens_k_t[1:] - cu_seqlens_k_t[:-1]).max().item())
453-
454458
orig_dtype = flat_q.dtype
455459
compute_dtype = orig_dtype
456460
if compute_dtype not in (torch.float16, torch.bfloat16):

fastvideo/attention/backends/flash_attn.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,24 @@ def get_builder_cls() -> type["AttentionMetadataBuilder"]:
5757
raise NotImplementedError
5858

5959

60+
def _key_padding_mask_from_attn_mask(attn_mask: torch.Tensor, key_len: int) -> torch.Tensor:
61+
# Normalize attn_mask to [B, key_len] where True means valid token.
62+
if attn_mask.dim() == 4:
63+
attn_mask = attn_mask[:, 0, 0, :]
64+
elif attn_mask.dim() == 3:
65+
attn_mask = attn_mask[:, 0, :]
66+
elif attn_mask.dim() != 2:
67+
raise ValueError(f"Unsupported attn_mask shape for FLASH_ATTN: {attn_mask.shape}")
68+
69+
# SDPA additive mask convention: valid=0, masked=-inf/large negative.
70+
key_padding_mask = attn_mask if attn_mask.dtype == torch.bool else attn_mask >= 0
71+
72+
if key_padding_mask.shape[-1] != key_len:
73+
raise ValueError("Invalid key padding mask length for FLASH_ATTN: "
74+
f"expected {key_len}, got {key_padding_mask.shape[-1]}")
75+
return key_padding_mask
76+
77+
6078
@dataclass
6179
class FlashAttnMetadata(AttentionMetadata):
6280
current_timestep: int
@@ -101,24 +119,6 @@ def forward(
101119
value: torch.Tensor,
102120
attn_metadata: FlashAttnMetadata,
103121
):
104-
105-
def _key_padding_mask_from_attn_mask(attn_mask: torch.Tensor, key_len: int) -> torch.Tensor:
106-
# Normalize attn_mask to [B, key_len] where True means valid token.
107-
if attn_mask.dim() == 4:
108-
attn_mask = attn_mask[:, 0, 0, :]
109-
elif attn_mask.dim() == 3:
110-
attn_mask = attn_mask[:, 0, :]
111-
elif attn_mask.dim() != 2:
112-
raise ValueError(f"Unsupported attn_mask shape for FLASH_ATTN: {attn_mask.shape}")
113-
114-
# SDPA additive mask convention: valid=0, masked=-inf/large negative.
115-
key_padding_mask = attn_mask if attn_mask.dtype == torch.bool else attn_mask >= 0
116-
117-
if key_padding_mask.shape[-1] != key_len:
118-
raise ValueError("Invalid key padding mask length for FLASH_ATTN: "
119-
f"expected {key_len}, got {key_padding_mask.shape[-1]}")
120-
return key_padding_mask
121-
122122
if (attn_metadata is not None and hasattr(attn_metadata, "attn_mask") and attn_metadata.attn_mask is not None):
123123
from fastvideo.attention.utils.flash_attn_no_pad import (
124124
flash_attn_no_pad,
@@ -136,6 +136,7 @@ def _key_padding_mask_from_attn_mask(attn_mask: torch.Tensor, key_len: int) -> t
136136
device=query.device,
137137
)
138138
key_padding_mask = _key_padding_mask_from_attn_mask(attn_mask, key.shape[1]).to(device=key.device)
139+
139140
return flash_attn_varlen_qk_no_pad(
140141
query,
141142
key,
@@ -148,9 +149,8 @@ def _key_padding_mask_from_attn_mask(attn_mask: torch.Tensor, key_len: int) -> t
148149
)
149150

150151
qkv = torch.stack([query, key, value], dim=2)
151-
152-
attn_mask = F.pad(attn_mask, (qkv.shape[1] - attn_mask.shape[1], 0), value=True)
153-
output = flash_attn_no_pad(qkv, attn_mask, causal=False, dropout_p=0, softmax_scale=None)
152+
attn_mask_padded = F.pad(attn_mask, (qkv.shape[1] - attn_mask.shape[1], 0), value=True)
153+
output = flash_attn_no_pad(qkv, attn_mask_padded, causal=False, dropout_p=0, softmax_scale=None)
154154
else:
155155
output = flash_attn_func(
156156
query, # type: ignore[no-untyped-call]

fastvideo/attention/backends/sla.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -307,9 +307,12 @@ def forward(
307307
# Sparse attention
308308
o_s = _attention.apply(q, k, v, sparse_map, lut, real_topk, self.BLKQ, self.BLKK)
309309

310-
# Linear attention with feature maps
311-
q_linear = self.feature_map_q(q).contiguous().to(self.dtype)
312-
k_linear = self.feature_map_k(k).contiguous().to(self.dtype)
310+
# Linear attention with feature maps. Note: softmax / elu / relu
311+
# all preserve their input's memory layout, so the results are already
312+
# contiguous from the transpose-contiguous above — no need to
313+
# call .contiguous() again here.
314+
q_linear = self.feature_map_q(q).to(self.dtype)
315+
k_linear = self.feature_map_k(k).to(self.dtype)
313316
o_l = self._calc_linear_attention(q_linear, k_linear, v)
314317

315318
# Project linear attention output and combine
@@ -539,9 +542,10 @@ def forward(
539542
False, 1, scale, 0)
540543
# ========== END SPARGE ==========
541544

542-
# Linear attention with feature maps
543-
q_linear = self.feature_map_q(q).contiguous().to(self.dtype)
544-
k_linear = self.feature_map_k(k).contiguous().to(self.dtype)
545+
# Linear attention with feature maps (see SLAAttentionImpl.forward
546+
# for why .contiguous() is unnecessary here).
547+
q_linear = self.feature_map_q(q).to(self.dtype)
548+
k_linear = self.feature_map_k(k).to(self.dtype)
545549
o_l = self._calc_linear_attention(q_linear, k_linear, v)
546550

547551
# Project linear attention output and combine

fastvideo/attention/backends/video_sparse_attn.py

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,18 @@ class VideoSparseAttentionMetadata(AttentionMetadata):
140140
reverse_tile_partition_indices: torch.LongTensor
141141
variable_block_sizes: torch.LongTensor
142142
non_pad_index: torch.LongTensor
143+
# Precomputed fancy index that fuses ``x[:, non_pad_index][:, reverse_tile_partition_indices]``
144+
# in postprocess_output(). Avoids materializing the intermediate
145+
# ``[B, len(non_pad_index), H, D]`` tensor on every layer.
146+
untile_combined_index: torch.LongTensor
147+
# Per-step shared padded buffer used by tile(). Lazily populated on
148+
# the first layer's call and reused by every subsequent VSA layer in
149+
# the same denoising step. Scoping to metadata (not class/instance)
150+
# makes the reuse thread-safe across concurrent requests and keeps
151+
# the "pad positions are zero" invariant trivially true (the buffer
152+
# is zero-initialized once and only written at this instance's fixed
153+
# ``non_pad_index``, so the index set cannot drift between calls).
154+
tile_buf: torch.Tensor | None = None
143155

144156

145157
class VideoSparseAttentionMetadataBuilder(AttentionMetadataBuilder):
@@ -171,6 +183,7 @@ def build( # type: ignore
171183
reverse_tile_partition_indices = get_reverse_tile_partition_indices(dit_seq_shape, VSA_TILE_SIZE, device)
172184
variable_block_sizes = construct_variable_block_sizes(dit_seq_shape, num_tiles, device)
173185
non_pad_index = get_non_pad_index(variable_block_sizes, math.prod(VSA_TILE_SIZE))
186+
untile_combined_index = non_pad_index[reverse_tile_partition_indices]
174187

175188
return VideoSparseAttentionMetadata(
176189
current_timestep=current_timestep,
@@ -181,7 +194,8 @@ def build( # type: ignore
181194
tile_partition_indices=tile_partition_indices, # type: ignore
182195
reverse_tile_partition_indices=reverse_tile_partition_indices,
183196
variable_block_sizes=variable_block_sizes,
184-
non_pad_index=non_pad_index)
197+
non_pad_index=non_pad_index,
198+
untile_combined_index=untile_combined_index)
185199

186200

187201
class VideoSparseAttentionImpl(AttentionImpl):
@@ -200,37 +214,59 @@ def __init__(
200214
sp_group = get_sp_group()
201215
self.sp_size = sp_group.world_size
202216

203-
def tile(self, x: torch.Tensor, num_tiles: list[int], tile_partition_indices: torch.LongTensor,
204-
non_pad_index: torch.LongTensor) -> torch.Tensor:
217+
def tile(self, x: torch.Tensor, attn_metadata: VideoSparseAttentionMetadata) -> torch.Tensor:
218+
"""Tile ``x`` into ``attn_metadata.tile_buf`` and return it.
219+
220+
The returned tensor aliases the per-metadata buffer and is only
221+
valid until the next ``tile()`` / ``preprocess_qkv`` call on the
222+
same ``attn_metadata``. Callers must consume (or copy) the
223+
result before invoking another VSA layer with the same metadata.
224+
Today both call sites materialize copies via
225+
``.transpose(...).contiguous()`` inside ``forward()``, so the
226+
contract holds; future callers must preserve it.
227+
"""
228+
num_tiles = attn_metadata.num_tiles
205229
t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
206230
h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
207231
w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]
208-
209-
x_padded = torch.zeros((x.shape[0], t_padded_size * h_padded_size * w_padded_size, x.shape[-2], x.shape[-1]),
210-
device=x.device,
211-
dtype=x.dtype)
212-
x_padded[:, non_pad_index] = x[:, tile_partition_indices]
213-
return x_padded
214-
215-
def untile(self, x: torch.Tensor, reverse_tile_partition_indices: torch.LongTensor,
216-
non_pad_index: torch.LongTensor) -> torch.Tensor:
217-
x = x[:, non_pad_index][:, reverse_tile_partition_indices]
218-
return x
232+
target_shape = (x.shape[0], t_padded_size * h_padded_size * w_padded_size, x.shape[-2], x.shape[-1])
233+
234+
# Reuse the per-step buffer stashed on metadata (lazily allocated
235+
# on the first VSA layer's call within a denoising step). Pad
236+
# positions are zero from the initial torch.zeros and never
237+
# written to. Scoping to metadata makes reuse safe across
238+
# concurrent requests and keeps the "pad positions are zero"
239+
# invariant trivially true: ``non_pad_index`` is fixed within
240+
# a single metadata instance.
241+
buf = attn_metadata.tile_buf
242+
if (buf is None or buf.shape != target_shape or buf.dtype != x.dtype or buf.device != x.device):
243+
buf = torch.zeros(target_shape, device=x.device, dtype=x.dtype)
244+
attn_metadata.tile_buf = buf
245+
246+
buf[:, attn_metadata.non_pad_index] = x[:, attn_metadata.tile_partition_indices]
247+
return buf
248+
249+
def untile(self, x: torch.Tensor, untile_combined_index: torch.LongTensor) -> torch.Tensor:
250+
# Single fancy index using precomputed combined indices; avoids
251+
# the intermediate ``[B, len(non_pad_index), H, D]`` tensor that
252+
# the two-step ``x[:, non_pad_index][:, reverse_tile_partition_indices]``
253+
# would allocate on every layer.
254+
return x[:, untile_combined_index]
219255

220256
def preprocess_qkv(
221257
self,
222258
qkv: torch.Tensor,
223259
attn_metadata: VideoSparseAttentionMetadata,
224260
) -> torch.Tensor:
225-
return self.tile(qkv, attn_metadata.num_tiles, attn_metadata.tile_partition_indices,
226-
attn_metadata.non_pad_index)
261+
"""Tile QKV; aliasing contract: see ``tile()``."""
262+
return self.tile(qkv, attn_metadata)
227263

228264
def postprocess_output(
229265
self,
230266
output: torch.Tensor,
231267
attn_metadata: VideoSparseAttentionMetadata,
232268
) -> torch.Tensor:
233-
return self.untile(output, attn_metadata.reverse_tile_partition_indices, attn_metadata.non_pad_index)
269+
return self.untile(output, attn_metadata.untile_combined_index)
234270

235271
def forward( # type: ignore[override]
236272
self,

fastvideo/pipelines/stages/conditioning.py

Lines changed: 5 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,11 @@
66
import torch
77

88
from fastvideo.fastvideo_args import FastVideoArgs
9-
from fastvideo.logger import init_logger
109
from fastvideo.pipelines.pipeline_batch_info import ForwardBatch
1110
from fastvideo.pipelines.stages.base import PipelineStage
1211
from fastvideo.pipelines.stages.validators import StageValidators as V
1312
from fastvideo.pipelines.stages.validators import VerificationResult
1413

15-
logger = init_logger(__name__)
16-
1714

1815
class ConditioningStage(PipelineStage):
1916
"""
@@ -39,31 +36,11 @@ def forward(
3936
Returns:
4037
The batch with applied conditioning.
4138
"""
42-
# TODO!!
43-
if not batch.do_classifier_free_guidance:
44-
return batch
45-
else:
46-
return batch
47-
48-
logger.info("batch.negative_prompt_embeds: %s", batch.negative_prompt_embeds)
49-
logger.info("do_classifier_free_guidance: %s", batch.do_classifier_free_guidance)
50-
logger.info("cfg_scale: %s", batch.guidance_scale)
51-
52-
# Ensure negative prompt embeddings are available
53-
assert batch.negative_prompt_embeds is not None, (
54-
"Negative prompt embeddings are required for classifier-free guidance")
55-
56-
# Concatenate primary embeddings and masks
57-
batch.prompt_embeds = torch.cat([batch.negative_prompt_embeds, batch.prompt_embeds])
58-
if batch.attention_mask is not None:
59-
batch.attention_mask = torch.cat([batch.negative_attention_mask, batch.attention_mask])
60-
61-
# Concatenate secondary embeddings and masks if present
62-
if batch.prompt_embeds_2 is not None:
63-
batch.prompt_embeds_2 = torch.cat([batch.negative_prompt_embeds_2, batch.prompt_embeds_2])
64-
if batch.attention_mask_2 is not None:
65-
batch.attention_mask_2 = torch.cat([batch.negative_attention_mask_2, batch.attention_mask_2])
66-
39+
# Forward is a no-op: CFG is applied via two separate
40+
# transformer forward passes inside DenoisingStage (e.g.
41+
# denoising.py:364-394, :706, :930). The class is kept because
42+
# verify_input / verify_output still validate CFG fields and
43+
# disable CFG when prompt_embeds is empty.
6744
return batch
6845

6946
def verify_input(self, batch: ForwardBatch, fastvideo_args: FastVideoArgs) -> VerificationResult:

fastvideo/pipelines/stages/denoising.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,23 @@ def forward(
212212
trajectory_timesteps: list[torch.Tensor] = []
213213
trajectory_latents: list[torch.Tensor] = []
214214

215+
# Hoisted out of the per-step loop: depends only on inputs that
216+
# are constant across denoising steps.
217+
use_meanflow = getattr(self.transformer.config, "use_meanflow", False)
218+
embedded_cfg_scale = fastvideo_args.pipeline_config.embedded_cfg_scale
219+
if embedded_cfg_scale is not None:
220+
guidance_expand = (torch.tensor(
221+
[embedded_cfg_scale] * latents.shape[0],
222+
dtype=torch.float32,
223+
device=get_local_torch_device(),
224+
).to(target_dtype) * 1000.0)
225+
else:
226+
guidance_expand = None
227+
# V2V padding: zero-filled tensor concatenated with each step's
228+
# latent_model_input. Shape is fixed by latents and is never
229+
# written to, so we allocate once.
230+
v2v_zero_pad = torch.zeros_like(latents) if batch.video_latent is not None else None
231+
215232
# Run denoising loop
216233
with self.progress_bar(total=num_inference_steps) as progress_bar:
217234
for i, t in enumerate(timesteps):
@@ -248,8 +265,7 @@ def forward(
248265
# Expand latents for V2V/I2V
249266
latent_model_input = latents.to(target_dtype)
250267
if batch.video_latent is not None:
251-
latent_model_input = torch.cat([latent_model_input, batch.video_latent,
252-
torch.zeros_like(latents)],
268+
latent_model_input = torch.cat([latent_model_input, batch.video_latent, v2v_zero_pad],
253269
dim=1).to(target_dtype)
254270
elif batch.image_latent is not None:
255271
assert not fastvideo_args.pipeline_config.ti2v_task, "image latents should not be provided for TI2V task"
@@ -266,7 +282,6 @@ def forward(
266282
t_expand = t.repeat(latent_model_input.shape[0])
267283
t_expand = t_expand.to(get_local_torch_device())
268284

269-
use_meanflow = getattr(self.transformer.config, "use_meanflow", False)
270285
if use_meanflow:
271286
if i == len(timesteps) - 1:
272287
timesteps_r = torch.tensor([0.0], device=get_local_torch_device())
@@ -285,13 +300,6 @@ def forward(
285300

286301
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
287302

288-
# Prepare inputs for transformer
289-
guidance_expand = (torch.tensor(
290-
[fastvideo_args.pipeline_config.embedded_cfg_scale] * latent_model_input.shape[0],
291-
dtype=torch.float32,
292-
device=get_local_torch_device(),
293-
).to(target_dtype) * 1000.0 if fastvideo_args.pipeline_config.embedded_cfg_scale is not None else None)
294-
295303
# Predict noise residual
296304
with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
297305
if (vsa_available and self.attn_backend == VideoSparseAttentionBackend):

0 commit comments

Comments
 (0)