
Commit 67233e9

fix lint issues
1 parent bca0d0e commit 67233e9

3 files changed

Lines changed: 31 additions & 51 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py
src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py
src/maxdiffusion/tests/wan_animate_diffusers_parity_test.py

src/maxdiffusion/models/wan/transformers/transformer_wan_animate.py

Lines changed: 25 additions & 33 deletions
@@ -735,38 +735,32 @@ def __call__(
     # frame-wise reshape so each query batch element still corresponds to a full frame.
     query = jax.lax.with_sharding_constraint(
         query,
-        nn.logical_to_mesh_axes(
-            (
-                common_types.BATCH,
-                None,
-                common_types.CROSS_ATTN_HEAD,
-                common_types.D_KV,
-            )
-        ),
+        nn.logical_to_mesh_axes((
+            common_types.BATCH,
+            None,
+            common_types.CROSS_ATTN_HEAD,
+            common_types.D_KV,
+        )),
     )
     key = jax.lax.with_sharding_constraint(
         key,
-        nn.logical_to_mesh_axes(
-            (
-                common_types.BATCH,
-                None,
-                None,
-                common_types.CROSS_ATTN_HEAD,
-                common_types.D_KV,
-            )
-        ),
+        nn.logical_to_mesh_axes((
+            common_types.BATCH,
+            None,
+            None,
+            common_types.CROSS_ATTN_HEAD,
+            common_types.D_KV,
+        )),
     )
     value = jax.lax.with_sharding_constraint(
         value,
-        nn.logical_to_mesh_axes(
-            (
-                common_types.BATCH,
-                None,
-                None,
-                common_types.CROSS_ATTN_HEAD,
-                common_types.D_KV,
-            )
-        ),
+        nn.logical_to_mesh_axes((
+            common_types.BATCH,
+            None,
+            None,
+            common_types.CROSS_ATTN_HEAD,
+            common_types.D_KV,
+        )),
     )

     query_S = query.shape[1]
@@ -803,13 +797,11 @@ def __call__(
     hidden_states = self.to_out(attn_output)
     hidden_states = jax.lax.with_sharding_constraint(
         hidden_states,
-        nn.logical_to_mesh_axes(
-            (
-                common_types.BATCH,
-                common_types.LENGTH,
-                common_types.EMBED,
-            )
-        ),
+        nn.logical_to_mesh_axes((
+            common_types.BATCH,
+            common_types.LENGTH,
+            common_types.EMBED,
+        )),
     )

     if attention_mask is not None:
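These hunks are formatting-only; the underlying pattern maps logical axis names to a PartitionSpec via nn.logical_to_mesh_axes and pins the layout of the attention intermediates with jax.lax.with_sharding_constraint. Below is a minimal, self-contained sketch of that pattern, not code from this commit: the mesh shape, the axis names ("batch", "heads", "kv"), and the rule mapping are illustrative assumptions.

import jax
import jax.numpy as jnp
import flax.linen as nn
from jax.sharding import Mesh, NamedSharding
from jax.experimental import mesh_utils

# Illustrative 1-D mesh over all available devices.
devices = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(devices, axis_names=("data",))

# Hypothetical logical-to-mesh rules: shard "batch" over the "data" mesh axis,
# replicate the head and head-dim axes.
rules = (("batch", "data"), ("heads", None), ("kv", None))

@jax.jit
def constrain(x):
  # Translate logical names into a PartitionSpec, then pin the array's layout.
  spec = nn.logical_to_mesh_axes(("batch", None, "heads", "kv"), rules)
  return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))

y = constrain(jnp.zeros((8, 16, 4, 64)))  # (batch, seq, heads, head_dim)
print(y.sharding)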

src/maxdiffusion/pipelines/wan/wan_pipeline_animate.py

Lines changed: 3 additions & 9 deletions
@@ -610,9 +610,7 @@ def prepare_prev_segment_cond_latents(
     remaining_frames = segment_frame_length - prev_segment_cond_frames
     remaining = jnp.zeros((batch_size, 3, remaining_frames, height, width), dtype=vae_dtype)

-    full_cond_video = jnp.concatenate(
-        [prev_segment_cond_video.astype(vae_dtype), remaining], axis=2
-    )  # (B, C, T_seg, H, W)
+    full_cond_video = jnp.concatenate([prev_segment_cond_video.astype(vae_dtype), remaining], axis=2)  # (B, C, T_seg, H, W)

     cond_latents = self._encode_video_to_latents(full_cond_video, dtype)
     # (B, T_lat, H_lat, W_lat, z_dim)
@@ -785,9 +783,7 @@ def __call__(
           f"`segment_frame_length - 1` must be divisible by {self.vae_scale_factor_temporal}. "
           f"Rounding {segment_frame_length}."
       )
-      segment_frame_length = (
-          segment_frame_length // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
-      )
+      segment_frame_length = segment_frame_length // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
       segment_frame_length = max(segment_frame_length, 1)

     do_classifier_free_guidance = guidance_scale > 1.0
@@ -994,9 +990,7 @@ def __call__(
           noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

         noise_pred = noise_pred.astype(seg_latents.dtype)
-        seg_latents, scheduler_state = self.scheduler.step(
-            scheduler_state, noise_pred, t, seg_latents, return_dict=False
-        )
+        seg_latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, seg_latents, return_dict=False)

       # Decode this segment (skip reference frame at index 0).
       out_frames_cf = self._decode_segment_to_pixels(seg_latents[:, 1:, :, :, :])
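The reflowed assignment keeps the same rounding rule: segment_frame_length is snapped down so that (segment_frame_length - 1) is a multiple of the VAE's temporal scale factor. A small worked example of that arithmetic, with an assumed vae_scale_factor_temporal of 4 (illustrative values, not taken from this commit):

vae_scale_factor_temporal = 4  # assumed value, for illustration only
segment_frame_length = 30      # assumed requested length

segment_frame_length = segment_frame_length // vae_scale_factor_temporal * vae_scale_factor_temporal + 1
assert segment_frame_length == 29                                    # rounded down to 7 * 4 + 1
assert (segment_frame_length - 1) % vae_scale_factor_temporal == 0   # now valid for the temporal VAE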

src/maxdiffusion/tests/wan_animate_diffusers_parity_test.py

Lines changed: 3 additions & 9 deletions
@@ -503,9 +503,7 @@ def fake_encode(video, dtype):
       latent_t = (video.shape[2] - 1) // self.max_pipeline.vae_scale_factor_temporal + 1
       latent_h = video.shape[3] // self.max_pipeline.vae_scale_factor_spatial
       latent_w = video.shape[4] // self.max_pipeline.vae_scale_factor_spatial
-      return jnp.zeros(
-          (video.shape[0], latent_t, latent_h, latent_w, self.max_pipeline.vae.z_dim), dtype=jnp.float32
-      )
+      return jnp.zeros((video.shape[0], latent_t, latent_h, latent_w, self.max_pipeline.vae.z_dim), dtype=jnp.float32)

     self.max_pipeline._encode_video_to_latents = fake_encode

@@ -722,9 +720,7 @@ def __call__(
     np.testing.assert_allclose(to_numpy(capture["pose_hidden_states"]), to_numpy(expected_pose), atol=0.0, rtol=0.0)
     np.testing.assert_allclose(to_numpy(capture["face_pixel_values"]), to_numpy(face_video), atol=0.0, rtol=0.0)
     np.testing.assert_allclose(to_numpy(capture["encoder_hidden_states"]), to_numpy(prompt_embeds), atol=0.0, rtol=0.0)
-    np.testing.assert_allclose(
-        to_numpy(capture["encoder_hidden_states_image"]), to_numpy(image_embeds), atol=0.0, rtol=0.0
-    )
+    np.testing.assert_allclose(to_numpy(capture["encoder_hidden_states_image"]), to_numpy(image_embeds), atol=0.0, rtol=0.0)
     self.assertEqual(capture["motion_encode_batch_size"], 7)
     self.assertFalse(capture["return_dict"])
     np.testing.assert_allclose(to_numpy(noise_pred), to_numpy(latents), atol=0.0, rtol=0.0)
@@ -872,9 +868,7 @@ def test_flax_unipc_flow_sigmas_match_diffusers(self):
       max_model_output = jnp.array(to_numpy(hf_model_output))

       hf_sample = hf_scheduler.step(hf_model_output, int(timestep), hf_sample, return_dict=False)[0]
-      max_sample, max_state = max_scheduler.step(
-          max_state, max_model_output, int(timestep), max_sample, return_dict=False
-      )
+      max_sample, max_state = max_scheduler.step(max_state, max_model_output, int(timestep), max_sample, return_dict=False)

       np.testing.assert_allclose(to_numpy(max_sample), to_numpy(hf_sample), atol=1e-4, rtol=1e-5)
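The fake_encode stub above reproduces the VAE's latent-shape arithmetic without running the encoder. A worked example of that mapping with assumed scale factors and an assumed video size (illustrative values, not from this commit):

vae_scale_factor_temporal = 4  # assumed value
vae_scale_factor_spatial = 8   # assumed value
video_shape = (1, 3, 81, 480, 832)  # (B, C, T, H, W), assumed sizes

latent_t = (video_shape[2] - 1) // vae_scale_factor_temporal + 1  # (81 - 1) // 4 + 1 = 21
latent_h = video_shape[3] // vae_scale_factor_spatial             # 480 // 8 = 60
latent_w = video_shape[4] // vae_scale_factor_spatial             # 832 // 8 = 104
assert (latent_t, latent_h, latent_w) == (21, 60, 104)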

0 commit comments
