
Commit 2401c2d

sagarchapara and claude committed
chore: update attention tests and README
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent b91e784 · commit 2401c2d

2 files changed: 5 additions & 3 deletions

README.md (2 additions & 0 deletions)

@@ -210,6 +210,7 @@ After installation completes, run the training script.
   --xla_enable_async_all_gather=true \
   --xla_tpu_scoped_vmem_limit_kib=65536 \
   --xla_tpu_enable_async_all_to_all=true \
+  --xla_tpu_enable_latency_hiding_scheduler=true \
   --xla_tpu_enable_all_experimental_scheduler_features=true \
   --xla_tpu_enable_scheduler_memory_pressure_tracking=true \
   --xla_tpu_host_transfer_overlap_limit=24 \

@@ -329,6 +330,7 @@ After installation completes, run the training script.
   --xla_enable_async_all_gather=true \
   --xla_tpu_scoped_vmem_limit_kib=65536 \
   --xla_tpu_enable_async_all_to_all=true \
+  --xla_tpu_enable_latency_hiding_scheduler=true \
   --xla_tpu_enable_all_experimental_scheduler_features=true \
   --xla_tpu_enable_scheduler_memory_pressure_tracking=true \
   --xla_tpu_host_transfer_overlap_limit=24 \
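The README hunks above add `--xla_tpu_enable_latency_hiding_scheduler=true` to the flag list in two places. As a minimal sketch of how such libtpu/XLA runtime flags are commonly wired into a TPU training job, the snippet below joins them into the `LIBTPU_INIT_ARGS` environment variable; the variable name and this usage pattern are assumptions for illustration, not taken from this commit.

```python
import os

# Flags from the README diff; the commented line is the one this commit adds.
# Passing them via LIBTPU_INIT_ARGS is a common convention (assumed here),
# not something this diff itself shows.
flags = [
    "--xla_enable_async_all_gather=true",
    "--xla_tpu_scoped_vmem_limit_kib=65536",
    "--xla_tpu_enable_async_all_to_all=true",
    "--xla_tpu_enable_latency_hiding_scheduler=true",  # newly added flag
    "--xla_tpu_enable_all_experimental_scheduler_features=true",
    "--xla_tpu_enable_scheduler_memory_pressure_tracking=true",
    "--xla_tpu_host_transfer_overlap_limit=24",
]
os.environ["LIBTPU_INIT_ARGS"] = " ".join(flags)
print(os.environ["LIBTPU_INIT_ARGS"])
```

Setting the variable before the training process starts lets the TPU runtime pick up the scheduler options without editing the launch command itself.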

src/maxdiffusion/tests/attention_test.py (3 additions & 3 deletions)

@@ -45,7 +45,7 @@ def _ulysses_mesh(self):
 
   def _ulysses_ring_mesh(self):
     devices = np.array(jax.devices()[:4]).reshape(1, 1, 2, 2)
-    return Mesh(devices, ("data", "fsdp", "context", "tensor"))
+    return Mesh(devices, ("data", "fsdp", "ring", "ulysses"))
 
   def _ulysses_axis_rules(self):
     return (

@@ -60,8 +60,8 @@ def _ulysses_ring_axis_rules(self):
     return (
         (attention_flax.BATCH, "data"),
         (attention_flax.SELF_ATTN_HEAD, None),
-        (attention_flax.SELF_ATTN_Q_LENGTH, ("context", "tensor")),
-        (attention_flax.SELF_ATTN_KV_LENGTH, ("context", "tensor")),
+        (attention_flax.SELF_ATTN_Q_LENGTH, ("ring", "ulysses")),
+        (attention_flax.SELF_ATTN_KV_LENGTH, ("ring", "ulysses")),
         (attention_flax.D_KV, None),
     )
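The test diff renames the last two mesh axes from `("context", "tensor")` to `("ring", "ulysses")` and updates the sequence-length sharding rules to match. A standalone sketch of the renamed mesh follows; the test hard-codes four devices reshaped to `(1, 1, 2, 2)`, so this version falls back to a degenerate `(1, 1, 1, 1)` mesh when fewer devices are available (e.g. on a single-device CPU host), which is an accommodation added here, not part of the test.

```python
import numpy as np
import jax
from jax.sharding import Mesh

# Mirror of _ulysses_ring_mesh after the rename: the sequence-parallel axes
# are now called "ring" and "ulysses" instead of "context" and "tensor".
n_devices = len(jax.devices())
shape = (1, 1, 2, 2) if n_devices >= 4 else (1, 1, 1, 1)
devices = np.array(jax.devices()[: int(np.prod(shape))]).reshape(shape)
mesh = Mesh(devices, ("data", "fsdp", "ring", "ulysses"))
print(mesh.axis_names)
```

Renaming the axes only changes the labels that sharding rules refer to; the physical device layout is unchanged, which is why the axis-rule entries for `SELF_ATTN_Q_LENGTH` and `SELF_ATTN_KV_LENGTH` are the only other lines the diff touches.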

0 commit comments
