|
20 | 20 | import jax |
21 | 21 | from jax import tree_util |
22 | 22 | import jax.numpy as jnp |
| 23 | +from jax.sharding import NamedSharding, PartitionSpec as P |
23 | 24 | from maxdiffusion.models.wan.autoencoder_kl_wan import AutoencoderKLWanCache, WanCausalConv3d # pylint: disable=g-importing-member |
24 | 25 |
|
25 | 26 | from ... import common_types |
@@ -1266,6 +1267,7 @@ def __init__( |
1266 | 1267 | self.temporal_upsample = temperal_downsample[::-1] |
1267 | 1268 | self.latents_mean = latents_mean |
1268 | 1269 | self.latents_std = latents_std |
| 1270 | + self.mesh = mesh |
1269 | 1271 |
|
1270 | 1272 | self.patch_size = 2 |
1271 | 1273 | self.patchify = WanPatchify(patch_size=self.patch_size) |
@@ -1339,16 +1341,23 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache): |
1339 | 1341 | iter_ = 1 + (t - 1) // 4 |
1340 | 1342 | enc_feat_map = feat_cache._enc_feat_map |
1341 | 1343 |
|
| 1344 | + spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "vae_spatial", None)) |
1342 | 1345 | for i in range(iter_): |
1343 | 1346 | enc_conv_idx = 0 |
1344 | 1347 | if i == 0: |
1345 | | - out, enc_feat_map, enc_conv_idx = self.encoder(x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=enc_conv_idx) |
| 1348 | + chunk = x[:, :1, :, :, :] |
| 1349 | + chunk = jax.lax.with_sharding_constraint(chunk, spatial_sharding) |
| 1350 | + out, enc_feat_map, enc_conv_idx = self.encoder(chunk, feat_cache=enc_feat_map, feat_idx=enc_conv_idx) |
| 1351 | + out = jax.lax.with_sharding_constraint(out, spatial_sharding) |
1346 | 1352 | else: |
| 1353 | + chunk = x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :] |
| 1354 | + chunk = jax.lax.with_sharding_constraint(chunk, spatial_sharding) |
1347 | 1355 | out_, enc_feat_map, enc_conv_idx = self.encoder( |
1348 | | - x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :], |
| 1356 | + chunk, |
1349 | 1357 | feat_cache=enc_feat_map, |
1350 | 1358 | feat_idx=enc_conv_idx, |
1351 | 1359 | ) |
| 1360 | + out_ = jax.lax.with_sharding_constraint(out_, spatial_sharding) |
1352 | 1361 | out = jnp.concatenate([out, out_], axis=1) |
1353 | 1362 |
|
1354 | 1363 | # Update back to the wrapper object if needed, but for result we use local vars |
@@ -1385,17 +1394,22 @@ def _decode( |
1385 | 1394 |
|
1386 | 1395 | dec_feat_map = feat_cache._feat_map |
1387 | 1396 |
|
| 1397 | + spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "vae_spatial", None)) |
1388 | 1398 | for i in range(iter_): |
1389 | 1399 | conv_idx = 0 |
| 1400 | + chunk = x[:, i : i + 1, :, :, :] |
| 1401 | + chunk = jax.lax.with_sharding_constraint(chunk, spatial_sharding) |
1390 | 1402 | if i == 0: |
1391 | 1403 | out, dec_feat_map, conv_idx = self.decoder( |
1392 | | - x[:, i : i + 1, :, :, :], |
| 1404 | + chunk, |
1393 | 1405 | feat_cache=dec_feat_map, |
1394 | 1406 | feat_idx=conv_idx, |
1395 | 1407 | first_chunk=True, |
1396 | 1408 | ) |
| 1409 | + out = jax.lax.with_sharding_constraint(out, spatial_sharding) |
1397 | 1410 | else: |
1398 | | - out_, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx) |
| 1411 | + out_, dec_feat_map, conv_idx = self.decoder(chunk, feat_cache=dec_feat_map, feat_idx=conv_idx) |
| 1412 | + out_ = jax.lax.with_sharding_constraint(out_, spatial_sharding) |
1399 | 1413 | out = jnp.concatenate([out, out_], axis=1) |
1400 | 1414 |
|
1401 | 1415 | feat_cache._feat_map = dec_feat_map |
|
0 commit comments