Skip to content

Commit 82033df

Browse files
committed
Merge PR #390 fixing mesh axis bugs and aligning Wan-VACE
2 parents dc594e4 + fa877ab commit 82033df

6 files changed

Lines changed: 175 additions & 126 deletions

File tree

src/maxdiffusion/kernels/custom_splash_attention.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,8 @@ def v_index_map(h, i, j, *_):
427427
compiler_params=pltpu.CompilerParams(
428428
dimension_semantics=("parallel", "arbitrary", "arbitrary"),
429429
flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": use_experimental_scheduler},
430+
disable_bounds_checks=True,
431+
skip_device_barrier=True,
430432
),
431433
out_shape=out_shapes,
432434
)(q, k, v)
@@ -514,6 +516,8 @@ def out_index_map(h, i, j, *_):
514516
compiler_params=pltpu.CompilerParams(
515517
dimension_semantics=("parallel", "arbitrary", "arbitrary"),
516518
flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": use_experimental_scheduler},
519+
disable_bounds_checks=True,
520+
skip_device_barrier=True,
517521
),
518522
out_shape=out_shapes,
519523
)(q, k, v)

src/maxdiffusion/models/wan/autoencoder_kl_wan_2p2.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import jax
2121
from jax import tree_util
2222
import jax.numpy as jnp
23+
from jax.sharding import NamedSharding, PartitionSpec as P
2324
from maxdiffusion.models.wan.autoencoder_kl_wan import AutoencoderKLWanCache, WanCausalConv3d # pylint: disable=g-importing-member
2425

2526
from ... import common_types
@@ -1266,6 +1267,7 @@ def __init__(
12661267
self.temporal_upsample = temperal_downsample[::-1]
12671268
self.latents_mean = latents_mean
12681269
self.latents_std = latents_std
1270+
self.mesh = mesh
12691271

12701272
self.patch_size = 2
12711273
self.patchify = WanPatchify(patch_size=self.patch_size)
@@ -1339,16 +1341,23 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
13391341
iter_ = 1 + (t - 1) // 4
13401342
enc_feat_map = feat_cache._enc_feat_map
13411343

1344+
spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "vae_spatial", None))
13421345
for i in range(iter_):
13431346
enc_conv_idx = 0
13441347
if i == 0:
1345-
out, enc_feat_map, enc_conv_idx = self.encoder(x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=enc_conv_idx)
1348+
chunk = x[:, :1, :, :, :]
1349+
chunk = jax.lax.with_sharding_constraint(chunk, spatial_sharding)
1350+
out, enc_feat_map, enc_conv_idx = self.encoder(chunk, feat_cache=enc_feat_map, feat_idx=enc_conv_idx)
1351+
out = jax.lax.with_sharding_constraint(out, spatial_sharding)
13461352
else:
1353+
chunk = x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :]
1354+
chunk = jax.lax.with_sharding_constraint(chunk, spatial_sharding)
13471355
out_, enc_feat_map, enc_conv_idx = self.encoder(
1348-
x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :],
1356+
chunk,
13491357
feat_cache=enc_feat_map,
13501358
feat_idx=enc_conv_idx,
13511359
)
1360+
out_ = jax.lax.with_sharding_constraint(out_, spatial_sharding)
13521361
out = jnp.concatenate([out, out_], axis=1)
13531362

13541363
# Update back to the wrapper object if needed, but for result we use local vars
@@ -1385,17 +1394,22 @@ def _decode(
13851394

13861395
dec_feat_map = feat_cache._feat_map
13871396

1397+
spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "vae_spatial", None))
13881398
for i in range(iter_):
13891399
conv_idx = 0
1400+
chunk = x[:, i : i + 1, :, :, :]
1401+
chunk = jax.lax.with_sharding_constraint(chunk, spatial_sharding)
13901402
if i == 0:
13911403
out, dec_feat_map, conv_idx = self.decoder(
1392-
x[:, i : i + 1, :, :, :],
1404+
chunk,
13931405
feat_cache=dec_feat_map,
13941406
feat_idx=conv_idx,
13951407
first_chunk=True,
13961408
)
1409+
out = jax.lax.with_sharding_constraint(out, spatial_sharding)
13971410
else:
1398-
out_, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
1411+
out_, dec_feat_map, conv_idx = self.decoder(chunk, feat_cache=dec_feat_map, feat_idx=conv_idx)
1412+
out_ = jax.lax.with_sharding_constraint(out_, spatial_sharding)
13991413
out = jnp.concatenate([out, out_], axis=1)
14001414

14011415
feat_cache._feat_map = dec_feat_map

0 commit comments

Comments (0)