Skip to content

Commit 51c34e6

Browse files
Merge pull request #407 from AI-Hypercomputer:ninatu/wan_vae_sharding_bug
PiperOrigin-RevId: 918635461
2 parents 52854a3 + f9c1253 commit 51c34e6

2 files changed

Lines changed: 4 additions & 7 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -674,11 +674,8 @@ def _create_common_components(cls, config, vae_only=False, i2v=False):
674674
vae_spatial = getattr(config, "vae_spatial", -1)
675675
total_devices = math.prod(devices_array.shape)
676676

677-
if vae_spatial <= 0:
678-
dp_size = mesh.shape.get("data", 1)
679-
if dp_size == -1 or dp_size == 0:
680-
dp_size = 1
681-
vae_spatial = (2 * total_devices) // dp_size
677+
if vae_spatial == -1:
678+
vae_spatial = total_devices
682679

683680
assert (
684681
total_devices % vae_spatial == 0

src/maxdiffusion/pyconfig.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,8 @@ def user_init(raw_keys):
281281
raw_keys["global_batch_size_to_train_on"],
282282
) = _HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"])
283283

284-
if raw_keys.get("vae_spatial", -1) == -1:
285-
raw_keys["vae_spatial"] = 1
284+
if "vae_spatial" not in raw_keys:
285+
raw_keys["vae_spatial"] = -1
286286

287287

288288
def get_num_slices(raw_keys):

0 commit comments

Comments
 (0)