Skip to content

Commit f9c1253

Browse files
committed
Fix VAE spatial sharding dynamic calculation bug in Wan pipeline.
Previously, setting `vae_spatial: -1` in the config (intended to trigger dynamic calculation of the VAE spatial sharding axis size) was ineffective because `pyconfig.py` prematurely overrode any `-1` or missing `vae_spatial` value to `1`. Furthermore, the dynamic calculation formula in `wan_pipeline.py` (`vae_spatial = (2 * total_devices) // dp_size`) was not robust. On single-device runs (where `total_devices=1` and `dp_size=1`) or configurations with odd data parallel (DP) sizes, it would calculate a `vae_spatial` value (e.g., 2) that does not divide `total_devices`, failing the mesh validation assertion.
1 parent 19d4e4d commit f9c1253

2 files changed

Lines changed: 4 additions & 7 deletions

File tree

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -672,11 +672,8 @@ def _create_common_components(cls, config, vae_only=False, i2v=False):
672672
vae_spatial = getattr(config, "vae_spatial", -1)
673673
total_devices = math.prod(devices_array.shape)
674674

675-
if vae_spatial <= 0:
676-
dp_size = mesh.shape.get("data", 1)
677-
if dp_size == -1 or dp_size == 0:
678-
dp_size = 1
679-
vae_spatial = (2 * total_devices) // dp_size
675+
if vae_spatial == -1:
676+
vae_spatial = total_devices
680677

681678
assert (
682679
total_devices % vae_spatial == 0

src/maxdiffusion/pyconfig.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,8 @@ def user_init(raw_keys):
281281
raw_keys["global_batch_size_to_train_on"],
282282
) = _HyperParameters.calculate_global_batch_sizes(raw_keys["per_device_batch_size"])
283283

284-
if raw_keys.get("vae_spatial", -1) == -1:
285-
raw_keys["vae_spatial"] = 1
284+
if "vae_spatial" not in raw_keys:
285+
raw_keys["vae_spatial"] = -1
286286

287287

288288
def get_num_slices(raw_keys):

0 commit comments

Comments
 (0)