Skip to content

Commit 4ea5caa

Browse files
committed
Fix tests for Flux, WAN, SDXL, and LTX-Video
- Resolve execution and environment issues
- Fix dimension mismatch in ControlNet and add tearDown for GC in SDXL tests
- Enable durations profiling and fix formatting/lint issues
1 parent 51c34e6 commit 4ea5caa

30 files changed

Lines changed: 292 additions & 87 deletions

.github/workflows/UnitTests.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,11 @@ jobs:
5959
python --version
6060
pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
6161
- name: PyTest
62+
env:
63+
HF_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
6264
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
6365
export LIBTPU_INIT_ARGS='--xla_tpu_scoped_vmem_limit_kib=65536'
64-
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --ignore=src/maxdiffusion/kernels/ --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
66+
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --ignore=src/maxdiffusion/kernels/ -x --durations=0 -W ignore::DeprecationWarning -W ignore::UserWarning -W ignore::RuntimeWarning
6567
# add_pull_ready
6668
# if: github.ref != 'refs/heads/main'
6769
# permissions:

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ ici_tensor_parallelism: 1
9292
allow_split_physical_axes: False
9393
learning_rate_schedule_steps: -1
9494
max_train_steps: 500
95-
pretrained_model_name_or_path: ''
95+
pretrained_model_name_or_path: 'Lightricks/LTX-Video'
9696
unet_checkpoint: ''
9797
dataset_name: 'diffusers/pokemon-gpt4-captions'
9898
train_split: 'train'

src/maxdiffusion/controlnet/generate_controlnet_sdxl_replicated.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ def create_key(seed=0):
3636
def run(config):
3737
rng = jax.random.PRNGKey(config.seed)
3838

39+
devices_array = max_utils.create_device_mesh(config)
40+
mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
41+
3942
prompts = config.prompt
4043
negative_prompts = config.negative_prompt
4144
controlnet_conditioning_scale = config.controlnet_conditioning_scale
@@ -48,13 +51,14 @@ def run(config):
4851
image = np.concatenate([image, image, image], axis=2)
4952
image = Image.fromarray(image)
5053

51-
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
52-
config.controlnet_model_name_or_path, from_pt=config.controlnet_from_pt, dtype=config.activations_dtype
53-
)
54+
with mesh:
55+
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
56+
config.controlnet_model_name_or_path, from_pt=config.controlnet_from_pt, dtype=config.activations_dtype
57+
)
5458

55-
pipe, params = FlaxStableDiffusionXLControlNetPipeline.from_pretrained(
56-
config.pretrained_model_name_or_path, controlnet=controlnet, revision=config.revision, dtype=config.activations_dtype
57-
)
59+
pipe, params = FlaxStableDiffusionXLControlNetPipeline.from_pretrained(
60+
config.pretrained_model_name_or_path, controlnet=controlnet, revision=config.revision, dtype=config.activations_dtype
61+
)
5862

5963
scheduler_state = params.pop("scheduler")
6064
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
@@ -68,21 +72,23 @@ def run(config):
6872
prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
6973
negative_prompt_ids = pipe.prepare_text_inputs([negative_prompts] * num_samples)
7074
processed_image = pipe.prepare_image_inputs([image] * num_samples)
71-
p_params = replicate(params)
72-
prompt_ids = shard(prompt_ids)
73-
negative_prompt_ids = shard(negative_prompt_ids)
74-
processed_image = shard(processed_image)
75-
76-
output = pipe(
77-
prompt_ids=prompt_ids,
78-
image=processed_image,
79-
params=p_params,
80-
prng_seed=rng,
81-
num_inference_steps=config.num_inference_steps,
82-
neg_prompt_ids=negative_prompt_ids,
83-
controlnet_conditioning_scale=controlnet_conditioning_scale,
84-
jit=True,
85-
).images
75+
76+
with mesh:
77+
p_params = replicate(params)
78+
prompt_ids = shard(prompt_ids)
79+
negative_prompt_ids = shard(negative_prompt_ids)
80+
processed_image = shard(processed_image)
81+
82+
output = pipe(
83+
prompt_ids=prompt_ids,
84+
image=processed_image,
85+
params=p_params,
86+
prng_seed=rng,
87+
num_inference_steps=config.num_inference_steps,
88+
neg_prompt_ids=negative_prompt_ids,
89+
controlnet_conditioning_scale=controlnet_conditioning_scale,
90+
jit=True,
91+
).images
8692

8793
output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
8894
output_images[0].save("generated_image.png")

src/maxdiffusion/data_preprocessing/wan_txt2vid_data_preprocessing.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020
"""
2121

2222
import os
23-
import functools
2423
from absl import app
2524
from typing import Sequence, Union, List
2625
from datasets import load_dataset
2726
import numpy as np
2827
import jax
28+
from flax import nnx
2929
import jax.numpy as jnp
3030
from jax.sharding import Mesh
3131
from maxdiffusion import pyconfig, max_utils
@@ -110,8 +110,9 @@ def generate_dataset(config, pipeline):
110110
vae_scale_factor_spatial = 2 ** len(pipeline.vae.temperal_downsample)
111111
video_processor = VideoProcessor(vae_scale_factor=vae_scale_factor_spatial)
112112

113-
# jit vae fun.
114-
p_vae_encode = jax.jit(functools.partial(vae_encode, vae=pipeline.vae, vae_cache=pipeline.vae_cache))
113+
@nnx.jit
114+
def p_vae_encode(video, rng, vae, vae_cache):
115+
return vae_encode(video, rng, vae, vae_cache)
115116

116117
# Load dataset
117118
ds = load_dataset(config.dataset_name, split="train")
@@ -126,7 +127,7 @@ def generate_dataset(config, pipeline):
126127
videos = [video_processor.preprocess_video([video], height=config.height, width=config.width) for video in videos]
127128
video = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype)
128129
with mesh:
129-
latents = p_vae_encode(video=video, rng=new_rng)
130+
latents = p_vae_encode(video=video, rng=new_rng, vae=pipeline.vae, vae_cache=pipeline.vae_cache)
130131
encoder_hidden_states = text_encode(pipeline, text)
131132
for latent, encoder_hidden_state in zip(latents, encoder_hidden_states):
132133
writer.write(create_example(latent, encoder_hidden_state))

src/maxdiffusion/generate_sdxl.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,18 @@ def tokenize(prompt, pipeline):
115115
return inputs
116116

117117

118-
def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
118+
def get_unet_inputs(pipeline, scheduler_params, states, config, rng, mesh, batch_size):
119119
data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
120120

121121
vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1)
122122
prompt_ids = [config.prompt] * batch_size
123123
prompt_ids = tokenize(prompt_ids, pipeline)
124+
prompt_ids = jax.lax.with_sharding_constraint(prompt_ids, jax.sharding.NamedSharding(mesh, P("data", None, None)))
124125
negative_prompt_ids = [config.negative_prompt] * batch_size
125126
negative_prompt_ids = tokenize(negative_prompt_ids, pipeline)
127+
negative_prompt_ids = jax.lax.with_sharding_constraint(
128+
negative_prompt_ids, jax.sharding.NamedSharding(mesh, P("data", None, None))
129+
)
126130
guidance_scale = config.guidance_scale
127131
guidance_rescale = config.guidance_rescale
128132
num_inference_steps = config.num_inference_steps
@@ -133,6 +137,8 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
133137
"text_encoder_2": states["text_encoder_2_state"].params,
134138
}
135139
prompt_embeds, pooled_embeds = get_embeddings(prompt_ids, pipeline, text_encoder_params)
140+
prompt_embeds = jax.lax.with_sharding_constraint(prompt_embeds, jax.sharding.NamedSharding(mesh, P("data", None, None)))
141+
pooled_embeds = jax.lax.with_sharding_constraint(pooled_embeds, jax.sharding.NamedSharding(mesh, P("data", None)))
136142

137143
batch_size = prompt_embeds.shape[0]
138144
add_time_ids = get_add_time_ids(
@@ -148,6 +154,9 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
148154

149155
prompt_embeds = jnp.concatenate([negative_prompt_embeds, prompt_embeds], axis=0)
150156
add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0)
157+
prompt_embeds = jax.lax.with_sharding_constraint(prompt_embeds, jax.sharding.NamedSharding(mesh, P("data", None, None)))
158+
add_text_embeds = jax.lax.with_sharding_constraint(add_text_embeds, jax.sharding.NamedSharding(mesh, P("data", None)))
159+
151160
add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0)
152161

153162
else:
@@ -166,8 +175,11 @@ def get_unet_inputs(pipeline, params, states, config, rng, mesh, batch_size):
166175

167176
latents = jax.random.normal(rng, shape=latents_shape, dtype=jnp.float32)
168177

178+
if isinstance(scheduler_params, dict) and "scheduler" in scheduler_params:
179+
scheduler_params = scheduler_params["scheduler"]
180+
169181
scheduler_state = pipeline.scheduler.set_timesteps(
170-
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
182+
scheduler_params, num_inference_steps=num_inference_steps, shape=latents.shape
171183
)
172184

173185
latents = latents * scheduler_state.init_noise_sigma
@@ -217,9 +229,11 @@ def run_inference(states, pipeline, params, config, rng, mesh, batch_size):
217229
def run(config):
218230
checkpoint_loader = GenerateSDXL(config)
219231
mesh = checkpoint_loader.mesh
220-
with mesh:
221-
pipeline, params = checkpoint_loader.load_checkpoint()
232+
# NOTE: load_checkpoint() is called outside the mesh context intentionally.
233+
# If checkpoint loading requires mesh-aware sharding, move this back inside `with mesh:`.
234+
pipeline, params = checkpoint_loader.load_checkpoint()
222235

236+
with mesh:
223237
noise_scheduler, noise_scheduler_state = create_scheduler(pipeline.scheduler.config, config)
224238

225239
weights_init_fn = functools.partial(pipeline.unet.init_weights, rng=checkpoint_loader.rng)
@@ -303,11 +317,13 @@ def run(config):
303317
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
304318
p_run_inference(states).block_until_ready()
305319
print("compile time: ", (time.time() - s))
320+
306321
s = time.time()
307322
with ExitStack() as stack:
308323
_ = [stack.enter_context(nn.intercept_methods(interceptor)) for interceptor in lora_interceptors]
309324
images = p_run_inference(states).block_until_ready()
310325
print("inference time: ", (time.time() - s))
326+
311327
images = jax.experimental.multihost_utils.process_allgather(images, tiled=True)
312328
numpy_images = np.array(images)
313329
images = VaeImageProcessor.numpy_to_pil(numpy_images)

src/maxdiffusion/tests/data_processing_test.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717
import os
1818
import pytest
19-
import functools
2019
import jax
2120
import jax.numpy as jnp
21+
from flax import nnx
2222
from flax.linen import partitioning as nn_partitioning
2323
from jax.sharding import Mesh
2424
from .. import pyconfig
@@ -81,11 +81,14 @@ def test_wan_vae_encode_normalization(self):
8181
video = load_video(video_path)
8282
videos = [video_processor.preprocess_video([video], height=config.height, width=config.width)]
8383
videos = jnp.array(np.squeeze(np.array(videos), axis=1), dtype=config.weights_dtype)
84-
p_vae_encode = jax.jit(functools.partial(vae_encode, vae=pipeline.vae, vae_cache=pipeline.vae_cache))
84+
85+
@nnx.jit
86+
def p_vae_encode(video, rng, vae, vae_cache):
87+
return vae_encode(video, rng, vae, vae_cache)
8588

8689
rng = jax.random.key(config.seed)
8790
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
88-
latents = p_vae_encode(videos, rng=rng)
91+
latents = p_vae_encode(videos, rng=rng, vae=pipeline.vae, vae_cache=pipeline.vae_cache)
8992
# 1. Verify Channel Count (Wan 2.1 requires 16)
9093
self.assertEqual(latents.shape[1], 16, f"Expected 16 channels, got {latents.shape[1]}")
9194

src/maxdiffusion/tests/generate_ltx2_smoke_test.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,10 @@ def setUpClass(cls):
5858
)
5959
cls.config = pyconfig.config
6060
checkpoint_loader = LTX2Checkpointer(config=cls.config)
61-
# Load pipeline without upsampler for simplicity in smoke test
6261
cls.pipeline, _, _ = checkpoint_loader.load_checkpoint(load_upsampler=False)
6362

64-
cls.prompt = [cls.config.prompt] * getattr(cls.config, "global_batch_size_to_train_on", 1)
65-
cls.negative_prompt = [cls.config.negative_prompt] * getattr(cls.config, "global_batch_size_to_train_on", 1)
63+
cls.prompt = [cls.config.prompt]
64+
cls.negative_prompt = [cls.config.negative_prompt]
6665

6766
def test_ltx2_inference(self):
6867
"""Test that LTX2 pipeline can run inference and produce output."""
@@ -92,9 +91,6 @@ def test_ltx2_inference(self):
9291
# Check that we got frames
9392
self.assertGreater(len(videos), 0)
9493

95-
# LTX2 might also produce audio, check if it's there if expected
96-
# The config doesn't explicitly say if it's T2AV or just T2V, but the pipeline seems to handle audio.
97-
# We can just log if audio is present.
9894
if audios is not None:
9995
print(f"Audio produced with shape: {audios[0].shape}")
10096
self.assertGreater(len(audios), 0)

src/maxdiffusion/tests/generate_sdxl_smoke_test.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,15 @@
3636
class Generate(unittest.TestCase):
3737
"""Smoke test."""
3838

39+
def tearDown(self):
40+
super().tearDown()
41+
import gc
42+
43+
gc.collect()
44+
import jax
45+
46+
jax.clear_caches()
47+
3948
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
4049
def test_hyper_sdxl_lora(self):
4150
img_url = os.path.join(THIS_DIR, "images", "test_hyper_sdxl.png")
@@ -53,6 +62,7 @@ def test_hyper_sdxl_lora(self):
5362
'diffusion_scheduler_config={"_class_name" : "FlaxDDIMScheduler", "timestep_spacing" : "trailing"}',
5463
'lora_config={"lora_model_name_or_path" : ["ByteDance/Hyper-SD"], "weight_name" : ["Hyper-SDXL-2steps-lora.safetensors"], "adapter_name" : ["hyper-sdxl"], "scale": [0.7], "from_pt": ["true"]}',
5564
f"jax_cache_dir={JAX_CACHE_DIR}",
65+
"jit_initializers=False",
5666
],
5767
unittest=True,
5868
)
@@ -84,6 +94,7 @@ def test_sdxl_config(self):
8494
"run_name=sdxl-inference-test",
8595
"split_head_dim=False",
8696
f"jax_cache_dir={JAX_CACHE_DIR}",
97+
"jit_initializers=False",
8798
],
8899
unittest=True,
89100
)
@@ -116,6 +127,7 @@ def test_sdxl_from_gcs(self):
116127
"run_name=sdxl-inference-test",
117128
"split_head_dim=False",
118129
f"jax_cache_dir={JAX_CACHE_DIR}",
130+
"jit_initializers=False",
119131
],
120132
unittest=True,
121133
)
@@ -139,14 +151,16 @@ def test_controlnet_sdxl(self):
139151
"activations_dtype=bfloat16",
140152
"weights_dtype=bfloat16",
141153
f"jax_cache_dir={JAX_CACHE_DIR}",
154+
"controlnet_image=" + os.path.join(THIS_DIR, "images", "cnet_test.png"),
155+
"jit_initializers=False",
142156
],
143157
unittest=True,
144158
)
145159
images = generate_run_sdxl_controlnet(pyconfig.config)
146160
test_image = np.array(images[0]).astype(np.uint8)
147161
ssim_compare = ssim(base_image, test_image, multichannel=True, channel_axis=-1, data_range=255)
148162
assert base_image.shape == test_image.shape
149-
assert ssim_compare >= 0.70
163+
assert ssim_compare >= 0.80
150164

151165
@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="Don't run smoke tests on Github Actions")
152166
def test_sdxl_lightning(self):
@@ -158,14 +172,15 @@ def test_sdxl_lightning(self):
158172
os.path.join(THIS_DIR, "..", "configs", "base_xl_lightning.yml"),
159173
"run_name=sdxl-lightning-test",
160174
f"jax_cache_dir={JAX_CACHE_DIR}",
175+
"jit_initializers=False",
161176
],
162177
unittest=True,
163178
)
164179
images = generate_run_xl(pyconfig.config)
165180
test_image = np.array(images[0]).astype(np.uint8)
166181
ssim_compare = ssim(base_image, test_image, multichannel=True, channel_axis=-1, data_range=255)
167182
assert base_image.shape == test_image.shape
168-
assert ssim_compare >= 0.70
183+
assert ssim_compare >= 0.80
169184

170185

171186
if __name__ == "__main__":

0 commit comments

Comments
 (0)