diff --git a/max/examples/diffusion/simple_offline_generation.py b/max/examples/diffusion/simple_offline_generation.py index f9b5a5b8a30..b0ec80b15c9 100644 --- a/max/examples/diffusion/simple_offline_generation.py +++ b/max/examples/diffusion/simple_offline_generation.py @@ -84,6 +84,11 @@ "Flux2KleinPipeline_ModuleV3", } +_Z_IMAGE_ARCH_NAMES = { + "ZImagePipeline", + "ZImagePipeline_ModuleV3", +} + def parse_args(argv: list[str] | None = None) -> argparse.Namespace: """Parse command-line arguments for the pixel generation example. @@ -419,7 +424,7 @@ async def generate_image(args: argparse.Namespace) -> None: max_length = components_config["tokenizer"]["config_dict"].get( "model_max_length", None ) - if arch.name in _FLUX2_ARCH_NAMES or arch.name == "ZImagePipeline": + if arch.name in _FLUX2_ARCH_NAMES or arch.name in _Z_IMAGE_ARCH_NAMES: max_length = 512 print(f"Using max length: {max_length} for tokenizer") diff --git a/max/python/max/pipelines/architectures/__init__.py b/max/python/max/pipelines/architectures/__init__.py index 89f574285e2..03a715b6d4c 100644 --- a/max/python/max/pipelines/architectures/__init__.py +++ b/max/python/max/pipelines/architectures/__init__.py @@ -82,7 +82,8 @@ def register_all_models() -> None: from .qwen3vl_moe import qwen3vl_arch, qwen3vl_moe_arch from .unified_eagle_llama3 import unified_eagle_llama3_arch from .unified_mtp_deepseekV3 import unified_mtp_deepseekV3_arch - from .z_image_modulev3 import z_image_arch + from .z_image import z_image_arch + from .z_image_modulev3 import z_image_modulev3_arch architectures = [ exaone_arch, @@ -138,6 +139,7 @@ def register_all_models() -> None: unified_eagle_llama3_arch, unified_mtp_deepseekV3_arch, z_image_arch, + z_image_modulev3_arch, ] for arch in architectures: diff --git a/max/python/max/pipelines/architectures/z_image/__init__.py b/max/python/max/pipelines/architectures/z_image/__init__.py new file mode 100644 index 00000000000..46710077828 --- /dev/null +++ b/max/python/max/pipelines/architectures/z_image/__init__.py @@ -0,0 +1,31 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .arch import ZImageArchConfig, z_image_arch +from .layers.attention import ZImageAttention +from .layers.embeddings import RopeEmbedder, TimestepEmbedder +from .model import ZImageTransformerModel +from .model_config import ZImageConfig, ZImageConfigBase +from .z_image import ZImageTransformer2DModel + +__all__ = [ + "RopeEmbedder", + "TimestepEmbedder", + "ZImageArchConfig", + "ZImageAttention", + "ZImageConfig", + "ZImageConfigBase", + "ZImageTransformer2DModel", + "ZImageTransformerModel", + "z_image_arch", +] diff --git a/max/python/max/pipelines/architectures/z_image/arch.py b/max/python/max/pipelines/architectures/z_image/arch.py new file mode 100644 index 00000000000..9b5f578cb93 --- /dev/null +++ b/max/python/max/pipelines/architectures/z_image/arch.py @@ -0,0 +1,62 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from __future__ import annotations + +from dataclasses import dataclass + +from max.graph.weights import WeightsFormat +from max.interfaces import PipelineTask +from max.pipelines.core import PixelContext +from max.pipelines.lib import PixelGenerationTokenizer, SupportedArchitecture +from max.pipelines.lib.config import MAXModelConfig, PipelineConfig +from max.pipelines.lib.interfaces import ArchConfig +from typing_extensions import Self + +from .pipeline_z_image import ZImagePipeline + + +@dataclass(kw_only=True) +class ZImageArchConfig(ArchConfig): + pipeline_config: PipelineConfig + + def get_max_seq_len(self) -> int: + return 0 + + @classmethod + def initialize( + cls, + pipeline_config: PipelineConfig, + model_config: MAXModelConfig | None = None, + ) -> Self: + model_config = model_config or pipeline_config.model + if len(model_config.device_specs) != 1: + raise ValueError("Z-Image is only supported on a single device") + return cls(pipeline_config=pipeline_config) + + +z_image_arch = SupportedArchitecture( + name="ZImagePipeline", + task=PipelineTask.PIXEL_GENERATION, + default_encoding="bfloat16", + supported_encodings={"bfloat16", "float32"}, + example_repo_ids=[ + "Tongyi-MAI/Z-Image", + "Tongyi-MAI/Z-Image-Turbo", + ], + pipeline_model=ZImagePipeline, # type: ignore[arg-type] + context_type=PixelContext, + default_weights_format=WeightsFormat.safetensors, + tokenizer=PixelGenerationTokenizer, + config=ZImageArchConfig, +) diff --git a/max/python/max/pipelines/architectures/z_image/layers/__init__.py b/max/python/max/pipelines/architectures/z_image/layers/__init__.py new file mode 100644 index 00000000000..c9c94f0d143 --- /dev/null +++ b/max/python/max/pipelines/architectures/z_image/layers/__init__.py @@ -0,0 +1,17 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .attention import ZImageAttention +from .embeddings import RopeEmbedder, TimestepEmbedder + +__all__ = ["RopeEmbedder", "TimestepEmbedder", "ZImageAttention"] diff --git a/max/python/max/pipelines/architectures/z_image/layers/attention.py b/max/python/max/pipelines/architectures/z_image/layers/attention.py new file mode 100644 index 00000000000..ae8cf623998 --- /dev/null +++ b/max/python/max/pipelines/architectures/z_image/layers/attention.py @@ -0,0 +1,158 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from __future__ import annotations + +from max.dtype import DType +from max.graph import DeviceRef, TensorValue, ops +from max.nn.attention.mask_config import MHAMaskVariant +from max.nn.kernels import flash_attention_gpu, rope_ragged_with_position_ids +from max.nn.layer import LayerList, Module +from max.nn.linear import Linear +from max.nn.norm import RMSNorm + + +def _apply_zimage_qk_rope( + query: TensorValue, + key: TensorValue, + freqs_cis: TensorValue, +) -> tuple[TensorValue, TensorValue]: + """Apply RoPE using precomputed interleaved [cos, sin] frequencies.""" + batch_size = query.shape[0] + seq_len = query.shape[1] + num_heads = query.shape[2] + head_dim = query.shape[3] + + query_ragged = ops.reshape( + query, [batch_size * seq_len, num_heads, head_dim] + ) + key_ragged = ops.reshape(key, [batch_size * seq_len, num_heads, head_dim]) + + position_ids = ops.range( + 0, seq_len, dtype=DType.uint32, device=query.device + ) + position_ids = ops.broadcast_to( + ops.unsqueeze(position_ids, 0), [batch_size, seq_len] + ) + + query_out = rope_ragged_with_position_ids( + query_ragged, freqs_cis, position_ids, interleaved=True + ) + key_out = rope_ragged_with_position_ids( + key_ragged, freqs_cis, position_ids, interleaved=True + ) + return ( + ops.reshape(query_out, [batch_size, seq_len, num_heads, head_dim]), + ops.reshape(key_out, [batch_size, seq_len, num_heads, head_dim]), + ) + + +class ZImageAttention(Module): + def __init__( + self, + dim: int, + n_heads: int, + qk_norm: bool, + eps: float, + *, + dtype: DType, + device: DeviceRef, + ) -> None: + """Initialize ZImageAttention.""" + super().__init__() + self.head_dim = dim // n_heads + self.inner_dim = dim + self.n_heads = n_heads + + self.to_q = Linear( + in_dim=dim, + out_dim=dim, + dtype=dtype, + device=device, + has_bias=False, + ) + self.to_k = Linear( + in_dim=dim, + out_dim=dim, + dtype=dtype, + device=device, + has_bias=False, + ) + self.to_v = Linear( + in_dim=dim, + out_dim=dim, + dtype=dtype, + device=device, + has_bias=False, + ) + + self.norm_q: RMSNorm | None = ( + RMSNorm(self.head_dim, dtype=dtype, eps=eps) if qk_norm else None + ) + self.norm_k: RMSNorm | None = ( + RMSNorm(self.head_dim, dtype=dtype, eps=eps) if qk_norm else None + ) + + # Keep LayerList naming for diffusers-compatible key loading. + self.to_out = LayerList( + [ + Linear( + in_dim=dim, + out_dim=dim, + dtype=dtype, + device=device, + has_bias=False, + ) + ] + ) + + def __call__( + self, + hidden_states: TensorValue, + freqs_cis: TensorValue, + ) -> TensorValue: + """Apply self-attention with rotary position embeddings.""" + batch_size = hidden_states.shape[0] + seq_len = hidden_states.shape[1] + + query = self.to_q(hidden_states) + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + + query = ops.reshape( + query, [batch_size, seq_len, self.n_heads, self.head_dim] + ) + key = ops.reshape( + key, [batch_size, seq_len, self.n_heads, self.head_dim] + ) + value = ops.reshape( + value, [batch_size, seq_len, self.n_heads, self.head_dim] + ) + + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + query, key = _apply_zimage_qk_rope(query, key, freqs_cis) + + out = flash_attention_gpu( + query, + key, + value, + mask_variant=MHAMaskVariant.NULL_MASK, + scale=1.0 / (self.head_dim**0.5), + ) + + out = ops.reshape(out, [batch_size, seq_len, self.inner_dim]) + return self.to_out[0](out) diff --git a/max/python/max/pipelines/architectures/z_image/layers/embeddings.py b/max/python/max/pipelines/architectures/z_image/layers/embeddings.py new file mode 100644 index 00000000000..a3328c1524e --- /dev/null +++ b/max/python/max/pipelines/architectures/z_image/layers/embeddings.py @@ -0,0 +1,120 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from __future__ import annotations + +import math + +from max.dtype import DType +from max.graph import DeviceRef, TensorValue, ops +from max.nn.layer import Module +from max.nn.linear import Linear + + +class TimestepEmbedder(Module): + def __init__( + self, + out_size: int, + mid_size: int | None = None, + frequency_embedding_size: int = 256, + *, + dtype: DType, + device: DeviceRef, + ) -> None: + """Initialize TimestepEmbedder.""" + super().__init__() + if mid_size is None: + mid_size = out_size + + self.frequency_embedding_size = frequency_embedding_size + + self.linear_1 = Linear( + in_dim=frequency_embedding_size, + out_dim=mid_size, + dtype=dtype, + device=device, + has_bias=True, + ) + self.linear_2 = Linear( + in_dim=mid_size, + out_dim=out_size, + dtype=dtype, + device=device, + has_bias=True, + ) + + @staticmethod + def timestep_embedding( + t: TensorValue, + dim: int, + max_period: float = 10000.0, + ) -> TensorValue: + """Create sinusoidal timestep embeddings.""" + half = dim // 2 + freqs = ops.range(0, half, dtype=DType.float32, device=t.device) + freqs = ops.exp((-math.log(max_period) * freqs) / float(half)) + + args = ops.cast(t, DType.float32)[:, None] * freqs[None, :] + embedding = ops.concat([ops.cos(args), ops.sin(args)], axis=-1) + + if dim % 2: + zero = ops.reshape( + ops.constant(0.0, DType.float32, device=t.device), + (1, 1), + ) + zeros_col = ops.broadcast_to(zero, (embedding.shape[0], 1)) + embedding = ops.concat([embedding, zeros_col], axis=-1) + + return embedding + + def __call__(self, t: TensorValue) -> TensorValue: + """Embed timesteps.""" + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_freq = ops.cast(t_freq, self.linear_1.weight.dtype) + t_emb = self.linear_2(ops.silu(self.linear_1(t_freq))) + return t_emb + + +class RopeEmbedder(Module): + def __init__( + self, + theta: float = 256.0, + axes_dims: tuple[int, ...] = (32, 48, 48), + ) -> None: + """Initialize RopeEmbedder.""" + super().__init__() + self.theta = theta + self.axes_dims = axes_dims + + def __call__(self, ids: TensorValue) -> TensorValue: + """Compute interleaved [cos, sin] rotary position embeddings.""" + pos = ops.cast(ids, DType.float32) + parts = [] + for i in range(len(self.axes_dims)): + dim = self.axes_dims[i] + half = dim // 2 + freq_exp = ( + ops.range( + 0, + half, + dtype=DType.float32, + device=pos.device, + ) + / half + ) + freq = 1.0 / (self.theta**freq_exp) + freqs = ops.outer(pos[:, i], freq) + paired = ops.stack([ops.cos(freqs), ops.sin(freqs)], axis=2) + parts.append(ops.reshape(paired, [freqs.shape[0], dim])) + + return ops.concat(parts, axis=-1) diff --git a/max/python/max/pipelines/architectures/z_image/model.py b/max/python/max/pipelines/architectures/z_image/model.py new file mode 100644 index 00000000000..b94e7f20621 --- /dev/null +++ b/max/python/max/pipelines/architectures/z_image/model.py @@ -0,0 +1,93 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from collections.abc import Callable +from typing import Any + +from max.driver import Buffer, Device +from max.engine import InferenceSession, Model +from max.graph import Graph +from max.graph.weights import Weights +from max.pipelines.lib import SupportedEncoding +from max.pipelines.lib.interfaces.component_model import ComponentModel + +from .model_config import ZImageConfig +from .weight_adapters import convert_z_image_transformer_state_dict +from .z_image import ZImageTransformer2DModel + + +class ZImageTransformerModel(ComponentModel): + def __init__( + self, + config: dict[str, Any], + encoding: SupportedEncoding, + devices: list[Device], + weights: Weights, + session: InferenceSession, + ) -> None: + super().__init__(config, encoding, devices, weights) + self.session = session + self.config = ZImageConfig.initialize_from_config( + config, + encoding, + devices, + ) + self.load_model() + + def load_model(self) -> Callable[..., Any]: + target_dtype = self.config.dtype + raw_state_dict = {} + for key, value in self.weights.items(): + weight = value.data() + if hasattr(weight, "dtype") and hasattr(weight, "astype"): + if weight.dtype != target_dtype: + if weight.dtype.is_float() and target_dtype.is_float(): + weight = weight.astype(target_dtype) + raw_state_dict[key] = weight + state_dict = convert_z_image_transformer_state_dict(raw_state_dict) + + nn_model = ZImageTransformer2DModel(self.config) + nn_model.load_state_dict(state_dict, weight_alignment=1, strict=True) + self.state_dict = nn_model.state_dict() + + with Graph( + "z_image_transformer", + input_types=nn_model.input_types(), + ) as graph: + outputs = nn_model(*(value.tensor for value in graph.inputs)) + if isinstance(outputs, tuple): + graph.output(*outputs) + else: + graph.output(outputs) + + self.model: Model = self.session.load( + graph, + weights_registry=self.state_dict, + ) + return self.model.execute + + def __call__( + self, + hidden_states: Buffer, + encoder_hidden_states: Buffer, + timestep: Buffer, + img_ids: Buffer, + txt_ids: Buffer, + ) -> list[Buffer]: + return self.model.execute( + hidden_states, + encoder_hidden_states, + timestep, + img_ids, + txt_ids, + ) diff --git a/max/python/max/pipelines/architectures/z_image/model_config.py b/max/python/max/pipelines/architectures/z_image/model_config.py new file mode 100644 index 00000000000..dce77123e3c --- /dev/null +++ b/max/python/max/pipelines/architectures/z_image/model_config.py @@ -0,0 +1,78 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from typing import Any + +from max.driver import Device +from max.dtype import DType +from max.graph import DeviceRef +from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding +from max.pipelines.lib.config.config_enums import supported_encoding_dtype +from pydantic import Field +from typing_extensions import Self + + +class ZImageConfig(MAXModelConfigBase): + all_patch_size: tuple[int, ...] = (2,) + all_f_patch_size: tuple[int, ...] = (1,) + in_channels: int = 16 + dim: int = 3840 + n_layers: int = 30 + n_refiner_layers: int = 2 + n_heads: int = 30 + n_kv_heads: int = 30 + norm_eps: float = 1e-5 + qk_norm: bool = True + cap_feat_dim: int = 2560 + rope_theta: float = 256.0 + t_scale: float = 1000.0 + axes_dims: tuple[int, ...] = (32, 48, 48) + axes_lens: tuple[int, ...] = (1024, 512, 512) + dtype: DType = DType.bfloat16 + device: DeviceRef = Field(default_factory=DeviceRef.GPU) + + @classmethod + def initialize_from_config( + cls, + config_dict: dict[str, Any], + encoding: SupportedEncoding, + devices: list[Device], + ) -> Self: + init_dict = { + key: value + for key, value in config_dict.items() + if key in cls.model_fields + } + # Ignore omni-only fields in phase 1 (may appear in full checkpoints). + init_dict.pop("siglip_feat_dim", None) + + init_dict.update( + { + "dtype": supported_encoding_dtype(encoding), + "device": DeviceRef.from_device(devices[0]), + } + ) + return cls(**init_dict) + + def fbcache_dims(self) -> tuple[int, int]: + """(hidden_dim, output_dim) per image token for FBCache / Taylor tensors.""" + out_dim = ( + self.all_patch_size[0] + * self.all_patch_size[0] + * self.all_f_patch_size[0] + * self.in_channels + ) + return self.dim, out_dim + + +ZImageConfigBase = ZImageConfig diff --git a/max/python/max/pipelines/architectures/z_image/pipeline_z_image.py b/max/python/max/pipelines/architectures/z_image/pipeline_z_image.py new file mode 100644 index 00000000000..79aa8d5e87f --- /dev/null +++ b/max/python/max/pipelines/architectures/z_image/pipeline_z_image.py @@ -0,0 +1,1059 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # +"""Z-Image diffusion pipeline (Graph API / ModuleV2).""" + +from __future__ import annotations + +import hashlib +from collections.abc import Callable +from dataclasses import MISSING, dataclass, field, fields +from typing import Any, cast + +import numpy as np +import numpy.typing as npt +from max.driver import CPU, Buffer, Device +from max.dtype import DType +from max.graph import TensorType, TensorValue, ops +from max.interfaces import TokenBuffer +from max.pipelines.core import PixelContext +from max.pipelines.lib.interfaces import ( + DiffusionPipeline, + DiffusionPipelineOutput, +) +from max.pipelines.lib.interfaces.diffusion_pipeline import max_compile +from max.pipelines.lib.utils import BoundedCache +from max.profiler import Tracer, traced + +from ..autoencoders import AutoencoderKLModel +from ..qwen3.text_encoder import Qwen3TextEncoderZImageModel +from .model import ZImageTransformerModel + +_DEVICE_TENSOR_FIELDS = frozenset( + { + "tokens_tensor", + "negative_tokens_tensor", + "txt_ids_tensor", + "img_ids_tensor", + "negative_txt_ids_tensor", + "negative_img_ids_tensor", + "input_image_tensor", + "latents_tensor", + "sigmas_tensor", + "h_carrier", + "w_carrier", + } +) + + +def _validate_z_image_context(context: PixelContext) -> None: + """Fail fast before device uploads.""" + if context.latents.size == 0: + raise ValueError( + "ZImagePipeline requires non-empty latents in PixelContext." + ) + for name in ("latent_image_ids", "sigmas", "timesteps"): + if not hasattr(context, name): + raise TypeError( + f"ZImagePipeline requires PixelContext with attribute {name!r}; " + f"{type(context).__name__} has no {name!r}." + ) + arr = getattr(context, name) + if not isinstance(arr, np.ndarray) or arr.size == 0: + raise ValueError( + f"ZImagePipeline requires non-empty {name} in PixelContext." + ) + + +@dataclass(kw_only=True) +class ZImageModelInputs: + """Z-Image execution inputs with device tensors and host metadata.""" + + tokens: TokenBuffer + tokens_2: TokenBuffer | None = None + negative_tokens: TokenBuffer | None = None + negative_tokens_2: TokenBuffer | None = None + timesteps: npt.NDArray[np.float32] = field( + default_factory=lambda: np.array([], dtype=np.float32) + ) + sigmas: npt.NDArray[np.float32] = field( + default_factory=lambda: np.array([], dtype=np.float32) + ) + latents: npt.NDArray[np.float32] = field( + default_factory=lambda: np.array([], dtype=np.float32) + ) + latent_image_ids: npt.NDArray[np.float32] = field( + default_factory=lambda: np.array([], dtype=np.float32) + ) + guidance: npt.NDArray[np.float32] | None = None + true_cfg_scale: float = 1.0 + num_warmup_steps: int = 0 + input_image: npt.NDArray[np.uint8] | None = None + strength: float = 0.6 + cfg_normalization: bool = False + cfg_truncation: float = 1.0 + width: int = 1024 + height: int = 1024 + guidance_scale: float = 5.0 + num_inference_steps: int = 50 + num_images_per_prompt: int = 1 + explicit_negative_prompt: bool = False + do_cfg: bool = False + tokens_tensor: Buffer + negative_tokens_tensor: Buffer | None = None + txt_ids_tensor: Buffer + img_ids_tensor: Buffer + negative_txt_ids_tensor: Buffer | None = None + negative_img_ids_tensor: Buffer | None = None + input_image_tensor: Buffer | None = None + latents_tensor: Buffer + sigmas_tensor: Buffer + h_carrier: Buffer + w_carrier: Buffer + + @classmethod + def kwargs_from_context(cls, context: PixelContext) -> dict[str, Any]: + """Build kwargs for all fields except device tensors.""" + kwargs: dict[str, Any] = {} + for dataclass_field in fields(cls): + name = dataclass_field.name + if name in _DEVICE_TENSOR_FIELDS: + continue + if not hasattr(context, name): + continue + value = getattr(context, name) + if value is None: + if dataclass_field.default is not MISSING: + kwargs[name] = dataclass_field.default + elif dataclass_field.default_factory is not MISSING: + kwargs[name] = dataclass_field.default_factory() + else: + kwargs[name] = None + else: + kwargs[name] = value + return kwargs + + def __post_init__(self) -> None: + if not isinstance(self.height, int) or self.height <= 0: + raise ValueError( + f"height must be a positive int. Got {self.height!r}" + ) + if not isinstance(self.width, int) or self.width <= 0: + raise ValueError( + f"width must be a positive int. Got {self.width!r}" + ) + if ( + not isinstance(self.num_inference_steps, int) + or self.num_inference_steps <= 0 + ): + raise ValueError( + "num_inference_steps must be a positive int. " + f"Got {self.num_inference_steps!r}" + ) + if ( + not isinstance(self.num_images_per_prompt, int) + or self.num_images_per_prompt <= 0 + ): + raise ValueError( + "num_images_per_prompt must be > 0. " + f"Got {self.num_images_per_prompt!r}" + ) + if self.sigmas.size == 0: + raise ValueError( + "ZImagePipeline requires non-empty sigmas in context." + ) + if self.latent_image_ids.size == 0: + raise ValueError( + "ZImagePipeline requires non-empty latent image ids in context." + ) + + +class ZImagePipeline(DiffusionPipeline): + """Diffusion pipeline for Z-Image generation (Graph API).""" + + unprefixed_weight_component = "transformer" + default_num_inference_steps = 50 + default_residual_threshold = 0.06 + + vae: AutoencoderKLModel + text_encoder: Qwen3TextEncoderZImageModel + transformer: ZImageTransformerModel + + components = { + "vae": AutoencoderKLModel, + "text_encoder": Qwen3TextEncoderZImageModel, + "transformer": ZImageTransformerModel, + } + + @traced(message="ZImagePipeline.init_remaining_components") + def init_remaining_components(self) -> None: + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.block_out_channels) - 1) + if getattr(self, "vae", None) + else 8 + ) + + self.build_preprocess_latents() + self.build_prepare_scheduler() + self.build_scheduler_step() + self.build_decode_latents() + self.build_cfg_combine() + self.build_duplicate_batch() + self.build_cfg_finalize_batched() + self.build_cfg_renormalization() + self.build_postprocess_image() + self.build_pad_seq() + self.build_truncate_seq() + self.build_concat_batch() + + self._cached_text_ids: BoundedCache[str, Buffer] = BoundedCache(32) + self._cached_sigmas: BoundedCache[str, Buffer] = BoundedCache(32) + self._cached_img_ids: BoundedCache[str, Buffer] = BoundedCache(32) + self._cached_img_ids_base_np: BoundedCache[str, np.ndarray] = ( + BoundedCache(32) + ) + self._cached_shape_carriers: BoundedCache[int, Buffer] = BoundedCache( + 32 + ) + self._cached_prompt_token_tensors: BoundedCache[str, Buffer] = ( + BoundedCache(32) + ) + self._cached_prompt_padding: BoundedCache[str, Buffer] = BoundedCache( + 32 + ) + self._cached_guidance: BoundedCache[str, Buffer] = BoundedCache(32) + + @traced(message="ZImagePipeline.build_preprocess_latents") + def build_preprocess_latents(self) -> None: + device = self.transformer.devices[0] + target_dtype = self.transformer.config.dtype + + def _graph(latents: TensorValue) -> TensorValue: + batch = latents.shape[0] + c = latents.shape[1] + h = latents.shape[2] + w = latents.shape[3] + latents = ops.rebind( + latents, [batch, c, (h // 2) * 2, (w // 2) * 2] + ) + latents = ops.reshape(latents, (batch, c, h // 2, 2, w // 2, 2)) + latents = ops.permute(latents, [0, 2, 4, 3, 5, 1]) + latents = ops.reshape(latents, (batch, (h // 2) * (w // 2), c * 4)) + return ops.cast(latents, target_dtype) + + self._patchify_and_pack = cast( + Callable[[Buffer], Buffer], + max_compile( + _graph, + input_types=[ + TensorType( + DType.float32, + shape=["batch", "channels", "height", "width"], + device=device, + ), + ], + ), + ) + + @traced(message="ZImagePipeline.build_prepare_scheduler") + def build_prepare_scheduler(self) -> None: + device = self.transformer.devices[0] + + def _graph( + timesteps: TensorValue, sigmas: TensorValue + ) -> tuple[TensorValue, TensorValue]: + all_timesteps = 1.0 - timesteps + sigmas_curr = ops.slice_tensor(sigmas, [slice(0, -1)]) + sigmas_next = ops.slice_tensor(sigmas, [slice(1, None)]) + all_dt = sigmas_next - sigmas_curr + return all_timesteps, all_dt + + self._prepare_scheduler = cast( + Callable[[Buffer, Buffer], tuple[Buffer, Buffer]], + max_compile( + _graph, + input_types=[ + TensorType( + DType.float32, + shape=["num_timesteps"], + device=device, + ), + TensorType( + DType.float32, + shape=["num_sigmas"], + device=device, + ), + ], + ), + ) + + @traced(message="ZImagePipeline.build_scheduler_step") + def build_scheduler_step(self) -> None: + dtype = self.transformer.config.dtype + device = self.transformer.devices[0] + + def _graph( + latents: TensorValue, noise_pred: TensorValue, dt: TensorValue + ) -> TensorValue: + latents_dtype = latents.dtype + latents = ops.cast(latents, DType.float32) + latents = latents - dt * noise_pred + return ops.cast(latents, latents_dtype) + + self._scheduler_step = cast( + Callable[[Buffer, Buffer, Buffer], Buffer], + max_compile( + _graph, + input_types=[ + TensorType( + dtype, shape=["batch", "seq", "channels"], device=device + ), + TensorType( + dtype, shape=["batch", "seq", "channels"], device=device + ), + TensorType(DType.float32, shape=[1], device=device), + ], + ), + ) + + @traced(message="ZImagePipeline.build_decode_latents") + def build_decode_latents(self) -> None: + device = self.transformer.devices[0] + dtype = self.transformer.config.dtype + scaling = float(self.vae.config.scaling_factor) + shift = float(self.vae.config.shift_factor or 0.0) + + def _unpack( + latents: TensorValue, + h_carrier: TensorValue, + w_carrier: TensorValue, + ) -> TensorValue: + batch = latents.shape[0] + ch = latents.shape[2] + half_h = h_carrier.shape[0] + half_w = w_carrier.shape[0] + latents = ops.rebind(latents, [batch, half_h * half_w, ch]) + latents = ops.reshape(latents, (batch, half_h, half_w, ch)) + latents = ops.rebind( + latents, [batch, half_h, half_w, (ch // 4) * 4] + ) + latents = ops.reshape( + latents, (batch, half_h, half_w, 2, 2, ch // 4) + ) + latents = ops.permute(latents, [0, 5, 1, 3, 2, 4]) + latents = ops.reshape( + latents, (batch, ch // 4, half_h * 2, half_w * 2) + ) + return (latents / scaling) + shift + + self._unpack_and_postprocess = cast( + Callable[[Buffer, Buffer, Buffer], Buffer], + max_compile( + _unpack, + input_types=[ + TensorType( + dtype, + shape=["batch", "seq", "channels"], + device=device, + ), + TensorType(DType.float32, shape=["half_h"], device=device), + TensorType(DType.float32, shape=["half_w"], device=device), + ], + ), + ) + + @traced(message="ZImagePipeline.build_postprocess_image") + def build_postprocess_image(self) -> None: + device = self.transformer.devices[0] + dtype = self.transformer.config.dtype + + def _graph(image: TensorValue) -> TensorValue: + image = ops.cast(image, DType.float32) + image = image * 0.5 + 0.5 + image = ops.where(image < 0.0, 0.0, image) + image = ops.where(image > 1.0, 1.0, image) + image = ops.permute(image, [0, 2, 3, 1]) + image = image * 255.0 + return ops.cast(image, DType.uint8) + + self._postprocess_image = cast( + Callable[[Buffer], Buffer], + max_compile( + _graph, + input_types=[ + TensorType( + dtype, + shape=["batch", "channels", "height", "width"], + device=device, + ), + ], + ), + ) + + @traced(message="ZImagePipeline.build_cfg_combine") + def build_cfg_combine(self) -> None: + dtype = self.transformer.config.dtype + device = self.transformer.devices[0] + + def _graph( + pos: TensorValue, neg: TensorValue, scale: TensorValue + ) -> TensorValue: + result = pos + scale * (pos - neg) + return ops.cast(result, pos.dtype) + + self._cfg_combine = cast( + Callable[[Buffer, Buffer, Buffer], Buffer], + max_compile( + _graph, + input_types=[ + TensorType( + dtype, shape=["batch", "seq", "channels"], device=device + ), + TensorType( + dtype, shape=["batch", "seq", "channels"], device=device + ), + TensorType(DType.float32, shape=[], device=device), + ], + ), + ) + + @traced(message="ZImagePipeline.build_cfg_renormalization") + def build_cfg_renormalization(self) -> None: + dtype = self.transformer.config.dtype + device = self.transformer.devices[0] + + def _graph(pos: TensorValue, pred: TensorValue) -> TensorValue: + ori_norm = ops.sqrt( + ops.sum(ops.sum(pos * pos, axis=2), axis=1) + 1e-12 + ) + new_norm = ops.sqrt( + ops.sum(ops.sum(pred * pred, axis=2), axis=1) + 1e-12 + ) + while ori_norm.rank > 1: + ori_norm = ops.squeeze(ori_norm, -1) + while new_norm.rank > 1: + new_norm = ops.squeeze(new_norm, -1) + safe_new = ops.where(new_norm > 1e-12, new_norm, 1e-12) + ratio = ori_norm / safe_new + ratio = ops.where(new_norm > ori_norm, ratio, 1.0) + ratio = ops.unsqueeze(ops.unsqueeze(ratio, 1), 2) + return pred * ratio + + self._cfg_renormalization = cast( + Callable[[Buffer, Buffer], Buffer], + max_compile( + _graph, + input_types=[ + TensorType( + dtype, shape=["batch", "seq", "channels"], device=device + ), + TensorType( + dtype, shape=["batch", "seq", "channels"], device=device + ), + ], + ), + ) + + @traced(message="ZImagePipeline.build_duplicate_batch") + def build_duplicate_batch(self) -> None: + dtype = self.transformer.config.dtype + device = self.transformer.devices[0] + + def _graph(x: TensorValue) -> TensorValue: + batch = x.shape[0] + seq = x.shape[1] + ch = x.shape[2] + x = ops.unsqueeze(x, 0) + x = ops.broadcast_to(x, [2, batch, seq, ch]) + return ops.reshape(x, [batch * 2, seq, ch]) + + self._duplicate_batch = cast( + Callable[[Buffer], Buffer], + max_compile( + _graph, + input_types=[ + TensorType( + dtype, + shape=["batch", "seq", "channels"], + device=device, + ), + ], + ), + ) + + @traced(message="ZImagePipeline.build_cfg_finalize_batched") + def build_cfg_finalize_batched(self) -> None: + dtype = self.transformer.config.dtype + device = self.transformer.devices[0] + input_types = [ + TensorType( + dtype, + shape=["double_batch", "seq", "channels"], + device=device, + ), + TensorType(DType.float32, shape=[], device=device), + ] + + def _no_norm( + pred_cfg: TensorValue, scale: TensorValue + ) -> tuple[TensorValue, TensorValue]: + batch2 = pred_cfg.shape[0] + batch = batch2 // 2 + seq = pred_cfg.shape[1] + ch = pred_cfg.shape[2] + pos = ops.rebind(pred_cfg[:batch], [batch, seq, ch]) + neg = ops.rebind(pred_cfg[batch:], [batch, seq, ch]) + result = ops.cast(pos + scale * (pos - neg), pos.dtype) + return pos, result + + def _with_norm( + pred_cfg: TensorValue, scale: TensorValue + ) -> tuple[TensorValue, TensorValue]: + pos, result = _no_norm(pred_cfg, scale) + ori = ops.sqrt(ops.sum(ops.sum(pos * pos, axis=2), axis=1) + 1e-12) + new = ops.sqrt( + ops.sum(ops.sum(result * result, axis=2), axis=1) + 1e-12 + ) + while ori.rank > 1: + ori = ops.squeeze(ori, -1) + while new.rank > 1: + new = ops.squeeze(new, -1) + safe = ops.where(new > 1e-12, new, 1e-12) + ratio = ori / safe + ratio = ops.where(new > ori, ratio, 1.0) + ratio = ops.unsqueeze(ops.unsqueeze(ratio, 1), 2) + return pos, result * ratio + + self._cfg_finalize_no_norm = cast( + Callable[[Buffer, Buffer], tuple[Buffer, Buffer]], + max_compile(_no_norm, input_types=input_types), + ) + self._cfg_finalize_with_norm = cast( + Callable[[Buffer, Buffer], tuple[Buffer, Buffer]], + max_compile(_with_norm, input_types=input_types), + ) + + @traced(message="ZImagePipeline.build_pad_seq") + def build_pad_seq(self) -> None: + dtype = self.transformer.config.dtype + device = self.transformer.devices[0] + + def _graph(embeds: TensorValue, pad: TensorValue) -> TensorValue: + return ops.concat([embeds, pad], axis=1) + + self._pad_seq = cast( + Callable[[Buffer, Buffer], Buffer], + max_compile( + _graph, + input_types=[ + TensorType( + dtype, + shape=["batch", "seq_a", "hidden"], + device=device, + ), + TensorType( + dtype, + shape=["batch", "seq_b", "hidden"], + device=device, + ), + ], + ), + ) + + @traced(message="ZImagePipeline.build_truncate_seq") + def build_truncate_seq(self) -> None: + dtype = self.transformer.config.dtype + device = self.transformer.devices[0] + + def _graph( + embeds: TensorValue, target_carrier: TensorValue + ) -> TensorValue: + target_len = target_carrier.shape[0] + return embeds[:, :target_len, :] + + self._truncate_seq = cast( + Callable[[Buffer, Buffer], Buffer], + max_compile( + _graph, + input_types=[ + TensorType( + dtype, + shape=["batch", "seq", "hidden"], + device=device, + ), + TensorType( + dtype, + shape=["target_len"], + device=device, + ), + ], + ), + ) + + @traced(message="ZImagePipeline.build_concat_batch") + def build_concat_batch(self) -> None: + dtype = self.transformer.config.dtype + device = self.transformer.devices[0] + + def _graph(a: TensorValue, b: TensorValue) -> TensorValue: + return ops.concat([a, b], axis=0) + + self._concat_batch = cast( + Callable[[Buffer, Buffer], Buffer], + max_compile( + _graph, + input_types=[ + TensorType( + dtype, + shape=["batch_a", "seq", "hidden"], + device=device, + ), + TensorType( + dtype, + shape=["batch_b", "seq", "hidden"], + device=device, + ), + ], + ), + ) + + def _align_prompt_embeds( + self, + neg_embeds: Buffer, + pos_embeds: Buffer, + device: Device, + ) -> Buffer: + pos_len = int(pos_embeds.shape[1]) + neg_len = int(neg_embeds.shape[1]) + hidden = int(pos_embeds.shape[2]) + + if neg_len == pos_len: + return neg_embeds + if neg_len > pos_len: + carrier = Buffer.from_dlpack( + np.empty(pos_len, dtype=np.float32) + ).to(device) + return self._truncate_seq(neg_embeds, carrier) + pad_len = pos_len - neg_len + pad = Buffer.zeros( + (1, pad_len, hidden), pos_embeds.dtype, device=device + ) + return self._pad_seq(neg_embeds, pad) + + @traced(message="ZImagePipeline.prepare_inputs") + def prepare_inputs(self, context: PixelContext) -> ZImageModelInputs: # type: ignore[override] + _validate_z_image_context(context) + kwargs = ZImageModelInputs.kwargs_from_context(context) + device = self.transformer.devices[0] + text_device = self.text_encoder.devices[0] + + kwargs["latents"] = np.asarray(context.latents) + kwargs["sigmas"] = np.asarray(context.sigmas) + kwargs["latent_image_ids"] = np.asarray(context.latent_image_ids) + + latents_np = np.ascontiguousarray(kwargs["latents"]) + latent_h = int(latents_np.shape[-2]) + latent_w = int(latents_np.shape[-1]) + packed_h = latent_h // 2 + packed_w = latent_w // 2 + image_seq_len = int(np.asarray(context.latent_image_ids).shape[-2]) + + tokens_np = self._select_tokens_for_text_encoder( + context.tokens.array, context.mask + ) + tokens_buf = self._cache_token_buffer(tokens_np, text_device) + txt_ids_buf, img_ids_buf = self._prepare_conditioning_ids( + text_seq_len=int(tokens_np.shape[0]), + image_seq_len=image_seq_len, + latent_image_ids=np.asarray(context.latent_image_ids), + height=int(context.height), + width=int(context.width), + device=device, + ) + + neg_tokens_buf: Buffer | None = None + neg_txt_ids_buf: Buffer | None = None + neg_img_ids_buf: Buffer | None = None + if context.negative_tokens is not None: + neg_np = self._select_tokens_for_text_encoder( + context.negative_tokens.array, context.negative_mask + ) + neg_tokens_buf = self._cache_token_buffer(neg_np, text_device) + if context.explicit_negative_prompt: + neg_txt_ids_buf, neg_img_ids_buf = ( + self._prepare_conditioning_ids( + text_seq_len=int(neg_np.shape[0]), + image_seq_len=image_seq_len, + latent_image_ids=np.asarray(context.latent_image_ids), + height=int(context.height), + width=int(context.width), + device=device, + ) + ) + do_cfg = ( + float(context.guidance_scale) > 0.0 and neg_tokens_buf is not None + ) + + input_image_buf: Buffer | None = None + if context.input_image is not None: + input_image_buf = self._numpy_image_to_buffer( + image=np.ascontiguousarray( + context.input_image.astype(np.uint8, copy=False) + ), + batch_size=int(context.num_images_per_prompt), + ) + + latents_buf = Buffer.from_dlpack(latents_np).to(device) + + for n in (packed_h, packed_w): + if n not in self._cached_shape_carriers: + self._cached_shape_carriers[n] = Buffer.from_dlpack( + np.ascontiguousarray(np.empty(n, dtype=np.float32)) + ).to(device) + + num_steps = int(context.num_inference_steps) + sigmas_key = f"sigmas::{num_steps}::{latent_h}x{latent_w}" + if sigmas_key in self._cached_sigmas: + sigmas_buf = self._cached_sigmas[sigmas_key] + else: + sigmas_buf = Buffer.from_dlpack( + np.ascontiguousarray(context.sigmas) + ).to(device) + self._cached_sigmas[sigmas_key] = sigmas_buf + + return ZImageModelInputs( + **kwargs, + do_cfg=do_cfg, + tokens_tensor=tokens_buf, + negative_tokens_tensor=neg_tokens_buf, + txt_ids_tensor=txt_ids_buf, + img_ids_tensor=img_ids_buf, + negative_txt_ids_tensor=neg_txt_ids_buf, + negative_img_ids_tensor=neg_img_ids_buf, + input_image_tensor=input_image_buf, + latents_tensor=latents_buf, + sigmas_tensor=sigmas_buf, + h_carrier=self._cached_shape_carriers[packed_h], + w_carrier=self._cached_shape_carriers[packed_w], + ) + + @staticmethod + def _select_tokens_for_text_encoder( + tokens: np.ndarray, + mask: np.ndarray | None, + ) -> np.ndarray: + if tokens.ndim == 2: + tokens = tokens[0] + if mask is not None: + if mask.ndim == 2: + mask = mask[0] + selected = mask.astype(np.bool_, copy=False) + if not np.any(selected): + raise ValueError("ZImage mask cannot exclude all tokens.") + if not np.all(selected): + tokens = tokens[selected] + return np.ascontiguousarray(tokens.astype(np.int64, copy=False)) + + def _cache_token_buffer(self, tokens: np.ndarray, device: Device) -> Buffer: + digest = hashlib.sha1(tokens.tobytes()).hexdigest() + key = f"tokens::{tokens.shape[0]}::{digest}::{device}" + if key in self._cached_prompt_token_tensors: + return self._cached_prompt_token_tensors[key] + buf = Buffer.from_dlpack(tokens).to(device) + self._cached_prompt_token_tensors[key] = buf + return buf + + def _prepare_conditioning_ids( + self, + text_seq_len: int, + image_seq_len: int, + latent_image_ids: np.ndarray, + height: int, + width: int, + device: Device, + ) -> tuple[Buffer, Buffer]: + text_seq_len_padded = text_seq_len + (-text_seq_len % 32) + + img_base_key = f"img_ids_base::{image_seq_len}_{height}x{width}" + if img_base_key in self._cached_img_ids_base_np: + img_ids_base = self._cached_img_ids_base_np[img_base_key] + else: + img_ids_base = np.asarray(latent_image_ids, dtype=np.int64) + if img_ids_base.ndim == 3: + img_ids_base = img_ids_base[0] + img_ids_base = np.ascontiguousarray(img_ids_base) + self._cached_img_ids_base_np[img_base_key] = img_ids_base + + img_key = ( + f"img_ids::{text_seq_len_padded}_{image_seq_len}_{height}x{width}" + ) + if img_key in self._cached_img_ids: + img_buf = self._cached_img_ids[img_key] + else: + img_np = img_ids_base.copy() + img_np[:, 0] = img_np[:, 0] + text_seq_len_padded + 1 + img_buf = Buffer.from_dlpack(np.ascontiguousarray(img_np)).to( + device + ) + self._cached_img_ids[img_key] = img_buf + + txt_key = f"text_ids::{text_seq_len}" + if txt_key in self._cached_text_ids: + txt_buf = self._cached_text_ids[txt_key] + else: + txt_ids = np.zeros((text_seq_len, 3), dtype=np.int64) + txt_ids[:, 0] = np.arange(1, text_seq_len + 1, dtype=np.int64) + txt_buf = Buffer.from_dlpack(np.ascontiguousarray(txt_ids)).to( + device + ) + self._cached_text_ids[txt_key] = txt_buf + + return txt_buf, img_buf + + def _numpy_image_to_buffer( + self, + image: npt.NDArray[np.uint8], + batch_size: int, + ) -> Buffer: + if image.ndim != 3 or image.shape[2] != 3: + raise ValueError( + f"Expected input image shape [H, W, 3], got {image.shape}." + ) + img_array = (image.astype(np.float32) / 127.5) - 1.0 + img_array = np.transpose(img_array, (2, 0, 1)) + img_array = np.expand_dims(img_array, axis=0) + if batch_size > 1: + img_array = np.tile(img_array, (batch_size, 1, 1, 1)) + img_array = np.ascontiguousarray(img_array) + return Buffer.from_dlpack(img_array).to(self.vae.devices[0]) + + def _get_cached_guidance( + self, + guidance_scale: float, + device: Device, + ) -> Buffer: + key = f"{guidance_scale:.8f}::{device}" + if key in self._cached_guidance: + return self._cached_guidance[key] + buf = Buffer.from_dlpack(np.array(guidance_scale, dtype=np.float32)).to( + device + ) + self._cached_guidance[key] = buf + return buf + + @traced(message="ZImagePipeline.prepare_prompt_embeddings") + def prepare_prompt_embeddings( + self, + tokens: Buffer, + num_images_per_prompt: int, + ) -> Buffer: + del num_images_per_prompt + return cast(Buffer, self.text_encoder(tokens)) + + @traced(message="ZImagePipeline.decode_latents") + def decode_latents( + self, + latents: Buffer, + h_carrier: Buffer, + w_carrier: Buffer, + ) -> npt.NDArray[np.uint8]: + latents = self._unpack_and_postprocess(latents, h_carrier, w_carrier) + decoded = cast(Buffer, self.vae.decode(cast(Any, latents))) + image = self._postprocess_image(decoded) + result: np.ndarray + if hasattr(image, "to"): + image = image.to(CPU()) + if hasattr(image, "__dlpack__"): + result = np.from_dlpack(image) + elif hasattr(image, "to_numpy"): + result = image.to_numpy() + else: + result = image.to_numpy() + if result.dtype != np.uint8: + result = result.astype(np.uint8, copy=False) + return cast(npt.NDArray[np.uint8], result) + + @traced(message="ZImagePipeline.preprocess_latents") + def preprocess_latents(self, latents: Buffer) -> Buffer: + return self._patchify_and_pack(latents) + + @traced(message="ZImagePipeline.execute") + def execute( # type: ignore[override] + self, + model_inputs: ZImageModelInputs, + ) -> DiffusionPipelineOutput: + with Tracer("prepare_prompt_embeddings"): + prompt_embeds = self.prepare_prompt_embeddings( + tokens=model_inputs.tokens_tensor, + num_images_per_prompt=model_inputs.num_images_per_prompt, + ) + + negative_prompt_embeds: Buffer | None = None + if ( + model_inputs.do_cfg + and model_inputs.negative_tokens_tensor is not None + ): + negative_prompt_embeds = self.prepare_prompt_embeddings( + tokens=model_inputs.negative_tokens_tensor, + num_images_per_prompt=model_inputs.num_images_per_prompt, + ) + + latents = model_inputs.latents_tensor + sigmas = model_inputs.sigmas_tensor + h_carrier = model_inputs.h_carrier + w_carrier = model_inputs.w_carrier + + timesteps: np.ndarray = model_inputs.timesteps + num_timesteps = timesteps.shape[0] + if num_timesteps < 1: + raise ValueError("No timesteps were provided for denoising.") + + device = self.transformer.devices[0] + img_ids = model_inputs.img_ids_tensor + txt_ids = model_inputs.txt_ids_tensor + latents = self.preprocess_latents(latents) + + with Tracer("prepare_scheduler"): + timesteps_np = np.ascontiguousarray( + model_inputs.timesteps.astype(np.float32, copy=False) + ) + timesteps_buf = Buffer.from_dlpack(timesteps_np).to(device) + all_timesteps, all_dt = self._prepare_scheduler( + timesteps_buf, sigmas + ) + + timesteps_seq: Any = all_timesteps + dts_seq: Any = all_dt + if hasattr(timesteps_seq, "driver_tensor"): + timesteps_seq = timesteps_seq.driver_tensor + if hasattr(dts_seq, "driver_tensor"): + dts_seq = dts_seq.driver_tensor + + cfg_cutoff_step = 0 + if model_inputs.do_cfg: + transformed_host = (1.0 - timesteps_np).astype(np.float32) + if model_inputs.cfg_truncation > 1.0: + cfg_cutoff_step = num_timesteps + else: + mask = transformed_host <= model_inputs.cfg_truncation + cfg_cutoff_step = int(np.count_nonzero(mask)) + + guidance_buf: Buffer | None = None + if model_inputs.do_cfg: + guidance_buf = self._get_cached_guidance( + model_inputs.guidance_scale, device + ) + + use_batched_cfg = bool( + model_inputs.do_cfg and not model_inputs.explicit_negative_prompt + ) + cfg_prompt_embeds: Buffer | None = None + neg_img_ids = img_ids + neg_txt_ids = txt_ids + + if model_inputs.do_cfg and negative_prompt_embeds is not None: + if model_inputs.explicit_negative_prompt: + assert model_inputs.negative_img_ids_tensor is not None + assert model_inputs.negative_txt_ids_tensor is not None + neg_img_ids = model_inputs.negative_img_ids_tensor + neg_txt_ids = model_inputs.negative_txt_ids_tensor + else: + neg_aligned = self._align_prompt_embeds( + negative_prompt_embeds, prompt_embeds, device + ) + cfg_prompt_embeds = self._concat_batch( + prompt_embeds, neg_aligned + ) + + cfg_timestep_bufs: list[Buffer] | None = None + if use_batched_cfg: + transformed = (1.0 - timesteps_np).astype(np.float32) + batch_size = int(prompt_embeds.shape[0]) + cfg_timestep_bufs = [ + Buffer.from_dlpack( + np.full((2 * batch_size,), float(t), dtype=np.float32) + ).to(device) + for t in transformed + ] + + with Tracer("denoising_loop"): + for i in range(num_timesteps): + apply_cfg = i < cfg_cutoff_step + timestep = timesteps_seq[i : i + 1] + dt = dts_seq[i : i + 1] + + with Tracer(f"denoising_step_{i}"): + if apply_cfg and use_batched_cfg: + assert cfg_prompt_embeds is not None + assert cfg_timestep_bufs is not None + with Tracer("transformer"): + latents_cfg = self._duplicate_batch(latents) + noise_pred_cfg = self.transformer( + latents_cfg, + cfg_prompt_embeds, + cfg_timestep_bufs[i], + img_ids, + txt_ids, + )[0] + assert guidance_buf is not None + if model_inputs.cfg_normalization: + _, noise_pred = self._cfg_finalize_with_norm( + noise_pred_cfg, guidance_buf + ) + else: + _, noise_pred = self._cfg_finalize_no_norm( + noise_pred_cfg, guidance_buf + ) + elif apply_cfg: + with Tracer("transformer"): + noise_pred = self.transformer( + latents, + prompt_embeds, + timestep, + img_ids, + txt_ids, + )[0] + assert negative_prompt_embeds is not None + with Tracer("cfg_transformer"): + neg_noise_pred = self.transformer( + latents, + negative_prompt_embeds, + timestep, + neg_img_ids, + neg_txt_ids, + )[0] + assert guidance_buf is not None + noise_pred = self._cfg_combine( + noise_pred, neg_noise_pred, guidance_buf + ) + if model_inputs.cfg_normalization: + noise_pred = self._cfg_renormalization( + noise_pred, + noise_pred, + ) + else: + with Tracer("transformer"): + noise_pred = self.transformer( + latents, + prompt_embeds, + timestep, + img_ids, + txt_ids, + )[0] + + with Tracer("scheduler_step"): + latents = self._scheduler_step(latents, noise_pred, dt) + + with Tracer("decode_outputs"): + images = self.decode_latents(latents, h_carrier, w_carrier) + + return DiffusionPipelineOutput(images=images) diff --git a/max/python/max/pipelines/architectures/z_image/weight_adapters.py b/max/python/max/pipelines/architectures/z_image/weight_adapters.py new file mode 100644 index 00000000000..73080b7aada --- /dev/null +++ b/max/python/max/pipelines/architectures/z_image/weight_adapters.py @@ -0,0 +1,94 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from __future__ import annotations + +from max.graph.weights import WeightData + + +def _replace_prefix(key: str, old: str, new: str) -> str: + if key.startswith(old): + return new + key[len(old) :] + return key + + +def convert_z_image_transformer_state_dict( + state_dict: dict[str, WeightData], +) -> dict[str, WeightData]: + converted: dict[str, WeightData] = {} + + dropped_prefixes = ( + "x_pad_token", + "cap_pad_token", + "siglip_", + ) + + for original_key, value in state_dict.items(): + key = original_key + + if key.startswith(dropped_prefixes): + continue + + key = _replace_prefix(key, "all_x_embedder.2-1.", "x_embedder.") + key = _replace_prefix(key, "all_final_layer.2-1.", "final_layer.") + key = _replace_prefix(key, "t_embedder.mlp.0.", "t_embedder.linear_1.") + key = _replace_prefix(key, "t_embedder.mlp.2.", "t_embedder.linear_2.") + key = _replace_prefix(key, "cap_embedder.0.", "cap_norm.") + key = _replace_prefix(key, "cap_embedder.1.", "cap_proj.") + key = key.replace("adaLN_modulation.0.", "adaLN_modulation.") + key = _replace_prefix( + key, + "final_layer.adaLN_modulation.1.", + "final_layer.adaLN_modulation.", + ) + + converted[key] = value + + required_prefixes = ( + "x_embedder.", + "t_embedder.", + "cap_norm.", + "cap_proj.", + "noise_refiner.0.", + "context_refiner.0.", + "layers.0.", + "final_layer.", + ) + for prefix in required_prefixes: + if not any(k.startswith(prefix) for k in converted): + raise ValueError( + f"Missing required z-image transformer weights with prefix '{prefix}'" + ) + + allowed_prefixes = ( + "x_embedder.", + "noise_refiner.", + "context_refiner.", + "t_embedder.", + "cap_norm.", + "cap_proj.", + "layers.", + "final_layer.", + ) + unexpected_keys = [ + k + for k in converted + if not any(k.startswith(p) for p in allowed_prefixes) + ] + if unexpected_keys: + sample = ", ".join(unexpected_keys[:8]) + raise ValueError( + f"Unexpected z-image transformer keys in phase-1 adapter: {sample}" + ) + + return converted diff --git a/max/python/max/pipelines/architectures/z_image/z_image.py b/max/python/max/pipelines/architectures/z_image/z_image.py new file mode 100644 index 00000000000..9a49ba1589c --- /dev/null +++ b/max/python/max/pipelines/architectures/z_image/z_image.py @@ -0,0 +1,435 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2026, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +"""Z-Image DiT core model (Graph API / ModuleV2).""" + +from max.dtype import DType +from max.graph import DeviceRef, TensorType, TensorValue, Weight, ops +from max.nn.layer import LayerList, Module +from max.nn.linear import Linear +from max.nn.norm import RMSNorm + +from .layers.attention import ZImageAttention +from .layers.embeddings import RopeEmbedder, TimestepEmbedder +from .model_config import ZImageConfig + +ADALN_EMBED_DIM = 256 + + +class LayerNorm(Module): + """Layer normalisation with optional learned affine parameters.""" + + weight: Weight | None + bias: Weight | None + + def __init__( + self, + dim: int, + *, + dtype: DType, + device: DeviceRef, + eps: float = 1e-5, + elementwise_affine: bool = True, + use_bias: bool = True, + ) -> None: + super().__init__() + self.dim = dim + self.eps = eps + if elementwise_affine: + self.weight = Weight("weight", dtype, (dim,), device=device) + self.bias = ( + Weight("bias", dtype, (dim,), device=device) + if use_bias + else None + ) + else: + self.weight = None + self.bias = None + + def __call__(self, x: TensorValue) -> TensorValue: + if self.weight is None: + gamma = ops.broadcast_to( + ops.constant(1.0, dtype=x.dtype, device=x.device), + shape=(x.shape[-1],), + ) + else: + gamma = self.weight + + if self.bias is None: + beta = ops.broadcast_to( + ops.constant(0.0, dtype=x.dtype, device=x.device), + shape=(x.shape[-1],), + ) + else: + beta = self.bias + + return ops.layer_norm(x, gamma=gamma, beta=beta, epsilon=self.eps) + + +class FeedForward(Module): + """SwiGLU feed-forward network.""" + + def __init__( + self, + dim: int, + hidden_dim: int, + *, + dtype: DType, + device: DeviceRef, + ) -> None: + super().__init__() + self.w1 = Linear( + in_dim=dim, + out_dim=hidden_dim, + dtype=dtype, + device=device, + has_bias=False, + ) + self.w2 = Linear( + in_dim=hidden_dim, + out_dim=dim, + dtype=dtype, + device=device, + has_bias=False, + ) + self.w3 = Linear( + in_dim=dim, + out_dim=hidden_dim, + dtype=dtype, + device=device, + has_bias=False, + ) + + def __call__(self, x: TensorValue) -> TensorValue: + return self.w2(ops.silu(self.w1(x)) * self.w3(x)) + + +class ZImageTransformerBlock(Module): + """Single transformer block with optional adaLN modulation.""" + + def __init__( + self, + dim: int, + n_heads: int, + n_kv_heads: int, + norm_eps: float, + qk_norm: bool, + *, + dtype: DType, + device: DeviceRef, + modulation: bool = True, + ) -> None: + super().__init__() + del n_kv_heads + + self.modulation = modulation + self.dim = dim + + self.attention = ZImageAttention( + dim=dim, + n_heads=n_heads, + qk_norm=qk_norm, + eps=norm_eps, + dtype=dtype, + device=device, + ) + self.feed_forward = FeedForward( + dim=dim, + hidden_dim=int(dim / 3 * 8), + dtype=dtype, + device=device, + ) + self.attention_norm1 = RMSNorm(dim, dtype=dtype, eps=norm_eps) + self.ffn_norm1 = RMSNorm(dim, dtype=dtype, eps=norm_eps) + self.attention_norm2 = RMSNorm(dim, dtype=dtype, eps=norm_eps) + self.ffn_norm2 = RMSNorm(dim, dtype=dtype, eps=norm_eps) + + self.adaLN_modulation = ( + Linear( + in_dim=min(dim, ADALN_EMBED_DIM), + out_dim=4 * dim, + dtype=dtype, + device=device, + has_bias=True, + ) + if modulation + else None + ) + + def __call__( + self, + x: TensorValue, + freqs_cis: TensorValue, + adaln_input: TensorValue | None = None, + ) -> TensorValue: + if self.modulation: + if adaln_input is None: + raise ValueError("adaln_input is required when modulation=True") + if self.adaLN_modulation is None: + raise ValueError("adaLN_modulation is not initialized") + + mod = ops.unsqueeze(self.adaLN_modulation(adaln_input), 1) + d = self.dim + scale_msa = 1.0 + mod[:, :, :d] + gate_msa = ops.tanh(mod[:, :, d : 2 * d]) + scale_mlp = 1.0 + mod[:, :, 2 * d : 3 * d] + gate_mlp = ops.tanh(mod[:, :, 3 * d :]) + + attn_out = self.attention( + self.attention_norm1(x) * scale_msa, + freqs_cis=freqs_cis, + ) + x = x + gate_msa * self.attention_norm2(attn_out) + + ffn_out = self.feed_forward(self.ffn_norm1(x) * scale_mlp) + x = x + gate_mlp * self.ffn_norm2(ffn_out) + else: + attn_out = self.attention( + self.attention_norm1(x), + freqs_cis=freqs_cis, + ) + x = x + self.attention_norm2(attn_out) + x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x))) + + return x + + +class FinalLayer(Module): + """Final projection layer with adaLN conditioning.""" + + def __init__( + self, + hidden_size: int, + out_channels: int, + *, + dtype: DType, + device: DeviceRef, + ) -> None: + super().__init__() + self.norm_final = LayerNorm( + hidden_size, + dtype=dtype, + device=device, + eps=1e-6, + elementwise_affine=False, + use_bias=False, + ) + self.linear = Linear( + in_dim=hidden_size, + out_dim=out_channels, + dtype=dtype, + device=device, + has_bias=True, + ) + self.adaLN_modulation = Linear( + in_dim=min(hidden_size, ADALN_EMBED_DIM), + out_dim=hidden_size, + dtype=dtype, + device=device, + has_bias=True, + ) + + def __call__(self, x: TensorValue, c: TensorValue) -> TensorValue: + scale = 1.0 + self.adaLN_modulation(ops.silu(c)) + x = self.norm_final(x) * ops.unsqueeze(scale, 1) + return self.linear(x) + + +class ZImageTransformer2DModel(Module): + """Z-Image diffusion transformer (DiT) model.""" + + def __init__(self, config: ZImageConfig) -> None: + super().__init__() + + dim = config.dim + n_heads = config.n_heads + norm_eps = config.norm_eps + qk_norm = config.qk_norm + cap_feat_dim = config.cap_feat_dim + n_layers = config.n_layers + n_refiner_layers = config.n_refiner_layers + axes_dims = config.axes_dims + rope_theta = config.rope_theta + dtype = config.dtype + device = config.device + + patch_size = config.all_patch_size[0] + f_patch_size = config.all_f_patch_size[0] + in_channels = ( + config.in_channels * patch_size * patch_size * f_patch_size + ) + out_channels = in_channels + + self.packed_channels = in_channels + self.max_dtype = dtype + self.max_device = device + self.cap_feat_dim = cap_feat_dim + self.t_scale = config.t_scale + self.axes_dims = axes_dims + + self.x_embedder = Linear( + in_dim=in_channels, + out_dim=dim, + dtype=dtype, + device=device, + has_bias=True, + ) + self.final_layer = FinalLayer( + hidden_size=dim, + out_channels=out_channels, + dtype=dtype, + device=device, + ) + + self.noise_refiner = LayerList( + [ + ZImageTransformerBlock( + dim=dim, + n_heads=n_heads, + n_kv_heads=config.n_kv_heads, + norm_eps=norm_eps, + qk_norm=qk_norm, + dtype=dtype, + device=device, + modulation=True, + ) + for _ in range(n_refiner_layers) + ] + ) + self.context_refiner = LayerList( + [ + ZImageTransformerBlock( + dim=dim, + n_heads=n_heads, + n_kv_heads=config.n_kv_heads, + norm_eps=norm_eps, + qk_norm=qk_norm, + dtype=dtype, + device=device, + modulation=False, + ) + for _ in range(n_refiner_layers) + ] + ) + + self.t_embedder = TimestepEmbedder( + out_size=min(dim, ADALN_EMBED_DIM), + mid_size=1024, + dtype=dtype, + device=device, + ) + self.cap_norm = RMSNorm(cap_feat_dim, dtype=dtype, eps=norm_eps) + self.cap_proj = Linear( + in_dim=cap_feat_dim, + out_dim=dim, + dtype=dtype, + device=device, + has_bias=True, + ) + + self.layers = LayerList( + [ + ZImageTransformerBlock( + dim=dim, + n_heads=n_heads, + n_kv_heads=config.n_kv_heads, + norm_eps=norm_eps, + qk_norm=qk_norm, + dtype=dtype, + device=device, + modulation=True, + ) + for _ in range(n_layers) + ] + ) + + head_dim = dim // n_heads + if head_dim != sum(axes_dims): + raise ValueError( + f"head_dim ({head_dim}) must equal sum(axes_dims) ({sum(axes_dims)})" + ) + + self.rope_embedder = RopeEmbedder( + theta=rope_theta, + axes_dims=axes_dims, + ) + + def input_types(self) -> tuple[TensorType, ...]: + return ( + TensorType( + self.max_dtype, + shape=["batch_size", "image_seq_len", self.packed_channels], + device=self.max_device, + ), + TensorType( + self.max_dtype, + shape=["batch_size", "text_seq_len", self.cap_feat_dim], + device=self.max_device, + ), + TensorType( + DType.float32, + shape=["batch_size"], + device=self.max_device, + ), + TensorType( + DType.int64, + shape=["image_seq_len", len(self.axes_dims)], + device=self.max_device, + ), + TensorType( + DType.int64, + shape=["text_seq_len", len(self.axes_dims)], + device=self.max_device, + ), + ) + + def __call__( + self, + hidden_states: TensorValue, + encoder_hidden_states: TensorValue, + timestep: TensorValue, + img_ids: TensorValue, + txt_ids: TensorValue, + ) -> tuple[TensorValue]: + x = self.x_embedder(hidden_states) + t_emb = self.t_embedder(timestep * self.t_scale) + t_emb = ops.cast(t_emb, x.dtype) + + cap = self.cap_proj(self.cap_norm(encoder_hidden_states)) + + if img_ids.rank == 3: + img_ids = img_ids[0] + if txt_ids.rank == 3: + txt_ids = txt_ids[0] + + img_seq_len = img_ids.shape[0] + unified_ids = ops.concat([img_ids, txt_ids], axis=0) + unified_freqs = ops.cast(self.rope_embedder(unified_ids), x.dtype) + img_freqs = unified_freqs[:img_seq_len] + txt_freqs = unified_freqs[img_seq_len:] + + for block in self.noise_refiner: + x = block(x, freqs_cis=img_freqs, adaln_input=t_emb) + + for block in self.context_refiner: + cap = block(cap, freqs_cis=txt_freqs) + + img_len = x.shape[1] + x = ops.concat([x, cap], axis=1) + + for block in self.layers: + x = block(x, freqs_cis=unified_freqs, adaln_input=t_emb) + + x = x[:, :img_len, :] + x = self.final_layer(x, t_emb) + return (x,) diff --git a/max/python/max/pipelines/architectures/z_image_modulev3/__init__.py b/max/python/max/pipelines/architectures/z_image_modulev3/__init__.py index 34064d5cfd9..97c1b4dfecc 100644 --- a/max/python/max/pipelines/architectures/z_image_modulev3/__init__.py +++ b/max/python/max/pipelines/architectures/z_image_modulev3/__init__.py @@ -11,7 +11,7 @@ # limitations under the License. # ===----------------------------------------------------------------------=== # -from .arch import z_image_arch +from .arch import z_image_modulev3_arch from .model import ZImageTransformerModel -__all__ = ["ZImageTransformerModel", "z_image_arch"] +__all__ = ["ZImageTransformerModel", "z_image_modulev3_arch"] diff --git a/max/python/max/pipelines/architectures/z_image_modulev3/arch.py b/max/python/max/pipelines/architectures/z_image_modulev3/arch.py index 9b5f578cb93..85ea0467fd6 100644 --- a/max/python/max/pipelines/architectures/z_image_modulev3/arch.py +++ b/max/python/max/pipelines/architectures/z_image_modulev3/arch.py @@ -45,8 +45,8 @@ def initialize( return cls(pipeline_config=pipeline_config) -z_image_arch = SupportedArchitecture( - name="ZImagePipeline", +z_image_modulev3_arch = SupportedArchitecture( + name="ZImagePipeline_ModuleV3", task=PipelineTask.PIXEL_GENERATION, default_encoding="bfloat16", supported_encodings={"bfloat16", "float32"}, diff --git a/max/python/max/pipelines/lib/registry.py b/max/python/max/pipelines/lib/registry.py index 849391f60a5..be396544345 100644 --- a/max/python/max/pipelines/lib/registry.py +++ b/max/python/max/pipelines/lib/registry.py @@ -717,7 +717,12 @@ def retrieve_factory( "revision": pipeline_config.model.huggingface_model_revision, "trust_remote_code": pipeline_config.model.trust_remote_code, } - if arch.name in ("Flux2Pipeline", "ZImagePipeline"): + if arch.name in ( + "Flux2Pipeline", + "Flux2Pipeline_ModuleV3", + "ZImagePipeline", + "ZImagePipeline_ModuleV3", + ): tokenizer_kwargs["max_length"] = 512 if has_tokenizer_2: