|
| 1 | +""" |
| 2 | +Copyright 2026 Google LLC |
| 3 | +
|
| 4 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +you may not use this file except in compliance with the License. |
| 6 | +You may obtain a copy of the License at |
| 7 | +
|
| 8 | + https://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +
|
| 10 | +Unless required by applicable law or agreed to in writing, software |
| 11 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +See the License for the specific language governing permissions and |
| 14 | +limitations under the License. |
| 15 | +""" |
| 16 | + |
import os
import PIL.Image

# Enable JAX's CPU interpreter mode for Pallas custom kernels.
# Needed because the physical CPU backend does not support Pallas/Splash attention compilation.
# NOTE: this env var must be set BEFORE any import that can trigger Pallas kernel
# registration, which is why it sits in the middle of the import block.
os.environ["PALLAS_INTERPRET"] = "1"
import unittest
from unittest.mock import MagicMock, patch

import flax
import jax
# NOTE(review): jax._src.config is a private JAX API and may break across JAX
# releases — revisit when a public knob for Pallas interpret mode is available.
import jax._src.config as jax_config

# Force the CPU Pallas interpreter globally. JAX 0.10.0+ uses this internal config manager
# and ignores the standard environment variable during eager pipeline execution.
jax_config.pallas_tpu_interpret_mode_context_manager.set_global(True)
import jax.numpy as jnp

from maxdiffusion import pyconfig

from maxdiffusion.pipelines.wan.wan_vace_pipeline_2_1 import VaceWanPipeline2_1
from maxdiffusion.schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler

# Directory containing this test file; used to locate the config YAML relative to it.
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
| 41 | + |
| 42 | + |
class WanVacePipelineTest(unittest.TestCase):
  """End-to-end smoke test for VaceWanPipeline2_1.

  All heavyweight artifacts (HF configs, checkpoints, tokenizer, text encoder,
  scheduler) are mocked so the pipeline runs entirely on CPU with tiny shapes.
  The test asserts only that loading + inference produce a video of the
  expected shape; it does not check numerical output.
  """

  def setUp(self):
    # Initialize pyconfig with base_wan_1_3b.yml and override some parameters for speed.
    pyconfig.initialize(
        [
            None,
            os.path.join(THIS_DIR, "..", "configs", "base_wan_1_3b.yml"),
            # For completeness, all configs and weights are mocked in this test
            "pretrained_model_name_or_path=Wan-AI/Wan2.1-VACE-1.3B-Diffusers",
            "num_inference_steps=2",  # Reduced steps for speed
            "height=240",  # Reduced resolution for speed (divisible by 16)
            "width=416",  # Reduced resolution for speed (divisible by 16)
            "num_frames=9",  # Reduced num_frames for speed
            "attention=flash",
            "scan_layers=False",  # Explicitly disable scan for VACE
            "jit_initializers=False",  # Disable JIT for faster setup & better CPU debugging
            "skip_jax_distributed_system=True",
        ],
        unittest=True,
    )
    self.config = pyconfig.config

  @staticmethod
  def _mock_random_weights(eval_shapes):
    """Build a pytree of deterministic random CPU weights matching ``eval_shapes``.

    Shared by the transformer and VAE weight-loader mocks (they previously
    contained two byte-identical copies of this logic) so the two stay in sync.

    Args:
      eval_shapes: nested dict of shape structs (objects exposing ``.shape``
        and ``.dtype``, e.g. ``jax.ShapeDtypeStruct``) describing the params.

    Returns:
      A nested dict mirroring ``eval_shapes`` with random arrays placed on the
      first local CPU device.
    """
    cpu = jax.local_devices(backend="cpu")[0]
    flat_shapes = flax.traverse_util.flatten_dict(eval_shapes)
    flat_params = {}
    # Use a static seed to ensure deterministic random weights.
    key = jax.random.key(42)
    for k, shape_struct in flat_shapes.items():
      key, subkey = jax.random.split(key)
      val = jax.random.normal(subkey, shape_struct.shape, dtype=shape_struct.dtype)
      flat_params[k] = jax.device_put(val, device=cpu)
    return flax.traverse_util.unflatten_dict(flat_params)

  @patch("maxdiffusion.pipelines.wan.wan_vace_pipeline_2_1.WanVACEModel.load_config")
  @patch("maxdiffusion.pipelines.wan.wan_pipeline.AutoencoderKLWan.load_config")
  @patch("maxdiffusion.pipelines.wan.wan_vace_pipeline_2_1.load_wan_transformer")
  @patch("maxdiffusion.pipelines.wan.wan_pipeline.load_wan_vae")
  @patch("maxdiffusion.pipelines.wan.wan_pipeline.WanPipeline.load_tokenizer")
  @patch("maxdiffusion.pipelines.wan.wan_pipeline.WanPipeline.load_text_encoder")
  @patch("maxdiffusion.pipelines.wan.wan_pipeline.WanPipeline.load_scheduler")
  # pylint: disable=too-many-positional-arguments
  def test_pipeline_load_and_inference(
      self,
      # Decorators apply bottom-up, so the innermost (last) patch is the first arg.
      mock_load_scheduler_fn,
      mock_load_text_encoder_fn,
      mock_load_tokenizer_fn,
      mock_load_wan_vae_fn,
      mock_load_wan_transformer_fn,
      mock_vae_load_config_fn,
      mock_transformer_load_config_fn,
  ):
    # Mock configs to represent a 1.3B model but with only 2 layers.
    # Reference real config: https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B-diffusers/blob/main/transformer/config.json
    # pylint: disable=unused-argument
    def mock_transformer_load_config(pretrained_model_name_or_path, return_unused_kwargs=False, **kwargs):
      config_dict = {
          "added_kv_proj_dim": None,
          "attention_head_dim": 128,
          "cross_attn_norm": True,
          "eps": 1e-06,
          "ffn_dim": 8960,
          "freq_dim": 256,
          "image_dim": None,
          "in_channels": 16,
          "num_attention_heads": 12,
          "num_layers": 2,  # Overridden to 2 layers for speed
          "out_channels": 16,
          "patch_size": [1, 2, 2],
          "pos_embed_seq_len": None,
          "qk_norm": "rms_norm_across_heads",
          "rope_max_seq_len": 1024,
          "text_dim": 4096,
          "vace_in_channels": 96,
          "vace_layers": [0, 1],  # VACE conditioning on both layers
      }
      # Mirror diffusers' load_config contract: optionally return unused kwargs.
      if return_unused_kwargs:
        return config_dict, kwargs
      return config_dict

    mock_transformer_load_config_fn.side_effect = mock_transformer_load_config

    # Full-size VAE config.
    # Reference real config: https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B-diffusers/blob/main/vae/config.json
    # pylint: disable=unused-argument
    def mock_vae_load_config(pretrained_model_name_or_path, return_unused_kwargs=False, **kwargs):
      config_dict = {
          "attn_scales": [],
          "base_dim": 96,
          "dim_mult": [1, 2, 4, 4],
          "dropout": 0.0,
          "latents_mean": [
              -0.7571,
              -0.7089,
              -0.9113,
              0.1075,
              -0.1745,
              0.9653,
              -0.1517,
              1.5508,
              0.4134,
              -0.0715,
              0.5517,
              -0.3632,
              -0.1922,
              -0.9497,
              0.2503,
              -0.2921,
          ],
          "latents_std": [
              2.8184,
              1.4541,
              2.3275,
              2.6558,
              1.2196,
              1.7708,
              2.6052,
              2.0743,
              3.2687,
              2.1526,
              2.8652,
              1.5579,
              1.6382,
              1.1253,
              2.8251,
              1.916,
          ],
          "num_res_blocks": 2,
          # Key spelling ("temperal") matches the upstream Wan VAE config — do not "fix" it.
          "temperal_downsample": [False, True, True],
          "z_dim": 16,
      }
      if return_unused_kwargs:
        return config_dict, kwargs
      return config_dict

    mock_vae_load_config_fn.side_effect = mock_vae_load_config

    # Mock weight loaders to generate random weights in memory; both delegate
    # to the shared deterministic-random helper so they cannot drift apart.
    # pylint: disable=unused-argument
    def mock_load_wan_transformer(pretrained_model_name_or_path, eval_shapes, *args, **kwargs):
      return self._mock_random_weights(eval_shapes)

    mock_load_wan_transformer_fn.side_effect = mock_load_wan_transformer

    # pylint: disable=unused-argument
    def mock_load_wan_vae(pretrained_model_name_or_path, eval_shapes, *args, **kwargs):
      return self._mock_random_weights(eval_shapes)

    mock_load_wan_vae_fn.side_effect = mock_load_wan_vae

    # Mock scheduler to load from local config dictionary.
    # Reference real config: https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B-diffusers/blob/main/scheduler/scheduler_config.json # pylint: disable=line-too-long
    def mock_load_scheduler(config):
      scheduler = FlaxUniPCMultistepScheduler.from_config({
          "beta_end": 0.02,
          "beta_schedule": "linear",
          "beta_start": 0.0001,
          "disable_corrector": [],
          "dynamic_thresholding_ratio": 0.995,
          "final_sigmas_type": "zero",
          "flow_shift": config.flow_shift,
          "lower_order_final": True,
          "num_train_timesteps": 1000,
          "predict_x0": True,
          "prediction_type": "flow_prediction",
          "rescale_zero_terminal_snr": False,
          "sample_max_value": 1.0,
          "solver_order": 2,
          "solver_p": None,
          "solver_type": "bh2",
          "steps_offset": 0,
          "thresholding": False,
          "timestep_spacing": "linspace",
          "trained_betas": None,
          "use_beta_sigmas": False,
          "use_exponential_sigmas": False,
          "use_flow_sigmas": True,
          "use_karras_sigmas": False,
      })
      state = scheduler.create_state()
      return scheduler, state

    mock_load_scheduler_fn.side_effect = mock_load_scheduler

    # Mock tokenizer and text encoder to avoid Hugging Face downloads.
    mock_load_tokenizer_fn.return_value = MagicMock()
    mock_load_text_encoder_fn.return_value = MagicMock()

    pipeline = VaceWanPipeline2_1.from_pretrained(self.config)

    # Prepare dummy inputs.
    batch_size = 1

    height = self.config.height
    width = self.config.width
    num_frames = self.config.num_frames

    # Pre-computed dummy text embeddings matching T5 text dimension (4096).
    # Bypasses the actual text encoder.
    prompt_embeds = jnp.zeros((batch_size, 512, 4096), dtype=self.config.weights_dtype)
    negative_prompt_embeds = jnp.zeros((batch_size, 512, 4096), dtype=self.config.weights_dtype)

    # Create dummy PIL images for conditioning inputs (RGB video frames, L-mode mask).
    dummy_image_rgb = PIL.Image.new("RGB", (width, height), color="white")
    dummy_image_l = PIL.Image.new("L", (width, height), color="white")

    video_input = [dummy_image_rgb] * num_frames
    mask_input = [dummy_image_l] * num_frames
    ref_images_input = [dummy_image_rgb]

    video = pipeline(
        video=video_input,
        mask=mask_input,
        reference_images=ref_images_input,
        prompt=None,
        prompt_embeds=prompt_embeds,
        negative_prompt=None,
        negative_prompt_embeds=negative_prompt_embeds,
        height=height,
        width=width,
        num_frames=num_frames,
        num_inference_steps=self.config.num_inference_steps,
    )

    # One video per batch element, frames stacked as (T, H, W, C).
    self.assertEqual(len(video), batch_size)
    self.assertEqual(video[0].shape, (num_frames, height, width, 3))
| 279 | + |
| 280 | + |
# Allow running this test module directly (e.g. `python test_wan_vace_pipeline.py`).
if __name__ == "__main__":
  unittest.main()
0 commit comments