Skip to content

Commit f51786d

Browse files
Google-ML-Automationmichelle-yooh
authored andcommitted
source sync
PiperOrigin-RevId: 866197055
1 parent df9c7b0 commit f51786d

2 files changed

Lines changed: 3 additions & 2 deletions

File tree

src/maxdiffusion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666

6767
if is_flax_available():
6868
from flax import config as flax_config
69+
6970
flax_config.update('flax_always_shard_variable', False)
7071

7172
try:

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def loss_fn(params):
444444
noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
445445
noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
446446

447-
with jax.named_scope('forward_pass'):
447+
with jax.named_scope("forward_pass"):
448448
model_pred = model(
449449
hidden_states=noisy_latents,
450450
timestep=timesteps,
@@ -453,7 +453,7 @@ def loss_fn(params):
453453
rngs=nnx.Rngs(dropout_rng),
454454
)
455455

456-
with jax.named_scope('loss'):
456+
with jax.named_scope("loss"):
457457
training_target = scheduler.training_target(latents, noise, timesteps)
458458
training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
459459
loss = (training_target - model_pred) ** 2

0 commit comments

Comments
 (0)