forked from modular/modular
-
Notifications
You must be signed in to change notification settings - Fork 0
[Pipelines] Implement Z-Image ModuleV2 pipeline #21
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
byungchul-sqzb
wants to merge
1
commit into
byungchul-sqzb/stack/1
Choose a base branch
from
byungchul-sqzb/stack/2
base: byungchul-sqzb/stack/1
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
31 changes: 31 additions & 0 deletions
31
max/python/max/pipelines/architectures/z_image/__init__.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,31 @@ | ||
| # ===----------------------------------------------------------------------=== # | ||
| # Copyright (c) 2026, Modular Inc. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
| # https://llvm.org/LICENSE.txt | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ===----------------------------------------------------------------------=== # | ||
|
|
||
| from .arch import ZImageArchConfig, z_image_arch | ||
| from .layers.attention import ZImageAttention | ||
| from .layers.embeddings import RopeEmbedder, TimestepEmbedder | ||
| from .model import ZImageTransformerModel | ||
| from .model_config import ZImageConfig, ZImageConfigBase | ||
| from .z_image import ZImageTransformer2DModel | ||
|
|
||
| __all__ = [ | ||
| "RopeEmbedder", | ||
| "TimestepEmbedder", | ||
| "ZImageArchConfig", | ||
| "ZImageAttention", | ||
| "ZImageConfig", | ||
| "ZImageConfigBase", | ||
| "ZImageTransformer2DModel", | ||
| "ZImageTransformerModel", | ||
| "z_image_arch", | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| # ===----------------------------------------------------------------------=== # | ||
| # Copyright (c) 2026, Modular Inc. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
| # https://llvm.org/LICENSE.txt | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ===----------------------------------------------------------------------=== # | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from dataclasses import dataclass | ||
|
|
||
| from max.graph.weights import WeightsFormat | ||
| from max.interfaces import PipelineTask | ||
| from max.pipelines.core import PixelContext | ||
| from max.pipelines.lib import PixelGenerationTokenizer, SupportedArchitecture | ||
| from max.pipelines.lib.config import MAXModelConfig, PipelineConfig | ||
| from max.pipelines.lib.interfaces import ArchConfig | ||
| from typing_extensions import Self | ||
|
|
||
| from .pipeline_z_image import ZImagePipeline | ||
|
|
||
|
|
||
| @dataclass(kw_only=True) | ||
| class ZImageArchConfig(ArchConfig): | ||
| pipeline_config: PipelineConfig | ||
|
|
||
| def get_max_seq_len(self) -> int: | ||
| return 0 | ||
|
|
||
| @classmethod | ||
| def initialize( | ||
| cls, | ||
| pipeline_config: PipelineConfig, | ||
| model_config: MAXModelConfig | None = None, | ||
| ) -> Self: | ||
| model_config = model_config or pipeline_config.model | ||
| if len(model_config.device_specs) != 1: | ||
| raise ValueError("Z-Image is only supported on a single device") | ||
| return cls(pipeline_config=pipeline_config) | ||
|
|
||
|
|
||
| z_image_arch = SupportedArchitecture( | ||
| name="ZImagePipeline", | ||
| task=PipelineTask.PIXEL_GENERATION, | ||
| default_encoding="bfloat16", | ||
| supported_encodings={"bfloat16", "float32"}, | ||
| example_repo_ids=[ | ||
| "Tongyi-MAI/Z-Image", | ||
| "Tongyi-MAI/Z-Image-Turbo", | ||
| ], | ||
| pipeline_model=ZImagePipeline, # type: ignore[arg-type] | ||
| context_type=PixelContext, | ||
| default_weights_format=WeightsFormat.safetensors, | ||
| tokenizer=PixelGenerationTokenizer, | ||
| config=ZImageArchConfig, | ||
| ) |
17 changes: 17 additions & 0 deletions
17
max/python/max/pipelines/architectures/z_image/layers/__init__.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| # ===----------------------------------------------------------------------=== # | ||
| # Copyright (c) 2026, Modular Inc. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
| # https://llvm.org/LICENSE.txt | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ===----------------------------------------------------------------------=== # | ||
|
|
||
| from .attention import ZImageAttention | ||
| from .embeddings import RopeEmbedder, TimestepEmbedder | ||
|
|
||
| __all__ = ["RopeEmbedder", "TimestepEmbedder", "ZImageAttention"] |
158 changes: 158 additions & 0 deletions
158
max/python/max/pipelines/architectures/z_image/layers/attention.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,158 @@ | ||
| # ===----------------------------------------------------------------------=== # | ||
| # Copyright (c) 2026, Modular Inc. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License v2.0 with LLVM Exceptions: | ||
| # https://llvm.org/LICENSE.txt | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ===----------------------------------------------------------------------=== # | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from max.dtype import DType | ||
| from max.graph import DeviceRef, TensorValue, ops | ||
| from max.nn.attention.mask_config import MHAMaskVariant | ||
| from max.nn.kernels import flash_attention_gpu, rope_ragged_with_position_ids | ||
| from max.nn.layer import LayerList, Module | ||
| from max.nn.linear import Linear | ||
| from max.nn.norm import RMSNorm | ||
|
|
||
|
|
||
| def _apply_zimage_qk_rope( | ||
| query: TensorValue, | ||
| key: TensorValue, | ||
| freqs_cis: TensorValue, | ||
| ) -> tuple[TensorValue, TensorValue]: | ||
| """Apply RoPE using precomputed interleaved [cos, sin] frequencies.""" | ||
| batch_size = query.shape[0] | ||
| seq_len = query.shape[1] | ||
| num_heads = query.shape[2] | ||
| head_dim = query.shape[3] | ||
|
|
||
| query_ragged = ops.reshape( | ||
| query, [batch_size * seq_len, num_heads, head_dim] | ||
| ) | ||
| key_ragged = ops.reshape(key, [batch_size * seq_len, num_heads, head_dim]) | ||
|
|
||
| position_ids = ops.range( | ||
| 0, seq_len, dtype=DType.uint32, device=query.device | ||
| ) | ||
| position_ids = ops.broadcast_to( | ||
| ops.unsqueeze(position_ids, 0), [batch_size, seq_len] | ||
| ) | ||
|
|
||
| query_out = rope_ragged_with_position_ids( | ||
| query_ragged, freqs_cis, position_ids, interleaved=True | ||
| ) | ||
| key_out = rope_ragged_with_position_ids( | ||
| key_ragged, freqs_cis, position_ids, interleaved=True | ||
| ) | ||
| return ( | ||
| ops.reshape(query_out, [batch_size, seq_len, num_heads, head_dim]), | ||
| ops.reshape(key_out, [batch_size, seq_len, num_heads, head_dim]), | ||
| ) | ||
|
|
||
|
|
||
| class ZImageAttention(Module): | ||
| def __init__( | ||
| self, | ||
| dim: int, | ||
| n_heads: int, | ||
| qk_norm: bool, | ||
| eps: float, | ||
| *, | ||
| dtype: DType, | ||
| device: DeviceRef, | ||
| ) -> None: | ||
| """Initialize ZImageAttention.""" | ||
| super().__init__() | ||
| self.head_dim = dim // n_heads | ||
| self.inner_dim = dim | ||
| self.n_heads = n_heads | ||
|
|
||
| self.to_q = Linear( | ||
| in_dim=dim, | ||
| out_dim=dim, | ||
| dtype=dtype, | ||
| device=device, | ||
| has_bias=False, | ||
| ) | ||
| self.to_k = Linear( | ||
| in_dim=dim, | ||
| out_dim=dim, | ||
| dtype=dtype, | ||
| device=device, | ||
| has_bias=False, | ||
| ) | ||
| self.to_v = Linear( | ||
| in_dim=dim, | ||
| out_dim=dim, | ||
| dtype=dtype, | ||
| device=device, | ||
| has_bias=False, | ||
| ) | ||
|
|
||
| self.norm_q: RMSNorm | None = ( | ||
| RMSNorm(self.head_dim, dtype=dtype, eps=eps) if qk_norm else None | ||
| ) | ||
| self.norm_k: RMSNorm | None = ( | ||
| RMSNorm(self.head_dim, dtype=dtype, eps=eps) if qk_norm else None | ||
| ) | ||
|
|
||
| # Keep LayerList naming for diffusers-compatible key loading. | ||
| self.to_out = LayerList( | ||
| [ | ||
| Linear( | ||
| in_dim=dim, | ||
| out_dim=dim, | ||
| dtype=dtype, | ||
| device=device, | ||
| has_bias=False, | ||
| ) | ||
| ] | ||
| ) | ||
|
|
||
| def __call__( | ||
| self, | ||
| hidden_states: TensorValue, | ||
| freqs_cis: TensorValue, | ||
| ) -> TensorValue: | ||
| """Apply self-attention with rotary position embeddings.""" | ||
| batch_size = hidden_states.shape[0] | ||
| seq_len = hidden_states.shape[1] | ||
|
|
||
| query = self.to_q(hidden_states) | ||
| key = self.to_k(hidden_states) | ||
| value = self.to_v(hidden_states) | ||
|
|
||
| query = ops.reshape( | ||
| query, [batch_size, seq_len, self.n_heads, self.head_dim] | ||
| ) | ||
| key = ops.reshape( | ||
| key, [batch_size, seq_len, self.n_heads, self.head_dim] | ||
| ) | ||
| value = ops.reshape( | ||
| value, [batch_size, seq_len, self.n_heads, self.head_dim] | ||
| ) | ||
|
|
||
| if self.norm_q is not None: | ||
| query = self.norm_q(query) | ||
| if self.norm_k is not None: | ||
| key = self.norm_k(key) | ||
|
|
||
| query, key = _apply_zimage_qk_rope(query, key, freqs_cis) | ||
|
|
||
| out = flash_attention_gpu( | ||
| query, | ||
| key, | ||
| value, | ||
| mask_variant=MHAMaskVariant.NULL_MASK, | ||
| scale=1.0 / (self.head_dim**0.5), | ||
| ) | ||
|
|
||
| out = ops.reshape(out, [batch_size, seq_len, self.inner_dim]) | ||
| return self.to_out[0](out) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
position_idstensor must be flattened to 1D to match the first dimension of the ragged input tensors (query_ragged,key_ragged) passed to therope_ragged_with_position_idskernel. Currently, it is a 2D tensor of shape[batch_size, seq_len], which will cause a shape mismatch or incorrect indexing in the kernel.