[MAX] Add UniPC multistep scheduler for Wan diffusion #13

Draft: jglee-sqbits wants to merge 1 commit into main from jglee-sqbits/stack/1

jglee-sqbits (Collaborator) commented Apr 1, 2026

Stacked PRs:

  • [MAX] Add UniPC multistep scheduler for Wan diffusion

Summary

Add a numpy-only UniPC multistep scheduler for Wan diffusion pipelines.

Description

  • Implements the UniPC-BH2 algorithm with corrector and predictor steps
  • Supports flow-matching sigma schedules (used by Wan 2.1/2.2)
  • Provides build_step_coefficients() to precompute per-step coefficient matrices on the host, enabling on-device scheduler steps without Python-side numpy calls during denoising
  • Registers UniPCMultistepScheduler in the diffusion scheduler factory
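As a rough illustration of the flow-matching schedule mentioned in the list above, the sketch below builds a time-shifted sigma schedule and the half log-SNR values that a multistep solver steps through. The function names, the `shift` value, and the 1000-step training horizon are illustrative assumptions and are not taken from this PR.

```python
import numpy as np


def flow_sigmas(num_steps: int, shift: float = 3.0, num_train_steps: int = 1000) -> np.ndarray:
    """Descending sigmas in (0, 1); the terminal 0 is left out."""
    # Assumed layout: linear from just below 1 down to 0, then time-shifted.
    sigmas = np.linspace(1.0 - 1.0 / num_train_steps, 0.0, num_steps + 1)[:-1]
    return shift * sigmas / (1.0 + (shift - 1.0) * sigmas)


def lambda_from_sigma(sigma: np.ndarray) -> np.ndarray:
    """Half log-SNR, log(alpha / sigma), with the flow-matching alpha = 1 - sigma."""
    return np.log1p(-sigma) - np.log(sigma)


sigmas = flow_sigmas(8)
lambdas = lambda_from_sigma(sigmas)  # increases as sigma decreases
```

Keeping sigma strictly inside (0, 1) matters here: both endpoints make the logarithm diverge.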

This is a numpy-only port of the diffusers UniPCMultistepScheduler, specialized for the Wan pipeline configuration.
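To make the precomputation idea concrete, here is a minimal first-order sketch. It assumes x0-prediction and the flow-matching convention alpha = 1 - sigma; `build_order1_table` and `apply_step` are hypothetical names, and the actual layout returned by `build_step_coefficients()` is not shown in this PR.

```python
import numpy as np


def build_order1_table(sigmas: np.ndarray) -> np.ndarray:
    """Hypothetical helper: one (sample_scale, m0_scale) row per transition."""
    s0, t = sigmas[:-1], sigmas[1:]
    lam = np.log1p(-sigmas) - np.log(sigmas)  # assumes alpha = 1 - sigma
    h = lam[1:] - lam[:-1]
    # First-order update: x_t = (sigma_t / sigma_s0) * x - alpha_t * expm1(-h) * x0
    return np.stack([t / s0, -(1.0 - t) * np.expm1(-h)], axis=1)


def apply_step(row: np.ndarray, sample: np.ndarray, m0: np.ndarray) -> np.ndarray:
    """The per-step work reduces to a linear combination with precomputed scalars."""
    return row[0] * sample + row[1] * m0


table = build_order1_table(np.array([0.9, 0.6, 0.3]))
x0, eps = np.array([1.0, -2.0]), np.array([0.5, 0.25])
x = 0.1 * x0 + 0.9 * eps          # noisy sample at sigma = 0.9
x = apply_step(table[0], x, x0)   # with an exact x0, this lands on sigma = 0.6
```

Because every scalar depends only on the sigma schedule, the table can be built once on the host, and the denoising loop only performs multiply-adds.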

Dependencies

None — can be merged independently.

Checklist

  • [x] PR is small and focused
  • [x] I ran ./bazelw run format to format my changes

Assisted-by: Claude Code

stack-info: PR: #13, branch: jglee-sqbits/stack/1

gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces the UniPCMultistepScheduler, a NumPy-only implementation of the UniPC-BH2 algorithm designed for fast sampling in diffusion models like the Wan 2.2 T2V pipeline. The scheduler includes support for flow-matching and pre-computes step coefficients for optimized inference. Review feedback highlights several issues in the coefficient pre-computation methods, specifically that _predictor_coefficients and _corrector_coefficients hardcode parameters such as solver_type and predict_x0, and lack the necessary logic to handle solver orders greater than two.

Comment on lines +617 to +620
b_h = float(np.expm1(-h))
sample_scale = float(sigma_t / sigma_s0)
m0_scale = float(-alpha_t * b_h)
m1_scale = 0.0

high

The _predictor_coefficients method hardcodes the calculation for solver_type="bh2" and assumes predict_x0=True. It should respect self.solver_type and self.predict_x0 to ensure consistency with the step() method. Specifically, hh should be determined by predict_x0, and the scale factor for m0 should be alpha_t or sigma_t depending on the prediction type.
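One way the suggestion could look for orders 1 and 2 is sketched below. This is a standalone function rather than the PR's method: the signature is hypothetical, and the branch structure is assumed to follow the `step()` logic that the comment refers to.

```python
import numpy as np


def predictor_coefficients(alpha_t, sigma_t, alpha_s0, sigma_s0, h, rk=None,
                           solver_type="bh2", predict_x0=True):
    """Return (sample_scale, m0_scale, m1_scale); pass rk to enable order 2."""
    hh = -h if predict_x0 else h
    h_phi_1 = float(np.expm1(hh))
    if solver_type == "bh1":
        b_h = hh
    elif solver_type == "bh2":
        b_h = h_phi_1
    else:
        raise ValueError(f"unsupported solver_type: {solver_type}")

    if predict_x0:
        sample_scale, out_scale = sigma_t / sigma_s0, alpha_t
    else:
        sample_scale, out_scale = alpha_t / alpha_s0, sigma_t

    # The first-order term always uses h_phi_1; b_h only scales the correction.
    m0_scale = -out_scale * h_phi_1
    m1_scale = 0.0
    if rk is not None:
        m1_scale = -out_scale * b_h * 0.5 / rk
        m0_scale -= m1_scale
    return float(sample_scale), float(m0_scale), float(m1_scale)
```

With `solver_type="bh2"` and `predict_x0=True` this reduces to the expressions in the snippet under review, since `b_h` and `h_phi_1` coincide in that case.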

Comment on lines +643 to +645
hh = -h
h_phi_1 = float(np.expm1(hh))
b_h = float(np.expm1(hh))

high

In _corrector_coefficients, b_h is hardcoded for solver_type="bh2". It should use self.solver_type to decide between hh and expm1(hh), matching the logic in multistep_uni_c_bh_update.

Comment on lines +622 to +628
if order == 2:
    sigma_si_raw = float(self.sigmas[step_index - 1])
    lambda_si = self._lambda_from_sigma(sigma_si_raw)
    rk = (lambda_si - lambda_s0) / h
    m1_scale = float(-alpha_t * b_h * 0.5 / rk)
    m0_scale -= m1_scale

medium

_predictor_coefficients only supports up to order=2. If solver_order is set to a higher value, this method will return incorrect coefficients (effectively falling back to order 1 for the predictor part). It should be generalized to handle arbitrary orders using a linear system solver, similar to the implementation in multistep_uni_p_bh_update.
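A sketch of what the generalization might look like follows. It assumes the usual UniPC construction, where the weights `rhos_p` solve a small Vandermonde-style system in the step ratios `rks`; the helper names and signatures are hypothetical and would need to be checked against `multistep_uni_p_bh_update`.

```python
import math
import numpy as np


def predictor_rhos(rks, hh: float, b_h: float) -> np.ndarray:
    """Weights for the history differences; rks has length order - 1."""
    rks = np.asarray(rks, dtype=float)
    order = rks.size + 1
    if order == 1:
        return np.zeros(0)
    if order == 2:
        return np.array([0.5])  # closed form, no solve needed
    rows, rhs = [], []
    h_phi_k = math.expm1(hh) / hh - 1.0
    factorial_i = 1
    for i in range(1, order):
        rows.append(rks ** (i - 1))
        rhs.append(h_phi_k * factorial_i / b_h)
        factorial_i *= i + 1
        h_phi_k = h_phi_k / hh - 1.0 / factorial_i
    return np.linalg.solve(np.stack(rows), np.array(rhs))


def predictor_scales(rks, hh: float, out_scale: float, solver_type: str = "bh2"):
    """Return (m0_scale, history_scales), one history scale per older output."""
    rks = np.asarray(rks, dtype=float)
    h_phi_1 = math.expm1(hh)
    b_h = hh if solver_type == "bh1" else h_phi_1
    rhos = predictor_rhos(rks, hh, b_h)
    # Each difference (m_i - m0) / rk_i contributes to both m_i and m0.
    hist = -out_scale * b_h * rhos / rks
    m0_scale = -out_scale * h_phi_1 - hist.sum()
    return m0_scale, hist
```

Here `out_scale` stands for `alpha_t` under x0-prediction, or `sigma_t` otherwise. With a single ratio the result matches the `order == 2` branch quoted above.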

Comment on lines +670 to +674
m1_scale = float(-alpha_t * b_h * rhos_c[0] / rk)
m0_scale = float(
    -alpha_t * h_phi_1 + alpha_t * b_h * (rhos_c[0] / rk + rhos_c[-1])
)
mt_scale = float(-alpha_t * b_h * rhos_c[-1])

medium

The coefficient calculation for m0_scale and m1_scale in _corrector_coefficients only accounts for a single rk value (rhos_c[0] / rk). For order > 2, there are multiple history terms with different rk values that must be accounted for in the summation. This will lead to incorrect results when using higher solver orders.
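The summation over several `rk` values could be sketched as below. This is an illustration under the assumption that the corrector solves the full system (history ratios plus a trailing 1.0 for the new model output), as in the reference `multistep_uni_c_bh_update`; the function name and return layout are hypothetical.

```python
import math
import numpy as np


def corrector_scales(rks, hh: float, out_scale: float, solver_type: str = "bh2"):
    """Return (m0_scale, history_scales, mt_scale, rhos_c); rks has length order - 1."""
    rks = np.asarray(rks, dtype=float)
    order = rks.size + 1
    h_phi_1 = math.expm1(hh)
    b_h = hh if solver_type == "bh1" else h_phi_1
    if order == 1:
        rhos_c = np.array([0.5])
    else:
        rks_full = np.append(rks, 1.0)  # last column belongs to the new output
        rows, rhs = [], []
        h_phi_k = h_phi_1 / hh - 1.0
        factorial_i = 1
        for i in range(1, order + 1):
            rows.append(rks_full ** (i - 1))
            rhs.append(h_phi_k * factorial_i / b_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1.0 / factorial_i
        rhos_c = np.linalg.solve(np.stack(rows), np.array(rhs))
    # One scale per history term, each divided by its own rk.
    hist = -out_scale * b_h * rhos_c[:-1] / rks
    mt_scale = -out_scale * b_h * rhos_c[-1]
    m0_scale = -out_scale * h_phi_1 - hist.sum() - mt_scale
    return m0_scale, hist, mt_scale, rhos_c
```

For a single history ratio, `hist[0]`, `m0_scale`, and `mt_scale` reduce to the `m1_scale`, `m0_scale`, and `mt_scale` expressions in the snippet under review.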
