Skip to content

Commit a44a04b

Browse files
committed
[Pipelines] Add WanTokenizer and WanContext
Signed-off-by: jglee-sqbits <jingu.lee@squeezebits.com>
1 parent 4ebbf94 commit a44a04b

4 files changed

Lines changed: 241 additions & 10 deletions

File tree

max/python/max/interfaces/provider_options/modality/video.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,9 @@ class VideoProviderOptions(BaseModel):
6666
),
6767
gt=0,
6868
)
69+
70+
guidance_scale_2: float | None = Field(
71+
None,
72+
description="Secondary guidance scale for boundary timestep switching.",
73+
gt=0.0,
74+
)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# ===----------------------------------------------------------------------=== #
2+
# Copyright (c) 2026, Modular Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License v2.0 with LLVM Exceptions:
5+
# https://llvm.org/LICENSE.txt
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
# ===----------------------------------------------------------------------=== #
13+
"""Wan-specific pixel generation context."""
14+
15+
from __future__ import annotations
16+
17+
from dataclasses import dataclass, field
18+
19+
import numpy as np
20+
import numpy.typing as npt
21+
from max.pipelines.core import PixelContext
22+
23+
24+
@dataclass(kw_only=True)
25+
class WanContext(PixelContext):
26+
"""Pixel generation context with Wan-specific video/MoE fields."""
27+
28+
num_frames: int | None = field(default=None)
29+
"""Number of frames for video generation."""
30+
31+
guidance_scale_2: float | None = field(default=None)
32+
"""Secondary guidance scale for low-noise expert (MoE models)."""
33+
34+
step_coefficients: npt.NDArray[np.float32] | None = field(default=None)
35+
"""Pre-computed scheduler step coefficients."""
36+
37+
boundary_timestep: float | None = field(default=None)
38+
"""Timestep threshold for switching between high/low noise experts."""
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# ===----------------------------------------------------------------------=== #
2+
# Copyright (c) 2026, Modular Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License v2.0 with LLVM Exceptions:
5+
# https://llvm.org/LICENSE.txt
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
# ===----------------------------------------------------------------------=== #
13+
"""Wan-specific pixel generation tokenizer."""
14+
15+
from __future__ import annotations
16+
17+
import logging
18+
19+
import numpy as np
20+
import numpy.typing as npt
21+
import PIL.Image
22+
from max.interfaces.request import OpenResponsesRequest
23+
from max.pipelines.lib.pixel_tokenizer import PixelGenerationTokenizer
24+
25+
from .context import WanContext
26+
27+
logger = logging.getLogger("max.pipelines")
28+
29+
30+
class WanTokenizer(PixelGenerationTokenizer):
31+
"""Wan-specific tokenizer that produces WanContext with video/MoE fields."""
32+
33+
def __init__(self, *args, **kwargs) -> None:
34+
super().__init__(*args, **kwargs)
35+
# Override latent channel count: Wan uses out_channels (16) for noise
36+
# latents, not in_channels which may be 36 for I2V variants
37+
# (16 noise + 4 mask + 16 image).
38+
components = self.diffusers_config.get("components", {})
39+
transformer_config = components.get("transformer", {}).get(
40+
"config_dict", {}
41+
)
42+
self._num_channels_latents = transformer_config.get(
43+
"out_channels", transformer_config["in_channels"]
44+
)
45+
46+
def _select_wan_flow_shift(self, height: int, width: int) -> float:
47+
scheduler_cfg = (
48+
self.diffusers_config.get("components", {})
49+
.get("scheduler", {})
50+
.get("config_dict", {})
51+
)
52+
# Use explicit flow_shift from scheduler config if set (user override).
53+
cfg_shift = scheduler_cfg.get("flow_shift")
54+
if cfg_shift is not None and float(cfg_shift) != 1.0:
55+
return float(cfg_shift)
56+
# Default: interpolate based on pixel count.
57+
# 480p (480*832 = 399 360) → 3.0, 720p (720*1280 = 921 600) → 5.0
58+
pixels = height * width
59+
lo_px, hi_px = 399_360, 921_600
60+
lo_shift, hi_shift = 3.0, 5.0
61+
t = max(0.0, min(1.0, (pixels - lo_px) / (hi_px - lo_px)))
62+
return lo_shift + t * (hi_shift - lo_shift)
63+
64+
async def new_context(
65+
self,
66+
request: OpenResponsesRequest,
67+
input_image: PIL.Image.Image | None = None,
68+
) -> WanContext:
69+
base = await super().new_context(request, input_image=input_image)
70+
71+
video_options = request.body.provider_options.video
72+
image_options = request.body.provider_options.image
73+
74+
num_frames: int | None = (
75+
video_options.num_frames if video_options else None
76+
)
77+
guidance_scale_2: float | None = (
78+
video_options.guidance_scale_2 if video_options else None
79+
)
80+
81+
height = base.height
82+
width = base.width
83+
timesteps: npt.NDArray[np.float32] = base.timesteps
84+
sigmas: npt.NDArray[np.float32] = base.sigmas
85+
86+
if getattr(self._scheduler, "use_flow_sigmas", False):
87+
self._scheduler.flow_shift = self._select_wan_flow_shift(
88+
height, width
89+
)
90+
latent_height = 2 * (int(height) // (self._vae_scale_factor * 2))
91+
latent_width = 2 * (int(width) // (self._vae_scale_factor * 2))
92+
image_seq_len = (latent_height // 2) * (latent_width // 2)
93+
timesteps, sigmas = self._scheduler.retrieve_timesteps_and_sigmas(
94+
image_seq_len, base.num_inference_steps
95+
)
96+
97+
boundary_timestep: float | None = None
98+
boundary_ratio = self.diffusers_config.get("boundary_ratio")
99+
if boundary_ratio is not None:
100+
boundary_timestep = float(boundary_ratio) * float(
101+
getattr(self._scheduler, "num_train_timesteps", 1000)
102+
)
103+
104+
step_coefficients: npt.NDArray[np.float32] | None = None
105+
if hasattr(self._scheduler, "build_step_coefficients"):
106+
step_coefficients = self._scheduler.build_step_coefficients()
107+
108+
latents = base.latents
109+
if num_frames is not None:
110+
vae_scale_factor_temporal = 4
111+
latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1
112+
latent_height = 2 * (int(height) // (self._vae_scale_factor * 2))
113+
latent_width = 2 * (int(width) // (self._vae_scale_factor * 2))
114+
num_images = image_options.num_images if image_options else 1
115+
shape_5d = (
116+
num_images,
117+
self._num_channels_latents,
118+
latent_frames,
119+
latent_height,
120+
latent_width,
121+
)
122+
latents = self._randn_tensor(shape_5d, request.body.seed)
123+
124+
return WanContext(
125+
request_id=base.request_id,
126+
model_name=base.model_name,
127+
tokens=base.tokens,
128+
mask=base.mask,
129+
tokens_2=base.tokens_2,
130+
negative_tokens=base.negative_tokens,
131+
negative_mask=base.negative_mask,
132+
negative_tokens_2=base.negative_tokens_2,
133+
explicit_negative_prompt=base.explicit_negative_prompt,
134+
timesteps=timesteps,
135+
sigmas=sigmas,
136+
latents=latents,
137+
latent_image_ids=base.latent_image_ids,
138+
height=base.height,
139+
width=base.width,
140+
num_frames=num_frames,
141+
guidance_scale=base.guidance_scale,
142+
true_cfg_scale=base.true_cfg_scale,
143+
guidance_scale_2=guidance_scale_2,
144+
cfg_normalization=base.cfg_normalization,
145+
cfg_truncation=base.cfg_truncation,
146+
num_inference_steps=base.num_inference_steps,
147+
num_warmup_steps=base.num_warmup_steps,
148+
strength=base.strength,
149+
boundary_timestep=boundary_timestep,
150+
step_coefficients=step_coefficients,
151+
num_images_per_prompt=base.num_images_per_prompt,
152+
input_image=base.input_image,
153+
output_format=base.output_format,
154+
residual_threshold=base.residual_threshold,
155+
status=base.status,
156+
)

max/python/max/pipelines/lib/pixel_tokenizer.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ class PipelineClassName(str, Enum):
9999
FLUX2 = "Flux2Pipeline"
100100
FLUX2_KLEIN = "Flux2KleinPipeline"
101101
ZIMAGE = "ZImagePipeline"
102+
WAN = "WanPipeline"
103+
WAN_I2V = "WanImageToVideoPipeline"
102104

103105
@classmethod
104106
def from_class_name(cls, class_name: str) -> PipelineClassName:
@@ -902,9 +904,17 @@ async def new_context(
902904
" but may produce lower quality or unexpected results."
903905
)
904906

907+
# Resolve negative_prompt: prefer video options for video pipelines.
908+
video_options = request.body.provider_options.video
909+
negative_prompt_resolved = (
910+
video_options.negative_prompt
911+
if video_options and video_options.negative_prompt
912+
else None
913+
) or image_options.negative_prompt
914+
905915
if (
906916
image_options.true_cfg_scale > 1.0
907-
and image_options.negative_prompt is None
917+
and negative_prompt_resolved is None
908918
):
909919
logger.warning(
910920
f"true_cfg_scale={image_options.true_cfg_scale} is set, but no negative_prompt "
@@ -928,7 +938,7 @@ async def new_context(
928938
else:
929939
do_true_cfg = (
930940
image_options.true_cfg_scale > 1.0
931-
and image_options.negative_prompt is not None
941+
and negative_prompt_resolved is not None
932942
)
933943

934944
# 1. Tokenize prompts
@@ -953,7 +963,7 @@ async def new_context(
953963
) = await self._generate_tokens_ids(
954964
prompt,
955965
image_options.secondary_prompt,
956-
image_options.negative_prompt,
966+
negative_prompt_resolved,
957967
image_options.secondary_negative_prompt,
958968
do_true_cfg or do_zimage_cfg,
959969
images=images_for_tokenization,
@@ -992,28 +1002,49 @@ async def new_context(
9921002
self._pipeline_class_name != PipelineClassName.ZIMAGE
9931003
),
9941004
)
995-
height = image_options.height or preprocessed_image.height
996-
width = image_options.width or preprocessed_image.width
1005+
height = (
1006+
(video_options and video_options.height)
1007+
or image_options.height
1008+
or preprocessed_image.height
1009+
)
1010+
width = (
1011+
(video_options and video_options.width)
1012+
or image_options.width
1013+
or preprocessed_image.width
1014+
)
9971015
preprocessed_image_array = np.array(
9981016
preprocessed_image, dtype=np.uint8
9991017
).copy()
10001018
else:
10011019
height = (
1002-
image_options.height or default_sample_size * vae_scale_factor
1020+
(video_options and video_options.height)
1021+
or image_options.height
1022+
or default_sample_size * vae_scale_factor
10031023
)
10041024
width = (
1005-
image_options.width or default_sample_size * vae_scale_factor
1025+
(video_options and video_options.width)
1026+
or image_options.width
1027+
or default_sample_size * vae_scale_factor
10061028
)
10071029

10081030
# 3. Resolve image dimensions using cached static values
10091031
latent_height = 2 * (int(height) // (self._vae_scale_factor * 2))
10101032
latent_width = 2 * (int(width) // (self._vae_scale_factor * 2))
10111033
image_seq_len = (latent_height // 2) * (latent_width // 2)
10121034

1035+
video_steps = (
1036+
video_options.steps
1037+
if video_options and video_options.steps is not None
1038+
else None
1039+
)
10131040
num_inference_steps = (
1014-
image_options.steps
1015-
if "steps" in image_options.model_fields_set
1016-
else self._default_num_inference_steps
1041+
video_steps
1042+
if video_steps is not None
1043+
else (
1044+
image_options.steps
1045+
if "steps" in image_options.model_fields_set
1046+
else self._default_num_inference_steps
1047+
)
10171048
)
10181049
sigma_min = (
10191050
0.0

0 commit comments

Comments
 (0)