Skip to content

Commit c58e9d4

Browse files
[Pipelines] Implement Z-Image ModuleV2 pipeline
Port Z-Image to the Graph API / ModuleV2 runtime using V2 text encoder, transformer, and VAE components. Restore the current ModuleV3 feature set and behavior in the V2 path, including: - Z-Image transformer/model/config/weight adapter wiring - diffusion pipeline, arch registration, and ModuleV2/ModuleV3 selection via --prefer-module-v3 - batched CFG, CFG renormalization, and image parity with ModuleV3 - transformer-side RoPE micro-optimizations: - single unified RoPE embedder call - interleaved [cos, sin] frequency generation - rope_ragged_with_position_ids hot path - preamble dtype cast and direct modulation slicing
1 parent f76185d commit c58e9d4

15 files changed

Lines changed: 2166 additions & 7 deletions

File tree

max/examples/diffusion/simple_offline_generation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,11 @@
8484
"Flux2KleinPipeline_ModuleV3",
8585
}
8686

87+
_Z_IMAGE_ARCH_NAMES = {
88+
"ZImagePipeline",
89+
"ZImagePipeline_ModuleV3",
90+
}
91+
8792

8893
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
8994
"""Parse command-line arguments for the pixel generation example.
@@ -419,7 +424,7 @@ async def generate_image(args: argparse.Namespace) -> None:
419424
max_length = components_config["tokenizer"]["config_dict"].get(
420425
"model_max_length", None
421426
)
422-
if arch.name in _FLUX2_ARCH_NAMES or arch.name == "ZImagePipeline":
427+
if arch.name in _FLUX2_ARCH_NAMES or arch.name in _Z_IMAGE_ARCH_NAMES:
423428
max_length = 512
424429
print(f"Using max length: {max_length} for tokenizer")
425430

max/python/max/pipelines/architectures/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def register_all_models() -> None:
8282
from .qwen3vl_moe import qwen3vl_arch, qwen3vl_moe_arch
8383
from .unified_eagle_llama3 import unified_eagle_llama3_arch
8484
from .unified_mtp_deepseekV3 import unified_mtp_deepseekV3_arch
85-
from .z_image_modulev3 import z_image_arch
85+
from .z_image import z_image_arch
86+
from .z_image_modulev3 import z_image_modulev3_arch
8687

8788
architectures = [
8889
exaone_arch,
@@ -138,6 +139,7 @@ def register_all_models() -> None:
138139
unified_eagle_llama3_arch,
139140
unified_mtp_deepseekV3_arch,
140141
z_image_arch,
142+
z_image_modulev3_arch,
141143
]
142144

143145
for arch in architectures:
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# ===----------------------------------------------------------------------=== #
2+
# Copyright (c) 2026, Modular Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License v2.0 with LLVM Exceptions:
5+
# https://llvm.org/LICENSE.txt
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
# ===----------------------------------------------------------------------=== #
13+
14+
from .arch import ZImageArchConfig, z_image_arch
15+
from .layers.attention import ZImageAttention
16+
from .layers.embeddings import RopeEmbedder, TimestepEmbedder
17+
from .model import ZImageTransformerModel
18+
from .model_config import ZImageConfig, ZImageConfigBase
19+
from .z_image import ZImageTransformer2DModel
20+
21+
__all__ = [
22+
"RopeEmbedder",
23+
"TimestepEmbedder",
24+
"ZImageArchConfig",
25+
"ZImageAttention",
26+
"ZImageConfig",
27+
"ZImageConfigBase",
28+
"ZImageTransformer2DModel",
29+
"ZImageTransformerModel",
30+
"z_image_arch",
31+
]
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# ===----------------------------------------------------------------------=== #
2+
# Copyright (c) 2026, Modular Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License v2.0 with LLVM Exceptions:
5+
# https://llvm.org/LICENSE.txt
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
# ===----------------------------------------------------------------------=== #
13+
14+
from __future__ import annotations
15+
16+
from dataclasses import dataclass
17+
18+
from max.graph.weights import WeightsFormat
19+
from max.interfaces import PipelineTask
20+
from max.pipelines.core import PixelContext
21+
from max.pipelines.lib import PixelGenerationTokenizer, SupportedArchitecture
22+
from max.pipelines.lib.config import MAXModelConfig, PipelineConfig
23+
from max.pipelines.lib.interfaces import ArchConfig
24+
from typing_extensions import Self
25+
26+
from .pipeline_z_image import ZImagePipeline
27+
28+
29+
@dataclass(kw_only=True)
30+
class ZImageArchConfig(ArchConfig):
31+
pipeline_config: PipelineConfig
32+
33+
def get_max_seq_len(self) -> int:
34+
return 0
35+
36+
@classmethod
37+
def initialize(
38+
cls,
39+
pipeline_config: PipelineConfig,
40+
model_config: MAXModelConfig | None = None,
41+
) -> Self:
42+
model_config = model_config or pipeline_config.model
43+
if len(model_config.device_specs) != 1:
44+
raise ValueError("Z-Image is only supported on a single device")
45+
return cls(pipeline_config=pipeline_config)
46+
47+
48+
z_image_arch = SupportedArchitecture(
49+
name="ZImagePipeline",
50+
task=PipelineTask.PIXEL_GENERATION,
51+
default_encoding="bfloat16",
52+
supported_encodings={"bfloat16", "float32"},
53+
example_repo_ids=[
54+
"Tongyi-MAI/Z-Image",
55+
"Tongyi-MAI/Z-Image-Turbo",
56+
],
57+
pipeline_model=ZImagePipeline, # type: ignore[arg-type]
58+
context_type=PixelContext,
59+
default_weights_format=WeightsFormat.safetensors,
60+
tokenizer=PixelGenerationTokenizer,
61+
config=ZImageArchConfig,
62+
)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# ===----------------------------------------------------------------------=== #
2+
# Copyright (c) 2026, Modular Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License v2.0 with LLVM Exceptions:
5+
# https://llvm.org/LICENSE.txt
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
# ===----------------------------------------------------------------------=== #
13+
14+
from .attention import ZImageAttention
15+
from .embeddings import RopeEmbedder, TimestepEmbedder
16+
17+
__all__ = ["RopeEmbedder", "TimestepEmbedder", "ZImageAttention"]
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# ===----------------------------------------------------------------------=== #
2+
# Copyright (c) 2026, Modular Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License v2.0 with LLVM Exceptions:
5+
# https://llvm.org/LICENSE.txt
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
# ===----------------------------------------------------------------------=== #
13+
14+
from __future__ import annotations
15+
16+
from max.dtype import DType
17+
from max.graph import DeviceRef, TensorValue, ops
18+
from max.nn.attention.mask_config import MHAMaskVariant
19+
from max.nn.kernels import flash_attention_gpu, rope_ragged_with_position_ids
20+
from max.nn.layer import LayerList, Module
21+
from max.nn.linear import Linear
22+
from max.nn.norm import RMSNorm
23+
24+
25+
def _apply_zimage_qk_rope(
26+
query: TensorValue,
27+
key: TensorValue,
28+
freqs_cis: TensorValue,
29+
) -> tuple[TensorValue, TensorValue]:
30+
"""Apply RoPE using precomputed interleaved [cos, sin] frequencies."""
31+
batch_size = query.shape[0]
32+
seq_len = query.shape[1]
33+
num_heads = query.shape[2]
34+
head_dim = query.shape[3]
35+
36+
query_ragged = ops.reshape(
37+
query, [batch_size * seq_len, num_heads, head_dim]
38+
)
39+
key_ragged = ops.reshape(key, [batch_size * seq_len, num_heads, head_dim])
40+
41+
position_ids = ops.range(
42+
0, seq_len, dtype=DType.uint32, device=query.device
43+
)
44+
position_ids = ops.broadcast_to(
45+
ops.unsqueeze(position_ids, 0), [batch_size, seq_len]
46+
)
47+
48+
query_out = rope_ragged_with_position_ids(
49+
query_ragged, freqs_cis, position_ids, interleaved=True
50+
)
51+
key_out = rope_ragged_with_position_ids(
52+
key_ragged, freqs_cis, position_ids, interleaved=True
53+
)
54+
return (
55+
ops.reshape(query_out, [batch_size, seq_len, num_heads, head_dim]),
56+
ops.reshape(key_out, [batch_size, seq_len, num_heads, head_dim]),
57+
)
58+
59+
60+
class ZImageAttention(Module):
61+
def __init__(
62+
self,
63+
dim: int,
64+
n_heads: int,
65+
qk_norm: bool,
66+
eps: float,
67+
*,
68+
dtype: DType,
69+
device: DeviceRef,
70+
) -> None:
71+
"""Initialize ZImageAttention."""
72+
super().__init__()
73+
self.head_dim = dim // n_heads
74+
self.inner_dim = dim
75+
self.n_heads = n_heads
76+
77+
self.to_q = Linear(
78+
in_dim=dim,
79+
out_dim=dim,
80+
dtype=dtype,
81+
device=device,
82+
has_bias=False,
83+
)
84+
self.to_k = Linear(
85+
in_dim=dim,
86+
out_dim=dim,
87+
dtype=dtype,
88+
device=device,
89+
has_bias=False,
90+
)
91+
self.to_v = Linear(
92+
in_dim=dim,
93+
out_dim=dim,
94+
dtype=dtype,
95+
device=device,
96+
has_bias=False,
97+
)
98+
99+
self.norm_q: RMSNorm | None = (
100+
RMSNorm(self.head_dim, dtype=dtype, eps=eps) if qk_norm else None
101+
)
102+
self.norm_k: RMSNorm | None = (
103+
RMSNorm(self.head_dim, dtype=dtype, eps=eps) if qk_norm else None
104+
)
105+
106+
# Keep LayerList naming for diffusers-compatible key loading.
107+
self.to_out = LayerList(
108+
[
109+
Linear(
110+
in_dim=dim,
111+
out_dim=dim,
112+
dtype=dtype,
113+
device=device,
114+
has_bias=False,
115+
)
116+
]
117+
)
118+
119+
def __call__(
120+
self,
121+
hidden_states: TensorValue,
122+
freqs_cis: TensorValue,
123+
) -> TensorValue:
124+
"""Apply self-attention with rotary position embeddings."""
125+
batch_size = hidden_states.shape[0]
126+
seq_len = hidden_states.shape[1]
127+
128+
query = self.to_q(hidden_states)
129+
key = self.to_k(hidden_states)
130+
value = self.to_v(hidden_states)
131+
132+
query = ops.reshape(
133+
query, [batch_size, seq_len, self.n_heads, self.head_dim]
134+
)
135+
key = ops.reshape(
136+
key, [batch_size, seq_len, self.n_heads, self.head_dim]
137+
)
138+
value = ops.reshape(
139+
value, [batch_size, seq_len, self.n_heads, self.head_dim]
140+
)
141+
142+
if self.norm_q is not None:
143+
query = self.norm_q(query)
144+
if self.norm_k is not None:
145+
key = self.norm_k(key)
146+
147+
query, key = _apply_zimage_qk_rope(query, key, freqs_cis)
148+
149+
out = flash_attention_gpu(
150+
query,
151+
key,
152+
value,
153+
mask_variant=MHAMaskVariant.NULL_MASK,
154+
scale=1.0 / (self.head_dim**0.5),
155+
)
156+
157+
out = ops.reshape(out, [batch_size, seq_len, self.inner_dim])
158+
return self.to_out[0](out)

0 commit comments

Comments
 (0)