Skip to content

Commit 656e150

Browse files
committed
Move block_sizes computation to top of _ulysses_attention after reshape
1 parent 71cba8d commit 656e150

1 file changed

Lines changed: 1 addition & 2 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -487,12 +487,11 @@ def _ulysses_attention(
487 487           "Ulysses attention requires the number of heads to be divisible by the context shard count, "
488 488           f"got heads={num_heads} and context_shards={num_shards}."
489 489       )
    490 +     block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")
490 491
491 492       q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
492 493       kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)
493 494
494     -     block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash")
495     -
496 495       @functools.partial(
497 496           jax.shard_map,
498 497           mesh=mesh,

0 commit comments

Comments (0)