
Commit 0c10d44

entrpn and susanbao authored
Fixes bad quality in 720p videos. (#269)

Moves to mixed precision by excluding norm, conditioning, and AdaLN layers from being cast to bfloat16 when weights_dtype and activations_dtype are set to bfloat16 (the default). Moves the VAE to full fp32. Inputs are cast to fp32. Scheduler samples are cast to fp32.

---------

Co-authored-by: susanbao <susanbaonju@gmail.com>
Co-authored-by: Sanbao Su <sanbao@google.com>
1 parent 8ac20ca commit 0c10d44

8 files changed

Lines changed: 155 additions & 92 deletions


README.md

Lines changed: 98 additions & 69 deletions
Large diffs are not rendered by default.

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 1 deletion
@@ -323,4 +323,4 @@ eval_data_dir: ""
 enable_generate_video_for_eval: False # This will increase the used TPU memory.
 eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(considered_timesteps_list).
 
-enable_ssim: True
+enable_ssim: False

src/maxdiffusion/generate_wan.py

Lines changed: 1 addition & 1 deletion
@@ -162,7 +162,7 @@ def run(config, pipeline=None, filename_prefix=""):
 
 def main(argv: Sequence[str]) -> None:
   pyconfig.initialize(argv)
-  flax.config.update('flax_always_shard_variable', False)
+  flax.config.update("flax_always_shard_variable", False)
   run(pyconfig.config)
 
 

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 4 additions & 4 deletions
@@ -89,7 +89,7 @@ def __init__(
         in_features=in_channels,
         out_features=time_embed_dim,
         use_bias=sample_proj_bias,
-        dtype=dtype,
+        dtype=jnp.float32,
         param_dtype=weights_dtype,
         precision=precision,
         kernel_init=nnx.with_partitioning(
@@ -121,7 +121,7 @@ def __init__(
         in_features=time_embed_dim,
         out_features=time_embed_dim_out,
         use_bias=sample_proj_bias,
-        dtype=dtype,
+        dtype=jnp.float32,
         param_dtype=weights_dtype,
         precision=precision,
         kernel_init=nnx.with_partitioning(
@@ -269,7 +269,7 @@ def __init__(
         in_features=in_features,
         out_features=hidden_size,
         use_bias=True,
-        dtype=dtype,
+        dtype=jnp.float32,
         param_dtype=weights_dtype,
         precision=precision,
         kernel_init=nnx.with_partitioning(
@@ -288,7 +288,7 @@ def __init__(
         in_features=hidden_size,
         out_features=out_features,
         use_bias=True,
-        dtype=dtype,
+        dtype=jnp.float32,
         param_dtype=weights_dtype,
         precision=precision,
         kernel_init=nnx.with_partitioning(

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 14 additions & 8 deletions
@@ -116,7 +116,7 @@ def __init__(
        rngs=rngs,
        in_features=dim,
        out_features=time_proj_dim,
-        dtype=dtype,
+        dtype=jnp.float32,
        param_dtype=weights_dtype,
        precision=precision,
        kernel_init=nnx.with_partitioning(
@@ -332,33 +332,39 @@ def __call__(
      rngs: nnx.Rngs = None,
  ):
    shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
-        (self.adaln_scale_shift_table + temb), 6, axis=1
+        (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
    )
    hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
    encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
 
    # 1. Self-attention
-    norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
+    norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
+        hidden_states.dtype
+    )
    attn_output = self.attn1(
        hidden_states=norm_hidden_states,
        encoder_hidden_states=norm_hidden_states,
        rotary_emb=rotary_emb,
        deterministic=deterministic,
        rngs=rngs,
    )
-    hidden_states = (hidden_states + attn_output * gate_msa).astype(hidden_states.dtype)
+    hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
 
    # 2. Cross-attention
-    norm_hidden_states = self.norm2(hidden_states)
+    norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
    attn_output = self.attn2(
        hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
    )
    hidden_states = hidden_states + attn_output
 
    # 3. Feed-forward
-    norm_hidden_states = (self.norm3(hidden_states) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype)
+    norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
+        hidden_states.dtype
+    )
    ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-    hidden_states = (hidden_states + ff_output * c_gate_msa).astype(hidden_states.dtype)
+    hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
+        hidden_states.dtype
+    )
    return hidden_states
 
 
@@ -563,7 +569,7 @@ def layer_forward(hidden_states):
 
    shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
 
-    hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).astype(hidden_states.dtype)
+    hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
    hidden_states = self.proj_out(hidden_states)
 
    hidden_states = hidden_states.reshape(
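
The recurring pattern in these hunks is to run the LayerNorm and the AdaLN modulation in float32 and cast the result back to the activation dtype, so attention and the feed-forward blocks still run in bfloat16. A minimal standalone sketch of that pattern (hypothetical norm_fn, not this file's code):

import jax.numpy as jnp

def modulate_fp32(norm_fn, hidden_states, scale, shift):
  # Normalization statistics and the (1 + scale) * x + shift modulation are computed
  # in float32; the result is cast back to the activation dtype (e.g. bfloat16).
  out = norm_fn(hidden_states.astype(jnp.float32)) * (1 + scale) + shift
  return out.astype(hidden_states.dtype)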

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 32 additions & 7 deletions
@@ -40,6 +40,28 @@
 import qwix
 
 
+def cast_with_exclusion(path, x, dtype_to_cast):
+  """
+  Casts arrays to dtype_to_cast, but keeps params from any 'norm' layer in float32.
+  """
+
+  exclusion_keywords = [
+      "norm",  # For all LayerNorm/GroupNorm layers
+      "condition_embedder",  # The entire time/text conditioning module
+      "scale_shift_table",  # Catches both the final and the AdaLN tables
+  ]
+
+  path_str = ".".join(str(k.key) if isinstance(k, jax.tree_util.DictKey) else str(k) for k in path)
+
+  if any(keyword in path_str.lower() for keyword in exclusion_keywords):
+    print("is_norm_path: ", path)
+    # Keep LayerNorm/GroupNorm weights and biases in full precision
+    return x.astype(jnp.float32)
+  else:
+    # Cast everything else to dtype_to_cast
+    return x.astype(dtype_to_cast)
+
+
 def basic_clean(text):
   if is_ftfy_available():
     import ftfy
@@ -118,7 +140,10 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
        num_layers=wan_config["num_layers"],
        scan_layers=config.scan_layers,
    )
-    params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
+
+    params = jax.tree_util.tree_map_with_path(
+        lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype), params
+    )
    for path, val in flax.traverse_util.flatten_dict(params).items():
      if restored_checkpoint:
        path = path[:-1]
@@ -214,8 +239,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
        subfolder="vae",
        rngs=rngs,
        mesh=mesh,
-        dtype=config.activations_dtype,
-        weights_dtype=config.weights_dtype,
+        dtype=jnp.float32,
+        weights_dtype=jnp.float32,
    )
    return wan_vae
 
@@ -474,7 +499,7 @@ def encode_prompt(
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
    )
-    prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype)
+    prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32)
 
    if negative_prompt_embeds is None:
      negative_prompt = negative_prompt or ""
@@ -484,7 +509,7 @@
        num_videos_per_prompt=num_videos_per_prompt,
        max_sequence_length=max_sequence_length,
      )
-      negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype)
+      negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=jnp.float32)
 
    return prompt_embeds, negative_prompt_embeds
 
@@ -507,7 +532,7 @@ def prepare_latents(
        int(height) // vae_scale_factor_spatial,
        int(width) // vae_scale_factor_spatial,
    )
-    latents = jax.random.normal(rng, shape=shape, dtype=self.config.weights_dtype)
+    latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32)
 
    return latents
 
@@ -597,7 +622,7 @@ def __call__(
    latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1)
    latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1)
    latents = latents / latents_std + latents_mean
-    latents = latents.astype(self.config.weights_dtype)
+    latents = latents.astype(jnp.float32)
 
    with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
      video = self.vae.decode(latents, self.vae_cache)[0]
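
The cast_with_exclusion helper added in the first hunk is applied through jax.tree_util.tree_map_with_path in the second hunk. A toy sketch of its effect, assuming that helper is in scope and using made-up parameter names (the real WAN parameter tree differs):

import jax
import jax.numpy as jnp

params = {
    "blocks_0": {"attn1": {"kernel": jnp.ones((4, 4))}, "norm1": {"scale": jnp.ones((4,))}},
    "condition_embedder": {"linear_1": {"kernel": jnp.ones((4, 4))}},
}
mixed = jax.tree_util.tree_map_with_path(
    lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=jnp.bfloat16), params
)
# mixed["blocks_0"]["attn1"]["kernel"].dtype               -> bfloat16
# mixed["blocks_0"]["norm1"]["scale"].dtype                -> float32 (path matches "norm")
# mixed["condition_embedder"]["linear_1"]["kernel"].dtype  -> float32 (path matches "condition_embedder")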

src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py

Lines changed: 3 additions & 0 deletions
@@ -674,6 +674,9 @@ def step(
    Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
    the multistep UniPC.
    """
+
+    sample = sample.astype(jnp.float32)
+
    if state.timesteps is None:
      raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler")
 
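
The rationale for casting the sample at the top of step: per-step UniPC updates can be small relative to the latent magnitudes, and bfloat16, with roughly 8 bits of mantissa, rounds such updates away. A tiny illustration (not scheduler code):

import jax.numpy as jnp

latent, update = 1.0, 1e-3
print(jnp.float32(latent) + jnp.float32(update))    # ~1.001, the update survives
print(jnp.bfloat16(latent) + jnp.bfloat16(update))  # 1.0, the update rounds away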

tests/schedulers/test_scheduler_unipc.py

Lines changed: 2 additions & 2 deletions
@@ -518,8 +518,8 @@ def test_fp16_support(self):
      step_output = scheduler.step(state, residual, t, sample)
      sample = step_output.prev_sample
      state = step_output.state
-
-    self.assertEqual(sample.dtype, jnp.bfloat16)
+    # sample is casted to fp32 inside step and output should be fp32.
+    self.assertEqual(sample.dtype, jnp.float32)
 
  def test_full_loop_with_noise(self):
    scheduler_class = self.scheduler_classes[0]
