Skip to content

Commit 73687fe

Browse files
ninatum and martinarroyo committed
Add unit test to initialize and run Wan-VACE pipeline
Introduce a unit test to verify VaceWanPipeline2_1 initialization and end-to-end execution. The test loads the full-size VAE and a 2-layer 1.3B Transformer using in-memory random weights, and executes a 2-step inference pass with dummy PIL Images to validate the entire code path (preprocessing, sharding, Flash Attention compilation, and VAE encoding/decoding). Co-authored-by: martinarroyo <martinarroyo@google.com>
1 parent bb3b0c6 commit 73687fe

1 file changed

Lines changed: 282 additions & 0 deletions

File tree

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
"""
2+
Copyright 2026 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import os
18+
import PIL.Image
19+
20+
# Enable JAX's CPU interpreter mode for Pallas custom kernels.
21+
# Needed because the physical CPU backend does not support Pallas/Splash attention compilation.
22+
os.environ["PALLAS_INTERPRET"] = "1"
23+
import unittest
24+
from unittest.mock import MagicMock, patch
25+
26+
import flax
27+
import jax
28+
import jax._src.config as jax_config
29+
30+
# Force the CPU Pallas interpreter globally. JAX 0.10.0+ uses this internal config manager
31+
# and ignores the standard environment variable during eager pipeline execution.
32+
jax_config.pallas_tpu_interpret_mode_context_manager.set_global(True)
33+
import jax.numpy as jnp
34+
35+
from maxdiffusion import pyconfig
36+
37+
from maxdiffusion.pipelines.wan.wan_vace_pipeline_2_1 import VaceWanPipeline2_1
38+
from maxdiffusion.schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler
39+
40+
THIS_DIR = os.path.dirname(os.path.abspath(__file__))
41+
42+
43+
class WanVacePipelineTest(unittest.TestCase):
  """End-to-end smoke test for VaceWanPipeline2_1.

  Loads the full-size VAE and a 2-layer 1.3B transformer using in-memory
  random weights, then runs a 2-step inference pass with dummy PIL images to
  exercise the whole code path: preprocessing, sharding, Flash Attention
  compilation (via the Pallas interpreter), and VAE encoding/decoding.
  """

  def setUp(self):
    # Initialize pyconfig with base_wan_1_3b.yml and override some parameters for speed.
    pyconfig.initialize(
        [
            None,
            os.path.join(THIS_DIR, "..", "configs", "base_wan_1_3b.yml"),
            # For completeness, all configs and weights are mocked in this test
            "pretrained_model_name_or_path=Wan-AI/Wan2.1-VACE-1.3B-Diffusers",
            "num_inference_steps=2",  # Reduced steps for speed
            "height=240",  # Reduced resolution for speed (divisible by 16)
            "width=416",  # Reduced resolution for speed (divisible by 16)
            "num_frames=9",  # Reduced num_frames for speed
            "attention=flash",
            "scan_layers=False",  # Explicitly disable scan for VACE
            "jit_initializers=False",  # Disable JIT for faster setup & better CPU debugging
            "skip_jax_distributed_system=True",
        ],
        unittest=True,
    )
    self.config = pyconfig.config

  # Decorators are applied bottom-up, so the bottom-most patch maps to the
  # first mock argument below.
  @patch("maxdiffusion.pipelines.wan.wan_vace_pipeline_2_1.WanVACEModel.load_config")
  @patch("maxdiffusion.pipelines.wan.wan_pipeline.AutoencoderKLWan.load_config")
  @patch("maxdiffusion.pipelines.wan.wan_vace_pipeline_2_1.load_wan_transformer")
  @patch("maxdiffusion.pipelines.wan.wan_pipeline.load_wan_vae")
  @patch("maxdiffusion.pipelines.wan.wan_pipeline.WanPipeline.load_tokenizer")
  @patch("maxdiffusion.pipelines.wan.wan_pipeline.WanPipeline.load_text_encoder")
  @patch("maxdiffusion.pipelines.wan.wan_pipeline.WanPipeline.load_scheduler")
  # pylint: disable=too-many-positional-arguments
  def test_pipeline_load_and_inference(
      self,
      mock_load_scheduler_fn,
      mock_load_text_encoder_fn,
      mock_load_tokenizer_fn,
      mock_load_wan_vae_fn,
      mock_load_wan_transformer_fn,
      mock_vae_load_config_fn,
      mock_transformer_load_config_fn,
  ):
    """Verifies pipeline construction and a full 2-step inference pass."""
    # Mock configs to represent a 1.3B model but with only 2 layers.
    # Reference real config: https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B-diffusers/blob/main/transformer/config.json
    # pylint: disable=unused-argument
    def mock_transformer_load_config(pretrained_model_name_or_path, return_unused_kwargs=False, **kwargs):
      config_dict = {
          "added_kv_proj_dim": None,
          "attention_head_dim": 128,
          "cross_attn_norm": True,
          "eps": 1e-06,
          "ffn_dim": 8960,
          "freq_dim": 256,
          "image_dim": None,
          "in_channels": 16,
          "num_attention_heads": 12,
          "num_layers": 2,  # Overridden to 2 layers for speed
          "out_channels": 16,
          "patch_size": [1, 2, 2],
          "pos_embed_seq_len": None,
          "qk_norm": "rms_norm_across_heads",
          "rope_max_seq_len": 1024,
          "text_dim": 4096,
          "vace_in_channels": 96,
          "vace_layers": [0, 1],  # VACE conditioning on both layers
      }
      if return_unused_kwargs:
        return config_dict, kwargs
      return config_dict

    mock_transformer_load_config_fn.side_effect = mock_transformer_load_config

    # Full-size VAE config.
    # Reference real config: https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B-diffusers/blob/main/vae/config.json
    # pylint: disable=unused-argument
    def mock_vae_load_config(pretrained_model_name_or_path, return_unused_kwargs=False, **kwargs):
      config_dict = {
          "attn_scales": [],
          "base_dim": 96,
          "dim_mult": [1, 2, 4, 4],
          "dropout": 0.0,
          "latents_mean": [
              -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
              0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
          ],
          "latents_std": [
              2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
              3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.916,
          ],
          "num_res_blocks": 2,
          "temperal_downsample": [False, True, True],
          "z_dim": 16,
      }
      if return_unused_kwargs:
        return config_dict, kwargs
      return config_dict

    mock_vae_load_config_fn.side_effect = mock_vae_load_config

    # Shared weight loader for both the transformer and the VAE: generates
    # random weights in memory matching the requested eval_shapes pytree.
    # (The original had two identical copies of this function; deduplicated.)
    # pylint: disable=unused-argument
    def mock_load_weights(pretrained_model_name_or_path, eval_shapes, *args, **kwargs):
      cpu = jax.local_devices(backend="cpu")[0]
      flat_shapes = flax.traverse_util.flatten_dict(eval_shapes)
      flat_params = {}
      # Use a static seed to ensure deterministic random weights
      key = jax.random.key(42)
      for k, shape_struct in flat_shapes.items():
        key, subkey = jax.random.split(key)
        val = jax.random.normal(subkey, shape_struct.shape, dtype=shape_struct.dtype)
        flat_params[k] = jax.device_put(val, device=cpu)
      return flax.traverse_util.unflatten_dict(flat_params)

    mock_load_wan_transformer_fn.side_effect = mock_load_weights
    mock_load_wan_vae_fn.side_effect = mock_load_weights

    # Mock scheduler to load from a local config dictionary.
    # Reference real config: https://huggingface.co/Wan-AI/Wan2.1-VACE-1.3B-diffusers/blob/main/scheduler/scheduler_config.json # pylint: disable=line-too-long
    def mock_load_scheduler(config):
      scheduler = FlaxUniPCMultistepScheduler.from_config({
          "beta_end": 0.02,
          "beta_schedule": "linear",
          "beta_start": 0.0001,
          "disable_corrector": [],
          "dynamic_thresholding_ratio": 0.995,
          "final_sigmas_type": "zero",
          "flow_shift": config.flow_shift,
          "lower_order_final": True,
          "num_train_timesteps": 1000,
          "predict_x0": True,
          "prediction_type": "flow_prediction",
          "rescale_zero_terminal_snr": False,
          "sample_max_value": 1.0,
          "solver_order": 2,
          "solver_p": None,
          "solver_type": "bh2",
          "steps_offset": 0,
          "thresholding": False,
          "timestep_spacing": "linspace",
          "trained_betas": None,
          "use_beta_sigmas": False,
          "use_exponential_sigmas": False,
          "use_flow_sigmas": True,
          "use_karras_sigmas": False,
      })
      state = scheduler.create_state()
      return scheduler, state

    mock_load_scheduler_fn.side_effect = mock_load_scheduler

    # Mock tokenizer and text encoder to avoid Hugging Face downloads.
    mock_load_tokenizer_fn.return_value = MagicMock()
    mock_load_text_encoder_fn.return_value = MagicMock()

    pipeline = VaceWanPipeline2_1.from_pretrained(self.config)

    # Prepare dummy inputs.
    batch_size = 1
    height = self.config.height
    width = self.config.width
    num_frames = self.config.num_frames

    # Pre-computed dummy text embeddings matching the T5 text dimension (4096);
    # bypasses the actual (mocked) text encoder.
    prompt_embeds = jnp.zeros((batch_size, 512, 4096), dtype=self.config.weights_dtype)
    negative_prompt_embeds = jnp.zeros((batch_size, 512, 4096), dtype=self.config.weights_dtype)

    # Create dummy PIL images for conditioning inputs.
    dummy_image_rgb = PIL.Image.new("RGB", (width, height), color="white")
    dummy_image_l = PIL.Image.new("L", (width, height), color="white")

    video_input = [dummy_image_rgb] * num_frames
    mask_input = [dummy_image_l] * num_frames
    ref_images_input = [dummy_image_rgb]

    video = pipeline(
        video=video_input,
        mask=mask_input,
        reference_images=ref_images_input,
        prompt=None,
        prompt_embeds=prompt_embeds,
        negative_prompt=None,
        negative_prompt_embeds=negative_prompt_embeds,
        height=height,
        width=width,
        num_frames=num_frames,
        num_inference_steps=self.config.num_inference_steps,
    )

    # One video per batch element, frames in (T, H, W, C) layout.
    self.assertEqual(len(video), batch_size)
    self.assertEqual(video[0].shape, (num_frames, height, width, 3))
279+
280+
281+
if __name__ == "__main__":
282+
unittest.main()

0 commit comments

Comments
 (0)