Skip to content

Commit 23b1c55

Browse files
[Pipelines] Refactor and optimize Z-Image modulev3 pipeline
- Fix autoencoder import and image postprocessing - Absorb eager F.mul negate into compiled scheduler_step for z_image - Add batched CFG for Z-Image modulev3 pipeline - Optimize fused decode, scheduler caching, and eager reduction - Apply RoPE micro-optimizations stack-info: PR: #20, branch: byungchul-sqzb/stack/1
1 parent 15db7cf commit 23b1c55

6 files changed

Lines changed: 355 additions & 306 deletions

File tree

max/python/max/pipelines/architectures/z_image_modulev3/arch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ def initialize(
4949
name="ZImagePipeline",
5050
task=PipelineTask.PIXEL_GENERATION,
5151
default_encoding="bfloat16",
52-
supported_encodings={"bfloat16"},
52+
supported_encodings={"bfloat16", "float32"},
5353
example_repo_ids=[
5454
"Tongyi-MAI/Z-Image",
55-
"Zyphra/Z-Image",
55+
"Tongyi-MAI/Z-Image-Turbo",
5656
],
5757
pipeline_model=ZImagePipeline, # type: ignore[arg-type]
5858
context_type=PixelContext,

max/python/max/pipelines/architectures/z_image_modulev3/layers/attention.py

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,49 @@
1313

1414
import math
1515

16+
from max.dtype import DType
1617
from max.experimental import functional as F
1718
from max.experimental.nn import Linear, Module
1819
from max.experimental.nn.norm import RMSNorm
19-
from max.experimental.nn.sequential import ModuleList
2020
from max.experimental.tensor import Tensor
2121
from max.nn.attention.mask_config import MHAMaskVariant
2222
from max.nn.kernels import flash_attention_gpu as _flash_attention_gpu
23-
24-
from ...flux2_modulev3.layers.embeddings import apply_rotary_emb
23+
from max.nn.kernels import (
24+
rope_ragged_with_position_ids as _rope_ragged_with_position_ids,
25+
)
2526

2627
flash_attention_gpu = F.functional(_flash_attention_gpu)
28+
rope_ragged_with_position_ids = F.functional(_rope_ragged_with_position_ids)
29+
30+
31+
def _apply_zimage_qk_rope(
32+
query: Tensor,
33+
key: Tensor,
34+
freqs_cis: Tensor,
35+
) -> tuple[Tensor, Tensor]:
36+
"""Apply RoPE using precomputed interleaved [cos, sin] frequencies."""
37+
batch_size = query.shape[0]
38+
seq_len = query.shape[1]
39+
num_heads = query.shape[2]
40+
head_dim = query.shape[3]
41+
42+
query_ragged = F.reshape(query, [batch_size * seq_len, num_heads, head_dim])
43+
key_ragged = F.reshape(key, [batch_size * seq_len, num_heads, head_dim])
44+
45+
position_ids = F.arange(0, seq_len, dtype=DType.uint32, device=query.device)
46+
position_ids = F.broadcast_to(position_ids[None, :], [batch_size, seq_len])
47+
position_ids = F.reshape(position_ids, [batch_size * seq_len])
48+
49+
query_out = rope_ragged_with_position_ids(
50+
query_ragged, freqs_cis, position_ids, interleaved=True
51+
)
52+
key_out = rope_ragged_with_position_ids(
53+
key_ragged, freqs_cis, position_ids, interleaved=True
54+
)
55+
return (
56+
F.reshape(query_out, [batch_size, seq_len, num_heads, head_dim]),
57+
F.reshape(key_out, [batch_size, seq_len, num_heads, head_dim]),
58+
)
2759

2860

2961
class ZImageAttention(Module[..., Tensor]):
@@ -45,13 +77,12 @@ def __init__(
4577
self.norm_q = RMSNorm(self.head_dim, eps=eps) if qk_norm else None
4678
self.norm_k = RMSNorm(self.head_dim, eps=eps) if qk_norm else None
4779

48-
# Keep ModuleList naming for diffusers-compatible key loading.
49-
self.to_out = ModuleList([Linear(dim, dim, bias=False)])
80+
self.to_out = Linear(dim, dim, bias=False)
5081

5182
def forward(
5283
self,
5384
hidden_states: Tensor,
54-
freqs_cis: tuple[Tensor, Tensor],
85+
freqs_cis: Tensor,
5586
) -> Tensor:
5687
batch_size = hidden_states.shape[0]
5788
seq_len = hidden_states.shape[1]
@@ -73,22 +104,7 @@ def forward(
73104
if self.norm_k is not None:
74105
key = self.norm_k(key)
75106

76-
query = apply_rotary_emb(
77-
query,
78-
freqs_cis,
79-
use_real=True,
80-
use_real_unbind_dim=-1,
81-
sequence_dim=1,
82-
)
83-
key = apply_rotary_emb(
84-
key,
85-
freqs_cis,
86-
use_real=True,
87-
use_real_unbind_dim=-1,
88-
sequence_dim=1,
89-
)
90-
query = query.cast(value.dtype)
91-
key = key.cast(value.dtype)
107+
query, key = _apply_zimage_qk_rope(query, key, freqs_cis)
92108

93109
out = flash_attention_gpu(
94110
query,
@@ -99,4 +115,4 @@ def forward(
99115
)
100116

101117
out = F.reshape(out, [batch_size, seq_len, self.inner_dim])
102-
return self.to_out[0](out)
118+
return self.to_out(out)

max/python/max/pipelines/architectures/z_image_modulev3/layers/embeddings.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
from max.experimental.nn import Linear, Module
1919
from max.experimental.tensor import Tensor
2020

21-
from ...flux2_modulev3.layers.embeddings import get_1d_rotary_pos_embed
22-
2321

2422
class TimestepEmbedder(Module[[Tensor], Tensor]):
2523
def __init__(
@@ -67,7 +65,21 @@ def forward(self, t: Tensor) -> Tensor:
6765
return t_emb
6866

6967

70-
class RopeEmbedder(Module[[Tensor], tuple[Tensor, Tensor]]):
68+
def _get_1d_rope_interleaved(
69+
dim: int,
70+
pos: Tensor,
71+
theta: float = 10000.0,
72+
) -> Tensor:
73+
"""Compute 1-D RoPE in [cos, sin] interleaved pair format."""
74+
half = dim // 2
75+
freq_exp = F.arange(0, half, dtype=DType.float32, device=pos.device) / half
76+
freq = 1.0 / (theta**freq_exp)
77+
freqs = F.outer(pos, freq)
78+
paired = F.stack([F.cos(freqs), F.sin(freqs)], axis=2)
79+
return F.reshape(paired, [freqs.shape[0], dim])
80+
81+
82+
class RopeEmbedder(Module[[Tensor], Tensor]):
7183
def __init__(
7284
self,
7385
theta: float = 256.0,
@@ -76,28 +88,15 @@ def __init__(
7688
self.theta = theta
7789
self.axes_dims = axes_dims
7890

79-
def forward(self, ids: Tensor) -> tuple[Tensor, Tensor]:
80-
if ids.rank != 2:
81-
raise ValueError(f"Expected 2D ids tensor, got rank={ids.rank}")
82-
83-
if int(ids.shape[-1]) != len(self.axes_dims):
84-
raise ValueError(
85-
"ids last dimension must match axes_dims length "
86-
f"({len(self.axes_dims)}), got {ids.shape[-1]}"
87-
)
88-
91+
def forward(self, ids: Tensor) -> Tensor:
8992
pos = ids.cast(DType.float32)
90-
cos_out = []
91-
sin_out = []
93+
parts = []
9294
for i in range(len(self.axes_dims)):
93-
cos_i, sin_i = get_1d_rotary_pos_embed(
94-
self.axes_dims[i],
95-
pos[:, i],
96-
theta=self.theta,
97-
use_real=True,
98-
repeat_interleave_real=True,
95+
parts.append(
96+
_get_1d_rope_interleaved(
97+
self.axes_dims[i],
98+
pos[:, i],
99+
theta=self.theta,
100+
)
99101
)
100-
cos_out.append(cos_i)
101-
sin_out.append(sin_i)
102-
103-
return F.concat(cos_out, axis=-1), F.concat(sin_out, axis=-1)
102+
return F.concat(parts, axis=-1)

0 commit comments

Comments
 (0)