diff --git a/README.md b/README.md index 71ed5d9c..5ddcc323 100755 --- a/README.md +++ b/README.md @@ -628,16 +628,16 @@ To generate images, run the following command: We added ring attention support for Wan models. Below are the stats for one `720p` (81 frames) video generation (with CFG DP): | Accelerator | Model | Attention Type | Inference Steps | Sharding | e2e Generation Time | | -- | -- | -- | -- | -- | -- | -| v7x-8 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context4-tp1 | 264.2 | -| v7x-8 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context4-tp1 | **252.4** | -| v7x-8 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context4-tp1 | 212.7 | -| v7x-8 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context4-tp1 | **201.7** | +| v7x-8 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context4-tp1 | **249.3** | +| v7x-8 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context4-tp1 | 252.4 | +| v7x-8 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context4-tp1 | **194.4** | +| v7x-8 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context4-tp1 | 201.7 | | Accelerator | Model | Attention Type | Inference Steps | Sharding | e2e Generation Time | | -- | -- | -- | -- | -- | -- | -| v7x-16 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context8-tp1 | 146.6 | -| v7x-16 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context8-tp1 | **137.2** | -| v7x-16 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context8-tp1 | **117.8** | +| v7x-16 | WAN 2.1 | Tokamax Flash | 50 | dp2-fsdp1-context8-tp1 | **127.1** | +| v7x-16 | WAN 2.1 | Tokamax Ring | 50 | dp2-fsdp1-context8-tp1 | 137.2 | +| v7x-16 | WAN 2.2 | Tokamax Flash | 40 | dp2-fsdp1-context8-tp1 | **106.0** | | v7x-16 | WAN 2.2 | Tokamax Ring | 40 | dp2-fsdp1-context8-tp1 | 137.5 | (* There are some known stability issues for ring attention on 16 TPUs, please use `tokamax_flash` attention instead.) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index e0658bf0..7ffb659c 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -62,6 +62,8 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses +use_base2_exp: True +use_experimental_scheduler: True flash_min_seq_length: 0 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index b1d4c23c..9e59ba9c 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -61,6 +61,8 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses +use_base2_exp: True +use_experimental_scheduler: True flash_min_seq_length: 0 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index b0fb8cdc..f80c1551 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -62,6 +62,8 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses +use_base2_exp: True +use_experimental_scheduler: True flash_min_seq_length: 4096 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens. # Else we do not pass in segment ids and on vpu bound hardware like trillium this is faster. diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index c094b445..b136c7a9 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -61,6 +61,8 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses +use_base2_exp: True +use_experimental_scheduler: True flash_min_seq_length: 4096 dropout: 0.0 diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index fcba80e5..4af01187 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -61,6 +61,8 @@ jit_initializers: True from_pt: True split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses +use_base2_exp: True +use_experimental_scheduler: True flash_min_seq_length: 4096 dropout: 0.0 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index e9de3f66..51ce4574 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -302,14 +302,7 @@ def run(config, pipeline=None, filename_prefix="", commit_hash=None): f"{'=' * 50}" ) - s0 = time.perf_counter() - if max_utils.profiler_enabled(config): - with max_utils.Profiler(config): - videos = call_pipeline(config, pipeline, prompt, negative_prompt) - generation_time_with_profiler = time.perf_counter() - s0 - max_logging.log(f"generation_time_with_profiler: {generation_time_with_profiler}") - if writer and jax.process_index() == 0: - writer.add_scalar("inference/generation_time_with_profiler", generation_time_with_profiler, global_step=0) + videos = call_pipeline(config, pipeline, prompt, negative_prompt) return saved_video_path diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 77228261..53596eea 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -272,6 +272,7 @@ def convert_to_tokamax_splash_config( attn_logits_soft_cap: float | None = None, fuse_reciprocal: bool = True, use_base2_exp: bool = False, + use_experimental_scheduler: bool = False, max_logit_const: float | None = None, interpret: bool = False, dq_reduction_steps: int | None = None, @@ -294,6 +295,7 @@ def convert_to_tokamax_splash_config( attn_logits_soft_cap=attn_logits_soft_cap, fuse_reciprocal=fuse_reciprocal, use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, max_logit_const=max_logit_const, interpret=interpret, dq_reduction_steps=dq_reduction_steps, @@ -314,6 +316,8 @@ def _tpu_flash_attention( mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, attention_mask: jax.Array = None, + use_base2_exp: bool = False, + use_experimental_scheduler: bool = False, ) -> jax.Array: """TPU Flash Attention""" @@ -399,7 +403,12 @@ def wrap_flash_attention(query, key, value): splash_kernel = tokamax_splash_attention_kernel.make_splash_mha( mask=mask, q_seq_shards=1, # the sizes of the axis is sharding over seq_len - config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name), + config=convert_to_tokamax_splash_config( + block_sizes, + residual_checkpoint_name=residual_checkpoint_name, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + ), save_residuals=False, ) elif attention_kernel == "tokamax_ring": @@ -409,7 +418,12 @@ def wrap_flash_attention(query, key, value): splash_kernel = tokamax_ring_attention_kernel.make_ring_attention( mask=mask, is_mqa=False, - config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name), + config=convert_to_tokamax_splash_config( + block_sizes, + residual_checkpoint_name=residual_checkpoint_name, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, + ), save_residuals=False, ring_axis="context", rotate_segment_ids=False, # We don't rotate segment ids in tokamax ring attention because our segment ids is for padding each kv shard has same segment ids @@ -473,13 +487,13 @@ def ring_scan_body(carry, _): raise ValueError("ring attention requires context > 1") return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) - devices_in_data_context = mesh.shape["data"] * mesh.shape["context"] + devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1) # This warning might show up when doing model eval for example, when calculating model flops # and that is expected. - if not (query.shape[0] / devices_in_data_context).is_integer(): + if not (query.shape[0] / devices_in_batch_sharding).is_integer(): max_logging.log( - "Warning, batch dimension should be shardable among the devices in data and context" - f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}" + "Warning, batch dimension should be shardable among the devices in data and fsdp" + f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" ) x = wrap_flash_attention(query, key, value) # Trim back to original sequence length after context-axis padding. @@ -614,11 +628,11 @@ def wrap_ulysses_attention(query, key, value): attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True) return attention_output - devices_in_data_context = mesh.shape["data"] * num_shards - if not (query.shape[0] / devices_in_data_context).is_integer(): + devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1) + if not (query.shape[0] / devices_in_batch_sharding).is_integer(): max_logging.log( - "Warning, batch dimension should be shardable among the devices in data and context" - f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}" + "Warning, batch dimension should be shardable among the devices in data and fsdp" + f" axis, batch dimension: {query.shape[0]}, devices_in_batch_sharding: {devices_in_batch_sharding}" ) x = wrap_ulysses_attention(query, key, value) x = x[:, :, :orig_q_seq_len, :] @@ -741,6 +755,8 @@ def _apply_attention( mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, attention_mask: Array = None, + use_base2_exp: bool = False, + use_experimental_scheduler: bool = False, ): """Routes to different attention kernels.""" _check_attention_inputs(query, key, value) @@ -789,6 +805,8 @@ def _apply_attention( mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name=residual_checkpoint_name, attention_mask=attention_mask, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, ) elif "ring" in attention_kernel: return _tpu_flash_attention( @@ -983,8 +1001,12 @@ def __init__( quant: Quant = None, mask_padding_tokens: bool = True, residual_checkpoint_name: str | None = None, + use_base2_exp: bool = False, + use_experimental_scheduler: bool = False, ): self.dpa_layer = None + self.use_base2_exp = use_base2_exp + self.use_experimental_scheduler = use_experimental_scheduler if attention_kernel == "cudnn_flash_te": from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error @@ -1045,6 +1067,8 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask mask_padding_tokens=self.mask_padding_tokens, residual_checkpoint_name=self.residual_checkpoint_name, attention_mask=attention_mask, + use_base2_exp=self.use_base2_exp if hasattr(self, "use_base2_exp") else False, + use_experimental_scheduler=self.use_experimental_scheduler if hasattr(self, "use_experimental_scheduler") else False, ) @@ -1063,6 +1087,8 @@ class AttentionOp(nn.Module): flash_block_sizes: BlockSizes = None dtype: DType = jnp.float32 quant: Quant = None + use_base2_exp: bool = False + use_experimental_scheduler: bool = False def setup(self): self.dpa_layer = None @@ -1108,6 +1134,8 @@ def apply_attention(self, query: Array, key: Array, value: Array, attention_mask flash_block_sizes=self.flash_block_sizes, dpa_layer=self.dpa_layer, attention_mask=attention_mask, + use_base2_exp=self.use_base2_exp, + use_experimental_scheduler=self.use_experimental_scheduler, ) @@ -1144,6 +1172,8 @@ def __init__( enable_jax_named_scopes: bool = False, added_kv_proj_dim: Optional[int] = None, # New for I2V image_seq_len: Optional[int] = None, # New for I2V + use_base2_exp: bool = False, + use_experimental_scheduler: bool = False, ): if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None: raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}") @@ -1186,6 +1216,8 @@ def __init__( quant=quant, mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name=residual_checkpoint_name, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, ) # None axes corresponds to the stacked weights across all blocks # because of the use of nnx.vmap and nnx.scan. diff --git a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py index b177afbb..8948a24a 100644 --- a/src/maxdiffusion/models/wan/autoencoder_kl_wan.py +++ b/src/maxdiffusion/models/wan/autoencoder_kl_wan.py @@ -1206,7 +1206,7 @@ def _decode( fm1, fm2, fm3, fm4 = out_chunk_1[:, 0, ...], out_chunk_1[:, 1, ...], out_chunk_1[:, 2, ...], out_chunk_1[:, 3, ...] axis = 1 if fm1.shape[0] > 1 else 0 fm1, fm2, fm3, fm4 = [jnp.expand_dims(f, axis=axis) for f in [fm1, fm2, fm3, fm4]] - out_1 = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1) + out_1 = jnp.concatenate([fm1, fm2, fm3, fm4], axis=1) out_list = [out_0, out_1] @@ -1226,7 +1226,7 @@ def scan_fn(carry, chunk_in): fm1, fm2, fm3, fm4 = out_chunk[:, 0, ...], out_chunk[:, 1, ...], out_chunk[:, 2, ...], out_chunk[:, 3, ...] axis = 1 if fm1.shape[0] > 1 else 0 fm1, fm2, fm3, fm4 = [jnp.expand_dims(f, axis=axis) for f in [fm1, fm2, fm3, fm4]] - new_chunk = jnp.concatenate([fm1, fm3, fm2, fm4], axis=1) + new_chunk = jnp.concatenate([fm1, fm2, fm3, fm4], axis=1) return next_feat_map, new_chunk diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index b3956b84..7d721773 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -291,6 +291,8 @@ def __init__( dropout: float = 0.0, mask_padding_tokens: bool = True, enable_jax_named_scopes: bool = False, + use_base2_exp: bool = False, + use_experimental_scheduler: bool = False, ): self.enable_jax_named_scopes = enable_jax_named_scopes @@ -315,6 +317,8 @@ def __init__( mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name="self_attn", enable_jax_named_scopes=enable_jax_named_scopes, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, ) # 1. Cross-attention @@ -339,6 +343,8 @@ def __init__( mask_padding_tokens=mask_padding_tokens, residual_checkpoint_name="cross_attn", enable_jax_named_scopes=enable_jax_named_scopes, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, ) assert cross_attn_norm is True self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True) @@ -486,6 +492,8 @@ def __init__( mask_padding_tokens: bool = True, scan_layers: bool = True, enable_jax_named_scopes: bool = False, + use_base2_exp: bool = False, + use_experimental_scheduler: bool = False, ): inner_dim = num_attention_heads * attention_head_dim out_channels = out_channels or in_channels @@ -547,6 +555,8 @@ def init_block(rngs): enable_jax_named_scopes=enable_jax_named_scopes, added_kv_proj_dim=added_kv_proj_dim, image_seq_len=image_seq_len, + use_base2_exp=use_base2_exp, + use_experimental_scheduler=use_experimental_scheduler, ) self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index e1b1aadd..031fe2fe 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -67,7 +67,6 @@ def cast_with_exclusion(path, x, dtype_to_cast): path_str = ".".join(str(k.key) if isinstance(k, jax.tree_util.DictKey) else str(k) for k in path) if any(keyword in path_str.lower() for keyword in exclusion_keywords): - print("is_norm_path: ", path) # Keep LayerNorm/GroupNorm weights and biases in full precision return x.astype(jnp.float32) else: @@ -139,6 +138,8 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): wan_config["mask_padding_tokens"] = config.mask_padding_tokens wan_config["scan_layers"] = config.scan_layers wan_config["enable_jax_named_scopes"] = config.enable_jax_named_scopes + wan_config["use_base2_exp"] = config.use_base2_exp + wan_config["use_experimental_scheduler"] = config.use_experimental_scheduler # 2. eval_shape - will not use flops or create weights on device # thus not using HBM memory. diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index bd516450..8e859e8b 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -22,6 +22,8 @@ import jax import jax.numpy as jnp from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler +import numpy as np +from ... import max_utils class WanPipeline2_1(WanPipeline): @@ -142,6 +144,7 @@ def __call__( retention_ratio=retention_ratio, height=height, mag_ratios_base=getattr(config, "mag_ratios_base", None), + config=self.config, ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): @@ -175,6 +178,7 @@ def run_inference_2_1( retention_ratio: float = 0.2, height: int = 480, mag_ratios_base: Optional[List[float]] = None, + config=None, ): """Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache. @@ -253,7 +257,16 @@ def run_inference_2_1( skip_warmup = magcache_init[7] mag_ratios = magcache_init[8] + first_profiling_step = config.skip_first_n_steps_for_profiler if config else 0 + profiler_steps = config.profiler_steps if config else 0 + last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1) + + profiler = None for step in range(num_inference_steps): + if config and max_utils.profiler_enabled(config) and step == first_profiling_step: + profiler = max_utils.Profiler(config) + profiler.start() + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] if use_magcache and do_cfg: @@ -328,4 +341,8 @@ def run_inference_2_1( latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + if config and max_utils.profiler_enabled(config) and step == last_profiling_step: + if profiler: + profiler.stop() + return latents diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index ac63d048..9488f894 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -23,6 +23,7 @@ import jax.numpy as jnp import numpy as np from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler +from ... import max_utils class WanPipeline2_2(WanPipeline): @@ -176,6 +177,7 @@ def __call__( latents=latents, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, + config=self.config, ) latents = self._denormalize_latents(latents) return self._decode_latents_to_video(latents) @@ -200,6 +202,7 @@ def run_inference_2_2( use_cfg_cache: bool = False, use_sen_cache: bool = False, height: int = 480, + config=None, ): """Denoising loop for WAN 2.2 T2V with optional caching acceleration. @@ -451,7 +454,16 @@ def run_inference_2_2( jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds ) + first_profiling_step = config.skip_first_n_steps_for_profiler if config else 0 + profiler_steps = config.profiler_steps if config else 0 + last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1) + + profiler = None for step in range(num_inference_steps): + if config and max_utils.profiler_enabled(config) and step == first_profiling_step: + profiler = max_utils.Profiler(config) + profiler.start() + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] if step_uses_high[step]: @@ -487,4 +499,8 @@ def run_inference_2_2( ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + + if config and max_utils.profiler_enabled(config) and step == last_profiling_step: + if profiler: + profiler.stop() return latents diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py index 0ed78ee9..787f2295 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py @@ -25,6 +25,8 @@ import jax.numpy as jnp from jax.sharding import NamedSharding, PartitionSpec as P from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler +import numpy as np +from ... import max_utils class WanPipelineI2V_2_1(WanPipeline): @@ -252,6 +254,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): retention_ratio=retention_ratio, height=height, mag_ratios_base=self.config.mag_ratios_base_720p if height >= 720 else self.config.mag_ratios_base_480p, + config=self.config, ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): @@ -290,6 +293,7 @@ def run_inference_2_1_i2v( retention_ratio: float = 0.2, height: int = 480, mag_ratios_base: Optional[List[float]] = None, + config=None, ): do_cfg = guidance_scale > 1.0 @@ -309,7 +313,16 @@ def run_inference_2_1_i2v( image_embeds_combined = image_embeds condition_combined = condition + first_profiling_step = config.skip_first_n_steps_for_profiler if config else 0 + profiler_steps = config.profiler_steps if config else 0 + last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1) + + profiler = None for step in range(num_inference_steps): + if config and max_utils.profiler_enabled(config) and step == first_profiling_step: + profiler = max_utils.Profiler(config) + profiler.start() + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] skip_blocks = False @@ -349,4 +362,8 @@ def run_inference_2_1_i2v( noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents, return_dict=False) + + if config and max_utils.profiler_enabled(config) and step == last_profiling_step: + if profiler: + profiler.stop() return latents diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py index 60431d9c..d8398f58 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py @@ -26,6 +26,7 @@ import numpy as np from jax.sharding import NamedSharding, PartitionSpec as P from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler +from ... import max_utils class WanPipelineI2V_2_2(WanPipeline): @@ -279,6 +280,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): use_cfg_cache=use_cfg_cache, use_sen_cache=use_sen_cache, height=height, + config=self.config, ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): @@ -324,6 +326,7 @@ def run_inference_2_2_i2v( use_cfg_cache: bool = False, use_sen_cache: bool = False, height: int = 480, + config=None, ): do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 bsz = latents.shape[0] @@ -602,7 +605,16 @@ def low_noise_branch(operands): image_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0) condition = jnp.concatenate([condition] * 2) + first_profiling_step = config.skip_first_n_steps_for_profiler if config else 0 + profiler_steps = config.profiler_steps if config else 0 + last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1) + + profiler = None for step in range(num_inference_steps): + if config and max_utils.profiler_enabled(config) and step == first_profiling_step: + profiler = max_utils.Profiler(config) + profiler.start() + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] latents_input = latents if do_classifier_free_guidance: @@ -616,4 +628,8 @@ def low_noise_branch(operands): ) noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1)) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + + if config and max_utils.profiler_enabled(config) and step == last_profiling_step: + if profiler: + profiler.stop() return latents