@@ -140,6 +140,18 @@ class VideoSparseAttentionMetadata(AttentionMetadata):
140140 reverse_tile_partition_indices : torch .LongTensor
141141 variable_block_sizes : torch .LongTensor
142142 non_pad_index : torch .LongTensor
143+ # Precomputed fancy index that fuses ``x[:, non_pad_index][:, reverse_tile_partition_indices]``
144+ # in postprocess_output(). Avoids materializing the intermediate
145+ # ``[B, len(non_pad_index), H, D]`` tensor on every layer.
146+ untile_combined_index : torch .LongTensor
147+ # Per-step shared padded buffer used by tile(). Lazily allocated on
148+ # the first VSA layer's call and reused by every subsequent VSA layer
149+ # in the same denoising step. Storing it on the metadata (not on the
150+ # class/instance) keeps the reuse safe across concurrent requests,
151+ # since each request only touches its own metadata's buffer. It also
152+ # keeps the "pad positions are zero" invariant trivially true: writes
153+ # always go through this same metadata's ``non_pad_index``.
154+ tile_buf : torch .Tensor | None = None
143155
144156
145157class VideoSparseAttentionMetadataBuilder (AttentionMetadataBuilder ):
@@ -171,6 +183,7 @@ def build( # type: ignore
171183 reverse_tile_partition_indices = get_reverse_tile_partition_indices (dit_seq_shape , VSA_TILE_SIZE , device )
172184 variable_block_sizes = construct_variable_block_sizes (dit_seq_shape , num_tiles , device )
173185 non_pad_index = get_non_pad_index (variable_block_sizes , math .prod (VSA_TILE_SIZE ))
186+ untile_combined_index = non_pad_index [reverse_tile_partition_indices ]
174187
175188 return VideoSparseAttentionMetadata (
176189 current_timestep = current_timestep ,
@@ -181,7 +194,8 @@ def build( # type: ignore
181194 tile_partition_indices = tile_partition_indices , # type: ignore
182195 reverse_tile_partition_indices = reverse_tile_partition_indices ,
183196 variable_block_sizes = variable_block_sizes ,
184- non_pad_index = non_pad_index )
197+ non_pad_index = non_pad_index ,
198+ untile_combined_index = untile_combined_index )
185199
186200
187201class VideoSparseAttentionImpl (AttentionImpl ):
@@ -200,37 +214,59 @@ def __init__(
200214 sp_group = get_sp_group ()
201215 self .sp_size = sp_group .world_size
202216
def tile(self, x: torch.Tensor, attn_metadata: VideoSparseAttentionMetadata) -> torch.Tensor:
    """Scatter ``x`` into a zero-padded, tile-ordered buffer and return it.

    Tokens are gathered along dim 1 with
    ``attn_metadata.tile_partition_indices`` and written to the positions
    ``attn_metadata.non_pad_index`` of a buffer whose dim 1 has length
    ``prod(num_tiles[i] * VSA_TILE_SIZE[i])``; every other position stays
    zero.

    Args:
        x: Input tensor indexed along dim 1 (assumed to be
            ``[B, L, H, D]`` -- TODO confirm against callers).
        attn_metadata: Supplies ``num_tiles``, ``tile_partition_indices``
            and ``non_pad_index``; its ``tile_buf`` field is read and may
            be (re)assigned here.

    Returns:
        The padded tensor of shape
        ``(x.shape[0], padded_len, x.shape[-2], x.shape[-1])``.

    Aliasing contract:
        When autograd is not recording for ``x``, the returned tensor IS
        ``attn_metadata.tile_buf`` and is overwritten by the next
        ``tile()`` call on the same metadata. Callers must consume or
        copy it first (``forward()`` is expected to do so; that code is
        not verified here). When autograd is recording, a fresh buffer is
        returned and nothing is cached.
    """
    num_tiles = attn_metadata.num_tiles
    t_padded_size = num_tiles[0] * VSA_TILE_SIZE[0]
    h_padded_size = num_tiles[1] * VSA_TILE_SIZE[1]
    w_padded_size = num_tiles[2] * VSA_TILE_SIZE[2]
    target_shape = (x.shape[0], t_padded_size * h_padded_size * w_padded_size, x.shape[-2], x.shape[-1])

    gathered = x[:, attn_metadata.tile_partition_indices]

    # An in-place write into the shared buffer would attach it to the
    # autograd graph and chain every later layer's write onto the same
    # history (problematic with activation checkpointing, where earlier
    # graph segments are freed). Use a private buffer on that path.
    if torch.is_grad_enabled() and x.requires_grad:
        fresh = torch.zeros(target_shape, device=x.device, dtype=x.dtype)
        fresh[:, attn_metadata.non_pad_index] = gathered
        return fresh

    # Reuse the buffer stashed on the metadata; (re)allocate if it is
    # missing or no longer matches the request. Pad positions are zero
    # from torch.zeros and are never written, because ``non_pad_index``
    # is the only index set used for writes on this metadata.
    buf = attn_metadata.tile_buf
    if buf is None or buf.shape != target_shape or buf.dtype != x.dtype or buf.device != x.device:
        buf = torch.zeros(target_shape, device=x.device, dtype=x.dtype)
        attn_metadata.tile_buf = buf

    buf[:, attn_metadata.non_pad_index] = gathered
    return buf
248+
def untile(self, x: torch.Tensor, untile_combined_index: torch.LongTensor) -> torch.Tensor:
    """Undo tiling by gathering along dim 1 with a precomputed index.

    ``untile_combined_index`` is built as
    ``non_pad_index[reverse_tile_partition_indices]``, so a single gather
    gives the same result as
    ``x[:, non_pad_index][:, reverse_tile_partition_indices]`` without
    allocating the intermediate tensor of the two-step form.

    Args:
        x: Tensor to gather from; dim 0 is kept whole.
        untile_combined_index: Positions to select along dim 1.

    Returns:
        ``x`` with dim 1 replaced by the selected positions, in index order.
    """
    # One advanced-indexing op: full slice on dim 0, gather on dim 1.
    selector = (slice(None), untile_combined_index)
    return x[selector]
219255
def preprocess_qkv(
    self,
    qkv: torch.Tensor,
    attn_metadata: VideoSparseAttentionMetadata,
) -> torch.Tensor:
    """Tile ``qkv`` by delegating to ``self.tile``.

    The result is whatever ``self.tile(qkv, attn_metadata)`` returns; any
    buffer-aliasing rules of ``tile`` therefore apply to this return
    value as well.
    """
    tiled = self.tile(qkv, attn_metadata)
    return tiled
227263
def postprocess_output(
    self,
    output: torch.Tensor,
    attn_metadata: VideoSparseAttentionMetadata,
) -> torch.Tensor:
    """Untile the attention output by delegating to ``self.untile``.

    Passes ``attn_metadata.untile_combined_index`` as the gather index
    and returns the result of ``self.untile`` unchanged.
    """
    combined_index = attn_metadata.untile_combined_index
    return self.untile(output, combined_index)
234270
235271 def forward ( # type: ignore[override]
236272 self ,
0 commit comments