Skip to content

Commit 9997c59

Browse files
Source sync: Add Flax configuration updates and improve WAN trainer profiling (#326)
* Source sync (PiperOrigin-RevId: 866156279). * Source sync (PiperOrigin-RevId: 866197055). * Source sync (PiperOrigin-RevId: 866530641). Co-authored-by: maxdiffusion authors <google-ml-automation@google.com>
1 parent 5d16f16 commit 9997c59

3 files changed

Lines changed: 22 additions & 15 deletions

File tree

src/maxdiffusion/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@
6464
],
6565
}
6666

67+
if is_flax_available():
68+
from flax import config as flax_config
69+
70+
flax_config.update("flax_always_shard_variable", False)
71+
6772
try:
6873
if not is_onnx_available():
6974
raise OptionalDependencyNotAvailable()

src/maxdiffusion/configuration_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -376,11 +376,11 @@ def load_config(
376376
if os.path.isfile(pretrained_model_name_or_path):
377377
config_file = pretrained_model_name_or_path
378378
elif os.path.isdir(pretrained_model_name_or_path):
379-
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
379+
if subfolder is not None and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)):
380+
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
381+
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
380382
# Load from a PyTorch checkpoint
381383
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
382-
elif subfolder is not None and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)):
383-
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
384384
else:
385385
raise EnvironmentError(f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}.")
386386
else:

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -444,19 +444,21 @@ def loss_fn(params):
444444
noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
445445
noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
446446

447-
model_pred = model(
448-
hidden_states=noisy_latents,
449-
timestep=timesteps,
450-
encoder_hidden_states=encoder_hidden_states,
451-
deterministic=False,
452-
rngs=nnx.Rngs(dropout_rng),
453-
)
447+
with jax.named_scope("forward_pass"):
448+
model_pred = model(
449+
hidden_states=noisy_latents,
450+
timestep=timesteps,
451+
encoder_hidden_states=encoder_hidden_states,
452+
deterministic=False,
453+
rngs=nnx.Rngs(dropout_rng),
454+
)
454455

455-
training_target = scheduler.training_target(latents, noise, timesteps)
456-
training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
457-
loss = (training_target - model_pred) ** 2
458-
loss = loss * training_weight
459-
loss = jnp.mean(loss)
456+
with jax.named_scope("loss"):
457+
training_target = scheduler.training_target(latents, noise, timesteps)
458+
training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
459+
loss = (training_target - model_pred) ** 2
460+
loss = loss * training_weight
461+
loss = jnp.mean(loss)
460462

461463
return loss
462464

0 commit comments

Comments (0)