Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion max/examples/diffusion/simple_offline_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@
"Flux2KleinPipeline_ModuleV3",
}

# Architecture names recognized as Z-Image pipelines: the base pipeline and
# its ModuleV3 variant. Intended for membership tests against an
# architecture's name.
_Z_IMAGE_ARCH_NAMES = {
    "ZImagePipeline",
    "ZImagePipeline_ModuleV3",
}


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
"""Parse command-line arguments for the pixel generation example.
Expand Down Expand Up @@ -419,7 +424,7 @@ async def generate_image(args: argparse.Namespace) -> None:
max_length = components_config["tokenizer"]["config_dict"].get(
"model_max_length", None
)
if arch.name in _FLUX2_ARCH_NAMES or arch.name == "ZImagePipeline":
if arch.name in _FLUX2_ARCH_NAMES or arch.name in _Z_IMAGE_ARCH_NAMES:
max_length = 512
print(f"Using max length: {max_length} for tokenizer")

Expand Down
4 changes: 3 additions & 1 deletion max/python/max/pipelines/architectures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def register_all_models() -> None:
from .qwen3vl_moe import qwen3vl_arch, qwen3vl_moe_arch
from .unified_eagle_llama3 import unified_eagle_llama3_arch
from .unified_mtp_deepseekV3 import unified_mtp_deepseekV3_arch
from .z_image_modulev3 import z_image_arch
from .z_image import z_image_arch
from .z_image_modulev3 import z_image_modulev3_arch

architectures = [
exaone_arch,
Expand Down Expand Up @@ -138,6 +139,7 @@ def register_all_models() -> None:
unified_eagle_llama3_arch,
unified_mtp_deepseekV3_arch,
z_image_arch,
z_image_modulev3_arch,
]

for arch in architectures:
Expand Down
31 changes: 31 additions & 0 deletions max/python/max/pipelines/architectures/z_image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from .arch import ZImageArchConfig, z_image_arch
from .layers.attention import ZImageAttention
from .layers.embeddings import RopeEmbedder, TimestepEmbedder
from .model import ZImageTransformerModel
from .model_config import ZImageConfig, ZImageConfigBase
from .z_image import ZImageTransformer2DModel

# Public API of the z_image package: exactly the names imported above from
# the arch, layers, model, model_config and z_image submodules.
__all__ = [
    "RopeEmbedder",
    "TimestepEmbedder",
    "ZImageArchConfig",
    "ZImageAttention",
    "ZImageConfig",
    "ZImageConfigBase",
    "ZImageTransformer2DModel",
    "ZImageTransformerModel",
    "z_image_arch",
]
62 changes: 62 additions & 0 deletions max/python/max/pipelines/architectures/z_image/arch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from __future__ import annotations

from dataclasses import dataclass

from max.graph.weights import WeightsFormat
from max.interfaces import PipelineTask
from max.pipelines.core import PixelContext
from max.pipelines.lib import PixelGenerationTokenizer, SupportedArchitecture
from max.pipelines.lib.config import MAXModelConfig, PipelineConfig
from max.pipelines.lib.interfaces import ArchConfig
from typing_extensions import Self

from .pipeline_z_image import ZImagePipeline


@dataclass(kw_only=True)
class ZImageArchConfig(ArchConfig):
    """Architecture-level configuration for the Z-Image pipeline.

    Holds the pipeline configuration and enforces, at construction through
    :meth:`initialize`, that exactly one device is configured.
    """

    # Pipeline configuration this arch config was created from.
    pipeline_config: PipelineConfig

    def get_max_seq_len(self) -> int:
        """Return the maximum sequence length, which is the constant ``0``.

        NOTE(review): presumably 0 because a sequence length is not
        meaningful for this pipeline — confirm against the ``ArchConfig``
        contract.
        """
        return 0

    @classmethod
    def initialize(
        cls,
        pipeline_config: PipelineConfig,
        model_config: MAXModelConfig | None = None,
    ) -> Self:
        """Build the arch config after validating the device setup.

        Args:
            pipeline_config: Pipeline configuration to wrap.
            model_config: Model configuration whose ``device_specs`` are
                checked. When falsy, ``pipeline_config.model`` is used.

        Returns:
            A new instance holding ``pipeline_config``.

        Raises:
            ValueError: If the model configuration does not list exactly
                one device spec.
        """
        effective_model = (
            model_config if model_config else pipeline_config.model
        )
        device_count = len(effective_model.device_specs)
        if device_count == 1:
            return cls(pipeline_config=pipeline_config)
        raise ValueError("Z-Image is only supported on a single device")


# Architecture descriptor for the Z-Image pixel-generation pipeline.
# It ties the "ZImagePipeline" name to its pipeline model, context type,
# tokenizer and arch config. bfloat16 is the default encoding (float32 is
# also accepted) and weights default to the safetensors format.
z_image_arch = SupportedArchitecture(
    name="ZImagePipeline",
    task=PipelineTask.PIXEL_GENERATION,
    default_encoding="bfloat16",
    supported_encodings={"bfloat16", "float32"},
    example_repo_ids=[
        "Tongyi-MAI/Z-Image",
        "Tongyi-MAI/Z-Image-Turbo",
    ],
    pipeline_model=ZImagePipeline,  # type: ignore[arg-type]
    context_type=PixelContext,
    default_weights_format=WeightsFormat.safetensors,
    tokenizer=PixelGenerationTokenizer,
    config=ZImageArchConfig,
)
17 changes: 17 additions & 0 deletions max/python/max/pipelines/architectures/z_image/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from .attention import ZImageAttention
from .embeddings import RopeEmbedder, TimestepEmbedder

# Public API of the layers subpackage: the names imported above from the
# attention and embeddings modules.
__all__ = ["RopeEmbedder", "TimestepEmbedder", "ZImageAttention"]
158 changes: 158 additions & 0 deletions max/python/max/pipelines/architectures/z_image/layers/attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2026, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from __future__ import annotations

from max.dtype import DType
from max.graph import DeviceRef, TensorValue, ops
from max.nn.attention.mask_config import MHAMaskVariant
from max.nn.kernels import flash_attention_gpu, rope_ragged_with_position_ids
from max.nn.layer import LayerList, Module
from max.nn.linear import Linear
from max.nn.norm import RMSNorm


def _apply_zimage_qk_rope(
    query: TensorValue,
    key: TensorValue,
    freqs_cis: TensorValue,
) -> tuple[TensorValue, TensorValue]:
    """Apply RoPE to query and key using precomputed interleaved [cos, sin] frequencies.

    Both tensors are collapsed to ``[batch * seq, heads, head_dim]`` (all
    sizes are read from ``query``), rotated by
    ``rope_ragged_with_position_ids`` with ``interleaved=True``, and then
    reshaped back to four dimensions.

    Args:
        query: Query tensor; its first four dimensions are read as
            (batch, seq_len, num_heads, head_dim).
        key: Key tensor, reshaped with the sizes taken from ``query``.
        freqs_cis: Precomputed frequencies, forwarded unchanged to the
            RoPE kernel.

    Returns:
        A ``(query, key)`` pair with RoPE applied, each reshaped to
        ``[batch, seq_len, num_heads, head_dim]``.
    """
    dims = query.shape
    batch_size, seq_len, num_heads, head_dim = (
        dims[0],
        dims[1],
        dims[2],
        dims[3],
    )
    ragged_shape = [batch_size * seq_len, num_heads, head_dim]
    restored_shape = [batch_size, seq_len, num_heads, head_dim]

    flat_query = ops.reshape(query, ragged_shape)
    flat_key = ops.reshape(key, ragged_shape)

    # Positions 0..seq_len-1, repeated for every batch entry.
    # NOTE(review): this tensor is [batch_size, seq_len] while the ragged
    # inputs are [batch_size * seq_len, ...]. A PR reviewer suggested
    # flattening it to [batch_size * seq_len]; confirm the layout that
    # rope_ragged_with_position_ids expects before changing it.
    token_positions = ops.broadcast_to(
        ops.unsqueeze(
            ops.range(0, seq_len, dtype=DType.uint32, device=query.device),
            0,
        ),
        [batch_size, seq_len],
    )

    rotated_query, rotated_key = (
        rope_ragged_with_position_ids(
            flat, freqs_cis, token_positions, interleaved=True
        )
        for flat in (flat_query, flat_key)
    )
    return (
        ops.reshape(rotated_query, restored_shape),
        ops.reshape(rotated_key, restored_shape),
    )


class ZImageAttention(Module):
    """Self-attention layer for Z-Image with optional Q/K RMSNorm and RoPE.

    Query, key and value come from bias-free linear projections of the same
    input. They are split into ``n_heads`` heads of ``dim // n_heads``
    channels; query and key are optionally RMS-normalized and then rotated
    with RoPE. Attention is computed by ``flash_attention_gpu`` with a null
    mask, followed by a bias-free output projection.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        qk_norm: bool,
        eps: float,
        *,
        dtype: DType,
        device: DeviceRef,
    ) -> None:
        """Initialize ZImageAttention.

        Args:
            dim: Input and output width of every projection.
            n_heads: Number of attention heads; must evenly divide ``dim``.
            qk_norm: Whether to apply RMSNorm over the head dimension of
                query and key.
            eps: Epsilon for the RMSNorm layers (only used when ``qk_norm``
                is true).
            dtype: Data type of the layer weights.
            device: Device on which the linear layers are placed.

        Raises:
            ValueError: If ``n_heads`` is not positive or does not evenly
                divide ``dim``.
        """
        # Fail early with a clear message. Otherwise the floor division
        # below silently truncates head_dim and the error only surfaces
        # later as a reshape mismatch while building the graph.
        if n_heads <= 0 or dim % n_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive n_heads "
                f"({n_heads})"
            )

        super().__init__()
        self.head_dim = dim // n_heads
        self.inner_dim = dim
        self.n_heads = n_heads

        def _bias_free_proj() -> Linear:
            # All four projections are dim -> dim without bias.
            return Linear(
                in_dim=dim,
                out_dim=dim,
                dtype=dtype,
                device=device,
                has_bias=False,
            )

        self.to_q = _bias_free_proj()
        self.to_k = _bias_free_proj()
        self.to_v = _bias_free_proj()

        self.norm_q: RMSNorm | None = (
            RMSNorm(self.head_dim, dtype=dtype, eps=eps) if qk_norm else None
        )
        self.norm_k: RMSNorm | None = (
            RMSNorm(self.head_dim, dtype=dtype, eps=eps) if qk_norm else None
        )

        # Keep LayerList naming for diffusers-compatible key loading.
        self.to_out = LayerList([_bias_free_proj()])

    def __call__(
        self,
        hidden_states: TensorValue,
        freqs_cis: TensorValue,
    ) -> TensorValue:
        """Apply self-attention with rotary position embeddings.

        Args:
            hidden_states: Input tensor whose first two dimensions are read
                as (batch, seq_len); presumably ``[batch, seq_len, dim]`` —
                confirm against callers.
            freqs_cis: Precomputed RoPE frequencies, forwarded to
                ``_apply_zimage_qk_rope``.

        Returns:
            The output projection applied to the attention result after it
            is reshaped to ``[batch, seq_len, dim]``.
        """
        batch_size = hidden_states.shape[0]
        seq_len = hidden_states.shape[1]
        per_head_shape = [batch_size, seq_len, self.n_heads, self.head_dim]

        query = ops.reshape(self.to_q(hidden_states), per_head_shape)
        key = ops.reshape(self.to_k(hidden_states), per_head_shape)
        value = ops.reshape(self.to_v(hidden_states), per_head_shape)

        # Optional QK normalization over the head dimension.
        if self.norm_q is not None:
            query = self.norm_q(query)
        if self.norm_k is not None:
            key = self.norm_k(key)

        query, key = _apply_zimage_qk_rope(query, key, freqs_cis)

        # Unmasked attention scaled by 1 / sqrt(head_dim).
        out = flash_attention_gpu(
            query,
            key,
            value,
            mask_variant=MHAMaskVariant.NULL_MASK,
            scale=1.0 / (self.head_dim**0.5),
        )

        # Merge the heads back into a single channel dimension.
        out = ops.reshape(out, [batch_size, seq_len, self.inner_dim])
        return self.to_out[0](out)
Loading
Loading