diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 51ce45744..2a3fae518 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -255,6 +255,8 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}" ) videos = call_pipeline(config, pipeline, prompt, negative_prompt) + if isinstance(videos, tuple): + videos = videos[0] max_logging.log("===================== Model details =======================") max_logging.log(f"model name: {config.model_name}") @@ -278,7 +280,12 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): upload_video_to_gcs(os.path.join(config.output_dir, config.run_name), video_path) s0 = time.perf_counter() - videos = call_pipeline(config, pipeline, prompt, negative_prompt) + outputs = call_pipeline(config, pipeline, prompt, negative_prompt) + if isinstance(outputs, tuple): + videos, trace = outputs + else: + videos = outputs + trace = {} generation_time = time.perf_counter() - s0 max_logging.log(f"generation_time: {generation_time}") if writer and jax.process_index() == 0: @@ -291,18 +298,39 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): max_logging.log(f"generation time per video: {generation_time_per_video}") else: max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.") - max_logging.log( - f"\n{'=' * 50}\n" - f" TIMING SUMMARY\n" - f"{'=' * 50}\n" - f" Load (checkpoint): {load_time:>7.1f}s\n" - f" Compile: {compile_time:>7.1f}s\n" - f" {'─' * 40}\n" - f" Inference: {generation_time:>7.1f}s\n" - f"{'=' * 50}" - ) + summary = [ + f"\n{'=' * 50}", + " TIMING SUMMARY", + f"{'=' * 50}", + f" Load (checkpoint): {load_time:>7.1f}s", + f" Compile: {compile_time:>7.1f}s", + f" {'─' * 40}", + f" Inference: {generation_time:>7.1f}s", + ] + if trace: + summary.extend([ + f" Conditioning: {trace.get('conditioning', 0.0):>7.1f}s", + f" Denoise Total: {trace.get('denoise_total', 0.0):>7.1f}s", + f" VAE Decode: {trace.get('vae_decode', 0.0):>7.1f}s", + ]) + summary.append(f"{'=' * 50}") + max_logging.log("\n".join(summary)) - videos = call_pipeline(config, pipeline, prompt, negative_prompt) + s0 = time.perf_counter() + if max_utils.profiler_enabled(config): + # Injecting user requested XLA tracing flags + xla_flags = os.environ.get("XLA_FLAGS", "") + new_flags = "--xla_enable_mxu_trace=true --xla_jf_dump_llo_html=true --xla_tpu_enable_llo_profiling=true" + os.environ["XLA_FLAGS"] = f"{xla_flags} {new_flags}" + max_logging.log(f"Injected XLA_FLAGS for profiling: {new_flags}") + + videos = call_pipeline(config, pipeline, prompt, negative_prompt) + if isinstance(videos, tuple): + videos = videos[0] + generation_time_with_profiler = time.perf_counter() - s0 + max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}") + if writer and jax.process_index() == 0: + writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0) return saved_video_path diff --git a/src/maxdiffusion/kernels/custom_splash_attention.py b/src/maxdiffusion/kernels/custom_splash_attention.py new file mode 100644 index 000000000..078f9036c --- /dev/null +++ b/src/maxdiffusion/kernels/custom_splash_attention.py @@ -0,0 +1,754 @@ +""" +Copyright 2026 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Custom Pallas flash attention kernel for TPU.""" + +import functools +import math + +import jax +import jax.numpy as jnp +import numpy as np +from jax import lax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax.experimental.shard_map import shard_map +from jax.sharding import PartitionSpec as P + +DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) +NUM_LANES = 128 +NUM_SUBLANES = 8 +NT_DIM_NUMBERS = (((1,), (1,)), ((), ())) + +# Default block sizes (tuned for 720p Wan2.1 on v6e/v7x) +DEFAULT_BQSIZE = 3328 +DEFAULT_BKVSIZE = 2816 +# Cranked up to 1024 for massive MXU throughput +DEFAULT_BKVCOMPUTESIZE = 1024 +# Kept at 256 to protect VPU registers (V1 Optimization) +DEFAULT_BKVCOMPUTEINSIZE = 256 + + +class _BlockSizes: + __slots__ = ("block_q", "block_kv", "block_kv_compute") + + def __init__(self, block_q: int, block_kv: int, block_kv_compute: int | None = None): + self.block_q = block_q + self.block_kv = block_kv + self.block_kv_compute = block_kv_compute if block_kv_compute is not None else block_kv + + +def _flash_attention_kernel( + q_ref, + k_ref, + v_ref, + m_scratch_ref, + l_scratch_ref, + o_scratch_ref, + o_ref, + *, + mask_value: float, + grid_width: int, + bq: int, + bkv: int, + bkv_compute: int, + bkv_compute_in: int, + head_dim_v: int, + q_seq_len: int, + kv_seq_len: int, + use_base2_exp: bool = True, +): + float32 = jnp.float32 + head_dim_v_repeats, rem = divmod(head_dim_v, NUM_SUBLANES) + if rem != 0: + raise NotImplementedError(f"{head_dim_v=} should be a multiple of {NUM_SUBLANES}") + + _, _, j = pl.program_id(0), pl.program_id(1), pl.program_id(2) + exp = jnp.exp2 if use_base2_exp else jnp.exp + + @pl.when(j == 0) + def init(): + o_scratch_ref[...] = jnp.zeros_like(o_scratch_ref) + m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value) + l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref) + + def compute_body(kv_compute_index, _): + m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] + q = q_ref[...] + o_prev = o_scratch_ref[:] + + base_offset = kv_compute_index * bkv_compute + slice_k = pl.ds(base_offset, bkv_compute) + k_chunk = k_ref[slice_k, :] + + qk = lax.dot_general(k_chunk, q, NT_DIM_NUMBERS, preferred_element_type=float32) + v_chunk = v_ref[slice_k, :] + + # --- V1 VPU REGISTER TILING --- + step = bkv_compute_in + for i in range(0, qk.shape[0], step): + qk_slice = qk[i : i + step] + + m_curr = qk_slice.max(axis=0)[None, :] + m_next = jnp.maximum(m_prev, m_curr) + s_curr = exp(qk_slice - m_next[0:1]) + l_curr = s_curr.sum(axis=0, keepdims=True) + + alpha = exp(m_prev - m_next) + l_next = l_curr + alpha * l_prev + + sv_dims = (((0,), (0,)), ((), ())) + o_curr = lax.dot_general( + v_chunk[i : i + step], + s_curr.astype(q_ref.dtype), + sv_dims, + preferred_element_type=float32, + ) + + alpha_o = alpha[0:1, ...] + o_prev = alpha_o * o_prev + o_curr + + m_prev, l_prev = m_next, l_next + # --- END V1 TILING --- + + m_scratch_ref[...], l_scratch_ref[...] = m_prev, l_prev + o_scratch_ref[:] = o_prev + + def last_compute_body(kv_compute_index): + m_prev, l_prev = m_scratch_ref[...], l_scratch_ref[...] + q = q_ref[...] + o_prev = o_scratch_ref[:] + + slice_k_len = kv_seq_len % bkv_compute + slice_k = pl.ds(kv_compute_index * bkv_compute, slice_k_len) + k_chunk = k_ref[slice_k, :] + + qk = lax.dot_general(k_chunk, q, NT_DIM_NUMBERS, preferred_element_type=float32) + v_chunk = v_ref[slice_k, :] + + # --- V1 VPU REGISTER TILING --- + step = bkv_compute_in + for i in range(0, qk.shape[0], step): + qk_slice = qk[i : i + step] + + m_curr = qk_slice.max(axis=0)[None, :] + m_next = jnp.maximum(m_prev, m_curr) + s_curr = exp(qk_slice - m_next[0:1]) + l_curr = s_curr.sum(axis=0, keepdims=True) + + alpha = exp(m_prev - m_next) + l_next = l_curr + alpha * l_prev + + sv_dims = (((0,), (0,)), ((), ())) + o_curr = lax.dot_general( + v_chunk[i : i + step], + s_curr.astype(q_ref.dtype), + sv_dims, + preferred_element_type=float32, + ) + + alpha_o = alpha[0:1, ...] + o_prev = alpha_o * o_prev + o_curr + + m_prev, l_prev = m_next, l_next + # --- END V1 TILING --- + + m_scratch_ref[...], l_scratch_ref[...] = m_prev, l_prev + o_scratch_ref[:] = o_prev + + assert bkv % bkv_compute == 0 + + @pl.when(j != grid_width - 1) + def body(): + lax.fori_loop(0, (bkv // bkv_compute), compute_body, None, unroll=True) + + @pl.when(j == grid_width - 1) + def last_body(): + if kv_seq_len % bkv == 0: + iter_num = bkv // bkv_compute + lax.fori_loop(0, iter_num, compute_body, None, unroll=True) + else: + remain_kv_seq_len = kv_seq_len % bkv + iter_num = (remain_kv_seq_len + bkv_compute - 1) // bkv_compute + if remain_kv_seq_len % bkv_compute == 0: + lax.fori_loop(0, iter_num, compute_body, None, unroll=True) + else: + lax.fori_loop(0, iter_num - 1, compute_body, None, unroll=True) + last_compute_body(iter_num - 1) + + @pl.when(j == grid_width - 1) + def end(): + l = l_scratch_ref[...] + l_inv = pltpu.repeat(1.0 / l, head_dim_v_repeats, axis=0) + o_ref[...] = (o_scratch_ref[...] * l_inv).astype(o_ref.dtype) + + +def _flash_attention_kernel_mhpt( + q_ref, + k_ref, + v_ref, + m_scratch_ref, + l_scratch_ref, + o_scratch_ref, + o_ref, + *, + mask_value: float, + grid_width: int, + bq: int, + bkv: int, + bkv_compute: int, + bkv_compute_in: int, + head_dim_v: int, + q_seq_len: int, + kv_seq_len: int, + heads_per_tile: int, + use_base2_exp: bool = True, +): + float32 = jnp.float32 + head_dim_v_repeats, rem = divmod(head_dim_v, NUM_SUBLANES) + if rem != 0: + raise NotImplementedError(f"{head_dim_v=} should be a multiple of {NUM_SUBLANES}") + + _, _, j = pl.program_id(0), pl.program_id(1), pl.program_id(2) + exp = jnp.exp2 if use_base2_exp else jnp.exp + + @pl.when(j == 0) + def init(): + o_scratch_ref[...] = jnp.zeros_like(o_scratch_ref) + m_scratch_ref[...] = jnp.full_like(m_scratch_ref, mask_value) + l_scratch_ref[...] = jnp.zeros_like(l_scratch_ref) + + def compute_body(kv_compute_index, _): + base_offset = kv_compute_index * bkv_compute + slice_k = pl.ds(base_offset, bkv_compute) + + for h_local in range(heads_per_tile): + m_prev = m_scratch_ref[h_local] + l_prev = l_scratch_ref[h_local] + q = q_ref[h_local] + o_prev = o_scratch_ref[h_local] + + k_chunk = k_ref[h_local, slice_k, :] + qk = lax.dot_general(k_chunk, q, NT_DIM_NUMBERS, preferred_element_type=float32) + v_chunk = v_ref[h_local, slice_k, :] + + # --- V1 VPU REGISTER TILING --- + step = bkv_compute_in + for i in range(0, qk.shape[0], step): + qk_slice = qk[i : i + step] + + m_curr = qk_slice.max(axis=0)[None, :] + m_next = jnp.maximum(m_prev, m_curr) + s_curr = exp(qk_slice - m_next[0:1]) + l_curr = s_curr.sum(axis=0, keepdims=True) + + alpha = exp(m_prev - m_next) + l_next = l_curr + alpha * l_prev + + sv_dims = (((0,), (0,)), ((), ())) + o_curr = lax.dot_general( + v_chunk[i : i + step], + s_curr.astype(q_ref.dtype), + sv_dims, + preferred_element_type=float32, + ) + + alpha_o = alpha[0:1, ...] + o_prev = alpha_o * o_prev + o_curr + + m_prev, l_prev = m_next, l_next + # --- END V1 TILING --- + + m_scratch_ref[h_local] = m_prev + l_scratch_ref[h_local] = l_prev + o_scratch_ref[h_local] = o_prev + + def last_compute_body(kv_compute_index): + slice_k_len = kv_seq_len % bkv_compute + slice_k = pl.ds(kv_compute_index * bkv_compute, slice_k_len) + + for h_local in range(heads_per_tile): + m_prev = m_scratch_ref[h_local] + l_prev = l_scratch_ref[h_local] + q = q_ref[h_local] + o_prev = o_scratch_ref[h_local] + + k_chunk = k_ref[h_local, slice_k, :] + qk = lax.dot_general(k_chunk, q, NT_DIM_NUMBERS, preferred_element_type=float32) + v_chunk = v_ref[h_local, slice_k, :] + + # --- V1 VPU REGISTER TILING --- + step = bkv_compute_in + for i in range(0, qk.shape[0], step): + qk_slice = qk[i : i + step] + + m_curr = qk_slice.max(axis=0)[None, :] + m_next = jnp.maximum(m_prev, m_curr) + s_curr = exp(qk_slice - m_next[0:1]) + l_curr = s_curr.sum(axis=0, keepdims=True) + + alpha = exp(m_prev - m_next) + l_next = l_curr + alpha * l_prev + + sv_dims = (((0,), (0,)), ((), ())) + o_curr = lax.dot_general( + v_chunk[i : i + step], + s_curr.astype(q_ref.dtype), + sv_dims, + preferred_element_type=float32, + ) + + alpha_o = alpha[0:1, ...] + o_prev = alpha_o * o_prev + o_curr + + m_prev, l_prev = m_next, l_next + # --- END V1 TILING --- + + m_scratch_ref[h_local] = m_prev + l_scratch_ref[h_local] = l_prev + o_scratch_ref[h_local] = o_prev + + assert bkv % bkv_compute == 0 + + @pl.when(j != grid_width - 1) + def body(): + lax.fori_loop(0, (bkv // bkv_compute), compute_body, None, unroll=True) + + @pl.when(j == grid_width - 1) + def last_body(): + if kv_seq_len % bkv == 0: + iter_num = bkv // bkv_compute + lax.fori_loop(0, iter_num, compute_body, None, unroll=True) + else: + remain_kv_seq_len = kv_seq_len % bkv + iter_num = (remain_kv_seq_len + bkv_compute - 1) // bkv_compute + if remain_kv_seq_len % bkv_compute == 0: + lax.fori_loop(0, iter_num, compute_body, None, unroll=True) + else: + lax.fori_loop(0, iter_num - 1, compute_body, None, unroll=True) + last_compute_body(iter_num - 1) + + @pl.when(j == grid_width - 1) + def end(): + for h_local in range(heads_per_tile): + l = l_scratch_ref[h_local] + l_inv = pltpu.repeat(1.0 / l, head_dim_v_repeats, axis=0) + o_ref[h_local] = (o_scratch_ref[h_local] * l_inv).astype(o_ref.dtype) + + +def _splash_attention_forward( + q: jax.Array, + k: jax.Array, + v: jax.Array, + block_sizes: _BlockSizes, + bkv_compute_in: int, + q_seq_len: int | None = None, + kv_seq_len: int | None = None, + use_base2_exp: bool = True, + use_experimental_scheduler: bool = False, +): + num_q_heads, padded_q_seq_len, head_dim_qk = q.shape + head_dim_v = v.shape[-1] + bq, bkv = block_sizes.block_q, block_sizes.block_kv + bkv_compute = block_sizes.block_kv_compute + num_kv_heads = k.shape[0] + padded_kv_seq_len = k.shape[1] + + actual_q_seq_len = q_seq_len if q_seq_len is not None else padded_q_seq_len + actual_kv_seq_len = kv_seq_len if kv_seq_len is not None else padded_kv_seq_len + q_heads_per_kv_head = num_q_heads // num_kv_heads + + def q_index_map(h, i, j, *_): + return (h, i, 0) + + def out_index_map(h, i, j, *_): + return h, 0, i + + def k_index_map(h, i, j, *_): + return (h // q_heads_per_kv_head, j, 0) + + def v_index_map(h, i, j, *_): + return (h // q_heads_per_kv_head, j, 0) + + in_specs = [ + pl.BlockSpec((None, bq, head_dim_qk), q_index_map), + pl.BlockSpec((None, bkv, head_dim_qk), k_index_map), + pl.BlockSpec((None, bkv, head_dim_v), v_index_map), + ] + out_shapes = [ + jax.ShapeDtypeStruct((NUM_SUBLANES, bq), jnp.float32), + jax.ShapeDtypeStruct((NUM_SUBLANES, bq), jnp.float32), + jax.ShapeDtypeStruct((head_dim_v, bq), jnp.float32), + jax.ShapeDtypeStruct((num_q_heads, head_dim_v, actual_q_seq_len), q.dtype), + ] + out_specs = [ + pl.BlockSpec((NUM_SUBLANES, bq), lambda *_: (0, 0)), + pl.BlockSpec((NUM_SUBLANES, bq), lambda *_: (0, 0)), + pl.BlockSpec((head_dim_v, bq), lambda *_: (0, 0)), + pl.BlockSpec((None, head_dim_v, bq), out_index_map), + ] + grid_width = (actual_kv_seq_len + bkv - 1) // bkv + grid_height = (actual_q_seq_len + bq - 1) // bq + grid = (num_q_heads, grid_height, grid_width) + + all_out = pl.pallas_call( + functools.partial( + _flash_attention_kernel, + mask_value=DEFAULT_MASK_VALUE, + grid_width=grid_width, + bq=bq, + bkv=bkv, + bkv_compute=bkv_compute, + bkv_compute_in=bkv_compute_in, + head_dim_v=head_dim_v, + q_seq_len=actual_q_seq_len, + kv_seq_len=actual_kv_seq_len, + use_base2_exp=use_base2_exp, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=in_specs, + out_specs=out_specs, + grid=grid, + ), + compiler_params=pltpu.CompilerParams( + dimension_semantics=("parallel", "arbitrary", "arbitrary"), + flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": use_experimental_scheduler}, + disable_bounds_checks=True, + skip_device_barrier=True, + ), + out_shape=out_shapes, + )(q, k, v) + return all_out[-1] + + +def _splash_attention_forward_mhpt( + q: jax.Array, + k: jax.Array, + v: jax.Array, + block_sizes: _BlockSizes, + bkv_compute_in: int, + heads_per_tile: int, + q_seq_len: int | None = None, + kv_seq_len: int | None = None, + use_base2_exp: bool = True, + use_experimental_scheduler: bool = False, +): + num_q_heads, padded_q_seq_len, head_dim_qk = q.shape + head_dim_v = v.shape[-1] + bq, bkv = block_sizes.block_q, block_sizes.block_kv + bkv_compute = block_sizes.block_kv_compute + num_kv_heads = k.shape[0] + actual_q_seq_len = q_seq_len if q_seq_len is not None else padded_q_seq_len + actual_kv_seq_len = kv_seq_len if kv_seq_len is not None else k.shape[1] + hpt = heads_per_tile + + assert num_q_heads % hpt == 0, f"num_heads {num_q_heads} must be divisible by heads_per_tile {hpt}" + assert num_q_heads == num_kv_heads, "MHPT currently requires num_q_heads == num_kv_heads (no GQA)" + + def q_index_map(h, i, j, *_): + return (h, i, 0) + + def k_index_map(h, i, j, *_): + return (h, j, 0) + + def v_index_map(h, i, j, *_): + return (h, j, 0) + + def out_index_map(h, i, j, *_): + return (h, 0, i) + + in_specs = [ + pl.BlockSpec((hpt, bq, head_dim_qk), q_index_map), + pl.BlockSpec((hpt, bkv, head_dim_qk), k_index_map), + pl.BlockSpec((hpt, bkv, head_dim_v), v_index_map), + ] + out_shapes = [ + jax.ShapeDtypeStruct((hpt, NUM_SUBLANES, bq), jnp.float32), + jax.ShapeDtypeStruct((hpt, NUM_SUBLANES, bq), jnp.float32), + jax.ShapeDtypeStruct((hpt, head_dim_v, bq), jnp.float32), + jax.ShapeDtypeStruct((num_q_heads, head_dim_v, actual_q_seq_len), q.dtype), + ] + out_specs = [ + pl.BlockSpec((hpt, NUM_SUBLANES, bq), lambda *_: (0, 0, 0)), + pl.BlockSpec((hpt, NUM_SUBLANES, bq), lambda *_: (0, 0, 0)), + pl.BlockSpec((hpt, head_dim_v, bq), lambda *_: (0, 0, 0)), + pl.BlockSpec((hpt, head_dim_v, bq), out_index_map), + ] + grid_width = (actual_kv_seq_len + bkv - 1) // bkv + grid_height = (actual_q_seq_len + bq - 1) // bq + grid = (num_q_heads // hpt, grid_height, grid_width) + + all_out = pl.pallas_call( + functools.partial( + _flash_attention_kernel_mhpt, + mask_value=DEFAULT_MASK_VALUE, + grid_width=grid_width, + bq=bq, + bkv=bkv, + bkv_compute=bkv_compute, + bkv_compute_in=bkv_compute_in, + head_dim_v=head_dim_v, + q_seq_len=actual_q_seq_len, + kv_seq_len=actual_kv_seq_len, + heads_per_tile=hpt, + use_base2_exp=use_base2_exp, + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=in_specs, + out_specs=out_specs, + grid=grid, + ), + compiler_params=pltpu.CompilerParams( + dimension_semantics=("parallel", "arbitrary", "arbitrary"), + flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": use_experimental_scheduler}, + disable_bounds_checks=True, + skip_device_barrier=True, + ), + out_shape=out_shapes, + )(q, k, v) + return all_out[-1] + + +def make_splash_mha( + block_sizes: _BlockSizes, + bkv_compute_in: int = DEFAULT_BKVCOMPUTEINSIZE, + orig_q_seq_len: int | None = None, + orig_kv_seq_len: int | None = None, + heads_per_tile: int = 1, + use_base2_exp: bool = True, + use_experimental_scheduler: bool = False, +): + def _splash_attention(q, k, v): + if heads_per_tile > 1: + return _splash_attention_forward_mhpt( + q, + k, + v, + block_sizes, + bkv_compute_in, + heads_per_tile, + q_seq_len=orig_q_seq_len, + kv_seq_len=orig_kv_seq_len, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + ) + return _splash_attention_forward( + q, + k, + v, + block_sizes, + bkv_compute_in, + q_seq_len=orig_q_seq_len, + kv_seq_len=orig_kv_seq_len, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + ) + + return _splash_attention + + +# --------------------------------------------------------------------------- +# High-level attention function with shard_map +# --------------------------------------------------------------------------- + + +def tpu_custom_attention( + query, + key, + value, + mesh, + *, + scale=None, + block_q=None, + block_kv=None, + block_kv_compute=None, + block_kv_compute_in=None, + heads_per_tile=None, + use_base2_exp=True, + use_experimental_scheduler=False, + flash_block_sizes=None, +): + _LOG2_E = 1.44269504 + num_heads = query.shape[1] + + if flash_block_sizes is not None: + block_q = flash_block_sizes.get("block_q", block_q) + block_kv = flash_block_sizes.get("block_kv", block_kv) + block_kv_compute = flash_block_sizes.get("block_kv_compute", block_kv_compute) + block_kv_compute_in = flash_block_sizes.get("block_kv_compute_in", block_kv_compute_in) + heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile) + + block_q = block_q if block_q is not None else DEFAULT_BQSIZE + block_kv = block_kv if block_kv is not None else DEFAULT_BKVSIZE + block_kv_compute = block_kv_compute if block_kv_compute is not None else DEFAULT_BKVCOMPUTESIZE + block_kv_compute_in = block_kv_compute_in if block_kv_compute_in is not None else DEFAULT_BKVCOMPUTEINSIZE + heads_per_tile = heads_per_tile if heads_per_tile is not None else 1 + + def _attention_on_slices(q, k, v): + scale_factor = 1.0 / math.sqrt(q.shape[-1]) if scale is None else scale + if use_base2_exp: + q = q * scale_factor * _LOG2_E + else: + q = q * scale_factor + + def _pad_to_multiple(x, multiple, axis): + seq_len = x.shape[axis] + pad_len = (multiple - seq_len % multiple) % multiple + if pad_len == 0: + return x, seq_len + pad_width = [(0, 0)] * x.ndim + pad_width[axis] = (0, pad_len) + return jnp.pad(x, pad_width), seq_len + + def _kernel_3d(q_3d, k_3d, v_3d): + q_orig_len = q_3d.shape[1] + kv_orig_len = k_3d.shape[1] + + q_3d_padded, _ = _pad_to_multiple(q_3d, block_q, axis=1) + k_3d_padded, _ = _pad_to_multiple(k_3d, block_kv, axis=1) + v_3d_padded, _ = _pad_to_multiple(v_3d, block_kv, axis=1) + + padded_q_seq_len = q_3d_padded.shape[1] + padded_kv_seq_len = k_3d_padded.shape[1] + + bsizes = _BlockSizes( + block_q=min(block_q, padded_q_seq_len), + block_kv=min(block_kv, padded_kv_seq_len), + block_kv_compute=min(block_kv_compute, padded_kv_seq_len), + ) + splash_kernel = make_splash_mha( + block_sizes=bsizes, + bkv_compute_in=block_kv_compute_in, + orig_q_seq_len=q_orig_len, + orig_kv_seq_len=kv_orig_len, + heads_per_tile=heads_per_tile, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + ) + out = splash_kernel( + q_3d_padded.astype(jnp.bfloat16), + k_3d_padded, + v_3d_padded, + ) + out = jnp.swapaxes(out, 1, 2) + return out[:, :q_orig_len, ...] + + return jax.vmap(_kernel_3d, in_axes=(0, 0, 0), out_axes=0)(q, k, v) + + batch_size = query.shape[0] + if num_heads < mesh.size: + q_partition_spec = P() + kv_partition_spec = P() + out_constraint = P() + else: + axis_names = mesh.axis_names + if len(axis_names) == 1: + tp_axis = axis_names[0] + q_partition_spec = P(None, tp_axis, None, None) + kv_partition_spec = P(None, tp_axis, None, None) + out_constraint = P(None, None, tp_axis, None) + elif len(axis_names) == 2: + dp_axis, tp_axis = axis_names[0], axis_names[1] + dp_size = mesh.shape[dp_axis] + if batch_size >= dp_size: + q_partition_spec = P(dp_axis, tp_axis, None, None) + kv_partition_spec = P(dp_axis, tp_axis, None, None) + out_constraint = P(dp_axis, None, tp_axis, None) + else: + all_axes = tuple(axis_names) + q_partition_spec = P(None, all_axes, None, None) + kv_partition_spec = P(None, all_axes, None, None) + out_constraint = P(None, None, all_axes, None) + else: + q_partition_spec = P(axis_names[0], axis_names[1], axis_names[2], None) + kv_partition_spec = P(axis_names[0], axis_names[1], None, None) + out_constraint = P(axis_names[0], None, (axis_names[1], axis_names[2]), None) + + sharded_fn = shard_map( + _attention_on_slices, + mesh=mesh, + in_specs=(q_partition_spec, kv_partition_spec, kv_partition_spec), + out_specs=q_partition_spec, + check_rep=False, + ) + out = sharded_fn(query, key, value) + out = jax.lax.with_sharding_constraint(out, out_constraint) + return out + + +# --------------------------------------------------------------------------- +# TorchAX SDPA wrapper +# --------------------------------------------------------------------------- + + +def make_custom_splash_sdpa(mesh, env, **kwargs): + flash_block_sizes = kwargs.get("flash_block_sizes", None) + bq = kwargs.get("block_q", DEFAULT_BQSIZE) + bkv = kwargs.get("block_kv", DEFAULT_BKVSIZE) + bkv_compute = kwargs.get("block_kv_compute", DEFAULT_BKVCOMPUTESIZE) + bkv_compute_in = kwargs.get("block_kv_compute_in", DEFAULT_BKVCOMPUTEINSIZE) + hpt = kwargs.get("heads_per_tile", 1) + use_k_smooth = kwargs.get("use_k_smooth", True) + use_base2_exp = kwargs.get("use_base2_exp", True) + use_experimental_scheduler = kwargs.get("use_experimental_scheduler", False) + + def _simple_attention(q, k, v, scale=None): + s = scale if scale is not None else 1.0 / math.sqrt(q.shape[-1]) + attn = jnp.einsum("bhsd,bhtd->bhst", q * s, k) + attn = jax.nn.softmax(attn.astype(jnp.float32), axis=-1).astype(q.dtype) + return jnp.einsum("bhst,bhtd->bhsd", attn, v) + + def _sdpa( + query, + key, + value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + scale=None, + enable_gqa=False, + ): + jquery, jkey, jvalue = env.t2j_iso((query, key, value)) + num_heads = jquery.shape[1] + + if num_heads <= 8: + result = _simple_attention(jquery, jkey, jvalue, scale=scale) + return env.j2t_iso(result) + + if use_k_smooth: + key_mean = jnp.mean(jkey, axis=2, keepdims=True) + jkey = jkey - key_mean + + result = tpu_custom_attention( + jquery, + jkey, + jvalue, + mesh, + scale=scale, + block_q=bq, + block_kv=bkv, + block_kv_compute=bkv_compute, + block_kv_compute_in=bkv_compute_in, + heads_per_tile=hpt, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + flash_block_sizes=flash_block_sizes, + ) + return env.j2t_iso(result) + + return _sdpa diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 53596eeaf..b795f5417 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -31,6 +31,7 @@ from einops import rearrange from .. import common_types, max_logging +from ..kernels import custom_splash_attention as custom_splash from . import quantizations from .modeling_flax_utils import get_activation @@ -521,6 +522,9 @@ def _ulysses_attention( mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, attention_mask: jax.Array = None, + use_custom_kernel: bool = False, + use_base2_exp: bool = True, + use_experimental_scheduler: bool = False, ) -> jax.Array: """Ulysses sequence-parallel attention. @@ -544,7 +548,9 @@ def _ulysses_attention( "Ulysses attention requires the number of heads to be divisible by the context shard count, " f"got heads={num_heads} and context_shards={num_shards}." ) - block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash") + + if not use_custom_kernel: + block_sizes = _select_flash_block_sizes(query, key, flash_block_sizes, dtype, "flash") q_axis_names = nn.logical_to_mesh_axes(axis_names_q) kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) @@ -563,65 +569,112 @@ def wrap_ulysses_attention(query, key, value): key = jax.lax.all_to_all(key, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True) value = jax.lax.all_to_all(value, axis_name=axis_name, split_axis=1, concat_axis=2, tiled=True) - # Run the same local splash kernel as standard TPU flash attention, but now - # on full-sequence / fewer-heads tensors produced by the all-to-all above. - uses_fused_kernel = block_sizes.use_fused_bwd_kernel - block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv) - block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv) - if uses_fused_kernel: - block_q_sizes += (block_sizes.block_q_dkv,) - block_kv_sizes += (block_sizes.block_kv_dkv,) - else: - block_q_sizes += (block_sizes.block_q_dq,) - block_kv_sizes += (block_sizes.block_kv_dq,) - - block_q = max(*block_q_sizes) - query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q) - block_kv = max(*block_kv_sizes) - key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv) - value, _, _ = _pad_data_for_flash(value, heads, block_kv) - - mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) - multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) - - q_padded_len = query.shape[2] - q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0) - q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32) - - kv_padded_len = key.shape[2] - kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0) - kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32) + if use_custom_kernel: + bq = 4864 + bkv = 1024 + bkv_compute = 1024 + bkv_compute_in = 1024 + heads_per_tile = 1 + + if flash_block_sizes is not None: + if isinstance(flash_block_sizes, dict): + bq = flash_block_sizes.get("block_q", bq) + bkv = flash_block_sizes.get("block_kv", bkv) + bkv_compute = flash_block_sizes.get("block_kv_compute", bkv_compute) + bkv_compute_in = flash_block_sizes.get("block_kv_compute_in", bkv_compute_in) + heads_per_tile = flash_block_sizes.get("heads_per_tile", heads_per_tile) + else: + bq = getattr(flash_block_sizes, "block_q", bq) + bkv = getattr(flash_block_sizes, "block_kv", bkv) + bkv_compute = getattr(flash_block_sizes, "block_kv_compute", bkv_compute) + bkv_compute_in = getattr(flash_block_sizes, "block_kv_compute_in", bkv_compute_in) + heads_per_tile = getattr(flash_block_sizes, "heads_per_tile", heads_per_tile) + + if use_base2_exp: + query_scaled = query * 1.44269504 + else: + query_scaled = query + + query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, bq) + key, _, key_seq_len = _pad_data_for_flash(key, heads, bkv) + value, _, _ = _pad_data_for_flash(value, heads, bkv) + + bsizes = custom_splash._BlockSizes(block_q=bq, block_kv=bkv, block_kv_compute=bkv_compute) + + splash_kernel = custom_splash.make_splash_mha( + block_sizes=bsizes, + bkv_compute_in=bkv_compute_in, + orig_q_seq_len=query_seq_len, + orig_kv_seq_len=key_seq_len, + heads_per_tile=heads_per_tile, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + ) - # Reuse the standard flash-attention masking convention by zeroing invalid - # KV positions in the segment ids passed down to splash. - if attention_mask is not None: - mask_len = min(key_seq_len, attention_mask.shape[1]) - kv_mask_for_batch = attention_mask[0, :mask_len] - if key_seq_len > mask_len: - extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32) - kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0) - if kv_padded_len > key_seq_len: - padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32) - kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0) + vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0)) + attention_output = vmapped_splash(query_scaled, key, value) + attention_output = jnp.swapaxes(attention_output, 2, 3) + attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) + else: + # Run the same local splash kernel as standard TPU flash attention, but now + # on full-sequence / fewer-heads tensors produced by the all-to-all above. + uses_fused_kernel = block_sizes.use_fused_bwd_kernel + block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv) + block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv) + if uses_fused_kernel: + block_q_sizes += (block_sizes.block_q_dkv,) + block_kv_sizes += (block_sizes.block_kv_dkv,) else: - kv_mask_padded = kv_mask_for_batch - kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32) + block_q_sizes += (block_sizes.block_q_dq,) + block_kv_sizes += (block_sizes.block_kv_dq,) + + block_q = max(*block_q_sizes) + query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_q) + block_kv = max(*block_kv_sizes) + key, _, key_seq_len = _pad_data_for_flash(key, heads, block_kv) + value, _, _ = _pad_data_for_flash(value, heads, block_kv) + + mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) + multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) + + q_padded_len = query.shape[2] + q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0) + q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32) + + kv_padded_len = key.shape[2] + kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0) + kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32) + + # Reuse the standard flash-attention masking convention by zeroing invalid + # KV positions in the segment ids passed down to splash. + if attention_mask is not None: + mask_len = min(key_seq_len, attention_mask.shape[1]) + kv_mask_for_batch = attention_mask[0, :mask_len] + if key_seq_len > mask_len: + extra_valid = jnp.ones((key_seq_len - mask_len,), dtype=jnp.int32) + kv_mask_for_batch = jnp.concatenate([kv_mask_for_batch, extra_valid], axis=0) + if kv_padded_len > key_seq_len: + padding = jnp.zeros((kv_padded_len - key_seq_len,), dtype=jnp.int32) + kv_mask_padded = jnp.concatenate([kv_mask_for_batch, padding], axis=0) + else: + kv_mask_padded = kv_mask_for_batch + kv_segment_ids = (kv_segment_ids * kv_mask_padded).astype(jnp.int32) - segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) - if not mask_padding_tokens: - segment_ids = None + segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) + if not mask_padding_tokens: + segment_ids = None - splash_kernel = splash_attention_kernel.make_splash_mha( - mask=multi_head_mask, - head_shards=1, - q_seq_shards=1, - block_sizes=block_sizes, - save_residuals=False, - residual_checkpoint_name=residual_checkpoint_name, - ) - vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) - attention_output = vmapped_splash(query, key, value, segment_ids) - attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, + head_shards=1, + q_seq_shards=1, + block_sizes=block_sizes, + save_residuals=False, + residual_checkpoint_name=residual_checkpoint_name, + ) + vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) + attention_output = vmapped_splash(query, key, value, segment_ids) + attention_output = attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) # Restore the original layout expected by the rest of the model: # head-sharded / full-sequence -> sequence-sharded / full-heads. @@ -734,6 +787,138 @@ def _cudnn_flash_attention(query: Array, key: Array, value: Array, heads: int, m return _reshape_data_from_cudnn_flash(out) +KERNEL_REGISTRY = {} + + +def register_kernel(name: str): + def decorator(func): + KERNEL_REGISTRY[name] = func + return func + + return decorator + + +# Register existing kernels at module level with context dict +@register_kernel("dot_product") +def dot_product_kernel(q, k, v, context): + return _apply_attention_dot( + q, + k, + v, + context["dtype"], + context["heads"], + context["dim_head"], + context["scale"], + context["split_head_dim"], + context["float32_qk_product"], + context["use_memory_efficient_attention"], + ) + + +@register_kernel("ulysses_custom") +def ulysses_custom_kernel(q, k, v, context): + return _ulysses_attention( + q, + k * context["scale"], + v, + context["heads"], + context["mesh"], + context["axis_names_q"], + context["axis_names_kv"], + context["flash_block_sizes"], + context["dtype"], + mask_padding_tokens=context["mask_padding_tokens"], + residual_checkpoint_name=context["residual_checkpoint_name"], + attention_mask=context["attention_mask"], + use_custom_kernel=True, + use_base2_exp=context.get("use_base2_exp", True), + use_experimental_scheduler=context.get("use_experimental_scheduler", False), + ) + + +@register_kernel("ulysses") +def ulysses_kernel(q, k, v, context): + return _ulysses_attention( + q, + k * context["scale"], + v, + context["heads"], + context["mesh"], + context["axis_names_q"], + context["axis_names_kv"], + context["flash_block_sizes"], + context["dtype"], + mask_padding_tokens=context["mask_padding_tokens"], + residual_checkpoint_name=context["residual_checkpoint_name"], + attention_mask=context["attention_mask"], + ) + + +@register_kernel("flash") +def flash_kernel(q, k, v, context): + return _tpu_flash_attention( + q, + k * context["scale"], + v, + context["heads"], + context["mesh"], + context["axis_names_q"], + context["axis_names_kv"], + context["flash_block_sizes"], + context["dtype"], + attention_kernel="flash", + mask_padding_tokens=context["mask_padding_tokens"], + residual_checkpoint_name=context["residual_checkpoint_name"], + attention_mask=context["attention_mask"], + use_base2_exp=context["use_base2_exp"], + use_experimental_scheduler=context["use_experimental_scheduler"], + ) + + +@register_kernel("tokamax_flash") +def tokamax_flash_kernel(q, k, v, context): + return _tpu_flash_attention( + q, + k * context["scale"], + v, + context["heads"], + context["mesh"], + context["axis_names_q"], + context["axis_names_kv"], + context["flash_block_sizes"], + context["dtype"], + attention_kernel="tokamax_flash", + mask_padding_tokens=context["mask_padding_tokens"], + residual_checkpoint_name=context["residual_checkpoint_name"], + attention_mask=context["attention_mask"], + use_base2_exp=context["use_base2_exp"], + use_experimental_scheduler=context["use_experimental_scheduler"], + ) + + +@register_kernel("tokamax_ring") +def tokamax_ring_kernel(q, k, v, context): + return _tpu_flash_attention( + q, + k * context["scale"], + v, + context["heads"], + context["mesh"], + context["axis_names_q"], + context["axis_names_kv"], + context["flash_block_sizes"], + context["dtype"], + attention_kernel="tokamax_ring", + mask_padding_tokens=context["mask_padding_tokens"], + attention_mask=context["attention_mask"], + ) + + +@register_kernel("cudnn_flash_te") +def cudnn_flash_te_kernel(q, k, v, context): + return _cudnn_flash_attention(q, k, v, context["heads"], context["mesh"], context["dpa_layer"]) + + def _apply_attention( query: Array, key: Array, @@ -758,75 +943,50 @@ def _apply_attention( use_base2_exp: bool = False, use_experimental_scheduler: bool = False, ): - """Routes to different attention kernels.""" + """Routes to different attention kernels using a module-level registry.""" + _check_attention_inputs(query, key, value) seq_len_idx = 1 if query.ndim == 4: seq_len_idx = 2 - if attention_kernel in ["flash", "tokamax_flash", "ulysses"]: + + can_use_flash_attention = True + if attention_kernel in ["flash", "tokamax_flash", "ulysses", "ulysses_custom"]: can_use_flash_attention = ( query.shape[seq_len_idx] >= flash_min_seq_length and key.shape[seq_len_idx] >= flash_min_seq_length and value.shape[seq_len_idx] >= flash_min_seq_length ) - else: - can_use_flash_attention = True + + # Fallback logic + context = { + "heads": heads, + "mesh": mesh, + "axis_names_q": axis_names_q, + "axis_names_kv": axis_names_kv, + "flash_block_sizes": flash_block_sizes, + "dtype": dtype, + "mask_padding_tokens": mask_padding_tokens, + "residual_checkpoint_name": residual_checkpoint_name, + "attention_mask": attention_mask, + "scale": scale, + "use_base2_exp": use_base2_exp, + "use_experimental_scheduler": use_experimental_scheduler, + "dim_head": dim_head, + "split_head_dim": split_head_dim, + "float32_qk_product": float32_qk_product, + "use_memory_efficient_attention": use_memory_efficient_attention, + "dpa_layer": dpa_layer, + } + if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention: - return _apply_attention_dot( - query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention - ) - elif attention_kernel == "ulysses": - return _ulysses_attention( - query, - key * scale, - value, - heads, - mesh, - axis_names_q, - axis_names_kv, - flash_block_sizes, - dtype, - mask_padding_tokens=mask_padding_tokens, - residual_checkpoint_name=residual_checkpoint_name, - attention_mask=attention_mask, - ) - elif attention_kernel in ["flash", "tokamax_flash"]: - return _tpu_flash_attention( - query, - key * scale, - value, - heads, - mesh, - axis_names_q, - axis_names_kv, - flash_block_sizes, - dtype, - attention_kernel, - mask_padding_tokens=mask_padding_tokens, - residual_checkpoint_name=residual_checkpoint_name, - attention_mask=attention_mask, - use_base2_exp=use_base2_exp, - use_experimental_scheduler=use_experimental_scheduler, - ) - elif "ring" in attention_kernel: - return _tpu_flash_attention( - query, - key * scale, - value, - heads, - mesh, - axis_names_q, - axis_names_kv, - flash_block_sizes, - dtype, - attention_kernel, - mask_padding_tokens=mask_padding_tokens, - attention_mask=attention_mask, - ) - elif attention_kernel == "cudnn_flash_te": - return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer) - else: - raise ValueError(f"Unexpected attention kernel {attention_kernel=}.") + return KERNEL_REGISTRY["dot_product"](query, key, value, context) + + # Module-level Registry lookup + if attention_kernel in KERNEL_REGISTRY: + return KERNEL_REGISTRY[attention_kernel](query, key, value, context) + + raise ValueError(f"Unexpected attention kernel {attention_kernel=}.") def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096): diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan_2p2.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan_2p2.py index cd41edf7c..841b115c9 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan_2p2.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan_2p2.py @@ -20,6 +20,7 @@ import jax from jax import tree_util import jax.numpy as jnp +from jax.sharding import NamedSharding, PartitionSpec as P from maxdiffusion.models.wan.autoencoder_kl_wan import AutoencoderKLWanCache, WanCausalConv3d # pylint: disable=g-importing-member from ... import common_types @@ -1266,6 +1267,7 @@ def __init__( self.temporal_upsample = temperal_downsample[::-1] self.latents_mean = latents_mean self.latents_std = latents_std + self.mesh = mesh self.patch_size = 2 self.patchify = WanPatchify(patch_size=self.patch_size) @@ -1339,16 +1341,23 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache): iter_ = 1 + (t - 1) // 4 enc_feat_map = feat_cache._enc_feat_map + spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "vae_spatial", None)) for i in range(iter_): enc_conv_idx = 0 if i == 0: - out, enc_feat_map, enc_conv_idx = self.encoder(x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=enc_conv_idx) + chunk = x[:, :1, :, :, :] + chunk = jax.lax.with_sharding_constraint(chunk, spatial_sharding) + out, enc_feat_map, enc_conv_idx = self.encoder(chunk, feat_cache=enc_feat_map, feat_idx=enc_conv_idx) + out = jax.lax.with_sharding_constraint(out, spatial_sharding) else: + chunk = x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :] + chunk = jax.lax.with_sharding_constraint(chunk, spatial_sharding) out_, enc_feat_map, enc_conv_idx = self.encoder( - x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :], + chunk, feat_cache=enc_feat_map, feat_idx=enc_conv_idx, ) + out_ = jax.lax.with_sharding_constraint(out_, spatial_sharding) out = jnp.concatenate([out, out_], axis=1) # Update back to the wrapper object if needed, but for result we use local vars @@ -1385,17 +1394,22 @@ def _decode( dec_feat_map = feat_cache._feat_map + spatial_sharding = NamedSharding(self.mesh, P(None, None, None, "vae_spatial", None)) for i in range(iter_): conv_idx = 0 + chunk = x[:, i : i + 1, :, :, :] + chunk = jax.lax.with_sharding_constraint(chunk, spatial_sharding) if i == 0: out, dec_feat_map, conv_idx = self.decoder( - x[:, i : i + 1, :, :, :], + chunk, feat_cache=dec_feat_map, feat_idx=conv_idx, first_chunk=True, ) + out = jax.lax.with_sharding_constraint(out, spatial_sharding) else: - out_, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx) + out_, dec_feat_map, conv_idx = self.decoder(chunk, feat_cache=dec_feat_map, feat_idx=conv_idx) + out_ = jax.lax.with_sharding_constraint(out_, spatial_sharding) out = jnp.concatenate([out, out_], axis=1) feat_cache._feat_map = dec_feat_map diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 9488f8946..95912e436 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -17,6 +17,7 @@ from typing import List, Union, Optional from ...pyconfig import HyperParameters from functools import partial +import time from flax import nnx from flax.linen import partitioning as nn_partitioning import jax @@ -133,6 +134,9 @@ def __call__( "SenCache requires classifier-free guidance to be enabled for both transformer phases." ) + trace = {} + t_cond_start = time.perf_counter() + latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs( prompt, negative_prompt, @@ -147,6 +151,7 @@ def __call__( negative_prompt_embeds, vae_only, ) + trace["conditioning"] = time.perf_counter() - t_cond_start low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) @@ -166,6 +171,7 @@ def __call__( height=height, ) + t_denoise_start = time.perf_counter() with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): latents = p_run_inference( low_noise_graphdef=low_noise_graphdef, @@ -180,7 +186,14 @@ def __call__( config=self.config, ) latents = self._denormalize_latents(latents) - return self._decode_latents_to_video(latents) + latents.block_until_ready() + trace["denoise_total"] = time.perf_counter() - t_denoise_start + + t_decode_start = time.perf_counter() + video = self._decode_latents_to_video(latents) + trace["vae_decode"] = time.perf_counter() - t_decode_start + + return video, trace def run_inference_2_2(