Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -628,16 +628,16 @@ To generate images, run the following command:
We added ring attention support for Wan models. Below are the stats for one `720p` (81 frames) video generation (with CFG DP):
| Accelerator | Model | Attention Type | Inference Steps | Sharding | e2e Generation Time |
| -- | -- | -- | -- | -- | -- |
| v7x-8 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context4-tp1 | 264.2 |
| v7x-8 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context4-tp1 | **252.4** |
| v7x-8 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context4-tp1 | 212.7 |
| v7x-8 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context4-tp1 | **201.7** |
| v7x-8 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context4-tp1 | **249.3** |
| v7x-8 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context4-tp1 | 252.4 |
| v7x-8 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context4-tp1 | **194.4** |
| v7x-8 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context4-tp1 | 201.7 |

| Accelerator | Model | Attention Type | Inference Steps | Sharding | e2e Generation Time |
| -- | -- | -- | -- | -- | -- |
| v7x-16 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context8-tp1 | 146.6 |
| v7x-16 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context8-tp1 | **137.2** |
| v7x-16 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context8-tp1 | **117.8** |
| v7x-16 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context8-tp1 | **127.1** |
| v7x-16 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context8-tp1 | 137.2 |
| v7x-16 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context8-tp1 | **106.0** |
| v7x-16 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context8-tp1 | 137.5 |

(* There are some known stability issues for ring attention on 16 TPUs; please use `tokamax_flash` attention instead.)
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses
use_base2_exp: True
use_experimental_scheduler: True
flash_min_seq_length: 0

# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_wan_1_3b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
use_base2_exp: True
use_experimental_scheduler: True
flash_min_seq_length: 0

# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_wan_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
use_base2_exp: True
use_experimental_scheduler: True
flash_min_seq_length: 4096
# If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
# Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster.
Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_wan_i2v_14b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
use_base2_exp: True
use_experimental_scheduler: True
flash_min_seq_length: 4096
dropout: 0.0

Expand Down
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_wan_i2v_27b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ jit_initializers: True
from_pt: True
split_head_dim: True
attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
use_base2_exp: True
use_experimental_scheduler: True
flash_min_seq_length: 4096
dropout: 0.0

Expand Down
9 changes: 1 addition & 8 deletions src/maxdiffusion/generate_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,14 +302,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None):
f"{'=' * 50}"
)

s0 = time.perf_counter()
if max_utils.profiler_enabled(config):
with max_utils.Profiler(config):
videos = call_pipeline(config, pipeline, prompt, negative_prompt)
generation_time_with_profiler = time.perf_counter() - s0
max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}")
if writer and jax.process_index() == 0:
writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0)
videos = call_pipeline(config, pipeline, prompt, negative_prompt)

return saved_video_path

Expand Down
52 changes: 42 additions & 10 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def convert_to_tokamax_splash_config(
attn_logits_soft_cap: float | None = None,
fuse_reciprocal: bool = True,
use_base2_exp: bool = False,
use_experimental_scheduler: bool = False,
max_logit_const: float | None = None,
interpret: bool = False,
dq_reduction_steps: int | None = None,
Expand All @@ -294,6 +295,7 @@ def convert_to_tokamax_splash_config(
attn_logits_soft_cap=attn_logits_soft_cap,
fuse_reciprocal=fuse_reciprocal,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
max_logit_const=max_logit_const,
interpret=interpret,
dq_reduction_steps=dq_reduction_steps,
Expand All @@ -314,6 +316,8 @@ def _tpu_flash_attention(
mask_padding_tokens: bool = True,
residual_checkpoint_name: str | None = None,
attention_mask: jax.Array = None,
use_base2_exp: bool = False,
use_experimental_scheduler: bool = False,
) -> jax.Array:
"""TPU Flash Attention"""

Expand Down Expand Up @@ -399,7 +403,12 @@ def wrap_flash_attention(query, key, value):
splash_kernel = tokamax_splash_attention_kernel.make_splash_mha(
mask=mask,
q_seq_shards=1, # the sizes of the axis is sharding over seq_len
config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
config=convert_to_tokamax_splash_config(
block_sizes,
residual_checkpoint_name=residual_checkpoint_name,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
),
save_residuals=False,
)
elif attention_kernel == "tokamax_ring":
Expand All @@ -409,7 +418,12 @@ def wrap_flash_attention(query, key, value):
splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
mask=mask,
is_mqa=False,
config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
config=convert_to_tokamax_splash_config(
block_sizes,
residual_checkpoint_name=residual_checkpoint_name,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
),
save_residuals=False,
ring_axis="context",
rotate_segment_ids=False, # We don't rotate segment ids in tokamax ring attention because our segment ids are only for padding; each kv shard has the same segment ids
Expand Down Expand Up @@ -473,13 +487,13 @@ def ring_scan_body(carry, _):
raise ValueError("ring attention requires context > 1")
return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

devices_in_data_context = mesh.shape["data"] * mesh.shape["context"]
devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1)
# This warning might show up during model eval — for example, when calculating model FLOPs —
# and that is expected.
if not (query.shape[0] / devices_in_data_context).is_integer():
if not (query.shape[0] / devices_in_batch_sharding).is_integer():
max_logging.log(
"Warning, batch dimension should be shardable among the devices in data and context"
f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}"
"Warning, batch dimension should be shardable among the devices in data and fsdp"
f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
)
x = wrap_flash_attention(query, key, value)
# Trim back to original sequence length after context-axis padding.
Expand Down Expand Up @@ -614,11 +628,11 @@ def wrap_ulysses_attention(query, key, value):
attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True)
return attention_output

devices_in_data_context = mesh.shape["data"] * num_shards
if not (query.shape[0] / devices_in_data_context).is_integer():
devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1)
if not (query.shape[0] / devices_in_batch_sharding).is_integer():
max_logging.log(
"Warning, batch dimension should be shardable among the devices in data and context"
f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}"
"Warning, batch dimension should be shardable among the devices in data and fsdp"
f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}"
)
x = wrap_ulysses_attention(query, key, value)
x = x[:, :, :orig_q_seq_len, :]
Expand Down Expand Up @@ -741,6 +755,8 @@ def _apply_attention(
mask_padding_tokens: bool = True,
residual_checkpoint_name: str | None = None,
attention_mask: Array = None,
use_base2_exp: bool = False,
use_experimental_scheduler: bool = False,
):
"""Routes to different attention kernels."""
_check_attention_inputs(query, key, value)
Expand Down Expand Up @@ -789,6 +805,8 @@ def _apply_attention(
mask_padding_tokens=mask_padding_tokens,
residual_checkpoint_name=residual_checkpoint_name,
attention_mask=attention_mask,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
)
elif "ring" in attention_kernel:
return _tpu_flash_attention(
Expand Down Expand Up @@ -983,8 +1001,12 @@ def __init__(
quant: Quant = None,
mask_padding_tokens: bool = True,
residual_checkpoint_name: str | None = None,
use_base2_exp: bool = False,
use_experimental_scheduler: bool = False,
):
self.dpa_layer = None
self.use_base2_exp = use_base2_exp
self.use_experimental_scheduler = use_experimental_scheduler
if attention_kernel == "cudnn_flash_te":
from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error

Expand Down Expand Up @@ -1045,6 +1067,8 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask
mask_padding_tokens=self.mask_padding_tokens,
residual_checkpoint_name=self.residual_checkpoint_name,
attention_mask=attention_mask,
use_base2_exp=self.use_base2_exp if hasattr(self, "use_base2_exp") else False,
use_experimental_scheduler=self.use_experimental_scheduler if hasattr(self, "use_experimental_scheduler") else False,
)


Expand All @@ -1063,6 +1087,8 @@ class AttentionOp(nn.Module):
flash_block_sizes: BlockSizes = None
dtype: DType = jnp.float32
quant: Quant = None
use_base2_exp: bool = False
use_experimental_scheduler: bool = False

def setup(self):
self.dpa_layer = None
Expand Down Expand Up @@ -1108,6 +1134,8 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask
flash_block_sizes=self.flash_block_sizes,
dpa_layer=self.dpa_layer,
attention_mask=attention_mask,
use_base2_exp=self.use_base2_exp,
use_experimental_scheduler=self.use_experimental_scheduler,
)


Expand Down Expand Up @@ -1144,6 +1172,8 @@ def __init__(
enable_jax_named_scopes: bool = False,
added_kv_proj_dim: Optional[int] = None, # New for I2V
image_seq_len: Optional[int] = None, # New for I2V
use_base2_exp: bool = False,
use_experimental_scheduler: bool = False,
):
if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None:
raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
Expand Down Expand Up @@ -1186,6 +1216,8 @@ def __init__(
quant=quant,
mask_padding_tokens=mask_padding_tokens,
residual_checkpoint_name=residual_checkpoint_name,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
)
# None axes corresponds to the stacked weights across all blocks
# because of the use of nnx.vmap and nnx.scan.
Expand Down
4 changes: 2 additions & 2 deletions src/maxdiffusion/models/wan/autoencoder_kl_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1206,7 +1206,7 @@ def _decode(
fm1, fm2, fm3, fm4 = out_chunk_1[:, 0, ...], out_chunk_1[:, 1, ...], out_chunk_1[:, 2, ...], out_chunk_1[:, 3, ...]
axis = 1 if fm1.shape[0] > 1 else 0
fm1, fm2, fm3, fm4 = [jnp.expand_dims(f, axis=axis) for f in [fm1, fm2, fm3, fm4]]
out_1 = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1)
out_1 = jnp.concatenate([fm1, fm2, fm3, fm4], axis=1)

out_list = [out_0, out_1]

Expand All @@ -1226,7 +1226,7 @@ def scan_fn(carry, chunk_in):
fm1, fm2, fm3, fm4 = out_chunk[:, 0, ...], out_chunk[:, 1, ...], out_chunk[:, 2, ...], out_chunk[:, 3, ...]
axis = 1 if fm1.shape[0] > 1 else 0
fm1, fm2, fm3, fm4 = [jnp.expand_dims(f, axis=axis) for f in [fm1, fm2, fm3, fm4]]
new_chunk = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1)
new_chunk = jnp.concatenate([fm1, fm2, fm3, fm4], axis=1)

return next_feat_map, new_chunk

Expand Down
10 changes: 10 additions & 0 deletions src/maxdiffusion/models/wan/transformers/transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ def __init__(
dropout: float = 0.0,
mask_padding_tokens: bool = True,
enable_jax_named_scopes: bool = False,
use_base2_exp: bool = False,
use_experimental_scheduler: bool = False,
):
self.enable_jax_named_scopes = enable_jax_named_scopes

Expand All @@ -315,6 +317,8 @@ def __init__(
mask_padding_tokens=mask_padding_tokens,
residual_checkpoint_name="self_attn",
enable_jax_named_scopes=enable_jax_named_scopes,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
)

# 1. Cross-attention
Expand All @@ -339,6 +343,8 @@ def __init__(
mask_padding_tokens=mask_padding_tokens,
residual_checkpoint_name="cross_attn",
enable_jax_named_scopes=enable_jax_named_scopes,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
)
assert cross_attn_norm is True
self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
Expand Down Expand Up @@ -486,6 +492,8 @@ def __init__(
mask_padding_tokens: bool = True,
scan_layers: bool = True,
enable_jax_named_scopes: bool = False,
use_base2_exp: bool = False,
use_experimental_scheduler: bool = False,
):
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
Expand Down Expand Up @@ -547,6 +555,8 @@ def init_block(rngs):
enable_jax_named_scopes=enable_jax_named_scopes,
added_kv_proj_dim=added_kv_proj_dim,
image_seq_len=image_seq_len,
use_base2_exp=use_base2_exp,
use_experimental_scheduler=use_experimental_scheduler,
)

self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
Expand Down
3 changes: 2 additions & 1 deletion src/maxdiffusion/pipelines/wan/wan_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def cast_with_exclusion(path, x, dtype_to_cast):
path_str = ".".join(str(k.key) if isinstance(k, jax.tree_util.DictKey) else str(k) for k in path)

if any(keyword in path_str.lower() for keyword in exclusion_keywords):
print("is_norm_path: ", path)
# Keep LayerNorm/GroupNorm weights and biases in full precision
return x.astype(jnp.float32)
else:
Expand Down Expand Up @@ -139,6 +138,8 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
wan_config["mask_padding_tokens"] = config.mask_padding_tokens
wan_config["scan_layers"] = config.scan_layers
wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes
wan_config["use_base2_exp"] = config.use_base2_exp
wan_config["use_experimental_scheduler"] = config.use_experimental_scheduler

# 2. eval_shape - will not use flops or create weights on device
# thus not using HBM memory.
Expand Down
17 changes: 17 additions & 0 deletions src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import jax
import jax.numpy as jnp
from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler
import numpy as np
from ... import max_utils


class WanPipeline2_1(WanPipeline):
Expand Down Expand Up @@ -142,6 +144,7 @@ def __call__(
retention_ratio=retention_ratio,
height=height,
mag_ratios_base=getattr(config, "mag_ratios_base", None),
config=self.config,
)

with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
Expand Down Expand Up @@ -175,6 +178,7 @@ def run_inference_2_1(
retention_ratio: float = 0.2,
height: int = 480,
mag_ratios_base: Optional[List[float]] = None,
config=None,
):
"""Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache.

Expand Down Expand Up @@ -253,7 +257,16 @@ def run_inference_2_1(
skip_warmup = magcache_init[7]
mag_ratios = magcache_init[8]

first_profiling_step = config.skip_first_n_steps_for_profiler if config else 0
profiler_steps = config.profiler_steps if config else 0
last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1)

profiler = None
for step in range(num_inference_steps):
if config and max_utils.profiler_enabled(config) and step == first_profiling_step:
profiler = max_utils.Profiler(config)
profiler.start()

t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]

if use_magcache and do_cfg:
Expand Down Expand Up @@ -328,4 +341,8 @@ def run_inference_2_1(

latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()

if config and max_utils.profiler_enabled(config) and step == last_profiling_step:
if profiler:
profiler.stop()

return latents
Loading
Loading