From 97b3f3a8394e8b88723053e7182b9908e4f199fe Mon Sep 17 00:00:00 2001 From: Minkyu Kim Date: Wed, 21 Jan 2026 04:40:47 +0000 Subject: [PATCH 01/10] feat: flux1 entrypoint --- max/examples/diffusion/offline_generation.py | 55 +++++++++ max/python/max/entrypoints/BUILD.bazel | 36 ++++++ max/python/max/entrypoints/cli/generate.py | 39 ++++++ max/python/max/entrypoints/diffusion.py | 61 ++++++++++ max/python/max/entrypoints/pipelines.py | 111 ++++++++++++++++++ .../max/entrypoints/pipelines_diffusion.py | 27 +++++ 6 files changed, 329 insertions(+) create mode 100644 max/examples/diffusion/offline_generation.py create mode 100644 max/python/max/entrypoints/diffusion.py create mode 100644 max/python/max/entrypoints/pipelines_diffusion.py diff --git a/max/examples/diffusion/offline_generation.py b/max/examples/diffusion/offline_generation.py new file mode 100644 index 00000000000..12b5cebd735 --- /dev/null +++ b/max/examples/diffusion/offline_generation.py @@ -0,0 +1,55 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ===----------------------------------------------------------------------=== # + +import argparse +from pathlib import Path + +from max.entrypoints.diffusion import DiffusionPipeline +from max.experimental.realization_context import set_seed +from max.pipelines import PipelineConfig + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--model-path", type=str, default="black-forest-labs/FLUX.1-dev" + ) + parser.add_argument("--seed", type=int, default=42) + args = parser.parse_args() + + model_path = args.model_path + set_seed(args.seed) + pipeline_config = PipelineConfig(model_path=model_path) + pipe = DiffusionPipeline(pipeline_config) + + prompt = "A cat holding a sign that says hello world" + print(f"Prompt: {prompt}") + + result = pipe( + prompt=prompt, + height=1024, + width=1024, + num_inference_steps=50, + guidance_scale=3.5, + ) + + images = result.images + + output_path = Path("output.png") + output_path.parent.mkdir(parents=True, exist_ok=True) + images[0].save(output_path) + print(f"Image saved to: {output_path}") + + +if __name__ == "__main__": + main() diff --git a/max/python/max/entrypoints/BUILD.bazel b/max/python/max/entrypoints/BUILD.bazel index 6bef0b817c7..caf523b6598 100644 --- a/max/python/max/entrypoints/BUILD.bazel +++ b/max/python/max/entrypoints/BUILD.bazel @@ -93,6 +93,42 @@ modular_py_binary( ], ) +modular_py_binary( + name = "pipelines_diffusion", + srcs = [ + "pipelines_diffusion.py", + ], + data = [ + "@nvshmem_prebuilt//:host", + ], + env = { + "OTEL_EXPORTER_OTLP_METRICS_DEFAULT_HISTOGRAM_AGGREGATION": "base2_exponential_bucket_histogram", + "MODULAR_SHMEM_LIB_DIR": "../+http_archive+nvshmem_prebuilt", + }, + mojo_deps = select({ + "//:emit_mojo_enabled": PROD_MOJOPKGS, + "//conditions:default": [], + }), + deps = [ + # Provides the `max.entrypoints.pipelines` module for the wrapper to import. 
+ ":_pipelines", + ":entrypoints", + "//max/python/max:_core", + "//max/python/max/benchmark:benchmark_serving_lib", + "//max/python/max/interfaces", + "//max/python/max/pipelines", + "//max/python/max/serve:config", + "//max/python/max/serve/telemetry", + requirement("typing-extensions"), + requirement("click"), + ] + select({ + "//:nvidia_gpu": [ + requirement("torch"), + ], + "//conditions:default": [], + }), +) + modular_py_binary( name = "replay_recording", srcs = ["replay_recording.py"], diff --git a/max/python/max/entrypoints/cli/generate.py b/max/python/max/entrypoints/cli/generate.py index 0cf5adc4cbc..1bcd5aed3a1 100644 --- a/max/python/max/entrypoints/cli/generate.py +++ b/max/python/max/entrypoints/cli/generate.py @@ -19,6 +19,7 @@ import dataclasses import logging from collections.abc import Iterable +from pathlib import Path from typing import Any import requests @@ -158,3 +159,41 @@ def generate_text_for_pipeline( print_tokens=True, ) ) + + +def generate_image( + pipeline_config: PipelineConfig, + prompt: str, + height: int, + width: int, + num_inference_steps: int, + guidance_scale: float, + num_images_per_prompt: int, + output: Path, +) -> None: + from ..diffusion import DiffusionPipeline + + pipeline = DiffusionPipeline(pipeline_config) + result = pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + ) + + images = result.images + assert images, "Expected at least one generated image." 
+ + output.parent.mkdir(parents=True, exist_ok=True) + if num_images_per_prompt == 1: + images[0].save(output) + logger.info(f"Image saved to: {output}") + else: + stem = output.stem + suffix = output.suffix + for i, image in enumerate(images): + numbered_path = output.parent / f"{stem}_{i + 1}{suffix}" + image.save(numbered_path) + logger.info(f"{len(images)} images saved to: {output.parent}") diff --git a/max/python/max/entrypoints/diffusion.py b/max/python/max/entrypoints/diffusion.py new file mode 100644 index 00000000000..ec85979310c --- /dev/null +++ b/max/python/max/entrypoints/diffusion.py @@ -0,0 +1,61 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from max.interfaces import ( + ImageGenerationInputs, + ImageGenerationOutput, + PipelineTask, +) +from max.pipelines.lib import PIPELINE_REGISTRY, PipelineConfig + + +class DiffusionPipeline: + """Entrypoint for image-generation diffusion pipelines.""" + + def __init__(self, pipeline_config: PipelineConfig) -> None: + # NOTE: Currently, this entrypoint is implemented minimally + # for offline image generation. + # It will be developed further to support serving as well. 
+ self.pipeline_config = pipeline_config + _, model_factory = PIPELINE_REGISTRY.retrieve_factory( + pipeline_config, + task=PipelineTask.IMAGE_GENERATION, + ) + self.pipeline = model_factory() + + def __call__( + self, + prompt: str, + negative_prompt: str | None = None, + true_cfg_scale: float = 1.0, + height: int = 1024, + width: int = 1024, + num_inference_steps: int = 50, + guidance_scale: float = 3.5, + num_images_per_prompt: int = 1, + ) -> ImageGenerationOutput: + """Generate images from a prompt with the configured pipeline.""" + # TODO: consider all possible diffusion tasks, + # e.g. T2I, I2I, T2V, I2V, V2V. + inputs = ImageGenerationInputs( + prompt=prompt, + negative_prompt=negative_prompt, + true_cfg_scale=true_cfg_scale, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + ) + pipeline_output: ImageGenerationOutput = self.pipeline.execute(inputs) + return pipeline_output diff --git a/max/python/max/entrypoints/pipelines.py b/max/python/max/entrypoints/pipelines.py index c8dcada7415..a99c8c6fb18 100644 --- a/max/python/max/entrypoints/pipelines.py +++ b/max/python/max/entrypoints/pipelines.py @@ -18,6 +18,7 @@ import os import sys from collections.abc import Callable, Sequence +from pathlib import Path from typing import Any, TypeVar import click @@ -384,6 +385,116 @@ def cli_pipeline( ) +@main.group(name="diffusion", cls=ModelGroup) +def diffusion_group() -> None: + """Commands for diffusion-based image/video generation pipelines.""" + + +@diffusion_group.command(name="generate", cls=WithLazyPipelineOptions) +@click.option( + "--prompt", + type=str, + default="A cat holding a sign that says hello world", + help="The text prompt to use for image generation.", +) +@click.option( + "--height", + type=click.IntRange(min=64), + default=1024, + show_default=True, + help="Generated image height in pixels.", +) +@click.option( + "--width", + 
type=click.IntRange(min=64), + default=1024, + show_default=True, + help="Generated image width in pixels.", +) +@click.option( + "--num-inference-steps", + type=click.IntRange(min=1), + default=50, + show_default=True, + help="Number of denoising steps to run.", +) +@click.option( + "--guidance-scale", + type=float, + default=3.5, + show_default=True, + help="Classifier-free guidance scale.", +) +@click.option( + "--num-images-per-prompt", + type=click.IntRange(min=1), + default=1, + show_default=True, + help="Number of images to generate for a single prompt.", +) +@click.option( + "--output", + type=click.Path(dir_okay=False, path_type=Path), + default="output.png", + show_default=True, + help="Output image path (numbered if multiple images are generated).", +) +@click.option( + "--use-torch-randn/--no-use-torch-randn", + default=False, + show_default=True, + help=( + "Use torch-based random latents (set USE_TORCH_RANDN and SEED env vars)." + ), +) +@click.option( + "--seed", + type=int, + default=42, + show_default=True, + help="Random seed for torch-based latent initialization.", +) +def diffusion_generate( + prompt: str, + height: int, + width: int, + num_inference_steps: int, + guidance_scale: float, + num_images_per_prompt: int, + output: Path, + use_torch_randn: bool, + seed: int, + **config_kwargs: Any, +) -> None: + """Generate images using a diffusion pipeline.""" + from max.entrypoints.cli.generate import generate_image + from max.experimental.realization_context import set_seed + from max.pipelines import PipelineConfig + + set_seed(seed) + pipeline_config = PipelineConfig(**config_kwargs) + pipeline_config.log_basic_config() + + try: + generate_image( + pipeline_config=pipeline_config, + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + output=output, + ) + except Exception as exc: + logger.exception( + "Diffusion generation failed 
for model %s with prompt %r", + pipeline_config.model.model_path, + prompt, + ) + raise click.ClickException("Diffusion generation failed.") from exc + + @main.command(name="encode", cls=WithLazyPipelineOptions) @click.option( "--prompt", diff --git a/max/python/max/entrypoints/pipelines_diffusion.py b/max/python/max/entrypoints/pipelines_diffusion.py new file mode 100644 index 00000000000..52863bf27f4 --- /dev/null +++ b/max/python/max/entrypoints/pipelines_diffusion.py @@ -0,0 +1,27 @@ +"""Diffusion-only CLI wrapper. + +This exists so Bazel can keep `//max/python/max/entrypoints:pipelines` lean, +while allowing `//max/python/max/entrypoints:pipelines_diffusion` to pull in +extra runtime deps. +""" + +from __future__ import annotations + +import sys + + +def main() -> None: + # Import the main pipelines CLI and dispatch into the `diffusion` group. + # + # NOTE: `max.entrypoints.pipelines.main` is a click command object. Calling it + # with `args=[...]` is equivalent to invoking the CLI with those argv tokens. 
+ import max.entrypoints.pipelines as pipelines_cli + + pipelines_cli.main( + prog_name="pipelines", + args=["diffusion", *sys.argv[1:]], + ) + + +if __name__ == "__main__": + main() From 5ef581b1b7dec3eec47fd69c192698f430c8b405 Mon Sep 17 00:00:00 2001 From: Minkyu Kim Date: Wed, 21 Jan 2026 04:41:50 +0000 Subject: [PATCH 02/10] feat: flux1 interface --- max/python/max/interfaces/__init__.py | 4 + .../interfaces/pipeline_variants/__init__.py | 6 + .../pipeline_variants/image_generation.py | 42 ++++ max/python/max/interfaces/task.py | 2 + .../lib/pipeline_variants/__init__.py | 1 + .../lib/pipeline_variants/image_generation.py | 219 ++++++++++++++++++ max/python/max/pipelines/lib/registry.py | 190 +++++++++------ 7 files changed, 388 insertions(+), 76 deletions(-) create mode 100644 max/python/max/interfaces/pipeline_variants/image_generation.py create mode 100644 max/python/max/pipelines/lib/pipeline_variants/image_generation.py diff --git a/max/python/max/interfaces/__init__.py b/max/python/max/interfaces/__init__.py index adfcafb72cc..d6fe4c46cb6 100644 --- a/max/python/max/interfaces/__init__.py +++ b/max/python/max/interfaces/__init__.py @@ -49,6 +49,8 @@ EmbeddingsGenerationInputs, EmbeddingsGenerationOutput, ImageContentPart, + ImageGenerationInputs, + ImageGenerationOutput, ImageMetadata, TextContentPart, TextGenerationContext, @@ -109,6 +111,8 @@ def create_text_pipeline() -> Pipeline[TextGenerationInputs, TextGenerationOutpu "EmbeddingsGenerationOutput", "GenerationStatus", "ImageContentPart", + "ImageGenerationInputs", + "ImageGenerationOutput", "ImageMetadata", "LoRAOperation", "LoRARequest", diff --git a/max/python/max/interfaces/pipeline_variants/__init__.py b/max/python/max/interfaces/pipeline_variants/__init__.py index d5a7b10d1f3..073a7ad7a76 100644 --- a/max/python/max/interfaces/pipeline_variants/__init__.py +++ b/max/python/max/interfaces/pipeline_variants/__init__.py @@ -24,6 +24,10 @@ EmbeddingsGenerationInputs, EmbeddingsGenerationOutput, ) 
+from .image_generation import ( + ImageGenerationInputs, + ImageGenerationOutput, +) from .text_generation import ( BatchType, ImageContentPart, @@ -54,6 +58,8 @@ "EmbeddingsGenerationInputs", "EmbeddingsGenerationOutput", "ImageContentPart", + "ImageGenerationInputs", + "ImageGenerationOutput", "ImageMetadata", "TextContentPart", "TextGenerationContext", diff --git a/max/python/max/interfaces/pipeline_variants/image_generation.py b/max/python/max/interfaces/pipeline_variants/image_generation.py new file mode 100644 index 00000000000..d761298a7d6 --- /dev/null +++ b/max/python/max/interfaces/pipeline_variants/image_generation.py @@ -0,0 +1,42 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from dataclasses import dataclass + +from max.interfaces.pipeline import PipelineInputs +from PIL.Image import Image + + +@dataclass(eq=True) +class ImageGenerationInputs(PipelineInputs): + """Inputs for image-generation pipelines.""" + + # NOTE: Current implementation only considers offline generation without + # request scheduling. `ImageGenerationContext` should be used once + # request scheduling is implemented. 
+ prompt: str + negative_prompt: str | None + true_cfg_scale: float + height: int + width: int + num_inference_steps: int + guidance_scale: float + num_images_per_prompt: int + + +@dataclass(kw_only=True) +class ImageGenerationOutput: + """Output container for generated images.""" + + images: list[Image] + """List of generated images.""" diff --git a/max/python/max/interfaces/task.py b/max/python/max/interfaces/task.py index 477b77451e7..0422ebf657c 100644 --- a/max/python/max/interfaces/task.py +++ b/max/python/max/interfaces/task.py @@ -58,6 +58,8 @@ class PipelineTask(str, Enum): """Task for generating audio.""" SPEECH_TOKEN_GENERATION = "speech_token_generation" """Task for generating speech tokens.""" + IMAGE_GENERATION = "image_generation" + """Task for generating images.""" @property def output_type( diff --git a/max/python/max/pipelines/lib/pipeline_variants/__init__.py b/max/python/max/pipelines/lib/pipeline_variants/__init__.py index 991ed73eea4..2c04acdbcdd 100644 --- a/max/python/max/pipelines/lib/pipeline_variants/__init__.py +++ b/max/python/max/pipelines/lib/pipeline_variants/__init__.py @@ -11,4 +11,5 @@ # limitations under the License. # ===----------------------------------------------------------------------=== # +from .image_generation import ImageGenerationPipeline from .text_generation import TextGenerationPipeline diff --git a/max/python/max/pipelines/lib/pipeline_variants/image_generation.py b/max/python/max/pipelines/lib/pipeline_variants/image_generation.py new file mode 100644 index 00000000000..5841263e36f --- /dev/null +++ b/max/python/max/pipelines/lib/pipeline_variants/image_generation.py @@ -0,0 +1,219 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. 
+# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from __future__ import annotations + +import fnmatch +import logging +import os +import re +from pathlib import Path +from typing import TYPE_CHECKING + +import huggingface_hub +import requests +from huggingface_hub.utils import EntryNotFoundError, OfflineModeIsEnabled +from max.config import load_config +from max.interfaces import ( + ImageGenerationInputs, + ImageGenerationOutput, + Pipeline, + RequestID, +) +from requests.exceptions import HTTPError + +from ..config_enums import RepoType +from ..interfaces import DiffusionPipeline + +if TYPE_CHECKING: + from ..config import PipelineConfig + +logger = logging.getLogger(__name__) + + +class ImageGenerationPipeline( + Pipeline[ImageGenerationInputs, ImageGenerationOutput], +): + """Pipeline wrapper for diffusion image generation.""" + + def __init__( + self, + pipeline_config: PipelineConfig, + diffusion_pipeline: type[DiffusionPipeline], + ) -> None: + # Download checkpoints if required + # NOTE: Unlike TextGenerationPipeline where each file, + # such as configs and weights, are downloaded individually, + # DiffusionPipeline downloads the entire snapshot at once, + # since it normally contains multiple components. 
+ pretrained_model_name_or_path = ( + pipeline_config.model.huggingface_model_repo.repo_id + ) + if ( + pipeline_config.model.huggingface_model_repo.repo_type + == RepoType.online + ): + cached_folder = self.download( + pretrained_model_name_or_path, + config_name=diffusion_pipeline.config_name, + force_download=pipeline_config.model.force_download, + revision=pipeline_config.model.huggingface_model_revision, + ) + else: + cached_folder = pretrained_model_name_or_path + + self._diffusion_pipeline = diffusion_pipeline( + pipeline_config, cached_folder + ) + + def download( + self, + pretrained_model_name: str | os.PathLike, + config_name: str | None, + force_download: bool = False, + revision: str | None = None, + ) -> str: + """Download the pipeline components from the Hugging Face Hub. + + Args: + pretrained_model_name: Model identifier. + config_name: Pipeline config filename in the repo. + force_download: Whether to force download. + revision: Model revision. + + Returns: + Path to the downloaded model folder. + """ + try: + info = huggingface_hub.model_info( + pretrained_model_name, revision=revision + ) + except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e: + logger.warning( + f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache." + ) + model_info_call_error = ( + e # save error to reraise it if model is not cached locally + ) + + if config_name is None: + raise ValueError( + f"config_name for {pretrained_model_name} pipeline is not set. " + "Please set proper config file name from huggingface hub." + ) + try: + config_file = huggingface_hub.hf_hub_download( + pretrained_model_name, + config_name, + revision=revision, + force_download=force_download, + ) + except EntryNotFoundError as e: + raise ValueError( + f"config file {config_name} not found for {pretrained_model_name} pipeline. " + "Please check if the config file name is correct." 
+ ) from e + + config_dict = load_config(config_file) + ignore_filenames = config_dict.pop("_ignore_files", []) + + filenames = {sibling.rfilename for sibling in info.siblings} + filenames = set(filenames) - set(ignore_filenames) + + ignore_patterns = [ + "*.bin", + "*.msgpack", + "*.onnx", + "*.pb", + "*.bin.index.*json", + "*.msgpack.index.*json", + "*.onnx.index.*json", + "*.pb.index.*json", + ] + + allow_patterns = ["*/*"] + allow_patterns += [ + "scheduler_config.json", + "config.json", + config_name, + ] + re_ignore_pattern = [ + re.compile(fnmatch.translate(p)) for p in ignore_patterns + ] + re_allow_pattern = [ + re.compile(fnmatch.translate(p)) for p in allow_patterns + ] + + expected_files = [ + f + for f in filenames + if not any(p.match(f) for p in re_ignore_pattern) + ] + expected_files = [ + f + for f in expected_files + if any(p.match(f) for p in re_allow_pattern) + ] + + snapshot_folder = Path(config_file).parent + pipeline_is_cached = all( + (snapshot_folder / f).is_file() for f in expected_files + ) + + if pipeline_is_cached and not force_download: + # if the pipeline is cached, we can directly return it + # else call snapshot_download + return snapshot_folder + + # download all allow_patterns - ignore_patterns + try: + cached_folder = huggingface_hub.snapshot_download( + pretrained_model_name, + revision=revision, + allow_patterns=allow_patterns, + ignore_patterns=ignore_patterns, + ) + + return cached_folder + + except FileNotFoundError: + # Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache. + # This can happen in two cases: + # 1. If the user passed `local_files_only=True` => we raise the error directly + # 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error + if model_info_call_error is None: + # 1. user passed `local_files_only=True` + raise + else: + # 2. 
we forced `local_files_only=True` when `model_info` failed + raise OSError( + f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred" + " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace" + " above." + ) from model_info_call_error + + def execute(self, inputs: ImageGenerationInputs) -> ImageGenerationOutput: + outputs = self._diffusion_pipeline( + prompt=inputs.prompt, + negative_prompt=inputs.negative_prompt, + true_cfg_scale=inputs.true_cfg_scale, + height=inputs.height, + width=inputs.width, + num_inference_steps=inputs.num_inference_steps, + guidance_scale=inputs.guidance_scale, + num_images_per_prompt=inputs.num_images_per_prompt, + ) + return ImageGenerationOutput(images=outputs.images) + + def release(self, request_id: RequestID) -> None: + pass diff --git a/max/python/max/pipelines/lib/registry.py b/max/python/max/pipelines/lib/registry.py index 122973df2db..7246e5bb32c 100644 --- a/max/python/max/pipelines/lib/registry.py +++ b/max/python/max/pipelines/lib/registry.py @@ -16,6 +16,7 @@ from __future__ import annotations import functools +import json import logging from collections.abc import Callable from dataclasses import dataclass, field @@ -38,6 +39,7 @@ from transformers import ( AutoConfig, AutoTokenizer, + PretrainedConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, ) @@ -49,8 +51,9 @@ from .audio_generator_pipeline import AudioGeneratorPipeline from .config_enums import RopeType, SupportedEncoding from .embeddings_pipeline import EmbeddingsPipeline -from .hf_utils import HuggingFaceRepo +from .hf_utils import HuggingFaceRepo, get_model_index_path_for_diffusers from .interfaces import PipelineModel +from .pipeline_variants.image_generation import ImageGenerationPipeline from .pipeline_variants.text_generation import TextGenerationPipeline from .speculative_decoding import ( EAGLESpeculativeDecodingPipeline, @@ -74,6 +77,7 @@ def get_pipeline_for_task( | 
type[StandaloneSpeculativeDecodingPipeline] | type[SpeechTokenGenerationPipeline] | type[EAGLESpeculativeDecodingPipeline] + | type[ImageGenerationPipeline] ): if task == PipelineTask.TEXT_GENERATION: if pipeline_config._speculative is not None: @@ -100,6 +104,8 @@ def get_pipeline_for_task( return AudioGeneratorPipeline elif task == PipelineTask.SPEECH_TOKEN_GENERATION: return SpeechTokenGenerationPipeline + elif task == PipelineTask.IMAGE_GENERATION: + return ImageGenerationPipeline @dataclass(frozen=False) @@ -290,7 +296,7 @@ def retrieve_architecture( def get_active_huggingface_config( self, huggingface_repo: HuggingFaceRepo - ) -> AutoConfig: + ) -> AutoConfig | PretrainedConfig: """Retrieves or creates a cached HuggingFace AutoConfig for the given model configuration. @@ -311,7 +317,22 @@ def get_active_huggingface_config( Returns: AutoConfig: The HuggingFace configuration object for the model. """ - if huggingface_repo not in self._cached_huggingface_configs: + model_index_path = get_model_index_path_for_diffusers(huggingface_repo) + + if model_index_path is not None: + with open(model_index_path, encoding="utf-8") as f: + model_index = json.load(f) + + class_name = model_index.get("_class_name") + if not class_name or not isinstance(class_name, str): + raise ValueError( + f"Diffusers-style repository '{huggingface_repo.repo_id}' is missing a valid '_class_name' in model_index.json" + ) + + self._cached_huggingface_configs[huggingface_repo] = ( + PretrainedConfig(architectures=[class_name]) + ) + else: self._cached_huggingface_configs[huggingface_repo] = ( AutoConfig.from_pretrained( huggingface_repo.repo_id, @@ -444,87 +465,104 @@ def retrieve_factory( assert arch is not None devices = load_devices(pipeline_config.model.device_specs) - max_length = arch.pipeline_model.calculate_max_seq_len( - pipeline_config, huggingface_config=huggingface_config - ) - - # Old Mistral model like Mistral-7B-Instruct-v0.3 uses LlamaTokenizer - # and suffers from the 
whitespace decoding bug. So, we enable the fix - # for only MistralModel in order to avoid any issues with performance - # for rest of the models. This can be applied more generically once - # we have more time verifying this for all the models. - # More information: - # https://linear.app/modularml/issue/AIPIPE-197/add-support-for-mistral-7b-instruct-v03 - # TODO: remove this pipeline_model.__name__ check - if ( - arch.pipeline_model.__name__ in ("MistralModel", "Phi3Model") - and arch.tokenizer is TextTokenizer - ): - text_tokenizer = cast(type[TextTokenizer], arch.tokenizer) - tokenizer = text_tokenizer( - pipeline_config.model.model_path, - pipeline_config=pipeline_config, - revision=pipeline_config.model.huggingface_model_revision, - max_length=max_length, - trust_remote_code=pipeline_config.model.trust_remote_code, - enable_llama_whitespace_fix=True, - chat_template=pipeline_config.retrieve_chat_template(), - context_validators=arch.context_validators, + if task != PipelineTask.IMAGE_GENERATION: + max_length = arch.pipeline_model.calculate_max_seq_len( + pipeline_config, huggingface_config=huggingface_config ) - else: - tokenizer = arch.tokenizer( - model_path=pipeline_config.model.model_path, - pipeline_config=pipeline_config, - revision=pipeline_config.model.huggingface_model_revision, - max_length=max_length, - trust_remote_code=pipeline_config.model.trust_remote_code, - chat_template=pipeline_config.retrieve_chat_template(), - context_validators=arch.context_validators, + + # Old Mistral model like Mistral-7B-Instruct-v0.3 uses LlamaTokenizer + # and suffers from the whitespace decoding bug. So, we enable the fix + # for only MistralModel in order to avoid any issues with performance + # for rest of the models. This can be applied more generically once + # we have more time verifying this for all the models. 
+ # More information: + # https://linear.app/modularml/issue/AIPIPE-197/add-support-for-mistral-7b-instruct-v03 + # TODO: remove this pipeline_model.__name__ check + if ( + arch.pipeline_model.__name__ in ("MistralModel", "Phi3Model") + and arch.tokenizer is TextTokenizer + ): + text_tokenizer = cast(type[TextTokenizer], arch.tokenizer) + tokenizer = text_tokenizer( + pipeline_config.model.model_path, + pipeline_config=pipeline_config, + revision=pipeline_config.model.huggingface_model_revision, + max_length=max_length, + trust_remote_code=pipeline_config.model.trust_remote_code, + enable_llama_whitespace_fix=True, + chat_template=pipeline_config.retrieve_chat_template(), + context_validators=arch.context_validators, + ) + else: + tokenizer = arch.tokenizer( + model_path=pipeline_config.model.model_path, + pipeline_config=pipeline_config, + revision=pipeline_config.model.huggingface_model_revision, + max_length=max_length, + trust_remote_code=pipeline_config.model.trust_remote_code, + chat_template=pipeline_config.retrieve_chat_template(), + context_validators=arch.context_validators, + ) + # Cast tokenizer to the proper type for text generation pipeline compatibility + typed_tokenizer = cast( + PipelineTokenizer[ + Any, npt.NDArray[np.integer[Any]], TextGenerationRequest + ], + tokenizer, ) - # Cast tokenizer to the proper type for text generation pipeline compatibility - typed_tokenizer = cast( - PipelineTokenizer[ - Any, npt.NDArray[np.integer[Any]], TextGenerationRequest - ], - tokenizer, - ) - # For speculative decoding, retrieve draft model's architecture - factory_kwargs: dict[str, Any] = { - "pipeline_config": pipeline_config, - "pipeline_model": arch.pipeline_model, - "eos_token_id": tokenizer.eos, - "weight_adapters": arch.weight_adapters, - "tokenizer": typed_tokenizer, - } - - # If using speculative decoding, add draft model-specific parameters - if pipeline_config.draft_model is not None: - draft_arch = self.retrieve_architecture( - 
huggingface_repo=pipeline_config.draft_model.huggingface_weight_repo, - use_module_v3=pipeline_config.use_module_v3, + # For speculative decoding, retrieve draft model's architecture + factory_kwargs: dict[str, Any] = { + "pipeline_config": pipeline_config, + "pipeline_model": arch.pipeline_model, + "eos_token_id": tokenizer.eos, + "weight_adapters": arch.weight_adapters, + "tokenizer": typed_tokenizer, + } + + # If using speculative decoding, add draft model-specific parameters + if pipeline_config.draft_model is not None: + draft_arch = self.retrieve_architecture( + huggingface_repo=pipeline_config.draft_model.huggingface_weight_repo, + use_module_v3=pipeline_config.use_module_v3, + ) + if draft_arch is None: + raise ValueError( + f"MAX-Optimized architecture not found for draft model " + f"'{pipeline_config.draft_model.model_path}'" + ) + factory_kwargs["draft_pipeline_model"] = ( + draft_arch.pipeline_model + ) + factory_kwargs["draft_weight_adapters"] = ( + draft_arch.weight_adapters + ) + + pipeline_factory = cast( + Callable[[], PipelineTypes], + functools.partial( # type: ignore + pipeline_class, **factory_kwargs + ), ) - if draft_arch is None: + + if tokenizer.eos is None: raise ValueError( - f"MAX-Optimized architecture not found for draft model " - f"'{pipeline_config.draft_model.model_path}'" + "tokenizer.eos value is None, tokenizer configuration is incomplete." ) - factory_kwargs["draft_pipeline_model"] = draft_arch.pipeline_model - factory_kwargs["draft_weight_adapters"] = draft_arch.weight_adapters - - pipeline_factory = cast( - Callable[[], PipelineTypes], - functools.partial( # type: ignore - pipeline_class, **factory_kwargs - ), - ) - if tokenizer.eos is None: - raise ValueError( - "tokenizer.eos value is None, tokenizer configuration is incomplete." 
+ return tokenizer, pipeline_factory + else: + factory_kwargs = { + "pipeline_config": pipeline_config, + "diffusion_pipeline": arch.pipeline_model, + } + pipeline_factory = cast( + Callable[[], PipelineTypes], + functools.partial( # type: ignore + pipeline_class, **factory_kwargs + ), ) - - return tokenizer, pipeline_factory + return None, pipeline_factory def retrieve_context_type( self, pipeline_config: PipelineConfig From d81e142aeeb6759495c55a766573f8a00a9b20fb Mon Sep 17 00:00:00 2001 From: Minkyu Kim Date: Wed, 21 Jan 2026 04:42:39 +0000 Subject: [PATCH 03/10] feat: flux1 pipeline --- max/python/max/config/__init__.py | 22 +- max/python/max/experimental/BUILD.bazel | 1 + max/python/max/experimental/compile_utils.py | 97 ++ .../architectures/flux1/pipeline_flux.py | 765 ++++++++++++++++ max/python/max/pipelines/lib/config.py | 20 +- .../lib/diffusion_schedulers/__init__.py | 16 + .../scheduling_flow_match_euler_discrete.py | 852 ++++++++++++++++++ max/python/max/pipelines/lib/hf_utils.py | 49 +- .../max/pipelines/lib/image_processor.py | 226 +++++ .../max/pipelines/lib/interfaces/__init__.py | 2 + .../lib/interfaces/diffusion_pipeline.py | 177 ++++ max/python/max/pipelines/lib/model_config.py | 6 +- 12 files changed, 2211 insertions(+), 22 deletions(-) create mode 100644 max/python/max/experimental/compile_utils.py create mode 100644 max/python/max/pipelines/architectures/flux1/pipeline_flux.py create mode 100644 max/python/max/pipelines/lib/diffusion_schedulers/__init__.py create mode 100644 max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py create mode 100644 max/python/max/pipelines/lib/image_processor.py create mode 100644 max/python/max/pipelines/lib/interfaces/diffusion_pipeline.py diff --git a/max/python/max/config/__init__.py b/max/python/max/config/__init__.py index 4b10a5e89d1..440822a3478 100644 --- a/max/python/max/config/__init__.py +++ b/max/python/max/config/__init__.py @@ -16,7 +16,9 @@ import argparse 
import enum +import json import logging +import os import types from abc import abstractmethod from collections.abc import Mapping @@ -506,6 +508,7 @@ def _extract_max_config_data( config_dict: The loaded YAML configuration dictionary. config_class: The config class we're extracting data for. section_name: Optional specific section name to look for. + config_file_path: Path to the config file for resolving inheritance. Returns: Configuration data for the specific config class. @@ -854,9 +857,9 @@ def _add_field_as_argument( ): # For enums, use the string value as default but we'll need to convert back arg_kwargs = { - "default": field_value.value - if field_value - else field_obj.default + "default": ( + field_value.value if field_value else field_obj.default + ) } else: arg_kwargs = {"default": field_value} @@ -1071,6 +1074,19 @@ def parse_args( # type: ignore[override] # noqa: ANN202 return MAXConfigArgumentParser(parser, self) +def load_config(config_path: str | os.PathLike) -> dict: + if not os.path.exists(config_path): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + try: + with open(config_path, encoding="utf-8") as f: + config_dict = json.loads(f.read()) + except Exception as e: + raise ValueError( + f"Failed to load configuration from {config_path}: {e}" + ) from e + return config_dict + + all = [ "MAXBaseModel", "ConfigFileModel", diff --git a/max/python/max/experimental/BUILD.bazel b/max/python/max/experimental/BUILD.bazel index 9c95184007c..946a73ca810 100644 --- a/max/python/max/experimental/BUILD.bazel +++ b/max/python/max/experimental/BUILD.bazel @@ -9,6 +9,7 @@ modular_py_library( "_passes.py", "_tensor_repr.py", "functional.py", + "compile_utils.py", "random.py", "realization_context.py", "support.py", diff --git a/max/python/max/experimental/compile_utils.py b/max/python/max/experimental/compile_utils.py new file mode 100644 index 00000000000..13c10b93d83 --- /dev/null +++ b/max/python/max/experimental/compile_utils.py @@ 
class CompileWrapper:
    """Callable wrapper around a compiled function graph or compiled module.

    A `Module` target is compiled through its own `compile` method; any other
    callable is traced into a `Graph` and loaded into an `InferenceSession`.
    """

    def __init__(
        self,
        compile_target: Callable | Module,
        input_types: Iterable[TensorType] | None = None,
    ) -> None:
        """Compile ``compile_target`` for the given input types.

        Args:
            compile_target: The function or module to be compiled.
            input_types: The input types (TensorTypes) required for
                compilation.

        Raises:
            ValueError: If input_types is not provided.
        """
        # Module instances, functools.partial objects and other callable
        # instances have no `__name__`; fall back to the type name so the
        # documented ValueError (not an AttributeError) is raised below and
        # the graph still gets a usable name.
        target_name = getattr(
            compile_target, "__name__", type(compile_target).__name__
        )
        if input_types is None:
            raise ValueError(
                f"input_types must be provided for compilation of {target_name}."
            )

        self.is_module = isinstance(compile_target, Module)
        if self.is_module:
            # NOTE(review): input_types is forwarded as a single argument,
            # exactly as received — confirm against Module.compile's signature.
            self.session = compile_target.compile(input_types)
            return

        # input_types is only required to be an Iterable and is consumed
        # twice below (graph construction and device selection); materialize
        # it so a one-shot iterator does not silently select the CPU.
        input_types = list(input_types)

        with Graph(target_name, input_types=input_types) as graph:
            output = compile_target(*graph.inputs)
            graph.output(output)

        # Run on an accelerator as soon as any input lives on a GPU.
        if any(input_type.device.is_gpu() for input_type in input_types):
            device = Accelerator()
        else:
            device = CPU()
        self.session = InferenceSession([device]).load(graph)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Execute the compiled session with the given arguments.

        Args:
            *args: Positional arguments to pass to the session.
            **kwargs: Keyword arguments to pass to the session.

        Returns:
            The result of the session execution.
        """
        if self.is_module:
            return self.session(*args, **kwargs)
        return self.session.execute(*args, **kwargs)


def max_compile(
    compile_target: Callable | Module | None = None,
    input_types: Iterable[TensorType] | None = None,
) -> Callable[[Callable | Module], CompileWrapper] | CompileWrapper:
    """Decorator or function to compile a target with specified input types.

    Args:
        compile_target: The function or module to compile. If None, returns a
            decorator.
        input_types: The input types for the compilation.

    Returns:
        A CompileWrapper instance if compile_target is provided, otherwise a
        decorator that builds one from the decorated target.

    Raises:
        ValueError: From `CompileWrapper` when ``input_types`` is None; in
            decorator form this happens when the decorator is applied.
    """
    if compile_target is None:

        def decorator(f: Callable | Module) -> CompileWrapper:
            return CompileWrapper(f, input_types)

        return decorator

    return CompileWrapper(compile_target, input_types)
def retrieve_timesteps(
    scheduler: Any,
    num_inference_steps: int | None = None,
    device: str | DeviceRef | None = None,
    timesteps: list[int] | None = None,
    sigmas: list[float] | None = None,
    **kwargs: Any,
) -> tuple[np.ndarray, int]:
    r"""Configure the scheduler's schedule and return it with its step count.

    Calls `scheduler.set_timesteps` with either a step count, a custom
    `timesteps` list, or a custom `sigmas` list, then reads back
    `scheduler.timesteps`.

    Args:
        scheduler: Object exposing `set_timesteps(...)` and, afterwards, a
            `timesteps` attribute.
        num_inference_steps: Number of steps to request when neither
            `timesteps` nor `sigmas` is given.
        device: Forwarded to `set_timesteps` as `device`.
        timesteps: Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas: Custom sigmas. Mutually exclusive with `timesteps`.
        **kwargs: Extra keyword arguments forwarded to `set_timesteps`.

    Returns:
        `(scheduler.timesteps, num_inference_steps)`. When a custom schedule
        is used, the step count is the length of the resulting schedule;
        otherwise it is the `num_inference_steps` that was passed in.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the
            scheduler's `set_timesteps` does not accept the one given.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    # Default path: let the scheduler derive the schedule from a step count.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Exactly one override is set from here on.
    if timesteps is not None:
        override_name, override_value = "timesteps", timesteps
        unsupported_suffix = " timestep schedules. Please check whether you are using the correct scheduler."
    else:
        override_name, override_value = "sigmas", sigmas
        unsupported_suffix = " sigmas schedules. Please check whether you are using the correct scheduler."

    accepted_params = inspect.signature(scheduler.set_timesteps).parameters
    if override_name not in accepted_params:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            + unsupported_suffix
        )

    scheduler.set_timesteps(
        **{override_name: override_value}, device=device, **kwargs
    )
    schedule = scheduler.timesteps
    return schedule, int(schedule.shape[0])


def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    """Linearly interpolate the shift `mu` for a given image sequence length.

    The line passes through `(base_seq_len, base_shift)` and
    `(max_seq_len, max_shift)`; values outside that range are extrapolated.

    Args:
        image_seq_len: Sequence length to evaluate the line at.
        base_seq_len: Sequence length mapped to `base_shift`.
        max_seq_len: Sequence length mapped to `max_shift`.
        base_shift: Shift at `base_seq_len`.
        max_shift: Shift at `max_seq_len`.

    Returns:
        The interpolated shift.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept


@dataclass
class FluxPipelineOutput:
    """Result container returned by the Flux image generation pipeline.

    Args:
        images (`list[PIL.Image.Image]` or `np.ndarray` or `Tensor`)
            The pipeline result: PIL images, a numpy array, or a Max tensor,
            depending on the output type requested from the pipeline.
    """

    images: list[PIL.Image.Image] | np.ndarray | Tensor
class FluxPipeline(DiffusionPipeline):
    """Text-to-image pipeline built from the components listed below.

    Prompts are encoded with a CLIP and a T5 text encoder, latents are
    denoised by the Flux transformer under a flow-match Euler scheduler, and
    the result is decoded by the VAE.
    """

    config_name = "model_index.json"

    components = {
        "scheduler": FlowMatchEulerDiscreteScheduler,
        "vae": AutoencoderKLModel,
        "text_encoder": ClipModel,
        "tokenizer": CLIPTokenizer,
        "text_encoder_2": T5Model,
        "tokenizer_2": T5TokenizerFast,
        "transformer": Flux1Model,
    }

    def init_remaining_components(self) -> None:
        """Derive `vae_scale_factor` and build the image processor."""
        image_processor_class = self.components.get(
            "image_processor", VaeImageProcessor
        )
        # One 2x downsampling per VAE block transition; fall back to 8 when
        # no VAE is attached.
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1)
            if getattr(self, "vae", None)
            else 8
        )
        # Latents are additionally packed in 2x2 patches, hence the extra
        # factor of two handed to the image processor.
        self.image_processor = image_processor_class(
            vae_scale_factor=self.vae_scale_factor * 2
        )

    def encode_prompt(
        self,
        prompt: str | list[str],
        prompt_2: str | list[str],
        device: DeviceRef | None = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Tensor | None = None,
        pooled_prompt_embeds: Tensor | None = None,
        max_sequence_length: int = 512,
        lora_scale: float | None = None,
    ) -> tuple[Tensor, Tensor, Tensor]:
        r"""Encode the prompt into text-encoder embeddings.

        Args:
            prompt: Prompt(s) for `tokenizer` / `text_encoder`.
            prompt_2: Prompt(s) for `tokenizer_2` / `text_encoder_2`. Falls
                back to `prompt` when empty.
            device: Device on which the token id tensors are created.
            num_images_per_prompt: Number of images generated per prompt; the
                embeddings are repeated accordingly.
            prompt_embeds: Pre-generated text embeddings. When given, no text
                encoder is run and `pooled_prompt_embeds` must be given too.
            pooled_prompt_embeds: Pre-generated pooled text embeddings.
            max_sequence_length: Maximum token length used for padding and
                truncation.
            lora_scale: Scale forwarded to the text encoders' `set_lora_scale`
                when they expose it.

        Returns:
            `(prompt_embeds, pooled_prompt_embeds, text_ids)`, where
            `text_ids` is a zero tensor with one row per text token.
        """
        if lora_scale is not None and isinstance(self, FluxPipeline):
            self._lora_scale = lora_scale

            if self.text_encoder is not None and hasattr(
                self.text_encoder, "set_lora_scale"
            ):
                self.text_encoder.set_lora_scale(lora_scale)
            if self.text_encoder_2 is not None and hasattr(
                self.text_encoder_2, "set_lora_scale"
            ):
                self.text_encoder_2.set_lora_scale(lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        # Only run the text encoders when the caller did not supply
        # embeddings, so precomputed embeddings are honoured as documented.
        if prompt_embeds is None:
            text_inputs = self.tokenizer(
                prompt,
                padding="max_length",
                max_length=min(
                    max_sequence_length, self.tokenizer.model_max_length
                ),
                truncation=True,
                return_length=False,
                return_overflowing_tokens=False,
            )
            text_input_ids = Tensor_v3.constant(
                text_inputs.input_ids, device=device, dtype=DType.int64
            )

            text_encoder_outputs = self.text_encoder(text_input_ids)
            prompt_embeds = text_encoder_outputs[0]
            pooled_prompt_embeds = text_encoder_outputs[1]

            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            if self.text_encoder_2 is not None:
                text_inputs_2 = self.tokenizer_2(
                    prompt_2,
                    padding="max_length",
                    max_length=max_sequence_length,
                    truncation=True,
                    return_length=False,
                    return_overflowing_tokens=False,
                )
                text_input_ids_2 = Tensor_v3.constant(
                    text_inputs_2.input_ids, device=device, dtype=DType.int64
                )
                # The second encoder's output replaces the first encoder's
                # sequence embeddings; only the pooled output of the first
                # encoder is kept.
                prompt_embeds = self.text_encoder_2(text_input_ids_2)

        text_ids = Tensor_v3.zeros(
            (prompt_embeds.shape[1], 3),
            device=device,
            dtype=prompt_embeds.dtype,
        )

        bs_embed, seq_len, _ = prompt_embeds.shape

        # Repeat the embeddings once per requested image.
        prompt_embeds = F.tile(prompt_embeds, (1, num_images_per_prompt, 1))
        prompt_embeds = prompt_embeds.reshape(
            (bs_embed.dim * num_images_per_prompt, seq_len, -1)
        )

        pooled_prompt_embeds = F.tile(
            pooled_prompt_embeds, (1, num_images_per_prompt)
        )
        pooled_prompt_embeds = pooled_prompt_embeds.reshape(
            (bs_embed.dim * num_images_per_prompt, -1)
        )

        return prompt_embeds, pooled_prompt_embeds, text_ids

    @staticmethod
    def _prepare_latent_image_ids(
        batch_size: int,
        height: int,
        width: int,
        device: DeviceRef,
        dtype: DType,
    ) -> Tensor_v3:
        """Build per-position ids `(0, row, col)` for a `height x width` grid.

        `batch_size` is accepted for signature compatibility but not used.

        Returns:
            A tensor of shape `(height * width, 3)` on `device` cast to
            `dtype`.
        """
        latent_image_ids = np.stack(
            [
                np.zeros((height, width)),
                np.broadcast_to(np.arange(height)[:, None], (height, width)),
                np.broadcast_to(np.arange(width)[None, :], (height, width)),
            ],
            axis=-1,
        )

        grid_height, grid_width, id_channels = latent_image_ids.shape
        latent_image_ids = np.reshape(
            latent_image_ids, (grid_height * grid_width, id_channels)
        )
        return Tensor_v3.from_dlpack(latent_image_ids).to(device).cast(dtype)

    @staticmethod
    def _pack_latents(
        latents: Tensor_v3,
        batch_size: int,
        num_channels_latents: int,
        height: int,
        width: int,
    ) -> Tensor_v3:
        """Fold 2x2 spatial patches into the channel axis.

        Reshapes `(batch, channels, height, width)` latents to
        `(batch, (height // 2) * (width // 2), channels * 4)`.
        """
        latents = F.reshape(
            latents,
            (batch_size, num_channels_latents, height // 2, 2, width // 2, 2),
        )
        latents = F.permute(latents, (0, 2, 4, 1, 3, 5))
        latents = F.reshape(
            latents,
            (
                batch_size,
                (height // 2) * (width // 2),
                num_channels_latents * 4,
            ),
        )

        return latents

    @staticmethod
    def _unpack_latents(
        latents: Tensor_v3,
        height: int,
        width: int,
        vae_scale_factor: int,
    ) -> Tensor_v3:
        """Invert `_pack_latents` for an image of `height x width` pixels.

        Returns:
            Latents reshaped to `(batch, channels // 4, latent_height,
            latent_width)`.
        """
        # TODO: should compile this function for speed up.
        batch_size, _, channels = latents.shape

        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (height // (vae_scale_factor * 2))
        width = 2 * (width // (vae_scale_factor * 2))

        latents = F.reshape(
            latents,
            (batch_size.dim, height // 2, width // 2, channels.dim // 4, 2, 2),
        )
        latents = F.permute(latents, (0, 3, 1, 4, 2, 5))

        latents = F.reshape(
            latents, (batch_size.dim, channels.dim // (2 * 2), height, width)
        )

        return latents

    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int,
        height: int,
        width: int,
        dtype: DType,
        device: DeviceRef,
        latents: Tensor_v3 | None = None,
    ) -> tuple[Tensor_v3, Tensor_v3]:
        """Prepare latents for the Flux pipeline.

        Args:
            batch_size: The number of images to generate.
            num_channels_latents: The number of latent channels.
            height: The height of the generated image.
            width: The width of the generated image.
            dtype: The data type for the latents.
            device: The device to run on.
            latents: Pre-generated latents. Moved to `device` and cast to
                `dtype` without repacking.

        Returns:
            Tuple of latents and latent image ids.
        """
        # VAE applies 8x compression on images but we must also account for packing which requires
        # latent height and width to be divisible by 2.
        height = 2 * (int(height) // (self.vae_scale_factor * 2))
        width = 2 * (int(width) // (self.vae_scale_factor * 2))

        # The position ids depend only on the latent grid, so they are the
        # same whether latents are supplied or sampled.
        latent_image_ids = self._prepare_latent_image_ids(
            batch_size, height // 2, width // 2, device, dtype
        )

        if latents is not None:
            return latents.to(device).cast(dtype), latent_image_ids

        shape = (batch_size, num_channels_latents, height, width)
        latents = random.normal(shape, device=device, dtype=dtype)
        latents = self._pack_latents(
            latents, batch_size, num_channels_latents, height, width
        )

        return latents, latent_image_ids

    def __call__(
        self,
        prompt: str | list[str] | None = None,
        prompt_2: str | list[str] | None = None,
        negative_prompt: str | list[str] | None = None,
        negative_prompt_2: str | list[str] | None = None,
        true_cfg_scale: float = 1.0,
        height: int | None = None,
        width: int | None = None,
        num_inference_steps: int = 28,
        sigmas: list[float] | None = None,
        guidance_scale: float = 3.5,
        num_images_per_prompt: int | None = 1,
        latents: Tensor | None = None,
        prompt_embeds: Tensor | None = None,
        pooled_prompt_embeds: Tensor | None = None,
        ip_adapter_image: PipelineImageInput | None = None,
        ip_adapter_image_embeds: list[Tensor] | None = None,
        negative_ip_adapter_image: PipelineImageInput | None = None,
        negative_ip_adapter_image_embeds: list[Tensor] | None = None,
        negative_prompt_embeds: Tensor | None = None,
        negative_pooled_prompt_embeds: Tensor | None = None,
        output_type: str | None = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: dict[str, Any] | None = None,
        callback_on_step_end: Callable[..., dict[str, Any] | None]
        | None = None,
        callback_on_step_end_tensor_inputs: list[str] | None = None,
        max_sequence_length: int = 512,
    ):
        r"""Generate images for the given prompt(s).

        Args:
            prompt: Prompt or prompts guiding generation. Either `prompt` or
                `prompt_embeds` must be given.
            prompt_2: Prompt(s) for `tokenizer_2` / `text_encoder_2`;
                `prompt` is used when not given.
            negative_prompt: Prompt(s) to steer away from. Only used when
                true classifier-free guidance is active (see
                `true_cfg_scale`).
            negative_prompt_2: Negative prompt(s) for the second text encoder.
            true_cfg_scale: True classifier-free guidance is enabled when
                this is > 1 and a negative prompt (or both negative
                embeddings) is provided.
            height: Height in pixels of the generated image. Defaults to
                `self.default_sample_size * self.vae_scale_factor`.
            width: Width in pixels of the generated image. Same default as
                `height`.
            num_inference_steps: Number of denoising steps (default 28).
            sigmas: Custom sigmas handed to the scheduler. Ignored when the
                scheduler sets `use_flow_sigmas`.
            guidance_scale: Embedded guidance value passed to the transformer
                when its config enables `guidance_embeds`.
            num_images_per_prompt: Number of images to generate per prompt;
                `None` is treated as 1.
            latents: Pre-generated noisy latents. Sampled from a normal
                distribution when not given.
            prompt_embeds: Pre-generated text embeddings.
            pooled_prompt_embeds: Pre-generated pooled text embeddings.
            ip_adapter_image: Not supported; must be `None`.
            ip_adapter_image_embeds: Not supported; must be `None`.
            negative_ip_adapter_image: Not supported; must be `None`.
            negative_ip_adapter_image_embeds: Not supported; must be `None`.
            negative_prompt_embeds: Pre-generated negative text embeddings.
            negative_pooled_prompt_embeds: Pre-generated negative pooled text
                embeddings.
            output_type: `"latent"` returns the raw latents; any other value
                is forwarded to the image processor's `postprocess`.
            return_dict: Return a `FluxPipelineOutput` when true, otherwise a
                one-element tuple.
            joint_attention_kwargs: Optional kwargs; its `"scale"` entry is
                used as the LoRA scale for prompt encoding.
            callback_on_step_end: Called after every denoising step as
                `callback_on_step_end(self, step, timestep, callback_kwargs)`.
                It may return a dict whose `"latents"` / `"prompt_embeds"`
                entries replace the current values.
            callback_on_step_end_tensor_inputs: Names of local tensors passed
                to the callback in `callback_kwargs`. Defaults to
                `["latents"]`.
            max_sequence_length: Maximum sequence length used with the
                prompt.

        Returns:
            `FluxPipelineOutput` if `return_dict` is true, otherwise a tuple
            whose first element is the generated images.

        Raises:
            NotImplementedError: If any IP-adapter argument is provided.
            ValueError: If neither `prompt` nor `prompt_embeds` is provided.
        """
        # Reject unsupported inputs before doing any expensive encoding.
        if (
            ip_adapter_image is not None
            or ip_adapter_image_embeds is not None
            or negative_ip_adapter_image is not None
            or negative_ip_adapter_image_embeds is not None
        ):
            raise NotImplementedError(
                "IP adapter is not supported for Max yet."
            )

        # NOTE(review): `default_sample_size` is not set anywhere in this
        # class; presumably provided by the base class — confirm.
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        self._guidance_scale = guidance_scale
        self._joint_attention_kwargs = joint_attention_kwargs
        self._current_timestep = None
        self._interrupt = False

        if num_images_per_prompt is None:
            num_images_per_prompt = 1
        if callback_on_step_end_tensor_inputs is None:
            # Without a default, supplying only a callback would fail when
            # iterating over `None` inside the denoising loop.
            callback_on_step_end_tensor_inputs = ["latents"]

        # 1. Define call parameters
        if isinstance(prompt, str):
            batch_size = 1
        elif isinstance(prompt, list):
            batch_size = len(prompt)
        elif prompt_embeds is not None:
            batch_size = prompt_embeds.shape[0]
        else:
            raise ValueError(
                "Either `prompt` or `prompt_embeds` must be provided."
            )

        device = self._execution_device()

        lora_scale = (
            self._joint_attention_kwargs.get("scale", None)
            if self._joint_attention_kwargs is not None
            else None
        )
        has_neg_prompt = negative_prompt is not None or (
            negative_prompt_embeds is not None
            and negative_pooled_prompt_embeds is not None
        )
        do_true_cfg = true_cfg_scale > 1 and has_neg_prompt

        # 2. Encode the prompt(s)
        (
            prompt_embeds,
            pooled_prompt_embeds,
            text_ids,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            lora_scale=lora_scale,
        )
        if do_true_cfg:
            (
                negative_prompt_embeds,
                negative_pooled_prompt_embeds,
                negative_text_ids,
            ) = self.encode_prompt(
                prompt=negative_prompt,
                prompt_2=negative_prompt_2,
                prompt_embeds=negative_prompt_embeds,
                pooled_prompt_embeds=negative_pooled_prompt_embeds,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                lora_scale=lora_scale,
            )

        # 3. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels // 4
        latents, latent_image_ids = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            latents,
        )

        # 4. Prepare timesteps
        sigmas = (
            np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
            if sigmas is None
            else sigmas
        )
        if (
            hasattr(self.scheduler, "use_flow_sigmas")
            and self.scheduler.use_flow_sigmas
        ):
            sigmas = None
        image_seq_len = latents.shape[1].dim
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.base_image_seq_len,
            self.scheduler.max_image_seq_len,
            self.scheduler.base_shift,
            self.scheduler.max_shift,
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler,
            num_inference_steps,
            device,
            sigmas=sigmas,
            mu=mu,
        )

        self._num_timesteps = timesteps.shape[0]

        # 5. Prepare the guidance input (zeros when the transformer was not
        # configured with guidance embeddings).
        if self.transformer.config.guidance_embeds:
            guidance = Tensor_v3.full(
                [latents.shape[0].dim],
                guidance_scale,
                device=device,
                dtype=prompt_embeds.dtype,
            )
        else:
            guidance = Tensor_v3.zeros(
                [latents.shape[0].dim],
                device=device,
                dtype=prompt_embeds.dtype,
            )

        if self._joint_attention_kwargs is None:
            self._joint_attention_kwargs = {}

        # IP adapters are rejected above, so these stay None; the branches
        # below that consume them are kept for when support is added.
        image_embeds = None
        negative_image_embeds = None

        # 6. Denoising loop
        # We set the index here to remove DtoH sync, helpful especially during compilation.
        # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
        self.scheduler.set_begin_index(0)
        batch_size = latents.shape[0].dim
        for i in tqdm(range(self._num_timesteps), desc="Denoising"):
            if self._interrupt:
                continue

            t = timesteps[i]
            self._current_timestep = t
            if image_embeds is not None:
                self._joint_attention_kwargs["ip_adapter_image_embeds"] = (
                    image_embeds
                )

            # NOTE: Convert timesteps to a Max Tensor before denoising loop,
            # as in the original implementation, results in a significant slow down.
            # As a workaround, we keep timesteps as a numpy array and convert it
            # to a Max Tensor here. This might require a more efficient way to handle this.
            # Converting to a Max module V3 Tensor also results in a significant slow down.
            timestep = np.full((batch_size,), t) / 1000.0
            timestep = Tensor.from_dlpack(timestep).to(prompt_embeds.device)

            noise_pred = self.transformer(
                latents,
                prompt_embeds,
                pooled_prompt_embeds,
                timestep,
                latent_image_ids,
                text_ids,
                guidance,
            )[0]

            if do_true_cfg:
                if negative_image_embeds is not None:
                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = (
                        negative_image_embeds
                    )

                neg_noise_pred = self.transformer(
                    latents,
                    negative_prompt_embeds,
                    negative_pooled_prompt_embeds,
                    timestep,
                    latent_image_ids,
                    negative_text_ids,
                    guidance,
                )[0]
                # TODO: negative prompt path is very slow, need to optimize.
                noise_pred = neg_noise_pred + true_cfg_scale * (
                    noise_pred - neg_noise_pred
                )

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(
                noise_pred, t, latents, return_dict=False
            )[0]

            if latents.dtype != latents_dtype:
                # NOTE(review): elsewhere in this class dtype changes use
                # `.cast(dtype)` and `.to(...)` takes a device — confirm that
                # `.to` accepts a dtype for the tensor type returned by the
                # scheduler.
                latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    callback_kwargs[k] = locals()[k]
                # A callback that returns nothing leaves the state unchanged.
                callback_outputs = (
                    callback_on_step_end(self, i, t, callback_kwargs) or {}
                )

                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop(
                    "prompt_embeds", prompt_embeds
                )

        self._current_timestep = None

        # 7. Decode
        if output_type == "latent":
            image = latents
        else:
            latents = Tensor_v3.from_dlpack(latents)  # V2 Tensor to V3 Tensor
            latents = self._unpack_latents(
                latents, height, width, self.vae_scale_factor
            )
            # Undo the latent normalisation before decoding.
            latents = (
                latents / self.vae.config.scaling_factor
            ) + self.vae.config.shift_factor
            image = self.vae.decode(latents)

            image = self.image_processor.postprocess(
                image, output_type=output_type
            )

        if not return_dict:
            return (image,)

        return FluxPipelineOutput(images=image)
arch.pipeline_model.finalize_pipeline_config(self) + if arch.task == PipelineTask.IMAGE_GENERATION: + # diffusion pipeline does not use KV cache, + # so we can skip profile run. + return + MemoryEstimator.estimate_memory_footprint( self, arch.pipeline_model, @@ -1139,12 +1145,14 @@ def log_basic_config(self) -> None: pipeline_class = get_pipeline_for_task(task, self) # Get reserved memory info from KVCache config - kv_config = self.model._kv_cache - if kv_config._available_cache_memory is None: - raise ValueError( - "KVCache config is not available after config resolution." - ) - memory_str = to_human_readable_bytes(kv_config._available_cache_memory) + memory_str = None + if task != PipelineTask.IMAGE_GENERATION: + kv_config = self.model._kv_cache + if kv_config._available_cache_memory is None: + raise ValueError( + "KVCache config is not available after config resolution." + ) + memory_str = to_human_readable_bytes(kv_config._available_cache_memory) devices_str = ", ".join( f"{d.device_type}[{d.id}]" for d in self.model.device_specs diff --git a/max/python/max/pipelines/lib/diffusion_schedulers/__init__.py b/max/python/max/pipelines/lib/diffusion_schedulers/__init__.py new file mode 100644 index 00000000000..df278d92447 --- /dev/null +++ b/max/python/max/pipelines/lib/diffusion_schedulers/__init__.py @@ -0,0 +1,16 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ===----------------------------------------------------------------------=== # + +from .scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler, +) diff --git a/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py b/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py new file mode 100644 index 00000000000..1eb39c90b33 --- /dev/null +++ b/max/python/max/pipelines/lib/diffusion_schedulers/scheduling_flow_match_euler_discrete.py @@ -0,0 +1,852 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +import logging +import math +from dataclasses import dataclass + +import numpy as np +from max.driver import CPU, Accelerator, Device +from max.dtype import DType +from max.engine import InferenceSession +from max.experimental import Tensor, random +from max.graph import DeviceRef, Graph, TensorType + +try: + import scipy.stats + + is_scipy_available = True +except ImportError: + is_scipy_available = False + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class FlowMatchEulerDiscreteSchedulerOutput: + """Output class for the scheduler's `step` function output. + + Args: + prev_sample (`Tensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. 
class FlowMatchEulerDiscreteScheduler:
    """Flow-matching Euler scheduler.

    Native Modular implementation (ported from diffusers). The per-step
    update is traced once into a MAX graph at construction time (see
    `load_model`) and executed through `step`.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
        use_dynamic_shifting (`bool`, defaults to False):
            Whether to apply timestep shifting on-the-fly based on the image
            resolution.
        base_shift (`float`, defaults to 0.5):
            Value to stabilize image generation. Increasing `base_shift`
            reduces variation and image is more consistent with desired
            output.
        max_shift (`float`, defaults to 1.15):
            Value change allowed to latent vectors. Increasing `max_shift`
            encourages more variation and image may be more exaggerated or
            stylized.
        base_image_seq_len (`int`, defaults to 256):
            The base image sequence length.
        max_image_seq_len (`int`, defaults to 4096):
            The maximum image sequence length.
        invert_sigmas (`bool`, defaults to False):
            Whether to invert the sigmas.
        shift_terminal (`float`, defaults to None):
            The end value of the shifted timestep schedule.
        use_karras_sigmas (`bool`, defaults to False):
            Whether to use Karras sigmas for step sizes in the noise schedule
            during sampling.
        use_exponential_sigmas (`bool`, defaults to False):
            Whether to use exponential sigmas for step sizes in the noise
            schedule during sampling.
        use_beta_sigmas (`bool`, defaults to False):
            Whether to use beta sigmas for step sizes in the noise schedule
            during sampling.
        time_shift_type (`str`, defaults to "exponential"):
            The type of dynamic resolution-dependent timestep shifting to
            apply. Either "exponential" or "linear".
        stochastic_sampling (`bool`, defaults to False):
            Whether to use stochastic sampling.
        device (`DeviceRef`, defaults to `DeviceRef.CPU()`):
            The device the compiled step graph runs on.
        dtype (`DType`, defaults to `DType.float32`):
            The dtype of the model output / sample inputs of the step graph.
    """

    config_name = "scheduler_config.json"

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        use_dynamic_shifting: bool = False,
        base_shift: float | None = 0.5,
        max_shift: float | None = 1.15,
        base_image_seq_len: int | None = 256,
        max_image_seq_len: int | None = 4096,
        invert_sigmas: bool = False,
        shift_terminal: float | None = None,
        use_karras_sigmas: bool | None = False,
        use_exponential_sigmas: bool | None = False,
        use_beta_sigmas: bool | None = False,
        time_shift_type: str = "exponential",
        stochastic_sampling: bool = False,
        device: DeviceRef = DeviceRef.CPU(),
        dtype: DType = DType.float32,
        **kwargs,
    ):
        """Validate the configuration, build the training schedule and compile the step graph.

        See the class docstring for the meaning of each argument. Extra
        keyword arguments are accepted and ignored.

        Raises:
            ImportError: If `use_beta_sigmas` is set but scipy is missing.
            ValueError: If more than one sigma schedule flag is set, or
                `time_shift_type` is not "exponential" or "linear".
        """
        self.num_train_timesteps = num_train_timesteps
        self._shift = shift
        self.use_dynamic_shifting = use_dynamic_shifting
        self.base_shift = base_shift
        self.max_shift = max_shift
        self.base_image_seq_len = base_image_seq_len
        self.max_image_seq_len = max_image_seq_len
        self.invert_sigmas = invert_sigmas
        self.shift_terminal = shift_terminal
        self.use_karras_sigmas = use_karras_sigmas
        self.use_exponential_sigmas = use_exponential_sigmas
        self.use_beta_sigmas = use_beta_sigmas
        self.time_shift_type = time_shift_type
        self.stochastic_sampling = stochastic_sampling
        self.device = device
        self.dtype = dtype

        if self.use_beta_sigmas and not is_scipy_available:
            raise ImportError(
                "Make sure to install scipy if you want to use beta sigmas."
            )
        # The flags are typed `bool | None`; coerce to bool so that a `None`
        # flag counts as "off" instead of breaking the sum with a TypeError.
        enabled_schedules = sum(
            bool(flag)
            for flag in (
                self.use_beta_sigmas,
                self.use_exponential_sigmas,
                self.use_karras_sigmas,
            )
        )
        if enabled_schedules > 1:
            raise ValueError(
                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
            )
        if time_shift_type not in {"exponential", "linear"}:
            raise ValueError(
                "`time_shift_type` must either be 'exponential' or 'linear'."
            )

        # Descending training schedule: num_train_timesteps, ..., 1.
        timesteps = np.linspace(
            1, num_train_timesteps, num_train_timesteps, dtype=np.float32
        )[::-1].copy()

        sigmas = timesteps / num_train_timesteps
        if not use_dynamic_shifting:
            # when use_dynamic_shifting is True, we apply the timestep
            # shifting on the fly based on the image resolution
            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

        self.load_model()

    @property
    def shift(self) -> float:
        """The value used for shifting."""
        return self._shift

    @property
    def step_index(self) -> int | None:
        """The index counter for current timestep. It will increase 1 after each scheduler step."""
        return self._step_index

    @property
    def begin_index(self) -> int | None:
        """The index for the first timestep. It should be set from pipeline with `set_begin_index` method."""
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0) -> None:
        """Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`, defaults to `0`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def set_shift(self, shift: float) -> None:
        """Set the shift value."""
        self._shift = shift

    def scale_noise(
        self,
        sample: Tensor,
        timestep: float | Tensor,
        noise: Tensor | None = None,
    ) -> Tensor:
        """Forward process in flow-matching.

        NOTE(review): this method is a direct port from diffusers and calls
        tensor methods (`.to`, `.cast`) on `self.sigmas` / `self.timesteps`,
        while this class stores `self.timesteps` as a NumPy array — confirm
        the intended array type before relying on this path.

        Args:
            sample (`Tensor`):
                The input sample.
            timestep (`float` or `Tensor`):
                The current timestep in the diffusion chain.
            noise (`Tensor`):
                The noise tensor. Required despite the `None` default.

        Returns:
            `Tensor`:
                A scaled input sample.

        Raises:
            ValueError: If `noise` is not provided.
        """
        if noise is None:
            # Fail early with a clear message instead of a TypeError from
            # `sigma * None` at the end of the method.
            raise ValueError("`noise` must be provided to `scale_noise`.")

        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=sample.device).cast(sample.dtype)

        if sample.device.type == "mps":
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(sample.device).cast(
                DType.float32
            )
            timestep = timestep.to(sample.device).cast(DType.float32)
        else:
            schedule_timesteps = self.timesteps.to(sample.device)
            timestep = timestep.to(sample.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [
                self.index_for_timestep(t, schedule_timesteps) for t in timestep
            ]
        elif self.step_index is not None:
            # add_noise is called after first denoising step (for inpainting)
            step_indices = [self.step_index] * timestep.shape[0]
        else:
            # add noise is called before first denoising step to create initial latent(img2img)
            step_indices = [self.begin_index] * timestep.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(sample.shape):
            sigma = sigma.unsqueeze(-1)

        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma: float) -> float:
        """Converts sigma to timestep."""
        return sigma * self.num_train_timesteps

    def time_shift(self, mu: float, sigma: float, t: Tensor) -> Tensor:
        """Apply time shifting to the timesteps.

        Args:
            mu (`float`):
                The mu parameter for time shifting.
            sigma (`float`):
                The sigma parameter for time shifting.
            t (`Tensor`):
                The timesteps to shift.

        Returns:
            `Tensor`:
                The shifted timesteps.

        Raises:
            ValueError: If `self.time_shift_type` is neither "exponential"
                nor "linear" (e.g. it was mutated after construction).
        """
        if self.time_shift_type == "exponential":
            return self._time_shift_exponential(mu, sigma, t)
        if self.time_shift_type == "linear":
            return self._time_shift_linear(mu, sigma, t)
        # Previously fell through and returned None, which only surfaced
        # later as an unrelated arithmetic error.
        raise ValueError(
            "`time_shift_type` must either be 'exponential' or 'linear'."
        )

    def stretch_shift_to_terminal(self, t: Tensor) -> Tensor:
        r"""Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal`.

        Reference:
        https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51

        Args:
            t (`Tensor`):
                A tensor of timesteps to be stretched and shifted.

        Returns:
            `Tensor`:
                A tensor of adjusted timesteps such that the final value equals `self.shift_terminal`.
        """
        one_minus_z = 1 - t
        scale_factor = one_minus_z[-1] / (1 - self.shift_terminal)
        stretched_t = 1 - (one_minus_z / scale_factor)
        return stretched_t

    def set_timesteps(
        self,
        num_inference_steps: int | None = None,
        device: str | Device | None = None,
        sigmas: list[float] | None = None,
        mu: float | None = None,
        timesteps: list[float] | None = None,
    ) -> None:
        """Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Resets both the step index and the begin index to `None`, so the
        pipeline must call `set_begin_index` again afterwards if it relies
        on it.

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `Device`, *optional*):
                The device to which the sigmas should be moved to.
            sigmas (`List[float]`, *optional*):
                Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
                automatically.
            mu (`float`, *optional*):
                Determines the amount of shifting applied to sigmas when performing resolution-dependent timestep
                shifting.
            timesteps (`List[float]`, *optional*):
                Custom values for timesteps to be used for each diffusion step. If `None`, the timesteps are computed
                automatically.

        Raises:
            ValueError: If `mu` is missing while dynamic shifting is enabled,
                if the provided lengths are inconsistent, or if none of
                `num_inference_steps`, `sigmas` and `timesteps` is given.
        """
        if self.use_dynamic_shifting and mu is None:
            raise ValueError(
                "`mu` must be passed when `use_dynamic_shifting` is set to be `True`"
            )

        if sigmas is not None and timesteps is not None:
            if len(sigmas) != len(timesteps):
                raise ValueError(
                    "`sigmas` and `timesteps` should have the same length"
                )

        if num_inference_steps is not None:
            if (sigmas is not None and len(sigmas) != num_inference_steps) or (
                timesteps is not None and len(timesteps) != num_inference_steps
            ):
                raise ValueError(
                    "`sigmas` and `timesteps` should have the same length as num_inference_steps, if `num_inference_steps` is provided"
                )
        elif sigmas is not None:
            num_inference_steps = len(sigmas)
        elif timesteps is not None:
            num_inference_steps = len(timesteps)
        else:
            # Previously this case crashed with `len(None)`.
            raise ValueError(
                "One of `num_inference_steps`, `sigmas` or `timesteps` must be provided."
            )

        self.num_inference_steps = num_inference_steps

        # 1. Prepare default sigmas
        is_timesteps_provided = timesteps is not None

        if is_timesteps_provided:
            timesteps = np.array(timesteps).astype(np.float32)

        if sigmas is None:
            if timesteps is None:
                timesteps = np.linspace(
                    self._sigma_to_t(self.sigma_max),
                    self._sigma_to_t(self.sigma_min),
                    num_inference_steps,
                )
            sigmas = timesteps / self.num_train_timesteps
        else:
            sigmas = np.array(sigmas).astype(np.float32)
            num_inference_steps = len(sigmas)

        # 2. Perform timestep shifting. Either no shifting is applied, or resolution-dependent shifting of
        # "exponential" or "linear" type is applied
        if self.use_dynamic_shifting:
            sigmas = self.time_shift(mu, 1.0, sigmas)
        else:
            sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)

        # 3. If required, stretch the sigmas schedule to terminate at the configured `shift_terminal` value
        if self.shift_terminal:
            sigmas = self.stretch_shift_to_terminal(sigmas)

        # 4. If required, convert sigmas to one of karras, exponential, or beta sigma schedules
        if self.use_karras_sigmas:
            sigmas = self._convert_to_karras(
                in_sigmas=sigmas, num_inference_steps=num_inference_steps
            )
        elif self.use_exponential_sigmas:
            sigmas = self._convert_to_exponential(
                in_sigmas=sigmas, num_inference_steps=num_inference_steps
            )
        elif self.use_beta_sigmas:
            sigmas = self._convert_to_beta(
                in_sigmas=sigmas, num_inference_steps=num_inference_steps
            )

        if not is_timesteps_provided:
            timesteps = sigmas * self.num_train_timesteps

        # 5. Append the terminal sigma value.
        # If a model requires inverted sigma schedule for denoising but timesteps without inversion, the
        # `invert_sigmas` flag can be set to `True`. This case is only required in Mochi
        if self.invert_sigmas:
            sigmas = 1.0 - sigmas
            timesteps = sigmas * self.num_train_timesteps
            terminal = np.ones((1,), dtype=sigmas.dtype)
        else:
            terminal = np.zeros((1,), dtype=sigmas.dtype)
        sigmas = np.concatenate([sigmas, terminal])

        # 6. Convert sigmas to a tensor on the requested device. Timesteps
        # deliberately stay a NumPy array.
        sigmas = (
            Tensor.from_dlpack(sigmas).to(device=device).cast(DType.float32)
        )

        self.timesteps = timesteps
        self.sigmas = sigmas
        self._step_index = None
        self._begin_index = None

    def index_for_timestep(
        self, timestep: Tensor, schedule_timesteps: Tensor | None = None
    ) -> int:
        """Returns the index for a given timestep.

        Args:
            timestep (`Tensor`):
                The timestep to find the index for.
            schedule_timesteps (`Tensor`, *optional*):
                The schedule timesteps to search in. If `None`, defaults to `self.timesteps`.

        Returns:
            `int`:
                The index of the timestep.

        Raises:
            ValueError: If the schedule is a NumPy array and `timestep` does
                not occur in it.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        mask = schedule_timesteps == timestep
        if isinstance(mask, (np.ndarray, np.bool_)):
            # `ndarray.nonzero()` returns a tuple with one array per
            # dimension, so `len(...)` on it never reflects the number of
            # matches. Flatten to a 1-D array of matching positions instead.
            indices = np.flatnonzero(mask)
            if indices.size == 0:
                raise ValueError(
                    f"Timestep {timestep} is not part of the schedule."
                )
        else:
            # Tensor-like schedule: keep the behaviour of the original port.
            indices = mask.nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep: Tensor) -> None:
        """Initialize the step index based on the given timestep.

        Uses `begin_index` when it has been set; otherwise looks the
        timestep up in `self.timesteps`.

        Args:
            timestep (`Tensor`):
                The current timestep.
        """
        if self.begin_index is None:
            if isinstance(timestep, Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def _step(
        self,
        model_output: Tensor,
        timestep: float | Tensor,
        sample: Tensor,
        sigmas: Tensor | None = None,
        step_index: Tensor | None = None,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        per_token_timesteps: Tensor | None = None,
        return_dict: bool = True,
    ) -> FlowMatchEulerDiscreteSchedulerOutput | tuple:
        """Predict the sample from the previous timestep by reversing the SDE.

        This is the function traced into the compiled step graph by
        `load_model`; the first five parameters line up positionally with
        `step_input_types`.

        Args:
            model_output (`Tensor`):
                The direct output from learned diffusion model.
            timestep (`float` or `Tensor`):
                The current discrete timestep in the diffusion chain.
            sample (`Tensor`):
                A current instance of a sample created by the diffusion process.
            sigmas (`Tensor`, *optional*):
                The sigmas tensor.
            step_index (`Tensor`, *optional*):
                The step index.
            s_churn (`float`):
                Churn parameter (unused here).
            s_tmin (`float`):
                Min churn timestep (unused here).
            s_tmax (`float`):
                Max churn timestep (unused here).
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample (unused here).
            per_token_timesteps (`Tensor`, *optional*):
                The timesteps for each token in the sample.
            return_dict (`bool`):
                Whether to wrap the result in a
                `FlowMatchEulerDiscreteSchedulerOutput`.

        Returns:
            `FlowMatchEulerDiscreteSchedulerOutput` if `return_dict` is
            `True`; otherwise the bare `prev_sample` tensor (not a tuple),
            which is what `load_model` passes to `graph.output`.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.cast(DType.float32)

        if per_token_timesteps is not None:
            per_token_sigmas = per_token_timesteps / self.num_train_timesteps

            sigmas = sigmas[:, None, None]
            lower_mask = sigmas < per_token_sigmas[None] - 1e-6
            lower_sigmas = lower_mask * sigmas
            lower_sigmas, _ = lower_sigmas.max(axis=0)

            current_sigma = per_token_sigmas[..., None]
            next_sigma = lower_sigmas[..., None]
            dt = current_sigma - next_sigma
        else:
            current_sigma = sigmas[step_index]
            next_sigma = sigmas[step_index + 1]
            dt = next_sigma - current_sigma

        if self.stochastic_sampling:
            x0 = sample - current_sigma * model_output
            noise = random.normal(sample)
            prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
        else:
            prev_sample = sample + dt * model_output

        # upon completion increase step index by one
        self._step_index += 1
        if per_token_timesteps is None:
            # Cast sample back to model compatible dtype
            prev_sample = prev_sample.cast(model_output.dtype)

        if not return_dict:
            return prev_sample

        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)

    def _sigma_bounds(self, in_sigmas: Tensor) -> tuple[float, float]:
        """Return `(sigma_min, sigma_max)`, falling back to the ends of `in_sigmas`.

        The attribute lookups tolerate a missing or `None` `sigma_min` /
        `sigma_max`, matching the defensive logic shared by the sigma
        conversion helpers.
        """
        sigma_min = getattr(self, "sigma_min", None)
        sigma_max = getattr(self, "sigma_max", None)
        if sigma_min is None:
            sigma_min = in_sigmas[-1].item()
        if sigma_max is None:
            sigma_max = in_sigmas[0].item()
        return sigma_min, sigma_max

    def _convert_to_karras(
        self, in_sigmas: Tensor, num_inference_steps: int
    ) -> Tensor:
        """Construct the noise schedule as proposed in [Elucidating the Design Space of Diffusion-Based Generative Models](https://huggingface.co/papers/2206.00364).

        Args:
            in_sigmas (`Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            The converted sigma values following the Karras noise schedule
            (a NumPy array, since it is built with `np.linspace`).
        """
        sigma_min, sigma_max = self._sigma_bounds(in_sigmas)

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def _convert_to_exponential(
        self, in_sigmas: Tensor, num_inference_steps: int
    ) -> Tensor:
        """Construct an exponential noise schedule.

        Args:
            in_sigmas (`Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.

        Returns:
            The converted sigma values following an exponential schedule
            (a NumPy array).
        """
        sigma_min, sigma_max = self._sigma_bounds(in_sigmas)

        sigmas = np.exp(
            np.linspace(
                math.log(sigma_max), math.log(sigma_min), num_inference_steps
            )
        )
        return sigmas

    def _convert_to_beta(
        self,
        in_sigmas: Tensor,
        num_inference_steps: int,
        alpha: float = 0.6,
        beta: float = 0.6,
    ) -> Tensor:
        """Construct a beta noise schedule as proposed in [Beta Sampling is All You Need](https://huggingface.co/papers/2407.12173).

        Args:
            in_sigmas (`Tensor`):
                The input sigma values to be converted.
            num_inference_steps (`int`):
                The number of inference steps to generate the noise schedule for.
            alpha (`float`, *optional*, defaults to `0.6`):
                The alpha parameter for the beta distribution.
            beta (`float`, *optional*, defaults to `0.6`):
                The beta parameter for the beta distribution.

        Returns:
            The converted sigma values following a beta distribution
            schedule (a NumPy array).
        """
        sigma_min, sigma_max = self._sigma_bounds(in_sigmas)

        # `ppf` is evaluated element-wise, so one vectorised call yields the
        # same values as the per-timestep loop it replaces.
        quantiles = 1 - np.linspace(0, 1, num_inference_steps)
        ppf = scipy.stats.beta.ppf(quantiles, alpha, beta)
        sigmas = np.asarray(sigma_min + (ppf * (sigma_max - sigma_min)))
        return sigmas

    def _time_shift_exponential(
        self, mu: float, sigma: float, t: Tensor
    ) -> Tensor:
        """Apply exponential time shifting.

        Args:
            mu (`float`):
                The mu parameter.
            sigma (`float`):
                The sigma parameter.
            t (`Tensor`):
                The timesteps.

        Returns:
            `Tensor`:
                The shifted timesteps.
        """
        exp_mu = math.exp(mu)
        return exp_mu / (exp_mu + (1 / t - 1) ** sigma)

    def _time_shift_linear(self, mu: float, sigma: float, t: Tensor) -> Tensor:
        """Apply linear time shifting.

        Args:
            mu (`float`):
                The mu parameter.
            sigma (`float`):
                The sigma parameter.
            t (`Tensor`):
                The timesteps.

        Returns:
            `Tensor`:
                The shifted timesteps.
        """
        return mu / (mu + (1 / t - 1) ** sigma)

    def __len__(self) -> int:
        """Returns the number of train timesteps."""
        return self.num_train_timesteps

    def step_input_types(self) -> tuple[TensorType, ...]:
        """Return the input types for the step function.

        The order matches the leading positional parameters of `_step`:
        model output, timestep, sample, sigmas, step index.
        """
        latent_shape = ["batch_size", "image_seq_len", "channel"]
        return (
            # model_output
            TensorType(self.dtype, shape=latent_shape, device=self.device),
            # timestep (scalar)
            TensorType(DType.float32, shape=[], device=self.device),
            # sample
            TensorType(self.dtype, shape=latent_shape, device=self.device),
            # sigmas
            TensorType(
                DType.float32,
                shape=["num_inference_steps"],
                device=self.device,
            ),
            # step_index (scalar, always on CPU)
            TensorType(DType.int64, shape=[], device=DeviceRef.CPU()),
        )

    def load_model(self) -> None:
        """Trace `_step` into a graph and load it into an inference session.

        Tracing runs `_step` once in Python, which needs a begin index and
        advances the step counter as a side effect. Both are restored
        afterwards so that building the graph does not leave the scheduler
        looking as if one denoising step had already been taken.
        """
        if self.device.is_cpu():
            session = InferenceSession([CPU()])
        else:
            session = InferenceSession([Accelerator()])

        saved_step_index = self._step_index
        saved_begin_index = self._begin_index
        # `_step` falls back to a timestep lookup when no begin index is
        # set, which is not possible on symbolic graph inputs.
        self.set_begin_index(0)
        try:
            with Graph(
                "scheduler_step", input_types=self.step_input_types()
            ) as graph:
                outputs = self._step(
                    *graph.inputs,
                    return_dict=False,
                )
                graph.output(outputs)
        finally:
            self._step_index = saved_step_index
            self._begin_index = saved_begin_index
        self.session = session.load(graph)

    def step(
        self,
        model_output: Tensor,
        timestep: float | Tensor,
        sample: Tensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        per_token_timesteps: Tensor | None = None,
        return_dict: bool = True,
    ) -> FlowMatchEulerDiscreteSchedulerOutput | tuple:
        """Predict the sample from the previous timestep by running the compiled step graph.

        Only `model_output`, `timestep`, `sample`, the current sigmas and
        the current step index are forwarded to the graph; `s_churn`,
        `s_tmin`, `s_tmax`, `s_noise` and `per_token_timesteps` are accepted
        for signature compatibility but are not used by this method.

        Args:
            model_output (`Tensor`):
                The direct output from learned diffusion model.
            timestep (`float` or `Tensor`):
                The current discrete timestep in the diffusion chain.
            sample (`Tensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`):
                Churn parameter (ignored).
            s_tmin (`float`):
                Min churn timestep (ignored).
            s_tmax (`float`):
                Max churn timestep (ignored).
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample (ignored).
            per_token_timesteps (`Tensor`, *optional*):
                The timesteps for each token in the sample (ignored).
            return_dict (`bool`):
                Whether or not to return a
                `FlowMatchEulerDiscreteSchedulerOutput` or tuple.

        Returns:
            `FlowMatchEulerDiscreteSchedulerOutput` if `return_dict` is
            `True`, otherwise a one-element tuple holding the sample tensor.
        """
        if self.step_index is None:
            self._init_step_index(timestep)
        schedule_output = self.session.execute(
            model_output,
            timestep,
            sample,
            self.sigmas,
            self.step_index,
        )[0]
        # The compiled graph cannot mutate Python state, so the counter is
        # advanced here.
        self._step_index += 1

        if not return_dict:
            return (schedule_output,)
        return FlowMatchEulerDiscreteSchedulerOutput(
            prev_sample=schedule_output
        )
def get_model_index_path_for_diffusers(
    huggingface_repo: HuggingFaceRepo,
) -> str | None:
    """Locate the diffusers `model_index.json` for a repository.

    Args:
        huggingface_repo: The repository to inspect. Its `repo_type`,
            `repo_id` and `revision` attributes are read.

    Returns:
        The path to `model_index.json` as a string. For a non-local
        repository, `None` is returned when the file does not exist or when
        any error occurs while querying or downloading it.

    Raises:
        ValueError: If the repository is local and contains no
            `model_index.json`.
    """
    index_filename = "model_index.json"

    if huggingface_repo.repo_type == RepoType.local:
        candidate = Path(huggingface_repo.repo_id) / index_filename
        if not candidate.exists():
            raise ValueError(
                f"Failed to find model_index.json in {huggingface_repo.repo_id}."
            )
        return str(candidate)

    # Remote lookup is best-effort: any failure is reported as "no index".
    try:
        is_present = huggingface_hub.file_exists(
            huggingface_repo.repo_id,
            index_filename,
            revision=huggingface_repo.revision,
        )
        if not is_present:
            return None
        return huggingface_hub.hf_hub_download(
            huggingface_repo.repo_id,
            index_filename,
            revision=huggingface_repo.revision,
        )
    except Exception:
        return None
class VaeImageProcessor:
    """Post-process image tensors into MAX, NumPy or PIL outputs.

    The constructor compiles ``_denormalize_conditionally`` with
    ``max_compile`` for a 4-D ``(batch_size, num_channels, height, width)``
    input on the configured device/dtype; ``postprocess`` then routes a tensor
    through that compiled function and converts it to the requested type.
    """

    config_name = "config.json"

    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        vae_latent_channels: int = 4,
        resample: str = "lanczos",
        reducing_gap: int | None = None,
        do_normalize: bool = True,
        do_binarize: bool = False,
        do_convert_rgb: bool = False,
        do_convert_grayscale: bool = False,
        device: DeviceRef = DeviceRef.GPU(),
        dtype: DType = DType.bfloat16,
    ) -> None:
        """Initialize the VaeImageProcessor.

        Only ``do_normalize``, ``device`` and ``dtype`` are stored and used by
        this implementation. The remaining arguments are accepted (and
        ``do_convert_rgb``/``do_convert_grayscale`` are validated against each
        other) but are otherwise not read here.

        Args:
            do_resize: Whether to resize images. Defaults to True.
            vae_scale_factor: The VAE scale factor. Defaults to 8.
            vae_latent_channels: The number of latent channels for the VAE.
                Defaults to 4.
            resample: The resampling mode for resizing. Defaults to "lanczos".
            reducing_gap: A reduction gap parameter for resampling.
                Defaults to None.
            do_normalize: Whether images are normalized to [-1, 1]; when True,
                ``postprocess`` denormalizes them. Defaults to True.
            do_binarize: Whether to binarize images. Defaults to False.
            do_convert_rgb: Whether to convert images to RGB.
                Defaults to False.
            do_convert_grayscale: Whether to convert images to grayscale.
                Defaults to False.
            device: The device used for the compiled denormalize graph.
                Defaults to DeviceRef.GPU().
            dtype: The input data type of the compiled denormalize graph.
                Defaults to DType.bfloat16.

        Raises:
            ValueError: If both do_convert_rgb and do_convert_grayscale are
                set to True.
        """
        super().__init__()
        if do_convert_rgb and do_convert_grayscale:
            # The three literals are concatenated into ONE message. Separating
            # them with commas would pass several positional arguments to
            # ValueError and render the message as a tuple.
            raise ValueError(
                "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
                " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`."
                " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`"
            )

        self.do_normalize = do_normalize
        self.device = device
        self.dtype = dtype
        # Replace the bound method with its compiled counterpart; the input
        # types must be computed after device/dtype are set.
        self._denormalize_conditionally = max_compile(
            self._denormalize_conditionally,
            input_types=self._denormalize_conditionally_input_types(),
        )

    @staticmethod
    def denormalize(images: np.ndarray | Tensor) -> np.ndarray | Tensor:
        r"""Denormalize an image array to [0,1].

        Computes ``images * 0.5 + 0.5`` and clamps the result to ``[0, 1]``.

        Args:
            images (`np.ndarray` or `Tensor`):
                The image array to denormalize.

        Returns:
            `np.ndarray` or `Tensor`:
                The denormalized image array.
        """
        if isinstance(images, (Tensor, TensorValue)):
            images = images * 0.5 + 0.5
            # Clamp with elementwise min/max against scalar constants that
            # share the input's dtype and device.
            images = F.min(
                images,
                Tensor.constant(1.0, dtype=images.dtype, device=images.device),
            )
            images = F.max(
                images,
                Tensor.constant(0.0, dtype=images.dtype, device=images.device),
            )
            return images
        return np.clip(images * 0.5 + 0.5, 0, 1)

    def _denormalize_conditionally(
        self,
        images: np.ndarray | Tensor,
    ) -> np.ndarray:
        r"""Denormalize images when ``do_normalize`` is set, then cast to float32.

        NOTE(review): the return annotation says ``np.ndarray`` but the value
        is the result of ``ops.cast`` -- confirm the intended annotation.

        Args:
            images (`np.ndarray` or `Tensor`):
                The input image tensor.

        Returns:
            The (optionally denormalized) images cast to ``DType.float32``.
        """
        images = self.denormalize(images) if self.do_normalize else images
        images = ops.cast(images, DType.float32)
        return images

    @staticmethod
    def max_to_numpy(images: Tensor) -> np.ndarray:
        r"""Convert a Max tensor to a NumPy image.

        Args:
            images (`Tensor`):
                The Max tensor to convert to NumPy format. It must be 4-D
                because the axes are permuted with ``(0, 2, 3, 1)``.

        Returns:
            `np.ndarray`:
                A NumPy array representation of the images with axis 1 moved
                to the last position.
        """
        images = DTensor.to_numpy(images)
        images = np.transpose(images, (0, 2, 3, 1))
        return images

    @staticmethod
    def numpy_to_pil(images: np.ndarray) -> list[PIL.Image.Image]:
        r"""Convert a numpy image or a batch of images to a PIL image.

        Args:
            images (`np.ndarray`):
                The image array to convert to PIL format. A 3-D array is
                treated as a single image and given a leading batch axis.
                Values are scaled by 255 and rounded to ``uint8``.

        Returns:
            `list[PIL.Image.Image]`:
                A list of PIL images.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [
                Image.fromarray(image.squeeze(), mode="L") for image in images
            ]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

    def _denormalize_conditionally_input_types(self) -> list[TensorType]:
        """Return the symbolic input signature used to compile the denormalizer."""
        return [
            TensorType(
                shape=("batch_size", "num_channels", "height", "width"),
                device=self.device,
                dtype=self.dtype,
            ),
        ]

    def postprocess(
        self,
        image: Tensor,
        output_type: str = "pil",
        do_denormalize: list[bool] | None = None,
    ) -> list[PIL.Image.Image] | np.ndarray | Tensor:
        """Postprocess the image output from tensor to `output_type`.

        Args:
            image (`Tensor`):
                The image input, should be a Max tensor with shape `B x C x H x W`.
            output_type (`str`, *optional*, defaults to `pil`):
                The output type of the image, one of `pil`, `np`, `max`,
                `latent`. Any other value logs a warning and is treated as
                `np`.
            do_denormalize (`list[bool]`, *optional*, defaults to `None`):
                Currently not read by this implementation; denormalization is
                controlled solely by `do_normalize` from the constructor.

        Returns:
            `image` unchanged for `latent`; element 0 of the compiled
            denormalize result for `max`; that element converted with
            `max_to_numpy` for `np`; a list of PIL images for `pil`.

        Raises:
            ValueError: If `image` is neither a `Tensor` nor a `TensorValue`.
        """
        if not isinstance(image, (Tensor, TensorValue)):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support Max tensor"
            )
        if output_type not in ["latent", "max", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `max`, `latent`"
            )
            logger.warning(deprecation_message)
            output_type = "np"

        if output_type == "latent":
            return image

        image = self._denormalize_conditionally(image)

        # NOTE(review): `image[0]` presumably selects the first output of the
        # compiled function rather than the first batch item -- confirm.
        if output_type == "max":
            return image[0]

        image = self.max_to_numpy(image[0])

        if output_type == "np":
            return image

        # Only "pil" remains after the validation above.
        return self.numpy_to_pil(image)
class DiffusionPipeline(ABC):
    """Abstract base for pipelines assembled from named sub-components.

    Subclasses declare ``components`` (a mapping from sub-directory name to
    the class that loads it) and implement ``init_remaining_components``.
    Construction loads every declared component from ``cached_folder`` and
    attaches it to the instance under its component name.
    """

    config_name: str | None = None
    """
    The name of the config file of the pipeline.

    It can be found in the downloaded path or HuggingFace hub.
    It's usually "model_index.json" or "config.json" for Diffusion models.
    """

    components: (
        dict[str, type[MaxModel] | type[FlowMatchEulerDiscreteScheduler]] | None
    ) = None
    """The components of the pipeline.

    It can be found in the downloaded path or HuggingFace hub.
    It usually contains text_encoder, tokenizer, transformer, vae, etc.
    """

    def __init__(
        self,
        pipeline_config: PipelineConfig,
        cached_folder: str,
        **kwargs: Any,
    ) -> None:
        """Load a pipeline from a pretrained model.

        Args:
            pipeline_config: Pipeline configuration for model and runtime setup.
            cached_folder: Local path to the downloaded model snapshot.
            **kwargs: Additional pipeline-specific arguments (not read here).

        Raises:
            ValueError: Propagated from ``load_sub_models`` when ``components``
                is unset or a component class lacks ``config_name``.
        """
        self.pipeline_config = pipeline_config
        self.devices = load_devices(pipeline_config.model.device_specs)

        # Load sub models and expose each one as an attribute named after
        # its component key.
        loaded_sub_models = self.load_sub_models(cached_folder)
        for name, model in loaded_sub_models.items():
            setattr(self, name, model)

        self.init_remaining_components()

    @abstractmethod
    def init_remaining_components(self) -> None:
        """Hook called at the end of ``__init__`` after sub-models are attached."""
        pass

    def load_sub_models(
        self,
        pretrained_model_name_or_path: str | os.PathLike,
    ) -> dict[str, Any]:
        """Load sub-models for the pipeline.

        Each entry of ``components`` is loaded from the sub-directory of the
        same name. Components whose name contains ``"tokenizer"`` are loaded
        via ``from_pretrained``; ``MaxModel`` subclasses receive their config,
        encoding, devices and the weights whose path starts with the component
        name; everything else is constructed from its config plus a device and
        dtype.

        Args:
            pretrained_model_name_or_path: Path to pretrained model.

        Returns:
            Dictionary containing the loaded sub-models, keyed by component
            name.

        Raises:
            ValueError: If ``components`` is not set, or a non-tokenizer
                component class has no ``config_name``.
        """
        loaded_sub_models: dict[str, Any] = {}
        if self.components is None:
            raise ValueError(
                f"`components` for {self.__class__.__name__} pipeline is not set. "
                "Please set proper components based on its sub-directories in the downloaded path."
            )
        for name, component_class in tqdm(
            self.components.items(), desc="Loading sub models"
        ):
            component_path = os.path.join(pretrained_model_name_or_path, name)
            if "tokenizer" in name:
                # NOTE: Currently, we are using tokenizers from transformers.
                # TODO(minkyu): Check if we can use Tokenizer in Max,
                # and remove this conditional path.
                loaded_sub_models[name] = component_class.from_pretrained(
                    component_path
                )
                continue

            if (
                not hasattr(component_class, "config_name")
                or component_class.config_name is None
            ):
                raise ValueError(
                    f"`config_name` for {component_class.__name__} is not set. "
                    "Please set proper config file name in the downloaded path."
                )
            config = load_config(
                f"{component_path}/{component_class.config_name}"
            )
            if issubclass(component_class, MaxModel):
                # Keep only the weight files that live under this component's
                # sub-directory.
                # NOTE(review): assumes entries of `weight_path` are `str`
                # (uses `.split("/")`) -- confirm against the config type.
                weight_paths = [
                    Path(pretrained_model_name_or_path) / weight_path
                    for weight_path in self.pipeline_config.model.weight_path
                    if weight_path.split("/")[0] == name
                ]
                loaded_sub_models[name] = component_class(
                    config=config,
                    encoding=self.pipeline_config.model.quantization_encoding,
                    devices=self.devices,
                    weights=load_weights(weight_paths),
                )
            else:
                loaded_sub_models[name] = component_class(
                    **config,
                    device=DeviceRef.from_device(self.devices[0]),
                    dtype=self.pipeline_config.model.quantization_encoding.dtype,
                )

        return loaded_sub_models

    def finalize_pipeline_config(self) -> None:
        """No-op hook; subclasses may override to adjust the pipeline config."""
        return

    def _execution_device(self) -> DeviceRef:
        r"""Return the device on which the pipeline's models will be executed.

        Walks the declared components in order, skipping names listed in the
        optional ``_exclude_from_cpu_offload`` attribute, and returns the
        ``device`` of the first sub-model whose ``device`` is a ``DeviceRef``.
        If none qualifies, falls back to ``self.device`` when that is a
        ``DeviceRef``, and finally to ``DeviceRef.CPU()``.

        Returns:
            DeviceRef: The execution device.
        """
        # `components` defaults to None on the base class; treat that as
        # "no sub-models" instead of failing on iteration.
        component_names = self.components or {}
        sub_models = {k: getattr(self, k) for k in component_names}
        # Loop-invariant, so resolve it once.
        exclude_from_cpu_offload = getattr(
            self, "_exclude_from_cpu_offload", set()
        )
        for name, model in sub_models.items():
            if name in exclude_from_cpu_offload:
                continue

            if hasattr(model, "device") and isinstance(model.device, DeviceRef):
                return model.device

        device = getattr(self, "device", None)
        if isinstance(device, DeviceRef):
            return device

        return DeviceRef.CPU()
+ relative_path = Path(relative_path) if relative_path.exists() and relative_path.is_file(): return str(relative_path.resolve()) From 2b80ab7856833bd97f82b693b5b37c73450c4814 Mon Sep 17 00:00:00 2001 From: byungchul-sqzb Date: Wed, 21 Jan 2026 07:39:26 +0000 Subject: [PATCH 04/10] [MAX] Add Conv2d and GroupNorm to module_v3 --- max/python/max/nn/module_v3/__init__.py | 6 + max/python/max/nn/module_v3/conv.py | 217 ++++++++++++++++++ max/python/max/nn/module_v3/norm/__init__.py | 2 + .../max/nn/module_v3/norm/group_norm.py | 171 ++++++++++++++ 4 files changed, 396 insertions(+) create mode 100644 max/python/max/nn/module_v3/conv.py create mode 100644 max/python/max/nn/module_v3/norm/group_norm.py diff --git a/max/python/max/nn/module_v3/__init__.py b/max/python/max/nn/module_v3/__init__.py index 56267745b84..a9575e755e4 100644 --- a/max/python/max/nn/module_v3/__init__.py +++ b/max/python/max/nn/module_v3/__init__.py @@ -12,15 +12,21 @@ # ===----------------------------------------------------------------------=== # """Module implementation using eager tensors.""" +from .conv import Conv2d from .embedding import Embedding from .linear import Linear from .module import Module, module_dataclass +from .norm import GemmaRMSNorm, GroupNorm, RMSNorm from .sequential import Sequential __all__ = [ + "Conv2d", "Embedding", + "GemmaRMSNorm", + "GroupNorm", "Linear", "Module", + "RMSNorm", "Sequential", "module_dataclass", ] diff --git a/max/python/max/nn/module_v3/conv.py b/max/python/max/nn/module_v3/conv.py new file mode 100644 index 00000000000..863873fd2e3 --- /dev/null +++ b/max/python/max/nn/module_v3/conv.py @@ -0,0 +1,217 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. 
class Conv2d(Module[[Tensor], Tensor]):
    """A 2D convolution layer using module_v3.

    This is a module_v3-compatible version of Conv2d that uses Tensor
    instead of Weight objects.

    Example:
        .. code-block:: python

            from max.nn.module_v3 import Conv2d
            from max.experimental.tensor import Tensor

            conv = Conv2d(
                kernel_size=3,
                in_channels=3,
                out_channels=64,
                has_bias=True,
                permute=True,
            )

            x = Tensor.ones([1, 3, 32, 32])
            result = conv(x)
    """

    weight: Tensor
    """The weight tensor.

    Shape is [out_channels, in_channels // num_groups, kernel_height, kernel_width]
    when ``permute=True`` and
    [kernel_height, kernel_width, in_channels // num_groups, out_channels]
    otherwise.
    """

    bias: Tensor | Literal[0]
    """The bias tensor with shape [out_channels] (or 0 if bias is disabled)."""

    def __init__(
        self,
        kernel_size: int | tuple[int, int],
        in_channels: int,
        out_channels: int,
        dtype: DType | None = None,
        stride: int | tuple[int, int] = 1,
        padding: int | tuple[int, int] | tuple[int, int, int, int] = 0,
        dilation: int | tuple[int, int] = 1,
        num_groups: int = 1,
        device: DeviceRef | None = None,
        has_bias: bool = False,
        permute: bool = False,
        name: str | None = None,
    ):
        """Initialize Conv2d layer.

        Args:
            kernel_size: Size of the convolving kernel. Can be a single int (square kernel) or tuple (height, width).
            in_channels: Number of channels in the input image.
            out_channels: Number of channels produced by the convolution.
            dtype: The data type for both weights and bias. Optional; passed through to the random initializer.
            stride: Stride of the convolution for height and width dimensions.
                Can be int (applied to both dimensions) or tuple (stride_h, stride_w). Default: 1
            padding: Padding added to input. Can be int (applied to all sides),
                tuple of 2 ints (pad_h, pad_w), or tuple of 4 ints (pad_top, pad_bottom, pad_left, pad_right) to support asymmetric padding. Default: 0
            dilation: Spacing between kernel elements for height and width dimensions.
                Can be int (applied to both dimensions) or tuple (dilation_h, dilation_w). Default: 1
            num_groups: Number of blocked connections from input channels to output channels.
                Input channels and output channels are divided into groups. Default: 1
            device: The target device for the initial weights. Optional; when None the initializer's default is used.
            has_bias: If true, adds a learnable bias vector to the layer.
                Defaults to :obj:`False`.
            permute: If true, weights are stored in PyTorch order and inputs/outputs are NCHW.
                PyTorch order: (out_channels, in_channels / num_groups, height, width).
                MAX API order: (height, width, in_channels / num_groups, out_channels).
                Defaults to :obj:`False`.
            name: Base name for weights. Stored but not used for Weight naming.
                Defaults to :obj:`None`.

        Raises:
            ValueError: If the created weight reports a quantization encoding.
        """
        # Store configuration for easy reconstruction
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dtype = dtype
        self.device = device
        self.permute = permute
        self.num_groups = num_groups
        self.has_bias = has_bias
        self.name = name

        # Handle kernel_size as int or tuple
        if isinstance(kernel_size, int):
            kernel_height = kernel_width = kernel_size
            self.kernel_size = (kernel_size, kernel_size)
        else:
            kernel_height, kernel_width = kernel_size
            self.kernel_size = kernel_size

        # Weight is created before bias so the random draws happen in a
        # fixed order.
        self.weight = random.normal(
            [
                out_channels,
                in_channels // num_groups,
                kernel_height,
                kernel_width,
            ]
            if self.permute
            else [
                kernel_height,
                kernel_width,
                in_channels // num_groups,
                out_channels,
            ],
            dtype=self.dtype,
            device=self.device.to_device() if self.device is not None else None,
        )

        if has_bias:
            self.bias = random.normal(
                [out_channels],
                dtype=self.dtype,
                device=self.device.to_device()
                if self.device is not None
                else None,
            )
        else:
            self.bias = 0

        # Convert scalar parameters to tuples as needed
        self.stride = (stride, stride) if isinstance(stride, int) else stride

        if isinstance(padding, int):
            padding = (padding, padding, padding, padding)
        elif len(padding) == 2:
            # Convert (pad_h, pad_w) to (pad_top, pad_bottom, pad_left, pad_right)
            pad_h, pad_w = padding
            padding = (pad_h, pad_h, pad_w, pad_w)

        self.padding = padding

        if isinstance(dilation, int):
            dilation = (dilation, dilation)
        self.dilation = dilation

        if (
            isinstance(self.weight, Tensor)
            and hasattr(self.weight, "quantization_encoding")
            and self.weight.quantization_encoding is not None
        ):
            raise ValueError("Conv2d not implemented with weight quantization.")

    def forward(self, x: Tensor) -> Tensor:
        """Apply 2D convolution to input.

        Args:
            x: Input tensor. Shape depends on `permute`:
                - If permute=True: [batch_size, in_channels, height, width]
                - If permute=False: [batch_size, height, width, in_channels]

        Returns:
            Output tensor. Shape depends on `permute`:
                - If permute=True: [batch_size, out_channels, new_height, new_width]
                - If permute=False: [batch_size, new_height, new_width, out_channels]
        """
        # Move parameters to the same device as the input. The bias gets the
        # same treatment as the weight so a layer created on one device works
        # with inputs on another.
        weight = self.weight.to(x.device)
        bias = self.bias.to(x.device) if isinstance(self.bias, Tensor) else None

        is_nvidia_gpu = (
            isinstance(x.device, Accelerator) and accelerator_api() == "cuda"
        )

        if self.permute:
            # Input: [batch_size, in_channels, height, width] -> [batch_size, height, width, in_channels]
            x = F.permute(x, [0, 2, 3, 1])

            # GPU supports FCRS but CPU doesn't. On CPU, permute from
            # FCRS to RSCF format.
            if not is_nvidia_gpu:
                # Permute weight from [out_channels, in_channels // num_groups, height, width]
                # to [height, width, in_channels // num_groups, out_channels] (RSCF)
                weight = F.permute(weight, [2, 3, 1, 0])

        output = F.conv2d(
            x,
            weight,
            self.stride,
            self.dilation,
            self.padding,
            self.num_groups,
            bias,
            filter_layout=FilterLayout.FCRS
            if (self.permute and is_nvidia_gpu)
            else FilterLayout.RSCF,
        )

        if self.permute:
            # Output: [batch_size, new_height, new_width, out_channels] -> [batch_size, out_channels, new_height, new_width]
            output = F.permute(output, [0, 3, 1, 2])

        return output
def group_norm(
    x: Tensor,
    weight: Tensor,
    bias: Tensor,
    num_groups: int,
    eps: float,
) -> Tensor:
    """Run the ``group_norm`` custom op on ``x``.

    ``weight`` and ``bias`` are moved to ``x.device`` before the call, while
    ``eps`` and ``num_groups`` are passed as constants placed on the CPU.

    Args:
        x: Input tensor of shape [N, C, *] where C is number of channels
        weight: Weight tensor of shape [C]
        bias: Bias tensor of shape [C]
        num_groups: Number of groups to separate the channels into
        eps: Small constant added to denominator for numerical stability

    Returns:
        Normalized tensor with the same type as the input

    Raises:
        ValueError: If ``x`` has fewer than two dimensions.
    """
    rank = len(x.shape)
    if rank < 2:
        raise ValueError(
            f"Expected input tensor with >=2 dimensions, got shape {x.shape}"
        )

    target_device = x.device
    op_inputs = [
        x,
        weight.to(target_device),
        bias.to(target_device),
        F.constant(eps, dtype=x.dtype, device=CPU()),
        F.constant(num_groups, dtype=DType.int32, device=CPU()),
    ]
    results = F.custom("group_norm", target_device, op_inputs, [x.type])
    return results[0]


class GroupNorm(Module[[Tensor], Tensor]):
    """Group normalization layer for module_v3.

    Channels are split into ``num_groups`` groups and normalized per group.
    Parameters are plain ``Tensor`` objects rather than ``Weight`` objects.

    Example:
        .. code-block:: python

            from max.nn.module_v3 import GroupNorm
            from max.experimental.tensor import Tensor

            norm = GroupNorm(num_groups=32, num_channels=128)
            x = Tensor.ones([1, 128, 32, 32])
            result = norm(x)
    """

    weight: Tensor | None
    """The weight tensor with shape [num_channels] (None if affine=False)."""
    bias: Tensor | None
    """The bias tensor with shape [num_channels] (None if affine=False)."""
    num_groups: int
    """Number of groups to separate the channels into."""
    num_channels: int
    """Number of input channels."""
    eps: float
    """Small constant added to denominator for numerical stability."""

    def __init__(
        self,
        num_groups: int,
        num_channels: int,
        eps: float = 1e-5,
        affine: bool = True,
    ) -> None:
        """Create the layer and, when ``affine`` is set, its scale and shift.

        Args:
            num_groups: Number of groups to separate the channels into
            num_channels: Number of input channels
            eps: Small constant added to denominator for numerical stability.
                Default: 1e-5
            affine: If True, hold a ones-initialized weight and a
                zeros-initialized bias. Default: True

        Raises:
            ValueError: If ``num_channels`` is not a multiple of
                ``num_groups``.
        """
        remainder = num_channels % num_groups
        if remainder != 0:
            raise ValueError(
                f"num_channels({num_channels}) should be divisible by "
                f"num_groups({num_groups})"
            )

        self.num_groups = num_groups
        self.num_channels = num_channels
        self.eps = eps
        self.affine = affine

        if not self.affine:
            self.weight = None
            self.bias = None
        else:
            self.weight = Tensor.ones([num_channels])
            self.bias = Tensor.zeros([num_channels])

    def __rich_repr__(self):
        """Yield the fields shown by rich's repr protocol."""
        yield "num_groups", self.num_groups
        yield "num_channels", self.num_channels
        yield "eps", self.eps, 1e-5
        yield "affine", self.affine, True

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` with this layer's groups, epsilon and parameters.

        Args:
            x: Input tensor of shape [N, C, *] where C is number of channels

        Returns:
            Normalized tensor of same shape as input

        Raises:
            ValueError: If ``x`` has fewer than two dimensions or its channel
                dimension differs from ``num_channels``.
        """
        if len(x.shape) < 2:
            raise ValueError(
                f"Expected input tensor with >=2 dimensions, got shape {x.shape}"
            )
        channels = x.shape[1]
        if channels != self.num_channels:
            raise ValueError(
                f"Expected {self.num_channels} channels, got shape {x.shape}"
            )

        if not self.affine:
            # Without learnable parameters, feed identity scale/shift that
            # match the input's dtype and device.
            param_shape = [self.num_channels]
            scale = Tensor.ones(param_shape, dtype=x.dtype, device=x.device)
            shift = Tensor.zeros(param_shape, dtype=x.dtype, device=x.device)
        else:
            scale = self.weight
            shift = self.bias

        return group_norm(x, scale, shift, self.num_groups, self.eps)
max/python/max/pipelines/architectures/autoencoders/model.py create mode 100644 max/python/max/pipelines/architectures/autoencoders/model_config.py create mode 100644 max/python/max/pipelines/architectures/autoencoders/vae.py diff --git a/max/python/max/pipelines/architectures/autoencoders/__init__.py b/max/python/max/pipelines/architectures/autoencoders/__init__.py new file mode 100644 index 00000000000..1b8f3f4929d --- /dev/null +++ b/max/python/max/pipelines/architectures/autoencoders/__init__.py @@ -0,0 +1,14 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .autoencoder_kl import AutoencoderKLModel diff --git a/max/python/max/pipelines/architectures/autoencoders/autoencoder_kl.py b/max/python/max/pipelines/architectures/autoencoders/autoencoder_kl.py new file mode 100644 index 00000000000..70b122372a7 --- /dev/null +++ b/max/python/max/pipelines/architectures/autoencoders/autoencoder_kl.py @@ -0,0 +1,100 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class AutoencoderKL(Module[[Tensor], Tensor]):
    r"""VAE (KL) model built on module_v3; this implementation wires up only the decoder."""

    def __init__(
        self,
        config: AutoencoderKLConfig,
    ) -> None:
        """Build the decoder from ``config``.

        Args:
            config: Autoencoder configuration supplying channel sizes, block
                structure, normalization settings, and device/dtype.
        """
        super().__init__()
        # Collect the decoder options first so the mapping from config fields
        # to Decoder arguments is visible in one place.
        decoder_options = {
            "in_channels": config.latent_channels,
            "out_channels": config.out_channels,
            "up_block_types": config.up_block_types,
            "block_out_channels": config.block_out_channels,
            "layers_per_block": config.layers_per_block,
            "norm_num_groups": config.norm_num_groups,
            "act_fn": config.act_fn,
            "norm_type": "group",
            "mid_block_add_attention": config.mid_block_add_attention,
            "use_post_quant_conv": config.use_post_quant_conv,
            "device": config.device,
            "dtype": config.dtype,
        }
        self.decoder = Decoder(**decoder_options)

    def forward(self, z: Tensor, temb: Tensor | None = None) -> Tensor:
        """Decode latents by delegating to the decoder.

        Args:
            z: Input latent tensor of shape [N, C_latent, H_latent, W_latent].
            temb: Optional time embedding tensor, forwarded unchanged.

        Returns:
            Decoded image tensor of shape [N, C_out, H, W].
        """
        decoded = self.decoder(z, temb)
        return decoded


class AutoencoderKLModel(BaseAutoencoderModel):
    """MaxModel-style wrapper around :class:`AutoencoderKL`.

    It forwards its constructor arguments to ``BaseAutoencoderModel`` together
    with the config class and autoencoder class to use.
    """

    config_name: ClassVar[str] = AutoencoderKLConfig.config_name

    def __init__(
        self,
        config: dict,
        encoding: SupportedEncoding,
        devices: list[Device],
        weights: Weights,
    ) -> None:
        """Initialize AutoencoderKLModel.

        Args:
            config: Model configuration dictionary.
            encoding: Supported encoding for the model.
            devices: List of devices to use.
            weights: Model weights.
        """
        base_arguments = {
            "config": config,
            "encoding": encoding,
            "devices": devices,
            "weights": weights,
            "config_class": AutoencoderKLConfig,
            "autoencoder_class": AutoencoderKL,
        }
        super().__init__(**base_arguments)
# ===----------------------------------------------------------------------=== #

from .attention import VAEAttention
from .resnet import ResnetBlock2D
from .upsampling import Upsample2D


# File: max/python/max/pipelines/architectures/autoencoders/layers/attention.py (new file)
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

import math

from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import DeviceRef
from max.nn.module_v3 import GroupNorm, Linear, Module
from max.nn.module_v3.sequential import ModuleList


class VAEAttention(Module[[Tensor], Tensor]):
    """Multi-head self-attention over the spatial positions of a feature map.

    The input [N, C, H, W] is group-normalised, flattened to a sequence of
    H*W tokens with C features, attended with explicit
    ``softmax(Q K^T * scale) V`` and reshaped back to [N, C, H, W]. The
    un-normalised input is added to the result.

    Note:
        The attention scores are materialised as a full
        [N, heads, H*W, H*W] tensor; no fused attention kernel is used.
    """

    def __init__(
        self,
        query_dim: int,
        heads: int,
        dim_head: int,
        num_groups: int = 32,
        eps: float = 1e-6,
        device: DeviceRef | None = None,
        dtype: DType | None = None,
    ) -> None:
        """Create the normalization and projection layers.

        Args:
            query_dim: Channel count C of the input feature map.
            heads: Number of attention heads.
            dim_head: Feature width of each head.
            num_groups: Group count for the leading GroupNorm.
            eps: Epsilon for the leading GroupNorm.
            device: Accepted for interface compatibility; not used here.
            dtype: Accepted for interface compatibility; not used here.
        """
        super().__init__()
        inner_dim = heads * dim_head

        self.query_dim = query_dim
        self.heads = heads
        self.dim_head = dim_head
        self.inner_dim = inner_dim

        self.group_norm = GroupNorm(
            num_groups=num_groups,
            num_channels=query_dim,
            eps=eps,
            affine=True,
        )

        # Q, K and V share the same projection geometry: C -> heads * dim_head.
        self.to_q = Linear(in_dim=query_dim, out_dim=inner_dim, bias=True)
        self.to_k = Linear(in_dim=query_dim, out_dim=inner_dim, bias=True)
        self.to_v = Linear(in_dim=query_dim, out_dim=inner_dim, bias=True)

        # Kept in a single-element ModuleList so the parameter path is
        # ``to_out.0.*``, matching the original weights format.
        output_projection = Linear(
            in_dim=inner_dim,
            out_dim=query_dim,
            bias=True,
        )
        self.to_out = ModuleList([output_projection])

        self.scale = 1.0 / math.sqrt(dim_head)

    def forward(self, x: Tensor) -> Tensor:
        """Run spatial self-attention and add the residual.

        Args:
            x: Feature map of shape [N, C, H, W].

        Returns:
            Tensor of shape [N, C, H, W]: ``x`` plus the attention output.
        """
        normed = self.group_norm(x)
        batch, channels, height, width = normed.shape
        tokens = height * width

        # [N, C, H, W] -> [N, C, H*W] -> [N, H*W, C]
        sequence = F.reshape(normed, [batch, channels, tokens])
        sequence = F.permute(sequence, [0, 2, 1])

        query_proj = self.to_q(sequence)
        key_proj = self.to_k(sequence)
        value_proj = self.to_v(sequence)

        def split_heads(projected: Tensor) -> Tensor:
            # [N, H*W, heads*dim_head] -> [N, heads, H*W, dim_head]
            per_head = F.reshape(
                projected, [batch, tokens, self.heads, self.dim_head]
            )
            return F.permute(per_head, [0, 2, 1, 3])

        query = split_heads(query_proj)
        key = split_heads(key_proj)
        value = split_heads(value_proj)

        # Scaled dot-product attention, written out explicitly.
        scores = (query @ F.permute(key, [0, 1, 3, 2])) * self.scale
        probabilities = F.softmax(scores, axis=-1)
        context = probabilities @ value

        # [N, heads, H*W, dim_head] -> [N, H*W, heads*dim_head]
        merged = F.permute(context, [0, 2, 1, 3])
        merged = F.reshape(merged, [batch, tokens, self.inner_dim])

        projected = self.to_out[0](merged)

        # [N, H*W, C] -> [N, C, H*W] -> [N, C, H, W]
        restored = F.permute(projected, [0, 2, 1])
        restored = F.reshape(restored, [batch, channels, height, width])

        return x + restored


# File: max/python/max/pipelines/architectures/autoencoders/layers/resnet.py (new file)
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import DeviceRef
from max.nn.module_v3 import Conv2d, GroupNorm, Module


class ResnetBlock2D(Module[[Tensor], Tensor]):
    """Two-convolution residual block used by the VAE decoder.

    The main path applies ``GroupNorm -> SiLU -> 3x3 Conv2d`` twice. The block
    input is added to the result, passing through a 1x1 convolution first when
    one was created in ``__init__``.

    Note:
        ``forward`` always applies ``F.silu`` and never reads ``temb``. The
        ``non_linearity`` and ``temb_channels`` constructor arguments are
        accepted but have no effect in this implementation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int | None,
        groups: int,
        groups_out: int,
        eps: float = 1e-6,
        non_linearity: str = "silu",
        use_conv_shortcut: bool = False,
        conv_shortcut_bias: bool = True,
        device: DeviceRef | None = None,
        dtype: DType | None = None,
    ) -> None:
        """Initialize ResnetBlock2D module.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            temb_channels: Number of time embedding channels (not used).
            groups: Number of groups for the first GroupNorm.
            groups_out: Number of groups for the second GroupNorm.
            eps: Epsilon value for both GroupNorm layers.
            non_linearity: Activation function name (not used; SiLU is always
                applied).
            use_conv_shortcut: Force a 1x1 convolution on the shortcut even
                when ``in_channels == out_channels``.
            conv_shortcut_bias: Whether the shortcut convolution has a bias.
            device: Device reference passed to the convolutions.
            dtype: Data type passed to the convolutions.
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = use_conv_shortcut

        self.norm1 = GroupNorm(
            num_groups=groups,
            num_channels=in_channels,
            eps=eps,
            affine=True,
        )

        self.conv1 = Conv2d(
            kernel_size=3,
            in_channels=in_channels,
            out_channels=out_channels,
            dtype=dtype,
            stride=1,
            padding=1,
            dilation=1,
            num_groups=1,
            has_bias=True,
            device=device,
            permute=True,
        )

        self.norm2 = GroupNorm(
            num_groups=groups_out,
            num_channels=out_channels,
            eps=eps,
            affine=True,
        )

        self.conv2 = Conv2d(
            kernel_size=3,
            in_channels=out_channels,
            out_channels=out_channels,
            dtype=dtype,
            stride=1,
            padding=1,
            dilation=1,
            num_groups=1,
            has_bias=True,
            device=device,
            permute=True,
        )

        # A projection on the shortcut is needed when explicitly requested or
        # when the channel count changes. Both cases use the same 1x1
        # convolution, so they share one construction site.
        self.conv_shortcut: Conv2d | None = None
        if use_conv_shortcut or in_channels != out_channels:
            self.conv_shortcut = Conv2d(
                kernel_size=1,
                in_channels=in_channels,
                out_channels=out_channels,
                dtype=dtype,
                stride=1,
                padding=0,
                dilation=1,
                num_groups=1,
                has_bias=conv_shortcut_bias,
                device=device,
                permute=True,
            )

    def forward(self, x: Tensor, temb: Tensor | None = None) -> Tensor:
        """Apply ResnetBlock2D forward pass.

        Args:
            x: Input tensor of shape [N, C, H, W].
            temb: Optional time embedding tensor (currently unused).

        Returns:
            Output tensor of shape [N, C_out, H, W] with residual connection.
        """
        shortcut = (
            self.conv_shortcut(x) if self.conv_shortcut is not None else x
        )

        h = F.silu(self.norm1(x))
        h = self.conv1(h)

        h = F.silu(self.norm2(h))
        h = self.conv2(h)

        return h + shortcut


# File: max/python/max/pipelines/architectures/autoencoders/layers/upsampling.py (new file)
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

"""Upsampling utilities for MAX framework."""

from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import DeviceRef, TensorValue, TensorValueLike
from max.nn.module_v3 import Conv2d, Module


def interpolate_2d_nearest(
    x: TensorValueLike,
    scale_factor: int = 2,
) -> TensorValue:
    """Upsamples a 2D tensor using nearest-neighbor interpolation.

    This is a workaround implementation because MAX framework's ops.resize
    does not support NEAREST mode (only BICUBIC is currently supported).
    The workaround uses reshape and broadcast operations to achieve
    nearest-neighbor upsampling by a factor of 2.

    Note:
        This workaround can be removed once ops.resize supports NEAREST mode.

    Args:
        x: Input tensor of shape [N, C, H, W] in NCHW format.
            Can be Tensor or TensorValue.
        scale_factor: Upsampling factor. Currently only 2 is supported.
            Default: 2

    Returns:
        Upsampled tensor of shape [N, C, H*scale_factor, W*scale_factor].

    Raises:
        ValueError: If input tensor doesn't have rank 4.
        NotImplementedError: If scale_factor is not 2.
    """
    x = TensorValue(x)

    if x.rank != 4:
        raise ValueError(f"Input tensor must have rank 4, got {x.rank}")

    if scale_factor != 2:
        raise NotImplementedError(
            f"Only scale_factor=2 is currently supported, got {scale_factor}"
        )

    n, c, h, w = x.shape
    target_shape = [n, c, h * scale_factor, w * scale_factor]

    # Reshape: [N, C, H, W] -> [N, C, H, 1, W, 1]
    x_reshaped = F.reshape(x, [n, c, h, 1, w, 1])

    ones_scalar = F.constant(1.0, dtype=x.dtype, device=x.device)
    ones = F.broadcast_to(
        ones_scalar,
        [1, 1, 1, scale_factor, 1, scale_factor],
    )

    # Broadcast: [N, C, H, 1, W, 1] * [1, 1, 1, 2, 1, 2] -> [N, C, H, 2, W, 2]
    # Multiplying by ones duplicates every pixel along the two new axes.
    x_expanded = F.mul(x_reshaped, ones)

    # Reshape: [N, C, H, 2, W, 2] -> [N, C, H*2, W*2]
    return F.reshape(x_expanded, target_shape)


class Upsample2D(Module[[Tensor], Tensor]):
    """2D upsampling module with optional convolution using module_v3.

    ``forward`` optionally doubles H and W with ``interpolate_2d_nearest`` and
    then optionally applies a convolution.
    """

    # Optional Conv2d layer applied after upsampling (None when unused).
    conv: Conv2d | None

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        use_conv_transpose: bool = False,
        out_channels: int | None = None,
        name: str = "conv",
        kernel_size: int | None = None,
        padding: int = 1,
        bias: bool = True,
        interpolate: bool = True,
        device: DeviceRef | None = None,
        dtype: DType | None = None,
    ) -> None:
        """Initialize 2D upsampling module.

        Args:
            channels: Number of input channels.
            use_conv: Whether to apply a convolution after upsampling.
            use_conv_transpose: Whether to use transposed convolution
                (not supported; raises).
            out_channels: Number of output channels. If None (or 0), uses
                ``channels``.
            name: Name for the convolution layer (unused, kept for
                compatibility).
            kernel_size: Kernel size for the convolution; 3 when None.
            padding: Padding for the convolution.
            bias: Whether to use bias in the convolution.
            interpolate: Whether to perform interpolation upsampling.
            device: Device reference passed to the convolution.
            dtype: Data type passed to the convolution.

        Raises:
            NotImplementedError: If ``use_conv_transpose`` is True.
        """
        # Every other Module subclass in this package initialises the base
        # class before assigning attributes; do the same here.
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.interpolate = interpolate
        self.device = device
        self.dtype = dtype

        if use_conv_transpose:
            raise NotImplementedError(
                "Upsample2D does not support use_conv_transpose=True yet."
            )
        elif use_conv:
            if kernel_size is None:
                kernel_size = 3
            self.conv = Conv2d(
                kernel_size=kernel_size,
                in_channels=self.channels,
                out_channels=self.out_channels,
                dtype=dtype,
                stride=1,
                padding=padding,
                has_bias=bias,
                device=device,
                permute=True,
            )
        else:
            self.conv = None

    def forward(self, x: Tensor) -> Tensor:
        """Apply 2D upsampling with optional convolution.

        Args:
            x: Input tensor of shape [N, C, H, W].

        Returns:
            Upsampled tensor, optionally convolved.
        """
        if self.interpolate:
            x = interpolate_2d_nearest(x, scale_factor=2)

        if self.use_conv and self.conv is not None:
            x = self.conv(x)

        return x


# File: max/python/max/pipelines/architectures/autoencoders/model.py (new file)
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from typing import TypeVar

from max.driver import Device
from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph.weights import Weights
from max.pipelines.lib import SupportedEncoding
from max.pipelines.lib.interfaces.max_model import MaxModel

from .model_config import AutoencoderKLConfigBase

TConfig = TypeVar("TConfig", bound=AutoencoderKLConfigBase)


class BaseAutoencoderModel(MaxModel):
    """Base class for autoencoder models with shared logic.

    Builds the typed config, instantiates the autoencoder, and compiles its
    ``decoder`` sub-module. Subclasses supply the config class and the
    autoencoder class.
    """

    def __init__(
        self,
        config: dict,
        encoding: SupportedEncoding,
        devices: list[Device],
        weights: Weights,
        config_class: type[TConfig],
        autoencoder_class: type,
    ) -> None:
        """Initialize base autoencoder model and compile the decoder.

        Args:
            config: Model configuration dictionary.
            encoding: Supported encoding for the model.
            devices: List of devices to use; the decoder is placed on
                ``devices[0]``.
            weights: Model weights.
            config_class: Configuration class exposing
                ``generate(config, encoding, devices)``
                (e.g., AutoencoderKLConfig).
            autoencoder_class: Autoencoder class to instantiate with the
                generated config (e.g., AutoencoderKL). It must expose a
                ``decoder`` attribute.
        """
        super().__init__(config, encoding, devices, weights)
        self.config = config_class.generate(config, encoding, devices)
        self.autoencoder_class = autoencoder_class
        self.load_model()

    def load_model(self) -> None:
        """Load and compile the decoder model.

        Every weight whose key does not start with ``"encoder."`` is kept,
        with a leading ``"decoder."`` stripped so keys are relative to the
        decoder module. The result is stored in ``self.model``.
        """
        # NOTE(review): keys outside both "encoder." and "decoder." are passed
        # through unchanged; confirm that compile() tolerates keys the decoder
        # does not own.
        state_dict = {
            key.removeprefix("decoder."): value.data()
            for key, value in self.weights.items()
            if not key.startswith("encoder.")
        }
        with F.lazy():
            autoencoder = self.autoencoder_class(self.config)
            autoencoder.decoder.to(self.devices[0])

        self.model = autoencoder.decoder.compile(
            *autoencoder.decoder.input_types(), weights=state_dict
        )

    def decode(self, *args, **kwargs) -> Tensor:
        """Decode latents to images using the compiled decoder.

        Args:
            *args: Input arguments (typically latents as Tensor).
            **kwargs: Additional keyword arguments.

        Returns:
            Tensor: Whatever the compiled decoder returns for the inputs.
        """
        return self.model(*args, **kwargs)

    def __call__(self, *args, **kwargs) -> Tensor:
        """Alias for :meth:`decode`.

        Args:
            *args: Input arguments (typically latents as Tensor).
            **kwargs: Additional keyword arguments.

        Returns:
            Tensor: Decoded image tensor.
        """
        return self.decode(*args, **kwargs)


# File: max/python/max/pipelines/architectures/autoencoders/model_config.py (new file)
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from typing import ClassVar

from max.driver import Device
from max.dtype import DType
from max.graph import DeviceRef
from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding
from pydantic import Field


class AutoencoderKLConfigBase(MAXModelConfigBase):
    """Field definitions and defaults for an AutoencoderKL configuration."""

    in_channels: int = 3
    out_channels: int = 3
    down_block_types: list[str] = Field(default_factory=list, max_length=4)
    up_block_types: list[str] = Field(default_factory=list, max_length=4)
    block_out_channels: list[int] = Field(default_factory=list, max_length=4)
    layers_per_block: int = 1
    act_fn: str = "silu"
    latent_channels: int = 4
    norm_num_groups: int = 32
    sample_size: int = 32
    scaling_factor: float = 0.18215
    shift_factor: float | None = None
    # Variable-length tuples: ``tuple[float]`` would only admit exactly one
    # element.
    latents_mean: tuple[float, ...] | None = None
    latents_std: tuple[float, ...] | None = None
    force_upcast: bool = True
    use_quant_conv: bool = True
    use_post_quant_conv: bool = True
    mid_block_add_attention: bool = True
    device: DeviceRef = Field(default_factory=DeviceRef.CPU)
    dtype: DType = DType.bfloat16


class AutoencoderKLConfig(AutoencoderKLConfigBase):
    """Concrete AutoencoderKL configuration."""

    config_name: ClassVar[str] = "config.json"

    @staticmethod
    def generate(
        config_dict: dict,
        encoding: SupportedEncoding,
        devices: list[Device],
    ) -> "AutoencoderKLConfig":
        """Build a config from a raw dictionary plus runtime placement.

        Args:
            config_dict: Raw configuration dictionary. Keys that are not
                fields declared on ``AutoencoderKLConfigBase`` are ignored.
            encoding: Encoding whose ``dtype`` overrides any ``dtype`` key.
            devices: Devices to use; ``devices[0]`` overrides any ``device``
                key.

        Returns:
            The populated ``AutoencoderKLConfig``.
        """
        init_dict = {
            key: value
            for key, value in config_dict.items()
            if key in AutoencoderKLConfigBase.__annotations__
        }
        # Runtime placement always wins over whatever the file specified.
        init_dict.update(
            {
                "dtype": encoding.dtype,
                "device": DeviceRef.from_device(devices[0]),
            }
        )
        return AutoencoderKLConfig(**init_dict)


# File: max/python/max/pipelines/architectures/autoencoders/vae.py (new file)
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from dataclasses import dataclass

from max.dtype import DType
from max.experimental import functional as F
from max.experimental.tensor import Tensor
from max.graph import DeviceRef, TensorType
from max.nn.module_v3 import Conv2d, GroupNorm, Module
from max.nn.module_v3.sequential import ModuleList

from .layers import ResnetBlock2D, Upsample2D, VAEAttention


class UpDecoderBlock2D(Module[[Tensor], Tensor]):
    """Upsampling decoder block for 2D VAE.

    Runs ``num_layers`` ResNet blocks in sequence, then an optional
    ``Upsample2D`` (nearest-neighbor x2 followed by a 3x3 convolution).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resolution_idx: int | None = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        temb_channels: int | None = None,
        device: DeviceRef | None = None,
        dtype: DType | None = None,
    ) -> None:
        """Initialize UpDecoderBlock2D module.

        Args:
            in_channels: Number of input channels (first ResNet block only).
            out_channels: Number of output channels.
            resolution_idx: Accepted but not used.
            dropout: Accepted but not used.
            num_layers: Number of ResNet blocks in this decoder block.
            resnet_eps: Epsilon value for ResNet GroupNorm layers.
            resnet_time_scale_shift: Accepted but not used.
            resnet_act_fn: Forwarded to ``ResnetBlock2D`` as ``non_linearity``.
            resnet_groups: Number of groups for ResNet GroupNorm.
            resnet_pre_norm: Accepted but not used.
            output_scale_factor: Accepted but not used.
            add_upsample: Whether to add upsampling layer after ResNet blocks.
            temb_channels: Forwarded to ``ResnetBlock2D``.
            device: Device reference for module placement.
            dtype: Data type for module parameters.
        """
        super().__init__()
        resnets_list = []
        for i in range(num_layers):
            # Only the first block changes the channel count.
            input_channels = in_channels if i == 0 else out_channels

            resnet = ResnetBlock2D(
                in_channels=input_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                groups=resnet_groups,
                groups_out=resnet_groups,
                eps=resnet_eps,
                non_linearity=resnet_act_fn,
                use_conv_shortcut=False,
                conv_shortcut_bias=True,
                device=device,
                dtype=dtype,
            )
            resnets_list.append(resnet)
        self.resnets = ModuleList(resnets_list)

        if add_upsample:
            upsampler = Upsample2D(
                channels=out_channels,
                use_conv=True,
                out_channels=out_channels,
                name="conv",
                kernel_size=3,
                padding=1,
                bias=True,
                interpolate=True,
                device=device,
                dtype=dtype,
            )
            self.upsamplers = ModuleList([upsampler])
        else:
            self.upsamplers = None

    def forward(
        self, hidden_states: Tensor, temb: Tensor | None = None
    ) -> Tensor:
        """Apply UpDecoderBlock2D forward pass.

        Args:
            hidden_states: Input tensor of shape [N, C_in, H, W].
            temb: Optional time embedding tensor, forwarded to each ResNet
                block.

        Returns:
            Output tensor of shape [N, C_out, H*2, W*2] (if upsampling) or
            [N, C_out, H, W] (if no upsampling).
        """
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)

        # Whether an upsampler exists was fixed at construction time.
        if self.upsamplers is not None:
            hidden_states = self.upsamplers[0](hidden_states)

        return hidden_states


class MidBlock2D(Module[[Tensor], Tensor]):
    """Middle block for 2D VAE using module_v3.

    Layout: one ResNet block, then ``num_layers`` repetitions of
    (optional attention, ResNet block). Channel count and spatial size are
    unchanged.
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int | None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        add_attention: bool = True,
        attention_head_dim: int = 1,
        output_scale_factor: float = 1.0,
        device: DeviceRef | None = None,
        dtype: DType | None = None,
    ) -> None:
        """Initialize MidBlock2D module.

        Args:
            in_channels: Number of input (and output) channels.
            temb_channels: Forwarded to ``ResnetBlock2D``.
            dropout: Accepted but not used.
            num_layers: Number of (attention, ResNet) pairs after the first
                ResNet block.
            resnet_eps: Epsilon for ResNet and attention GroupNorm layers.
            resnet_time_scale_shift: Accepted but not used.
            resnet_act_fn: Forwarded to ``ResnetBlock2D`` as ``non_linearity``.
            resnet_groups: Number of groups for ResNet and attention GroupNorm.
            resnet_pre_norm: Accepted but not used.
            add_attention: Whether to add attention layers between ResNet
                blocks.
            attention_head_dim: Dimension of each attention head; the head
                count is ``in_channels // attention_head_dim``.
            output_scale_factor: Accepted but not used.
            device: Device reference for module placement.
            dtype: Data type for module parameters.
        """
        super().__init__()
        resnets_list = []
        # One slot per layer; None marks a layer without attention.
        attentions_list = []

        resnet = ResnetBlock2D(
            in_channels=in_channels,
            out_channels=in_channels,
            temb_channels=temb_channels,
            groups=resnet_groups,
            groups_out=resnet_groups,
            eps=resnet_eps,
            non_linearity=resnet_act_fn,
            use_conv_shortcut=False,
            conv_shortcut_bias=True,
            device=device,
            dtype=dtype,
        )
        resnets_list.append(resnet)

        for _i in range(num_layers):
            if add_attention:
                attn = VAEAttention(
                    query_dim=in_channels,
                    heads=in_channels // attention_head_dim,
                    dim_head=attention_head_dim,
                    num_groups=resnet_groups,
                    eps=resnet_eps,
                    device=device,
                    dtype=dtype,
                )
                attentions_list.append(attn)
            else:
                attentions_list.append(None)

            resnet = ResnetBlock2D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                groups=resnet_groups,
                groups_out=resnet_groups,
                eps=resnet_eps,
                non_linearity=resnet_act_fn,
                use_conv_shortcut=False,
                conv_shortcut_bias=True,
                device=device,
                dtype=dtype,
            )
            resnets_list.append(resnet)

        self.resnets = ModuleList(resnets_list)

        # Only real attention modules go in the ModuleList; the index set
        # remembers which layer positions they belong to.
        real_attentions = [
            attn for attn in attentions_list if attn is not None
        ]
        self.attentions = (
            ModuleList(real_attentions) if real_attentions else None
        )
        self.attention_indices = {
            i for i, attn in enumerate(attentions_list) if attn is not None
        }

    def forward(
        self, hidden_states: Tensor, temb: Tensor | None = None
    ) -> Tensor:
        """Apply MidBlock2D forward pass.

        Args:
            hidden_states: Input tensor of shape [N, C, H, W].
            temb: Optional time embedding tensor, forwarded to each ResNet
                block.

        Returns:
            Output tensor of shape [N, C, H, W] with same spatial dimensions.
        """
        hidden_states = self.resnets[0](hidden_states, temb)

        # ``attention_idx`` walks the compacted attention list while ``i``
        # walks layer positions.
        attention_idx = 0
        for i in range(len(self.resnets) - 1):
            if self.attentions is not None and i in self.attention_indices:
                hidden_states = self.attentions[attention_idx](hidden_states)
                attention_idx += 1
            hidden_states = self.resnets[i + 1](hidden_states, temb)

        return hidden_states


@dataclass
class DecoderOutput:
    r"""Output of decoding method.

    Args:
        sample (`Tensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
        commit_loss (`Tensor`, optional):
            Optional extra tensor; defaults to ``None``.
    """

    sample: Tensor
    commit_loss: Tensor | None = None


class Decoder(Module[[Tensor], Tensor]):
    """VAE decoder for generating images from latent representations.

    Pipeline: optional 1x1 ``post_quant_conv`` -> ``conv_in`` -> ``MidBlock2D``
    -> a stack of ``UpDecoderBlock2D`` -> ``GroupNorm`` -> SiLU -> ``conv_out``.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: tuple[str, ...] = ("UpDecoderBlock2D",),
        block_out_channels: tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",
        mid_block_add_attention: bool = True,
        use_post_quant_conv: bool = True,
        device: DeviceRef | None = None,
        dtype: DType | None = None,
    ) -> None:
        """Initialize Decoder module.

        Args:
            in_channels: Number of input channels (latent channels).
            out_channels: Number of output channels (image channels).
            up_block_types: Sequence of upsampling block type names; only
                ``"UpDecoderBlock2D"`` is supported.
            block_out_channels: Channel counts per decoder stage, listed from
                the highest-resolution stage to the lowest (the decoder
                consumes them in reverse).
            layers_per_block: Each up block gets ``layers_per_block + 1``
                ResNet layers.
            norm_num_groups: Number of groups for GroupNorm layers.
            act_fn: Activation function name forwarded to the sub-blocks.
            norm_type: Normalization type; ``"spatial"`` is not implemented.
            mid_block_add_attention: Whether to add attention in middle block.
            use_post_quant_conv: Whether to create the 1x1 post-quantization
                convolution.
            device: Device reference for module placement.
            dtype: Data type for module parameters.

        Raises:
            ValueError: If an entry of ``up_block_types`` is unsupported.
            NotImplementedError: If ``norm_type`` is ``"spatial"``.
        """
        super().__init__()
        self.layers_per_block = layers_per_block
        self.session = None
        self.in_channels = in_channels
        self.device = device
        self.dtype = dtype

        self.post_quant_conv: Conv2d | None = None
        if use_post_quant_conv:
            self.post_quant_conv = Conv2d(
                kernel_size=1,
                in_channels=in_channels,
                out_channels=in_channels,
                dtype=dtype,
                stride=1,
                padding=0,
                dilation=1,
                num_groups=1,
                has_bias=True,
                device=device,
                permute=True,
            )

        self.conv_in = Conv2d(
            kernel_size=3,
            in_channels=in_channels,
            out_channels=block_out_channels[-1],
            dtype=dtype,
            stride=1,
            padding=1,
            dilation=1,
            num_groups=1,
            has_bias=True,
            device=device,
            permute=True,
        )

        temb_channels = in_channels if norm_type == "spatial" else None
        self.mid_block = MidBlock2D(
            in_channels=block_out_channels[-1],
            temb_channels=temb_channels,
            dropout=0.0,
            num_layers=1,
            resnet_eps=1e-6,
            resnet_time_scale_shift=(
                "default" if norm_type == "group" else norm_type
            ),
            resnet_act_fn=act_fn,
            resnet_groups=norm_num_groups,
            resnet_pre_norm=True,
            add_attention=mid_block_add_attention,
            # One head spanning all channels.
            attention_head_dim=block_out_channels[-1],
            output_scale_factor=1.0,
            device=device,
            dtype=dtype,
        )

        up_blocks_list = []
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            # Each block consumes the previous block's output width.
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            if up_block_type == "UpDecoderBlock2D":
                up_block = UpDecoderBlock2D(
                    in_channels=prev_output_channel,
                    out_channels=output_channel,
                    resolution_idx=i,
                    dropout=0.0,
                    num_layers=self.layers_per_block + 1,
                    resnet_eps=1e-6,
                    resnet_time_scale_shift=norm_type,
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    resnet_pre_norm=True,
                    output_scale_factor=1.0,
                    # The last block keeps the resolution.
                    add_upsample=not is_final_block,
                    temb_channels=temb_channels,
                    device=device,
                    dtype=dtype,
                )
                up_blocks_list.append(up_block)
            else:
                raise ValueError(f"Unsupported up_block_type: {up_block_type}")

        self.up_blocks = ModuleList(up_blocks_list)

        if norm_type == "spatial":
            raise NotImplementedError("SpatialNorm not implemented in MAX VAE")
        else:
            self.conv_norm_out = GroupNorm(
                num_groups=norm_num_groups,
                num_channels=block_out_channels[0],
                eps=1e-6,
                affine=True,
            )

        self.conv_out = Conv2d(
            kernel_size=3,
            in_channels=block_out_channels[0],
            out_channels=out_channels,
            dtype=dtype,
            stride=1,
            padding=1,
            dilation=1,
            num_groups=1,
            has_bias=True,
            device=device,
            permute=True,
        )

    def forward(self, z: Tensor, temb: Tensor | None = None) -> Tensor:
        """Apply Decoder forward pass.

        Args:
            z: Input latent tensor of shape [N, C_latent, H_latent, W_latent].
            temb: Optional time embedding tensor, forwarded to the mid block
                and every up block.

        Returns:
            Decoded image tensor of shape [N, C_out, H, W] where H and W are
            upsampled from H_latent and W_latent.
        """
        if self.post_quant_conv is not None:
            z = self.post_quant_conv(z)
        sample = self.conv_in(z)
        sample = self.mid_block(sample, temb)

        for up_block in self.up_blocks:
            sample = up_block(sample, temb)

        sample = self.conv_norm_out(sample)
        sample = F.silu(sample)
        sample = self.conv_out(sample)

        return sample

    def input_types(self) -> tuple[TensorType, ...]:
        """Define input tensor types for the decoder model.

        Returns:
            A one-element tuple holding the latent ``TensorType``: symbolic
            batch, height and width dimensions, a fixed channel dimension of
            ``self.in_channels``, and this decoder's dtype and device.
        """
        latent_type = TensorType(
            self.dtype,
            shape=[
                "batch_size",
                self.in_channels,
                "latent_height",
                "latent_width",
            ],
            device=self.device,
        )

        return (latent_type,)


# ---------------------------------------------------------------------------
# From 8cb21470b9ab220107013db4b89d36e2d5e5c579 Mon Sep 17 00:00:00 2001
# From: Minkyu Kim
# Date: Wed, 21 Jan 2026 08:31:42 +0000
# Subject: [PATCH 06/10] feat: add clip for flux1 pipeline support
# ---
#  max/python/max/dtype/__init__.py | 1 +
#  max/python/max/dtype/dtype_extension.py | 56 ++
#  max/python/max/nn/module_v3/norm/__init__.py | 2 +
#  max/python/max/nn/module_v3/norm/layer_norm.py | 115 ++++
#  .../pipelines/architectures/clip/__init__.py | 14 +
#  .../max/pipelines/architectures/clip/clip.py | 514 ++++++++++++++++++
#  .../max/pipelines/architectures/clip/model.py | 57 ++
#  .../architectures/clip/model_config.py | 63 +++
#  .../max/pipelines/lib/interfaces/max_model.py | 45 ++
#  9 files changed, 867 insertions(+)
#  create mode 100644 max/python/max/dtype/dtype_extension.py
#  create mode 100644 max/python/max/nn/module_v3/norm/layer_norm.py
#  create mode 100644 max/python/max/pipelines/architectures/clip/__init__.py
#  create mode 100644 max/python/max/pipelines/architectures/clip/clip.py
#  create mode 100644 max/python/max/pipelines/architectures/clip/model.py
#  create mode 100644 max/python/max/pipelines/architectures/clip/model_config.py
#  create mode 100644 max/python/max/pipelines/lib/interfaces/max_model.py
#
# File: max/python/max/dtype/__init__.py (index 864514236b8..ad702907d8c, hunk @@ -11,4 +11,5 @@)
# # limitations under the License.
# # ===----------------------------------------------------------------------=== #
#
# +from .
import dtype_extension
 from .dtype import DType
diff --git a/max/python/max/dtype/dtype_extension.py b/max/python/max/dtype/dtype_extension.py
new file mode 100644
index 00000000000..fba3de7e83f
--- /dev/null
+++ b/max/python/max/dtype/dtype_extension.py
@@ -0,0 +1,56 @@
+# ===----------------------------------------------------------------------=== #
+# Copyright (c) 2025, Modular Inc. All rights reserved.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions:
+# https://llvm.org/LICENSE.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===----------------------------------------------------------------------=== #
+
+"""Extension for max.dtype to support additional attributes."""
+
+from numpy import finfo as np_finfo
+
+from .dtype import DType
+
+
+class finfo:
+    """Numerical limits of a floating-point ``max.dtype.DType``.
+
+    Exposes ``min``, ``max``, ``bits``, ``eps``, ``resolution``, ``tiny`` and
+    ``dtype`` in the spirit of ``torch.finfo`` without depending on torch.
+    ``DType.bfloat16`` is answered from hard-coded constants; every other
+    dtype is delegated to ``numpy.finfo``.
+
+    NOTE: The class is attached to ``DType`` by the module-level assignment
+    below (monkey patching). It would be cleaner to implement it in the
+    dtype library itself.
+    """
+
+    def __init__(self, dtype: DType):
+        """Populate the limit attributes for ``dtype``.
+
+        Args:
+            dtype: The data type to get limits for.
+        """
+        if dtype != DType.bfloat16:
+            # Delegate to numpy for every dtype it can describe.
+            limits = np_finfo(dtype.to_numpy())
+            self.min = float(limits.min)
+            self.max = float(limits.max)
+            self.bits = limits.bits
+            self.eps = float(limits.eps)
+            self.resolution = float(limits.resolution)
+            self.tiny = float(limits.tiny)
+            self.dtype = str(limits.dtype)
+            return
+
+        # bfloat16 is special-cased with literal limits.
+        self.min = -3.38953e38
+        self.max = 3.38953e38
+        self.bits = 16
+        self.eps = 0.0078125
+        self.resolution = 0.01
+        self.tiny = 1.17549e-38
+        self.dtype = "bfloat16"
+
+
+DType.finfo = finfo
diff --git a/max/python/max/nn/module_v3/norm/__init__.py b/max/python/max/nn/module_v3/norm/__init__.py
index 1d78f1db418..3a7444d5df2 100644
--- a/max/python/max/nn/module_v3/norm/__init__.py
+++ b/max/python/max/nn/module_v3/norm/__init__.py
@@ -11,9 +11,11 @@
 # limitations under the License.
 # ===----------------------------------------------------------------------=== #
 
+from .layer_norm import LayerNorm
 from .rms_norm import GemmaRMSNorm, RMSNorm
 
 __all__ = [
     "GemmaRMSNorm",
+    "LayerNorm",
     "RMSNorm",
 ]
diff --git a/max/python/max/nn/module_v3/norm/layer_norm.py b/max/python/max/nn/module_v3/norm/layer_norm.py
new file mode 100644
index 00000000000..a5470aa8021
--- /dev/null
+++ b/max/python/max/nn/module_v3/norm/layer_norm.py
@@ -0,0 +1,115 @@
+# ===----------------------------------------------------------------------=== #
+# Copyright (c) 2025, Modular Inc. All rights reserved.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions:
+# https://llvm.org/LICENSE.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===----------------------------------------------------------------------=== #
+
+"""Layer normalization for module_v3."""
+
+from __future__ import annotations
+
+from max.dtype import DType
+from max.experimental import functional as F
+from max.experimental.tensor import Tensor
+
+from ..module import Module
+
+
+def layer_norm(
+    x: Tensor,
+    gamma: Tensor,
+    beta: Tensor,
+    eps: float,
+    keep_dtype: bool,
+) -> Tensor:
+    """Applies Layer Normalization to an input tensor.
+
+    Args:
+        x: Input tensor to normalize.
+        gamma: Scale tensor for elementwise affine transform.
+        beta: Bias tensor for elementwise affine transform.
+        eps: Numerical stability constant.
+        keep_dtype: If True, normalize directly in the dtype of ``x``. If
+            False, cast ``x``, ``gamma`` and ``beta`` to float32, normalize,
+            and cast the result back to ``x.dtype``.
+
+    Returns:
+        A layer-normalized tensor with the same shape and type as `x`.
+    """
+    if keep_dtype:
+        return F.layer_norm(x, gamma=gamma, beta=beta, epsilon=eps)
+    # Compute in float32, then restore the caller's dtype.
+    output = F.layer_norm(
+        F.cast(x, DType.float32),
+        gamma=F.cast(gamma, DType.float32),
+        beta=F.cast(beta, DType.float32),
+        epsilon=eps,
+    )
+    return F.cast(output, x.dtype)
+
+
+class LayerNorm(Module[[Tensor], Tensor]):
+    """Layer normalization over the last dimension."""
+
+    def __init__(
+        self,
+        dim: int,
+        eps: float = 1e-5,
+        *,
+        keep_dtype: bool = True,
+        elementwise_affine: bool = True,
+        use_bias: bool = True,
+    ) -> None:
+        """Initialize LayerNorm.
+
+        Args:
+            dim: Size of the last dimension to normalize.
+            eps: Numerical stability constant.
+            keep_dtype: Whether to preserve input dtype in computation.
+            elementwise_affine: Whether to apply learned scale and bias.
+            use_bias: Whether to learn an additive bias term. Only consulted
+                when ``elementwise_affine`` is True.
+        """
+        super().__init__()
+        self.dim = dim
+        self.eps = eps
+        self.keep_dtype = keep_dtype
+        self.elementwise_affine = elementwise_affine
+        self.use_bias = use_bias
+        if elementwise_affine:
+            self.weight = Tensor.ones([dim])
+            self.bias = Tensor.zeros([dim]) if use_bias else None
+        else:
+            self.weight = None
+            self.bias = None
+
+    def __rich_repr__(self):
+        """Repr matching the LayerNorm constructor.
+
+        The third element of a yielded tuple is the constructor default;
+        Rich omits the field when the value equals it, so it has to agree
+        with the ``eps`` default declared in ``__init__`` (1e-5).
+        """
+        yield "dim", self.dim
+        yield "eps", self.eps, 1e-5
+
+    def _affine_params(self, x: Tensor) -> tuple[Tensor, Tensor]:
+        """Return the (gamma, beta) pair to normalize ``x`` with.
+
+        A missing weight is replaced by ones and a missing bias by zeros,
+        both built with the dtype and device of ``x`` and broadcast to the
+        size of its last dimension.
+        """
+        if self.weight is None:
+            gamma = F.broadcast_to(
+                F.constant(1.0, dtype=x.dtype, device=x.device),
+                shape=(x.shape[-1],),
+            )
+        else:
+            gamma = self.weight
+
+        if self.bias is None:
+            beta = F.broadcast_to(
+                F.constant(0.0, dtype=x.dtype, device=x.device),
+                shape=(x.shape[-1],),
+            )
+        else:
+            beta = self.bias
+
+        return gamma, beta
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Normalize ``x`` over its last dimension and apply the affine terms."""
+        gamma, beta = self._affine_params(x)
+        return layer_norm(x, gamma, beta, self.eps, self.keep_dtype)
diff --git a/max/python/max/pipelines/architectures/clip/__init__.py b/max/python/max/pipelines/architectures/clip/__init__.py
new file mode 100644
index 00000000000..32cb1def84e
--- /dev/null
+++ b/max/python/max/pipelines/architectures/clip/__init__.py
@@ -0,0 +1,14 @@
+# ===----------------------------------------------------------------------=== #
+# Copyright (c) 2025, Modular Inc. All rights reserved.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions:
+# https://llvm.org/LICENSE.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===----------------------------------------------------------------------=== #
+
+from .model import ClipModel
diff --git a/max/python/max/pipelines/architectures/clip/clip.py b/max/python/max/pipelines/architectures/clip/clip.py
new file mode 100644
index 00000000000..42d832386cc
--- /dev/null
+++ b/max/python/max/pipelines/architectures/clip/clip.py
@@ -0,0 +1,514 @@
+# ===----------------------------------------------------------------------=== #
+# Copyright (c) 2025, Modular Inc. All rights reserved.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions:
+# https://llvm.org/LICENSE.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===----------------------------------------------------------------------=== #
+
+from functools import partial
+
+from max.driver import Device
+from max.dtype import DType
+from max.experimental import functional as F
+from max.experimental.tensor import Tensor
+from max.graph import TensorType
+from max.nn.module_v3 import Embedding, Linear, Module
+from max.nn.module_v3.norm import LayerNorm
+from max.nn.module_v3.sequential import ModuleList
+
+from .model_config import ClipConfig
+
+
+class CLIPTextEmbeddings(Module):
+    """Sum of token embeddings and learned position embeddings."""
+
+    def __init__(
+        self,
+        config: ClipConfig,
+    ):
+        """Initialize CLIP text embeddings.
+
+        Args:
+            config: CLIP configuration; ``hidden_size``, ``vocab_size`` and
+                ``max_position_embeddings`` size the two embedding tables.
+        """
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.position_embedding = Embedding(
+            config.max_position_embeddings,
+            dim=self.embed_dim,
+        )
+        self.token_embedding = Embedding(
+            config.vocab_size,
+            dim=self.embed_dim,
+        )
+
+    def forward(
+        self,
+        input_ids: Tensor | None = None,
+        position_ids: Tensor | None = None,
+        inputs_embeds: Tensor | None = None,
+    ) -> Tensor:
+        """Apply embeddings to input tokens.
+
+        Args:
+            input_ids: Input token IDs.
+            position_ids: Position IDs. When omitted, ``0..seq_length-1`` is
+                generated with a leading axis of size 1.
+            inputs_embeds: Pre-computed input embeddings; used instead of
+                looking up ``input_ids`` when given.
+
+        Returns:
+            Combined embeddings.
+
+        Raises:
+            ValueError: If both ``input_ids`` and ``inputs_embeds`` are None.
+        """
+        if input_ids is None and inputs_embeds is None:
+            raise ValueError(
+                "You have to specify either input_ids or inputs_embeds"
+            )
+
+        if input_ids is not None:
+            seq_length = input_ids.shape[-1]
+        else:
+            seq_length = inputs_embeds.shape[-2]
+
+        if position_ids is None:
+            device = (
+                input_ids.device
+                if input_ids is not None
+                else inputs_embeds.device
+            )
+            position_ids = F.arange(
+                0, seq_length, step=1, dtype=DType.int32, device=device
+            )
+            position_ids = F.unsqueeze(position_ids, 0)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.token_embedding(input_ids)
+
+        position_embeddings = self.position_embedding(position_ids)
+        embeddings = inputs_embeds + position_embeddings
+
+        return embeddings
+
+
+class CLIPAttention(Module):
+    """Multi-head self-attention with additive masks."""
+
+    def __init__(
+        self,
+        config: ClipConfig,
+    ):
+        """Initialize CLIP attention module.
+
+        Args:
+            config: CLIP configuration; ``hidden_size`` and
+                ``num_attention_heads`` fix the head layout.
+
+        Raises:
+            ValueError: If ``hidden_size`` is not divisible by
+                ``num_attention_heads``.
+        """
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.embed_dim // self.num_heads
+        if self.head_dim * self.num_heads != self.embed_dim:
+            raise ValueError(
+                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
+                f" {self.num_heads})."
+            )
+        self.scale = self.head_dim**-0.5
+        # Stored only; forward() below never applies dropout.
+        self.dropout = config.attention_dropout
+
+        self.k_proj = Linear(
+            self.embed_dim,
+            self.embed_dim,
+            bias=True,
+        )
+        self.v_proj = Linear(
+            self.embed_dim,
+            self.embed_dim,
+            bias=True,
+        )
+        self.q_proj = Linear(
+            self.embed_dim,
+            self.embed_dim,
+            bias=True,
+        )
+        self.out_proj = Linear(
+            self.embed_dim,
+            self.embed_dim,
+            bias=True,
+        )
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Tensor | None = None,
+        causal_attention_mask: Tensor | None = None,
+    ) -> Tensor:
+        """Apply multi-head attention.
+
+        Args:
+            hidden_states: Input hidden states, unpacked as
+                (batch, sequence, embedding).
+            attention_mask: Additive attention mask.
+            causal_attention_mask: Additive causal mask; summed with
+                ``attention_mask`` when both are given.
+
+        Returns:
+            Attention output.
+        """
+        batch_size, seq_length, embed_dim = hidden_states.shape
+
+        query = self.q_proj(hidden_states)
+        key = self.k_proj(hidden_states)
+        value = self.v_proj(hidden_states)
+
+        # Split the embedding axis into heads and move heads ahead of the
+        # sequence axis: (batch, heads, sequence, head_dim).
+        query = F.reshape(
+            query, (batch_size, seq_length, self.num_heads, self.head_dim)
+        )
+        query = F.transpose(query, 1, 2)
+
+        key = F.reshape(
+            key, (batch_size, seq_length, self.num_heads, self.head_dim)
+        )
+        key = F.transpose(key, 1, 2)
+
+        value = F.reshape(
+            value, (batch_size, seq_length, self.num_heads, self.head_dim)
+        )
+        value = F.transpose(value, 1, 2)
+
+        if attention_mask is not None and causal_attention_mask is not None:
+            attention_mask = attention_mask + causal_attention_mask
+        elif causal_attention_mask is not None:
+            attention_mask = causal_attention_mask
+
+        attn_weights = F.matmul(query, F.transpose(key, -1, -2)) * self.scale
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        # Softmax runs in float32 and is cast back to the input dtype.
+        attn_weights = F.softmax(F.cast(attn_weights, DType.float32), axis=-1)
+        attn_weights = F.cast(attn_weights, hidden_states.dtype)
+
+        attn_output = F.matmul(attn_weights, value)
+        attn_output = F.transpose(attn_output, 1, 2)
+        attn_output = F.reshape(
+            attn_output, (batch_size, seq_length, embed_dim)
+        )
+
+        attn_output = self.out_proj(attn_output)
+
+        return attn_output
+
+
+class CLIPMLP(Module):
+    """Two-layer feed-forward block with a quick-GELU activation."""
+
+    def __init__(
+        self,
+        config: ClipConfig,
+    ):
+        """Initialize CLIP MLP.
+
+        Args:
+            config: CLIP configuration; ``hidden_size`` and
+                ``intermediate_size`` size the two projections.
+        """
+        super().__init__()
+        self.config = config
+        self.fc1 = Linear(
+            config.hidden_size,
+            config.intermediate_size,
+            bias=True,
+        )
+        self.fc2 = Linear(
+            config.intermediate_size,
+            config.hidden_size,
+            bias=True,
+        )
+        # NOTE(review): the activation is fixed here; config.hidden_act is
+        # not consulted.
+        self.act_fn = partial(F.gelu, approximate="quick")
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        """Apply MLP block.
+
+        Args:
+            hidden_states: Input hidden states.
+
+        Returns:
+            Output hidden states.
+        """
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.act_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        return hidden_states
+
+
+class CLIPEncoderLayer(Module):
+    """Pre-norm transformer layer: attention and MLP, each with a residual."""
+
+    def __init__(
+        self,
+        config: ClipConfig,
+    ):
+        """Initialize CLIP encoder layer.
+
+        Args:
+            config: CLIP configuration for encoder layer structure.
+        """
+        super().__init__()
+        self.embed_dim = config.hidden_size
+        self.self_attn = CLIPAttention(config)
+        self.layer_norm1 = LayerNorm(
+            self.embed_dim,
+            eps=config.layer_norm_eps,
+            keep_dtype=True,
+        )
+        self.mlp = CLIPMLP(config)
+        self.layer_norm2 = LayerNorm(
+            self.embed_dim,
+            eps=config.layer_norm_eps,
+            keep_dtype=True,
+        )
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Tensor | None,
+        causal_attention_mask: Tensor | None,
+    ) -> Tensor:
+        """Apply encoder layer.
+
+        Args:
+            hidden_states: Input hidden states.
+            attention_mask: Attention mask.
+            causal_attention_mask: Causal attention mask.
+
+        Returns:
+            Output hidden states.
+        """
+        residual = hidden_states
+
+        hidden_states = self.layer_norm1(hidden_states)
+        hidden_states = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            causal_attention_mask=causal_attention_mask,
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states
+
+
+class CLIPEncoder(Module):
+    """Stack of ``num_hidden_layers`` encoder layers applied in order."""
+
+    def __init__(
+        self,
+        config: ClipConfig,
+    ):
+        """Initialize CLIP encoder.
+
+        Args:
+            config: CLIP configuration for encoder depth and dimensions.
+        """
+        super().__init__()
+        self.layers = ModuleList(
+            [CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]
+        )
+
+    def forward(
+        self,
+        inputs_embeds: Tensor,
+        attention_mask: Tensor | None = None,
+        causal_attention_mask: Tensor | None = None,
+    ) -> Tensor:
+        """Apply encoder (stack of layers).
+
+        Args:
+            inputs_embeds: Input embeddings.
+            attention_mask: Attention mask.
+            causal_attention_mask: Causal attention mask.
+
+        Returns:
+            Encoded hidden states.
+        """
+        hidden_states = inputs_embeds
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(
+                hidden_states,
+                attention_mask=attention_mask,
+                causal_attention_mask=causal_attention_mask,
+            )
+        return hidden_states
+
+
+class CLIPTextTransformer(Module):
+    """Embeddings, causal encoder, final LayerNorm and EOS pooling."""
+
+    def __init__(
+        self,
+        config: ClipConfig,
+    ):
+        """Initialize CLIP text transformer.
+
+        Args:
+            config: CLIP configuration; ``dtype`` is used for the causal
+                mask and ``eos_token_id`` selects the pooling position.
+        """
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.embeddings = CLIPTextEmbeddings(config)
+        self.encoder = CLIPEncoder(config)
+        self.final_layer_norm = LayerNorm(
+            self.embed_dim,
+            eps=config.layer_norm_eps,
+            keep_dtype=True,
+        )
+        self.eos_token_id = config.eos_token_id
+
+    def _create_causal_mask(
+        self, input_shape: tuple[int, int], device: Device
+    ) -> Tensor:
+        """Create causal mask for the transformer.
+
+        Args:
+            input_shape: Shape of the input tensor; only the sequence length
+                (second entry) is used.
+            device: Device on which the mask tensors are created.
+
+        Returns:
+            Additive causal mask with two leading singleton axes. Entries
+            where the column index exceeds the row index hold the most
+            negative finite value of ``config.dtype``; all others are zero.
+        """
+        _, seq_length = input_shape
+
+        rows = F.arange(0, seq_length, step=1, dtype=DType.int32, device=device)
+        rows = F.unsqueeze(rows, 1)
+        cols = F.arange(0, seq_length, step=1, dtype=DType.int32, device=device)
+        cols = F.unsqueeze(cols, 0)
+        # True strictly above the diagonal, i.e. for future positions.
+        mask = F.greater(cols, rows)
+        mask_float = F.cast(mask, self.config.dtype)
+
+        min_val = DType.finfo(self.config.dtype).min
+        min_val_tensor = F.constant(
+            min_val, dtype=self.config.dtype, device=device
+        )
+
+        causal_mask = mask_float * min_val_tensor
+        causal_mask = F.unsqueeze(causal_mask, 0)
+        causal_mask = F.unsqueeze(causal_mask, 1)
+        return causal_mask
+
+    def forward(
+        self,
+        input_ids: Tensor | None = None,
+        attention_mask: Tensor | None = None,
+        position_ids: Tensor | None = None,
+    ) -> tuple[Tensor, Tensor]:
+        """Apply text transformer.
+
+        Args:
+            input_ids: Input token IDs.
+            attention_mask: Attention mask with 1 for tokens to keep; it is
+                converted to an additive mask below.
+            position_ids: Position IDs.
+
+        Returns:
+            Tuple of (last_hidden_state, pooled_output).
+
+        Raises:
+            ValueError: If ``input_ids`` is None.
+        """
+        if input_ids is None:
+            raise ValueError("You have to specify input_ids")
+
+        hidden_states = self.embeddings(
+            input_ids=input_ids, position_ids=position_ids
+        )
+
+        input_shape = input_ids.shape
+        causal_attention_mask = self._create_causal_mask(
+            input_shape, input_ids.device
+        )
+
+        if attention_mask is not None:
+            # (1 - mask) * dtype-min: kept positions add 0, masked positions
+            # add the most negative finite value.
+            mask_multiplier = F.constant(
+                DType.finfo(hidden_states.dtype).min,
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+            )
+            inverted_mask = (
+                F.constant(
+                    1.0,
+                    dtype=hidden_states.dtype,
+                    device=hidden_states.device,
+                )
+                - F.cast(attention_mask, hidden_states.dtype)
+            ) * mask_multiplier
+            attention_mask = F.unsqueeze(inverted_mask, 1)
+            attention_mask = F.unsqueeze(attention_mask, 1)
+
+        encoder_outputs = self.encoder(
+            inputs_embeds=hidden_states,
+            attention_mask=attention_mask,
+            causal_attention_mask=causal_attention_mask,
+        )
+
+        last_hidden_state = self.final_layer_norm(encoder_outputs)
+
+        if self.eos_token_id == 2:
+            # Pool at the position of the largest token id in each row.
+            # NOTE(review): presumably kept for configs that carry the
+            # placeholder eos id 2 -- confirm against the reference model.
+            eos_token_indices = F.cast(
+                F.argmax(input_ids, axis=-1), DType.int32
+            )
+        else:
+            # Pool at the first position whose token equals eos_token_id.
+            eos_token_indices = F.cast(
+                F.argmax(
+                    F.cast(F.equal(input_ids, self.eos_token_id), DType.int32),
+                    axis=-1,
+                ),
+                DType.int32,
+            )
+
+        # NOTE(review): confirm gather_nd with batch_dims=1 accepts the index
+        # tensor in the shape produced above (no trailing index axis).
+        pooled_output = F.gather_nd(
+            last_hidden_state, eos_token_indices, batch_dims=1
+        )
+
+        return last_hidden_state, pooled_output
+
+
+class CLIPTextModel(Module):
+    """Top-level wrapper exposing ``CLIPTextTransformer`` as ``text_model``."""
+
+    def __init__(
+        self,
+        config: ClipConfig,
+    ):
+        """Initialize CLIP text model with MAX.
+
+        Args:
+            config: CLIP configuration for vocabulary size, dimensions, and
+                device/dtype settings.
+        """
+        super().__init__()
+        self.text_model = CLIPTextTransformer(config)
+        self.device = config.device
+
+    def input_types(self) -> tuple[TensorType, ...]:
+        """Define input tensor types for the model.
+
+        Returns:
+            One-element tuple: int64 token ids with symbolic batch and
+            sequence dimensions on ``self.device``.
+        """
+        return (
+            TensorType(
+                DType.int64,
+                shape=["batch_size", "sequence_length"],
+                device=self.device,
+            ),
+        )
+
+    def forward(
+        self,
+        input_ids: Tensor | None = None,
+        attention_mask: Tensor | None = None,
+        position_ids: Tensor | None = None,
+    ) -> tuple[Tensor, Tensor]:
+        """Apply CLIP text model forward pass.
+
+        Args:
+            input_ids: Input token IDs.
+            attention_mask: Attention mask.
+            position_ids: Position IDs.
+
+        Returns:
+            Tuple of (last_hidden_state, pooled_output).
+        """
+        return self.text_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+        )
diff --git a/max/python/max/pipelines/architectures/clip/model.py b/max/python/max/pipelines/architectures/clip/model.py
new file mode 100644
index 00000000000..c149258afc3
--- /dev/null
+++ b/max/python/max/pipelines/architectures/clip/model.py
@@ -0,0 +1,57 @@
+# ===----------------------------------------------------------------------=== #
+# Copyright (c) 2025, Modular Inc. All rights reserved.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions:
+# https://llvm.org/LICENSE.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===----------------------------------------------------------------------=== #
+
+from max.driver import Device
+from max.engine import Model
+from max.experimental import functional as F
+from max.graph.weights import Weights
+from max.pipelines.lib import SupportedEncoding
+from max.pipelines.lib.interfaces.max_model import MaxModel
+
+from .clip import CLIPTextModel
+from .model_config import ClipConfig
+
+
+class ClipModel(MaxModel):
+    """Pipeline wrapper that builds and compiles the CLIP text encoder.
+
+    Construction parses the raw config into a CLIP config object and
+    compiles the model immediately; calling the instance forwards to the
+    compiled model.
+    """
+
+    config_name = ClipConfig.config_name
+
+    def __init__(
+        self,
+        config: dict,
+        encoding: SupportedEncoding,
+        devices: list[Device],
+        weights: Weights,
+    ) -> None:
+        """Store the inputs, parse the config and compile the model.
+
+        Args:
+            config: Raw configuration dictionary.
+            encoding: Encoding whose dtype is applied to the parsed config.
+            devices: Target devices; the first one hosts the model.
+            weights: Weights used to populate the compiled model.
+        """
+        super().__init__(config, encoding, devices, weights)
+        # Replaces the raw dict stored by the base class.
+        self.config = ClipConfig.generate(config, encoding, devices)
+        self.load_model()
+
+    def load_model(self) -> Model:
+        """Build the text encoder, compile it and cache it on ``self.model``.
+
+        Returns:
+            The compiled model.
+        """
+        weights_by_name = {
+            name: weight.data() for name, weight in self.weights.items()
+        }
+        with F.lazy():
+            text_encoder = CLIPTextModel(self.config)
+            text_encoder.to(self.devices[0])
+        compiled = text_encoder.compile(
+            *text_encoder.input_types(), weights=weights_by_name
+        )
+        self.model = compiled
+        return compiled
+
+    def __call__(self, *args, **kwargs):
+        """Forward every argument to the compiled model."""
+        return self.model(*args, **kwargs)
diff --git a/max/python/max/pipelines/architectures/clip/model_config.py b/max/python/max/pipelines/architectures/clip/model_config.py
new file mode 100644
index 00000000000..ccf9ba35083
--- /dev/null
+++ b/max/python/max/pipelines/architectures/clip/model_config.py
@@ -0,0 +1,63 @@
+# ===----------------------------------------------------------------------=== #
+# Copyright (c) 2025, Modular Inc. All rights reserved.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions:
+# https://llvm.org/LICENSE.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===----------------------------------------------------------------------=== #
+
+from typing import ClassVar
+
+from max.driver import Device
+from max.dtype import DType
+from max.graph import DeviceRef
+from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding
+from pydantic import Field
+
+
+class ClipConfigBase(MAXModelConfigBase):
+    """Field set and default values for the CLIP text encoder."""
+
+    # Vocabulary and layer sizes.
+    vocab_size: int = 49408
+    hidden_size: int = 512
+    intermediate_size: int = 2048
+    projection_dim: int = 512
+    num_hidden_layers: int = 12
+    num_attention_heads: int = 8
+    max_position_embeddings: int = 77
+    # Activation, normalization and initialization settings.
+    hidden_act: str = "quick_gelu"
+    layer_norm_eps: float = 1e-5
+    attention_dropout: float = 0.0
+    initializer_range: float = 0.02
+    initializer_factor: float = 1.0
+    # Special token ids.
+    pad_token_id: int = 1
+    bos_token_id: int = 49406
+    eos_token_id: int = 49407
+    # Runtime placement; both are overridden by ClipConfig.generate().
+    dtype: DType = DType.bfloat16
+    device: DeviceRef = Field(default_factory=DeviceRef.GPU)
+
+
+class ClipConfig(ClipConfigBase):
+    """Factory that builds a ``ClipConfigBase`` from a raw config dict."""
+
+    config_name: ClassVar[str] = "config.json"
+
+    @staticmethod
+    def generate(
+        config_dict: dict,
+        encoding: SupportedEncoding,
+        devices: list[Device],
+    ) -> ClipConfigBase:
+        """Create a config from ``config_dict`` plus runtime placement.
+
+        Keys that are not annotated on ``ClipConfigBase`` are ignored;
+        ``dtype`` and ``device`` always come from ``encoding`` and the first
+        entry of ``devices``, replacing any value found in the dict.
+
+        Args:
+            config_dict: Raw configuration dictionary.
+            encoding: Encoding providing the dtype.
+            devices: Target devices; only the first one is used.
+
+        Returns:
+            The populated ``ClipConfigBase`` instance.
+        """
+        known_fields = ClipConfigBase.__annotations__
+        recognised = {
+            name: setting
+            for name, setting in config_dict.items()
+            if name in known_fields
+        }
+        placement = {
+            "dtype": encoding.dtype,
+            "device": DeviceRef.from_device(devices[0]),
+        }
+        return ClipConfigBase(**{**recognised, **placement})
diff --git a/max/python/max/pipelines/lib/interfaces/max_model.py b/max/python/max/pipelines/lib/interfaces/max_model.py
new file mode 100644
index 00000000000..3f323d81ad8
--- /dev/null
+++ b/max/python/max/pipelines/lib/interfaces/max_model.py
@@ -0,0 +1,45 @@
+# ===----------------------------------------------------------------------=== #
+# Copyright (c) 2025, Modular Inc. All rights reserved.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions:
+# https://llvm.org/LICENSE.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===----------------------------------------------------------------------=== #
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
+
+from max.driver import Device
+from max.engine import Model
+from max.graph.weights import Weights
+
+if TYPE_CHECKING:
+    from max.pipelines.lib import SupportedEncoding
+
+
+class MaxModel(ABC):
+    """Abstract base for pipeline models that are built from weights.
+
+    The constructor only records its arguments; subclasses implement
+    ``load_model`` to turn them into a runtime model.
+    """
+
+    def __init__(
+        self,
+        config: dict,
+        encoding: SupportedEncoding,
+        devices: list[Device],
+        weights: Weights,
+    ) -> None:
+        """Record the construction inputs on the instance.
+
+        Args:
+            config: Raw configuration dictionary.
+            encoding: Encoding of the model weights.
+            devices: Devices the model should be placed on.
+            weights: Weights backing the model.
+        """
+        self.config, self.encoding = config, encoding
+        self.devices, self.weights = devices, weights
+
+    @abstractmethod
+    def load_model(self) -> Model:
+        """Load and return a runtime model instance."""
+        ...
From dc452f937c4c864446cf840e3e0a7580bf0252ab Mon Sep 17 00:00:00 2001 From: Minkyu Kim Date: Wed, 21 Jan 2026 09:53:19 +0000 Subject: [PATCH 07/10] feat: add t5 for flux1 pipeline support --- .../pipelines/architectures/t5/__init__.py | 14 + .../max/pipelines/architectures/t5/model.py | 54 ++ .../architectures/t5/model_config.py | 69 ++ .../max/pipelines/architectures/t5/t5.py | 821 ++++++++++++++++++ .../architectures/t5/weight_adapters.py | 61 ++ 5 files changed, 1019 insertions(+) create mode 100644 max/python/max/pipelines/architectures/t5/__init__.py create mode 100644 max/python/max/pipelines/architectures/t5/model.py create mode 100644 max/python/max/pipelines/architectures/t5/model_config.py create mode 100644 max/python/max/pipelines/architectures/t5/t5.py create mode 100644 max/python/max/pipelines/architectures/t5/weight_adapters.py diff --git a/max/python/max/pipelines/architectures/t5/__init__.py b/max/python/max/pipelines/architectures/t5/__init__.py new file mode 100644 index 00000000000..ad108912d49 --- /dev/null +++ b/max/python/max/pipelines/architectures/t5/__init__.py @@ -0,0 +1,14 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ===----------------------------------------------------------------------=== #
+
+from .model import T5Model
diff --git a/max/python/max/pipelines/architectures/t5/model.py b/max/python/max/pipelines/architectures/t5/model.py
new file mode 100644
index 00000000000..bfbb5e7659c
--- /dev/null
+++ b/max/python/max/pipelines/architectures/t5/model.py
@@ -0,0 +1,54 @@
+# ===----------------------------------------------------------------------=== #
+# Copyright (c) 2025, Modular Inc. All rights reserved.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions:
+# https://llvm.org/LICENSE.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===----------------------------------------------------------------------=== #
+
+from max.driver import Device
+from max.engine import Model
+from max.experimental import functional as F
+from max.graph.weights import Weights
+from max.pipelines.lib import SupportedEncoding
+from max.pipelines.lib.interfaces.max_model import MaxModel
+
+from .model_config import T5Config
+from .t5 import T5EncoderModel
+from .weight_adapters import convert_safetensor_state_dict
+
+
+class T5Model(MaxModel):
+    """T5 encoder that is configured and compiled at construction time."""
+
+    config_name = T5Config.config_name
+
+    def __init__(
+        self,
+        config: dict,
+        encoding: SupportedEncoding,
+        devices: list[Device],
+        weights: Weights,
+    ) -> None:
+        super().__init__(config, encoding, devices, weights)
+        self.config = T5Config.generate(config, encoding, devices)
+        self.load_model()
+
+    def load_model(self) -> Model:
+        """Build the encoder on devices[0] and compile it with the weights."""
+        raw = {name: weight.data() for name, weight in self.weights.items()}
+        adapted = convert_safetensor_state_dict(raw)
+        with F.lazy():
+            enc = T5EncoderModel(self.config)
+            enc.to(self.devices[0])
+        self.model = enc.compile(*enc.input_types(), weights=adapted)
+        return self.model
+
+    def __call__(self, *args, **kwargs):
+        """Forward all arguments to the compiled model."""
+        return self.model(*args, **kwargs)
diff --git a/max/python/max/pipelines/architectures/t5/model_config.py b/max/python/max/pipelines/architectures/t5/model_config.py
new file mode 100644
index 00000000000..ab28ce4b2a8
--- /dev/null
+++ b/max/python/max/pipelines/architectures/t5/model_config.py
@@ -0,0 +1,69 @@
+# ===----------------------------------------------------------------------=== #
+# Copyright (c) 2025, Modular Inc. All rights reserved.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions:
+# https://llvm.org/LICENSE.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===----------------------------------------------------------------------=== #
+
+from typing import ClassVar
+
+from max.driver import Device
+from max.dtype import DType
+from max.graph import DeviceRef
+from max.pipelines.lib import MAXModelConfigBase, SupportedEncoding
+from pydantic import Field
+
+
+class T5ConfigBase(MAXModelConfigBase):
+    vocab_size: int = 32128
+    d_model: int = 512
+    d_kv: int = 64
+    d_ff: int = 2048
+    num_layers: int = 6
+    num_decoder_layers: int | None = None
+    num_heads: int = 8
+    relative_attention_num_buckets: int = 32
+    relative_attention_max_distance: int = 128
+    dropout_rate: float = 0.1
+    layer_norm_epsilon: float = 1e-6
+    initializer_factor: float = 1.0
+    feed_forward_proj: str = "relu"
+    dense_act_fn: str | None = Field(default=None, exclude=True)
+    is_gated_act: bool = Field(default=False, exclude=True)
+    is_decoder: bool = Field(default=False, exclude=True)
+    is_encoder_decoder: bool = True
+    use_cache: bool = True
+    pad_token_id: int = 0
+    eos_token_id: int = 1
+    classifier_dropout: float = 0.0
+    device: DeviceRef = Field(default_factory=DeviceRef.GPU)
+    dtype: DType = DType.bfloat16
+
+
+class T5Config(T5ConfigBase):
+    """T5 configuration with a factory that builds it from a config dict."""
+
+    config_name: ClassVar[str] = "config.json"
+
+    @staticmethod
+    def generate(
+        config_dict: dict,
+        encoding: SupportedEncoding,
+        devices: list[Device],
+    ) -> T5ConfigBase:
+        """Build a config from `config_dict`, `encoding` and `devices`.
+
+        Keys that are not annotated on `T5ConfigBase` are dropped; `dtype`
+        comes from `encoding` and `device` from the first entry of `devices`.
+        """
+        known_fields = T5ConfigBase.__annotations__
+        kwargs = {k: v for k, v in config_dict.items() if k in known_fields}
+        kwargs["dtype"] = encoding.dtype
+        kwargs["device"] = DeviceRef.from_device(devices[0])
+        return T5ConfigBase(**kwargs)
diff --git a/max/python/max/pipelines/architectures/t5/t5.py b/max/python/max/pipelines/architectures/t5/t5.py
new file mode 100644
index 00000000000..43491949829
--- /dev/null
+++ b/max/python/max/pipelines/architectures/t5/t5.py
@@ -0,0 +1,821 @@
+# ===----------------------------------------------------------------------=== #
+# Copyright (c) 2025, Modular Inc. All rights reserved.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions:
+# https://llvm.org/LICENSE.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===----------------------------------------------------------------------=== #
+
+import math
+
+from max.driver import Device
+from max.dtype import DType
+from max.experimental import functional as F
+from max.experimental.tensor import Tensor
+from max.graph import TensorType
+from max.nn.module_v3 import Embedding, Linear, Module
+from max.nn.module_v3.sequential import ModuleList
+
+from .model_config import T5Config
+
+
+def _gelu_new(x: Tensor) -> Tensor:
+    """Apply the tanh approximation of GELU used by both T5 feed-forwards."""
+    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * F.pow(x, 3.0))
+    return 0.5 * x * (1.0 + F.tanh(inner))
+
+
+class T5LayerNorm(Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+        dtype: DType = DType.float32,
+    ):
+        """Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
+
+        Args:
+            hidden_size: Hidden size.
+            eps: Epsilon.
+            dtype: Data type for the module.
+        """
+        super().__init__()
+        self.weight = Tensor.ones([hidden_size])
+        self.variance_epsilon = eps
+        self.dtype = dtype
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        """Process hidden states through the T5 layer norm.
+
+        Args:
+            hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size).
+
+        Returns:
+            Tensor: Output tensor of shape (batch_size, seq_length, hidden_size).
+        """
+        # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
+        # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
+        # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
+        # half-precision inputs is done in fp32
+
+        hidden_states_f32 = F.cast(hidden_states, DType.float32)
+        variance = F.mean(F.pow(hidden_states_f32, 2), axis=-1)
+        hidden_states = hidden_states * F.rsqrt(
+            variance + self.variance_epsilon
+        )
+
+        # convert into half-precision if necessary
+        if self.dtype in [DType.float16, DType.bfloat16]:
+            hidden_states = F.cast(hidden_states, self.dtype)
+
+        return self.weight * hidden_states
+
+
+class T5DenseActDense(Module):
+    def __init__(
+        self,
+        config: T5Config,
+    ):
+        """Construct a dense-activation-dense module.
+
+        Args:
+            config: T5 configuration supplying `d_model` and `d_ff`.
+        """
+        super().__init__()
+        self.wi = Linear(
+            config.d_model,
+            config.d_ff,
+            bias=False,
+        )
+        self.wo = Linear(
+            config.d_ff,
+            config.d_model,
+            bias=False,
+        )
+        # NOTE(review): `config.dense_act_fn` is not consulted here; the
+        # tanh-approximated GELU is applied whatever activation is configured.
+        self.act_fn = _gelu_new
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        """Process hidden states through the dense-activation-dense block.
+
+        Args:
+            hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size).
+
+        Returns:
+            Tensor: Output tensor of shape (batch_size, seq_length, hidden_size).
+        """
+        hidden_states = self.wi(hidden_states)
+        hidden_states = self.act_fn(hidden_states)
+        hidden_states = self.wo(hidden_states)
+        return hidden_states
+
+
+class T5DenseGatedActDense(Module):
+    def __init__(
+        self,
+        config: T5Config,
+    ):
+        """Construct a dense-gated-activation-dense module.
+
+        Args:
+            config: T5 configuration supplying `d_model` and `d_ff`.
+        """
+        super().__init__()
+        self.wi_0 = Linear(
+            config.d_model,
+            config.d_ff,
+            bias=False,
+        )
+        self.wi_1 = Linear(
+            config.d_model,
+            config.d_ff,
+            bias=False,
+        )
+        self.wo = Linear(
+            config.d_ff,
+            config.d_model,
+            bias=False,
+        )
+        # NOTE(review): `config.dense_act_fn` is not consulted here; the
+        # tanh-approximated GELU is applied whatever activation is configured.
+        self.act_fn = _gelu_new
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        """Process hidden states through the dense-gated-activation-dense block.
+
+        Args:
+            hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size).
+
+        Returns:
+            Tensor: Output tensor of shape (batch_size, seq_length, hidden_size).
+        """
+        hidden_gelu = self.act_fn(self.wi_0(hidden_states))
+        hidden_linear = self.wi_1(hidden_states)
+        hidden_states = hidden_gelu * hidden_linear
+        hidden_states = self.wo(hidden_states)
+        return hidden_states
+
+
+class T5LayerFF(Module):
+    def __init__(
+        self,
+        config: T5Config,
+    ):
+        """Construct a feed-forward layer.
+
+        Args:
+            config: T5 configuration for gating, dimensions, and dtype.
+        """
+        super().__init__()
+        if config.is_gated_act:
+            self.DenseReluDense = T5DenseGatedActDense(config)
+        else:
+            self.DenseReluDense = T5DenseActDense(config)
+
+        self.layer_norm = T5LayerNorm(
+            config.d_model,
+            eps=config.layer_norm_epsilon,
+            dtype=config.dtype,
+        )
+
+    def forward(self, hidden_states: Tensor) -> Tensor:
+        """Process hidden states through the feed-forward layer.
+
+        Args:
+            hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size).
+
+        Returns:
+            Tensor: Output tensor of shape (batch_size, seq_length, hidden_size).
+        """
+        forwarded_states = self.layer_norm(hidden_states)
+        forwarded_states = self.DenseReluDense(forwarded_states)
+        hidden_states = hidden_states + forwarded_states
+        return hidden_states
+
+
+class T5Attention(Module):
+    def __init__(
+        self,
+        config: T5Config,
+        has_relative_attention_bias: bool = False,
+        layer_idx: int | None = None,
+    ):
+        """Construct an attention layer.
+
+        Args:
+            config: T5 configuration.
+            has_relative_attention_bias: Whether to use relative attention bias.
+            layer_idx: Index of the layer.
+        """
+        super().__init__()
+        self.is_decoder = config.is_decoder
+        self.has_relative_attention_bias = has_relative_attention_bias
+        self.relative_attention_num_buckets = (
+            config.relative_attention_num_buckets
+        )
+        self.relative_attention_max_distance = (
+            config.relative_attention_max_distance
+        )
+        self.d_model = config.d_model
+        self.key_value_proj_dim = config.d_kv
+        self.n_heads = config.num_heads
+        self.dropout = config.dropout_rate
+        self.inner_dim = self.n_heads * self.key_value_proj_dim
+        self.device = config.device
+        self.dtype = config.dtype
+
+        self.q = Linear(
+            self.d_model,
+            self.inner_dim,
+            bias=False,
+        )
+        self.k = Linear(
+            self.d_model,
+            self.inner_dim,
+            bias=False,
+        )
+        self.v = Linear(
+            self.d_model,
+            self.inner_dim,
+            bias=False,
+        )
+        self.o = Linear(
+            self.inner_dim,
+            self.d_model,
+            bias=False,
+        )
+
+        if self.has_relative_attention_bias:
+            self.relative_attention_bias = Embedding(
+                self.relative_attention_num_buckets,
+                dim=self.n_heads,
+            )
+
+    def _relative_position_bucket(
+        self,
+        relative_position: Tensor,
+        bidirectional: bool = True,
+        num_buckets: int = 32,
+        max_distance: int = 128,
+    ) -> Tensor:
+        """Compute relative position bucket.
+
+        Adapted from Mesh Tensorflow:
+        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+        Args:
+            relative_position: Tensor with relative positions.
+            bidirectional: Whether the attention is bidirectional.
+            num_buckets: Number of buckets.
+            max_distance: Maximum distance for relative positions.
+
+        Returns:
+            Tensor: Relative position buckets.
+        """
+        relative_buckets = Tensor.constant(
+            0, dtype=DType.int32, device=relative_position.device
+        )
+
+        if bidirectional:
+            num_buckets = num_buckets // 2
+            is_positive = F.greater(relative_position, 0)
+            relative_buckets = relative_buckets + (
+                F.cast(is_positive, DType.int32) * num_buckets
+            )
+            relative_position = F.abs(relative_position)
+        else:
+            relative_position = -F.min(relative_position, 0)
+
+        max_exact = num_buckets // 2
+        is_small = F.greater(max_exact, relative_position)
+
+        scale = (num_buckets - max_exact) / math.log(max_distance / max_exact)
+        rel_pos_float = F.cast(relative_position, DType.float32)
+        val_log = F.log(rel_pos_float / float(max_exact))
+        relative_position_if_large = max_exact + F.cast(
+            val_log * scale, DType.int32
+        )
+        relative_position_if_large = F.min(
+            relative_position_if_large, num_buckets - 1
+        )
+        return relative_buckets + F.where(
+            is_small, relative_position, relative_position_if_large
+        )
+
+    def compute_bias(
+        self, query_length: int, key_length: int, device: Device
+    ) -> Tensor:
+        """Compute relative attention bias.
+
+        Args:
+            query_length: Length of the query sequence.
+            key_length: Length of the key sequence.
+            device: Device on which the position indices are created.
+
+        Returns:
+            Tensor: Relative attention bias tensor.
+        """
+        context_position = F.arange(
+            0, query_length, step=1, dtype=DType.int32, device=device
+        )
+        context_position = F.unsqueeze(context_position, 1)
+
+        memory_position = F.arange(
+            0, key_length, step=1, dtype=DType.int32, device=device
+        )
+        memory_position = F.unsqueeze(memory_position, 0)
+
+        relative_position = memory_position - context_position
+        relative_position_bucket = self._relative_position_bucket(
+            relative_position,
+            bidirectional=(not self.is_decoder),
+            num_buckets=self.relative_attention_num_buckets,
+            max_distance=self.relative_attention_max_distance,
+        )
+        values = self.relative_attention_bias(relative_position_bucket)
+        values = F.permute(values, (2, 0, 1))
+        values = F.unsqueeze(values, 0)
+        return values
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        mask: Tensor | None = None,
+        key_value_states: Tensor | None = None,
+        position_bias: Tensor | None = None,
+        past_key_values: Tensor | None = None,
+        layer_head_mask: Tensor | None = None,
+        query_length: int | None = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
+        cache_position: Tensor | None = None,
+    ) -> tuple[Tensor | None, ...]:
+        """Process hidden states through the attention layer.
+
+        Args:
+            hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size).
+            mask: Attention mask.
+            key_value_states: Key-value states for cross-attention.
+            position_bias: Position bias tensor.
+            past_key_values: Past key values for caching (not implemented).
+            layer_head_mask: Mask for attention heads.
+            query_length: Length of the query sequence.
+            use_cache: Whether to use cache (not implemented).
+            output_attentions: Whether to return attention weights.
+            cache_position: Cache position.
+
+        Returns:
+            tuple: (output, position_bias[, attn_weights]); position_bias may be None.
+        """
+        batch_size, seq_length = hidden_states.shape[:2]
+
+        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
+        is_cross_attention = key_value_states is not None
+        if is_cross_attention:
+            raise NotImplementedError(
+                "T5 CrossAttention is not implemented yet."
+            )
+        if past_key_values is not None:
+            raise NotImplementedError(
+                "T5 auto regressive model is not implemented yet."
+            )
+
+        query = self.q(hidden_states)
+        key = self.k(hidden_states)
+        value = self.v(hidden_states)
+
+        # Reshape to (batch, seq, heads, head_dim)
+        query = F.reshape(
+            query,
+            (batch_size, seq_length, self.n_heads, self.key_value_proj_dim),
+        )
+        key = F.reshape(
+            key, (batch_size, seq_length, self.n_heads, self.key_value_proj_dim)
+        )
+        value = F.reshape(
+            value,
+            (batch_size, seq_length, self.n_heads, self.key_value_proj_dim),
+        )
+
+        # Transpose to (batch, heads, seq, head_dim)
+        query = F.permute(query, (0, 2, 1, 3))
+        key = F.permute(key, (0, 2, 1, 3))
+        value = F.permute(value, (0, 2, 1, 3))
+
+        scores = F.matmul(query, F.permute(key, (0, 1, 3, 2)))
+
+        if position_bias is None and self.has_relative_attention_bias:
+            position_bias = self.compute_bias(
+                seq_length, seq_length, hidden_states.device
+            )
+
+        if position_bias is not None:
+            scores = scores + position_bias
+
+        if mask is not None:
+            scores = scores + mask
+
+        attn_weights = F.softmax(F.cast(scores, DType.float32), axis=-1)
+        attn_weights = F.cast(attn_weights, self.dtype)
+
+        if layer_head_mask is not None:
+            attn_weights = attn_weights * layer_head_mask
+
+        attn_output = F.matmul(attn_weights, value)
+        attn_output = F.permute(attn_output, (0, 2, 1, 3))
+        attn_output = F.reshape(
+            attn_output, (batch_size, seq_length, self.inner_dim)
+        )
+        attn_output = self.o(attn_output)
+
+        outputs = (attn_output, position_bias)
+        if output_attentions:
+            outputs = outputs + (attn_weights,)
+        return outputs
+
+
+class T5LayerSelfAttention(Module):
+    def __init__(
+        self,
+        config: T5Config,
+        has_relative_attention_bias: bool = False,
+        layer_idx: int | None = None,
+    ):
+        """Construct a self-attention layer.
+
+        Args:
+            config: T5 configuration.
+            has_relative_attention_bias: Whether to use relative attention bias.
+            layer_idx: Index of the layer.
+        """
+        super().__init__()
+        self.SelfAttention = T5Attention(
+            config,
+            has_relative_attention_bias=has_relative_attention_bias,
+            layer_idx=layer_idx,
+        )
+        self.layer_norm = T5LayerNorm(
+            config.d_model,
+            eps=config.layer_norm_epsilon,
+            dtype=config.dtype,
+        )
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Tensor | None = None,
+        position_bias: Tensor | None = None,
+        layer_head_mask: Tensor | None = None,
+        past_key_values: Tensor | None = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
+        cache_position: Tensor | None = None,
+    ) -> tuple[Tensor | None, ...]:
+        """Process hidden states through the self-attention layer.
+
+        Args:
+            hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size).
+            attention_mask: Attention mask.
+            position_bias: Position bias tensor.
+            layer_head_mask: Mask for attention heads.
+            past_key_values: Past key values for caching (not implemented).
+            use_cache: Whether to use cache (not implemented).
+            output_attentions: Whether to return attention weights.
+            cache_position: Cache position.
+
+        Returns:
+            tuple: (hidden_states, position_bias[, attn_weights]) with the residual added.
+        """
+        normed_hidden_states = self.layer_norm(hidden_states)
+        attention_output = self.SelfAttention(
+            normed_hidden_states,
+            mask=attention_mask,
+            position_bias=position_bias,
+            layer_head_mask=layer_head_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
+        )
+        hidden_states = hidden_states + attention_output[0]
+        outputs = (hidden_states,) + attention_output[1:]
+        return outputs
+
+
+class T5Block(Module):
+    def __init__(
+        self,
+        config: T5Config,
+        has_relative_attention_bias: bool = False,
+        layer_idx: int | None = None,
+    ):
+        """Construct a T5 block.
+
+        Args:
+            config: T5 configuration.
+            has_relative_attention_bias: Whether to use relative attention bias.
+            layer_idx: Index of the layer.
+        """
+        super().__init__()
+        layers = list()
+        self.is_decoder = config.is_decoder
+        if self.is_decoder:
+            raise NotImplementedError(
+                "T5 LayerCrossAttention is not implemented yet."
+            )
+
+        layers.append(
+            T5LayerSelfAttention(
+                config,
+                has_relative_attention_bias=has_relative_attention_bias,
+                layer_idx=layer_idx,
+            )
+        )
+        layers.append(T5LayerFF(config))
+        self.layer = ModuleList(layers)
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        attention_mask: Tensor | None = None,
+        position_bias: Tensor | None = None,
+        encoder_hidden_states: Tensor | None = None,
+        encoder_attention_mask: Tensor | None = None,
+        encoder_decoder_position_bias: Tensor | None = None,
+        cross_attn_layer_head_mask: Tensor | None = None,
+        layer_head_mask: Tensor | None = None,
+        past_key_values: Tensor | None = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
+        cache_position: Tensor | None = None,
+    ) -> tuple[Tensor | None, ...]:
+        """Process hidden states through the T5 block.
+
+        Args:
+            hidden_states: Input tensor of shape (batch_size, seq_length, hidden_size).
+            attention_mask: Attention mask.
+            position_bias: Position bias tensor.
+            encoder_hidden_states: Encoder hidden states (not implemented).
+            encoder_attention_mask: Encoder attention mask (not implemented).
+            encoder_decoder_position_bias: Encoder-decoder position bias (not implemented).
+            cross_attn_layer_head_mask: Cross attention layer head mask (not implemented).
+            layer_head_mask: Mask for attention heads.
+            past_key_values: Past key values for caching (not implemented).
+            use_cache: Whether to use cache (not implemented).
+            output_attentions: Whether to return attention weights.
+            cache_position: Cache position.
+
+        Returns:
+            tuple: (hidden_states, position_bias[, attn_weights]); position_bias may be None.
+        """
+        self_attention_outputs = self.layer[0](
+            hidden_states,
+            attention_mask=attention_mask,
+            position_bias=position_bias,
+            layer_head_mask=layer_head_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            cache_position=cache_position,
+        )
+        hidden_states = self_attention_outputs[0]
+        attention_outputs = self_attention_outputs[1:]
+
+        if hidden_states.dtype == DType.float16:
+            clamp_value = DType.finfo(hidden_states.dtype).max - 1000
+            hidden_states = hidden_states.clip(
+                min=-clamp_value, max=clamp_value
+            )
+
+        do_cross_attention = (
+            self.is_decoder and encoder_hidden_states is not None
+        )
+        if do_cross_attention:
+            raise NotImplementedError(
+                "T5 CrossAttention is not implemented yet."
+            )
+
+        hidden_states = self.layer[-1](hidden_states)
+        if hidden_states.dtype == DType.float16:
+            clamp_value = DType.finfo(hidden_states.dtype).max - 1000
+            hidden_states = hidden_states.clip(
+                min=-clamp_value, max=clamp_value
+            )
+
+        outputs = (hidden_states,)
+        return outputs + attention_outputs
+
+
+class T5Stack(Module):
+    def __init__(
+        self,
+        config: T5Config,
+        embed_tokens: Embedding | None = None,
+    ):
+        """Construct a T5 stack.
+
+        Args:
+            config: T5 configuration.
+            embed_tokens: Embedding module.
+        """
+        super().__init__()
+        self.config = config
+        self.embed_tokens = embed_tokens
+        self.is_decoder = config.is_decoder
+
+        self.block = ModuleList(
+            [
+                T5Block(
+                    config,
+                    has_relative_attention_bias=bool(i == 0),
+                    layer_idx=i,
+                )
+                for i in range(config.num_layers)
+            ]
+        )
+        self.final_layer_norm = T5LayerNorm(
+            config.d_model,
+            eps=config.layer_norm_epsilon,
+            dtype=config.dtype,
+        )
+        self.dropout = config.dropout_rate
+        self.device = config.device
+        self.dtype = config.dtype
+
+    def forward(
+        self,
+        input_ids: Tensor | None = None,
+        attention_mask: Tensor | None = None,
+        inputs_embeds: Tensor | None = None,
+        encoder_hidden_states: Tensor | None = None,
+        encoder_attention_mask: Tensor | None = None,
+        encoder_decoder_position_bias: Tensor | None = None,
+        cross_attn_layer_head_mask: Tensor | None = None,
+        layer_head_mask: Tensor | None = None,
+        past_key_values: Tensor | None = None,
+        use_cache: bool = False,
+        output_attentions: bool = False,
+        cache_position: Tensor | None = None,
+    ) -> Tensor:
+        """Process input through the T5 stack.
+
+        Args:
+            input_ids: Input IDs tensor of shape (batch_size, seq_length).
+            attention_mask: Attention mask tensor of shape (batch_size, seq_length).
+            inputs_embeds: Input embeddings tensor of shape (batch_size, seq_length, hidden_size).
+            encoder_hidden_states: Encoder hidden states (not implemented).
+            encoder_attention_mask: Encoder attention mask (not implemented).
+            encoder_decoder_position_bias: Encoder-decoder position bias (not implemented).
+            cross_attn_layer_head_mask: Cross attention layer head mask (not implemented).
+            layer_head_mask: Mask for attention heads.
+            past_key_values: Past key values for caching (not implemented).
+            use_cache: Whether to use cache (not implemented).
+            output_attentions: Whether layers compute attention weights (not returned).
+            cache_position: Cache position.
+
+        Returns:
+            Tensor: Output tensor of shape (batch_size, seq_length, hidden_size).
+        """
+        use_cache = (
+            use_cache if use_cache is not None else self.config.use_cache
+        )
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else getattr(self.config, "output_attentions", False)
+        )
+        if input_ids is not None:
+            if self.embed_tokens is None:
+                raise ValueError("input_ids require an embed_tokens module")
+            if input_ids.rank == 1:
+                input_ids = F.unsqueeze(input_ids, 0)
+            inputs_embeds = self.embed_tokens(input_ids)
+        elif inputs_embeds is None:
+            raise ValueError(
+                "You have to specify either input_ids or inputs_embeds"
+            )
+        elif inputs_embeds.rank == 2:
+            inputs_embeds = F.unsqueeze(inputs_embeds, 0)
+
+        if self.is_decoder or use_cache:
+            raise NotImplementedError("T5 decoder is not implemented yet.")
+
+        hidden_states = inputs_embeds
+
+        if attention_mask is not None:
+            if attention_mask.rank == 1:
+                attention_mask = F.unsqueeze(attention_mask, 0)
+            # Additive padding mask: 0 where attention_mask is 1, dtype-min
+            # where it is 0. It is not causal despite the variable name.
+            mask_multiplier = F.constant(
+                DType.finfo(hidden_states.dtype).min,
+                dtype=hidden_states.dtype,
+                device=hidden_states.device,
+            )
+            causal_mask = (
+                F.constant(
+                    1.0,
+                    dtype=hidden_states.dtype,
+                    device=hidden_states.device,
+                )
+                - F.cast(attention_mask, hidden_states.dtype)
+            ) * mask_multiplier
+            causal_mask = F.unsqueeze(causal_mask, 1)
+            causal_mask = F.unsqueeze(causal_mask, 1)
+        else:
+            causal_mask = None
+        encoder_extended_attention_mask = None
+
+        position_bias = None
+        for layer_module in self.block:
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_bias=position_bias,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_extended_attention_mask,
+                encoder_decoder_position_bias=encoder_decoder_position_bias,
+                cross_attn_layer_head_mask=cross_attn_layer_head_mask,
+                layer_head_mask=layer_head_mask,
+                past_key_values=past_key_values,
+                use_cache=use_cache,
+                output_attentions=output_attentions,
+                cache_position=cache_position,
+            )
+            hidden_states = layer_outputs[0]
+            position_bias = layer_outputs[1]
+            # Attention weights (layer_outputs[2] when requested) are not
+            # returned by this stack, so they are deliberately not collected.
+
+        hidden_states = self.final_layer_norm(hidden_states)
+        return hidden_states
+
+
+class T5EncoderModel(Module):
+    def __init__(
+        self,
+        config: T5Config,
+    ):
+        """Construct a T5 encoder model.
+
+        Args:
+            config: T5 configuration for vocabulary size, layer counts, and
+                device/dtype settings.
+        """
+        super().__init__()
+        # NOTE: `config` is modified in place below and shared with the stack.
+        act_info = config.feed_forward_proj.split("-")
+        config.dense_act_fn = act_info[-1]
+        config.is_gated_act = act_info[0] == "gated"
+
+        self.shared = Embedding(
+            config.vocab_size,
+            dim=config.d_model,
+        )
+
+        encoder_config = config
+        encoder_config.is_decoder = False
+        encoder_config.use_cache = False
+        encoder_config.is_encoder_decoder = False
+
+        self.encoder = T5Stack(encoder_config, self.shared)
+        self.device = config.device
+        self.dtype = config.dtype
+
+    def input_types(self) -> tuple[TensorType, ...]:
+        """Get input types for the model.
+
+        Returns:
+            tuple[TensorType, ...]: Input types.
+        """
+        # NOTE(review): only `input_ids` is declared; `attention_mask` has no
+        # input type here — confirm callers never need to pass a mask.
+        return (
+            TensorType(
+                DType.int64,
+                shape=["batch_size", "sequence_length"],
+                device=self.device,
+            ),
+        )
+
+    def forward(
+        self,
+        input_ids: Tensor | None = None,
+        attention_mask: Tensor | None = None,
+    ) -> Tensor:
+        """Process input through the T5 encoder model.
+
+        Args:
+            input_ids: Input IDs tensor of shape (batch_size, seq_length).
+            attention_mask: Attention mask tensor of shape (batch_size, seq_length).
+
+        Returns:
+            Tensor: Output tensor of shape (batch_size, seq_length, hidden_size).
+        """
+        return self.encoder(input_ids=input_ids, attention_mask=attention_mask)
diff --git a/max/python/max/pipelines/architectures/t5/weight_adapters.py b/max/python/max/pipelines/architectures/t5/weight_adapters.py
new file mode 100644
index 00000000000..32bdb0c5aa8
--- /dev/null
+++ b/max/python/max/pipelines/architectures/t5/weight_adapters.py
@@ -0,0 +1,61 @@
+# ===----------------------------------------------------------------------=== #
+# Copyright (c) 2025, Modular Inc. All rights reserved.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions:
+# https://llvm.org/LICENSE.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===----------------------------------------------------------------------=== #
+"""Weight adapters for T5 models."""
+
+from __future__ import annotations
+
+from collections.abc import Mapping
+
+from max.graph.weights import WeightData
+
+
+def _clone_weight(weight: WeightData, new_name: str) -> WeightData:
+    """Return a `WeightData` sharing `weight`'s fields under `new_name`."""
+    return WeightData(
+        data=weight.data,
+        name=new_name,
+        dtype=weight.dtype,
+        shape=weight.shape,
+        quantization_encoding=weight.quantization_encoding,
+    )
+
+
+def convert_safetensor_state_dict(
+    state_dict: Mapping[str, WeightData],
+) -> dict[str, WeightData]:
+    """Make the T5 embedding available under both names MAX looks up.
+
+    The encoder graph registers `shared.weight` as well as
+    `encoder.embed_tokens.weight`, while a checkpoint may ship only one of
+    them; the one that is present is cloned under the missing name.
+
+    Raises:
+        ValueError: If neither embedding key is in `state_dict`.
+    """
+    shared_key, encoder_key = "shared.weight", "encoder.embed_tokens.weight"
+    adapted = dict(state_dict)
+    shared, encoder = adapted.get(shared_key), adapted.get(encoder_key)
+
+    if shared is None and encoder is None:
+        raise ValueError(
+            "Missing T5 embedding weights. Expected one of "
+            "`shared.weight` or `encoder.embed_tokens.weight` to be present."
+        )
+
+    # At most one name is missing here; clone the other weight under it.
+    if shared is None and encoder is not None:
+        adapted[shared_key] = _clone_weight(encoder, shared_key)
+    elif encoder is None and shared is not None:
+        adapted[encoder_key] = _clone_weight(shared, encoder_key)
+
+    return adapted
From 74b220a9f816434733ea03a8ab6390fd6468aa75 Mon Sep 17 00:00:00 2001
From: Minkyu Kim
Date: Wed, 21 Jan 2026 11:01:32 +0000
Subject: [PATCH 08/10] feat: add flux1 model for flux1 pipeline support

---
 .../pipelines/architectures/flux1/flux1.py | 500 ++++++++++++++++++
 .../architectures/flux1/layers/__init__.py | 12 +
 .../architectures/flux1/layers/activations.py | 50 ++
 .../architectures/flux1/layers/embeddings.py | 430 +++++++++++++++
 .../flux1/layers/flux_attention.py | 427 +++++++++++++++
 .../flux1/layers/normalizations.py | 232 ++++++++
 .../pipelines/architectures/flux1/model.py | 59 +++
 .../architectures/flux1/model_config.py | 59 +++
 .../architectures/flux1/weight_adapters.py | 52 ++
 9 files changed, 1821 insertions(+)
 create mode 100644 max/python/max/pipelines/architectures/flux1/flux1.py
 create mode 100644 max/python/max/pipelines/architectures/flux1/layers/__init__.py
 create mode 100644 max/python/max/pipelines/architectures/flux1/layers/activations.py
 create mode 100644 max/python/max/pipelines/architectures/flux1/layers/embeddings.py
 create mode 100644 max/python/max/pipelines/architectures/flux1/layers/flux_attention.py
 create mode 100644 max/python/max/pipelines/architectures/flux1/layers/normalizations.py
 create mode 100644 max/python/max/pipelines/architectures/flux1/model.py
 create mode 100644 max/python/max/pipelines/architectures/flux1/model_config.py
 create mode 100644 max/python/max/pipelines/architectures/flux1/weight_adapters.py
diff --git a/max/python/max/pipelines/architectures/flux1/flux1.py b/max/python/max/pipelines/architectures/flux1/flux1.py
new file mode 100644
index
00000000000..9b912c036d3
--- /dev/null
+++ b/max/python/max/pipelines/architectures/flux1/flux1.py
@@ -0,0 +1,500 @@
+# ===----------------------------------------------------------------------=== #
+# Copyright (c) 2025, Modular Inc. All rights reserved.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions:
+# https://llvm.org/LICENSE.txt
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===----------------------------------------------------------------------=== #
+
+import logging
+import os
+from os import PathLike
+from typing import Any
+
+from max.driver import DLPackArray
+from max.dtype import DType
+from max.experimental import functional as F
+from max.experimental.tensor import Tensor
+from max.graph import TensorType
+from max.graph.weights import SafetensorWeights
+from max.nn.module_v3 import Linear, Module
+from max.nn.module_v3.norm import LayerNorm
+from max.nn.module_v3.sequential import ModuleList
+
+from .layers.embeddings import (
+    CombinedTimestepGuidanceTextProjEmbeddings,
+    CombinedTimestepTextProjEmbeddings,
+)
+from .layers.flux_attention import FeedForward, FluxAttention, FluxPosEmbed
+from .layers.normalizations import (
+    AdaLayerNormContinuous,
+    AdaLayerNormZero,
+    AdaLayerNormZeroSingle,
+)
+from .model_config import FluxConfig
+
+logger = logging.getLogger(__name__)
+
+
+def get_weight_registry_from_diffusers(
+    safe_tensor_folder: PathLike,
+) -> dict[str, DLPackArray]:
+    """Map weight names to data for the folder's `.safetensors` files."""
+    paths = []
+    for entry in os.listdir(safe_tensor_folder):
+        if entry.endswith(".safetensors"):
+            paths.append(os.path.join(safe_tensor_folder, entry))
+    weights = SafetensorWeights(paths)
+    return {name: weight.data().data for name, weight in weights.items()}
+
+
+class FluxSingleTransformerBlock(Module):
+    """Flux block whose attention and MLP branches share one normed input."""
+
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        mlp_ratio: float = 4.0,
+        dtype: DType = DType.bfloat16,
+    ):
+        """Create the norm, MLP projection, attention and output projection.
+
+        Args:
+            dim: Width of the block's input and output.
+            num_attention_heads: Head count given to the attention layer.
+            attention_head_dim: Per-head width given to the attention layer.
+            mlp_ratio: MLP hidden width as a multiple of `dim`.
+            dtype: Data type given to the norm and attention layers.
+        """
+        super().__init__()
+        self.mlp_hidden_dim = int(dim * mlp_ratio)
+
+        self.norm = AdaLayerNormZeroSingle(dim, dtype=dtype)
+        self.proj_mlp = Linear(dim, self.mlp_hidden_dim, bias=True)
+        self.act_mlp = F.gelu
+        self.proj_out = Linear(dim + self.mlp_hidden_dim, dim, bias=True)
+        self.attn = FluxAttention(
+            query_dim=dim,
+            dim_head=attention_head_dim,
+            heads=num_attention_heads,
+            out_dim=dim,
+            bias=True,
+            eps=1e-6,
+            pre_only=True,
+            dtype=dtype,
+        )
+
+    def forward(
+        self,
+        hidden_states: Tensor,
+        encoder_hidden_states: Tensor,
+        temb: Tensor,
+        image_rotary_emb: tuple[Tensor, Tensor] | None = None,
+        joint_attention_kwargs: dict[str, Any] | None = None,
+    ) -> tuple[Tensor, Tensor]:
+        """Run the fused attention + MLP block over the joint sequence.
+
+        `encoder_hidden_states` and `hidden_states` are concatenated along
+        axis 1, normalized with `temb`, and fed to the attention and MLP
+        branches; the gated, projected result is added to the input.
+
+        Args:
+            hidden_states: Input hidden states.
+            encoder_hidden_states: Encoder hidden states, placed first in
+                the concatenated sequence.
+            temb: Embedding passed to the norm layer as `emb`.
+            image_rotary_emb: Optional rotary position embeddings handed to
+                the attention layer.
+            joint_attention_kwargs: Optional extra keyword arguments for
+                the attention layer.
+
+        Returns:
+            Tuple of (encoder_hidden_states, hidden_states).
+        """
+        text_len = encoder_hidden_states.shape[1]
+        joint = F.concat([encoder_hidden_states, hidden_states], axis=1)
+
+        normed, gate = self.norm(joint, emb=temb)
+        mlp_out = self.act_mlp(self.proj_mlp(normed), approximate="tanh")
+        attn_out = self.attn(
+            hidden_states=normed,
+            image_rotary_emb=image_rotary_emb,
+            **(joint_attention_kwargs or {}),
+        )
+
+        # Both branches are concatenated along axis 2, projected, scaled by
+        # the gate and added back onto the block input.
+        fused = F.concat([attn_out, mlp_out], axis=2)
+        gated = F.unsqueeze(gate, 1) * self.proj_out(fused)
+        joint = joint + gated
+        if joint.dtype == DType.float16:
+            joint = joint.clip(min=-65504, max=65504)
+
+        # The first `text_len` positions came from encoder_hidden_states.
+        return joint[:, :text_len], joint[:, text_len:]
+
+
+class FluxTransformerBlock(Module):
+    def __init__(
+        self,
+        dim: int,
+        num_attention_heads: int,
+        attention_head_dim: int,
+        qk_norm: str = "rms_norm",
+        eps: float = 1e-6,
+        dtype: DType = DType.bfloat16,
+    ):
+        """Initialize Flux transformer block.
+
+        Args:
+            dim: Dimension of the input/output.
+            num_attention_heads: Number of attention heads.
+            attention_head_dim: Dimension of each attention head.
+            qk_norm: Type of normalization for query and key ("rms_norm").
+            eps: Epsilon for normalization layers.
+            dtype: Data type for the module.
+ """ + super().__init__() + + self.norm1 = AdaLayerNormZero(dim, dtype=dtype) + self.norm1_context = AdaLayerNormZero(dim, dtype=dtype) + + self.attn = FluxAttention( + query_dim=dim, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, + eps=eps, + dtype=dtype, + ) + + self.norm2 = LayerNorm( + dim, + eps=1e-6, + keep_dtype=True, + elementwise_affine=False, + use_bias=False, + ) + self.ff = FeedForward( + dim=dim, + dim_out=dim, + activation_fn="gelu-approximate", + ) + + self.norm2_context = LayerNorm( + dim, + eps=1e-6, + keep_dtype=True, + elementwise_affine=False, + use_bias=False, + ) + self.ff_context = FeedForward( + dim=dim, + dim_out=dim, + activation_fn="gelu-approximate", + ) + + def forward( + self, + hidden_states: Tensor, + encoder_hidden_states: Tensor, + temb: Tensor, + image_rotary_emb: tuple[Tensor, Tensor] | None = None, + joint_attention_kwargs: dict[str, Any] | None = None, + ) -> tuple[Tensor, Tensor]: + """Apply transformer block with dual-stream attention and feedforward. + + Args: + hidden_states: Input hidden states. + encoder_hidden_states: Encoder hidden states for cross-attention. + temb: Time embedding. + image_rotary_emb: Optional rotary position embeddings. + joint_attention_kwargs: Optional attention kwargs. + + Returns: + Tuple of (encoder_hidden_states, hidden_states). + """ + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.norm1(hidden_states, emb=temb) + ) + + ( + norm_encoder_hidden_states, + c_gate_msa, + c_shift_mlp, + c_scale_mlp, + c_gate_mlp, + ) = self.norm1_context(encoder_hidden_states, emb=temb) + joint_attention_kwargs = joint_attention_kwargs or {} + + # Attention. 
class FluxTransformer2DModel(Module):
    """Flux transformer: dual-stream blocks followed by single-stream blocks."""

    def __init__(
        self,
        config: FluxConfig,
    ):
        """Initialize Flux Transformer 2D model.

        Args:
            config: Flux configuration containing model dimensions, attention
                settings, and device/dtype information.
        """
        super().__init__()
        patch_size = config.patch_size
        in_channels = config.in_channels
        out_channels = config.out_channels
        num_layers = config.num_layers
        num_single_layers = config.num_single_layers
        attention_head_dim = config.attention_head_dim
        num_attention_heads = config.num_attention_heads
        joint_attention_dim = config.joint_attention_dim
        pooled_projection_dim = config.pooled_projection_dim
        guidance_embeds = config.guidance_embeds
        axes_dims_rope = config.axes_dims_rope
        device = config.device
        dtype = config.dtype
        # Fall back to the input channel count when no explicit output
        # channel count is configured.
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
        self.guidance_embeds = guidance_embeds

        # The conditioning embedder takes an extra guidance input only when
        # the checkpoint was configured with guidance embeddings.
        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings
            if guidance_embeds
            else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim,
            pooled_projection_dim=pooled_projection_dim,
            dtype=dtype,
        )
        self.context_embedder = Linear(
            joint_attention_dim,
            self.inner_dim,
            bias=True,
        )
        self.x_embedder = Linear(
            in_channels,
            self.inner_dim,
            bias=True,
        )

        self.transformer_blocks = ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dtype=dtype,
                )
                for _ in range(num_layers)
            ]
        )

        self.single_transformer_blocks = ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dtype=dtype,
                )
                for _ in range(num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(
            self.inner_dim, self.inner_dim, eps=1e-6, dtype=dtype
        )
        self.proj_out = Linear(
            self.inner_dim,
            patch_size * patch_size * self.out_channels,
            bias=True,
        )

        # Stored but not read anywhere in this class.
        self.gradient_checkpointing = False

        # Kept so that `input_types` can describe the graph inputs.
        self.max_device = device
        self.max_dtype = dtype
        self.in_channels = in_channels
        self.joint_attention_dim = joint_attention_dim
        self.pooled_projection_dim = pooled_projection_dim

    def input_types(self) -> tuple[TensorType, ...]:
        """Define input tensor types for the model.

        Returns:
            Tuple of TensorType specifications, in the positional order of
            `forward`: hidden states, encoder hidden states, pooled
            projections, timestep, image ids, text ids, guidance.
        """
        hidden_states_type = TensorType(
            self.max_dtype,
            shape=["batch_size", "image_seq_len", self.in_channels],
            device=self.max_device,
        )
        encoder_hidden_states_type = TensorType(
            self.max_dtype,
            shape=["batch_size", "text_seq_len", self.joint_attention_dim],
            device=self.max_device,
        )
        pooled_projections_type = TensorType(
            self.max_dtype,
            shape=["batch_size", self.pooled_projection_dim],
            device=self.max_device,
        )
        # The timestep is the only input declared as float32 rather than the
        # model dtype; `forward` casts it to the hidden-state dtype.
        timestep_type = TensorType(
            DType.float32, shape=["batch_size"], device=self.max_device
        )
        img_ids_type = TensorType(
            self.max_dtype, shape=["image_seq_len", 3], device=self.max_device
        )
        txt_ids_type = TensorType(
            self.max_dtype, shape=["text_seq_len", 3], device=self.max_device
        )
        guidance_type = TensorType(
            self.max_dtype, shape=["batch_size"], device=self.max_device
        )

        return (
            hidden_states_type,
            encoder_hidden_states_type,
            pooled_projections_type,
            timestep_type,
            img_ids_type,
            txt_ids_type,
            guidance_type,
        )

    def forward(
        self,
        hidden_states: Tensor,
        encoder_hidden_states: Tensor | None = None,
        pooled_projections: Tensor | None = None,
        timestep: Tensor | None = None,
        img_ids: Tensor | None = None,
        txt_ids: Tensor | None = None,
        guidance: Tensor | None = None,
        joint_attention_kwargs: dict[str, Any] | None = None,
        controlnet_block_samples: Any | None = None,
        controlnet_single_block_samples: Any | None = None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
    ) -> tuple[Tensor]:
        """Apply Flux Transformer 2D model forward pass.

        Args:
            hidden_states: Input latent hidden states.
            encoder_hidden_states: Text encoder hidden states. Required.
            pooled_projections: Pooled text embeddings. Required.
            timestep: Diffusion timestep; multiplied by 1000 before being
                embedded. Required.
            img_ids: Image position IDs. Required.
            txt_ids: Text position IDs. Required.
            guidance: Guidance scale values; multiplied by 1000 before being
                embedded. Required when the model was built with
                `guidance_embeds`, otherwise not used for conditioning.
            joint_attention_kwargs: Additional keyword arguments forwarded to
                every block's attention call.
            controlnet_block_samples: Accepted but ignored by this
                implementation.
            controlnet_single_block_samples: Accepted but ignored by this
                implementation.
            return_dict: Accepted but ignored; a tuple is always returned.
            controlnet_blocks_repeat: Accepted but ignored by this
                implementation.

        Returns:
            One-element tuple containing the output tensor.

        Raises:
            ValueError: If a required input is missing.
        """
        # The signature keeps `None` defaults for call-site compatibility,
        # but every one of these tensors is consumed unconditionally below.
        # Report omissions by name instead of failing inside a cast/concat.
        missing = [
            name
            for name, value in (
                ("encoder_hidden_states", encoder_hidden_states),
                ("pooled_projections", pooled_projections),
                ("timestep", timestep),
                ("img_ids", img_ids),
                ("txt_ids", txt_ids),
            )
            if value is None
        ]
        if self.guidance_embeds and guidance is None:
            missing.append("guidance")
        if missing:
            raise ValueError(
                "FluxTransformer2DModel.forward is missing required "
                f"input(s): {', '.join(missing)}"
            )

        # Shallow copy so the caller's dict is never shared with the blocks.
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()

        hidden_states = self.x_embedder(hidden_states)

        timestep = F.cast(timestep, hidden_states.dtype)
        timestep = timestep * 1000.0
        if guidance is not None:
            guidance = F.cast(guidance, hidden_states.dtype) * 1000.0

        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if not self.guidance_embeds
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        # Text ids come first, matching the [text, image] token order that the
        # single-stream blocks use when they concatenate both streams.
        ids = F.concat((txt_ids, img_ids), axis=0)
        image_rotary_emb = self.pos_embed(ids)

        for block in self.transformer_blocks:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )

        for block in self.single_transformer_blocks:
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                joint_attention_kwargs=joint_attention_kwargs,
            )

        # Only the image stream is normalized and projected to the output.
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        return (output,)
b/max/python/max/pipelines/architectures/flux1/layers/__init__.py new file mode 100644 index 00000000000..75c4f824f20 --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/layers/__init__.py @@ -0,0 +1,12 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # diff --git a/max/python/max/pipelines/architectures/flux1/layers/activations.py b/max/python/max/pipelines/architectures/flux1/layers/activations.py new file mode 100644 index 00000000000..c1a723ed78d --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/layers/activations.py @@ -0,0 +1,50 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class GELU(Module):
    """Linear projection followed by a GELU non-linearity."""

    def __init__(
        self,
        dim_in: int,
        dim_out: int,
        approximate: str = "none",
        bias: bool = True,
    ):
        """Build the projection and remember the GELU flavour.

        Args:
            dim_in: Size of the incoming feature dimension.
            dim_out: Size of the projected feature dimension.
            approximate: Value forwarded to `F.gelu` as its `approximate`
                argument ("none" or "tanh").
            bias: Whether the linear projection carries a bias term.
        """
        super().__init__()
        self.approximate = approximate
        self.proj = Linear(dim_in, dim_out, bias=bias)

    def forward(self, hidden_states: Tensor) -> Tensor:
        """Project the input, then activate it with GELU.

        Args:
            hidden_states: Tensor to transform.

        Returns:
            The projected tensor after GELU has been applied.
        """
        projected = self.proj(hidden_states)
        return F.gelu(projected, approximate=self.approximate)
def apply_rotary_emb(
    x: Tensor,
    freqs_cis: tuple[Tensor, Tensor],
    sequence_dim: int = 2,
) -> Tensor:
    """Rotate a query or key tensor with precomputed rotary frequencies.

    The last dimension of `x` is read as interleaved (real, imaginary) pairs.
    Each pair (a, b) is turned into (-b, a), and the result is blended with
    the input as `x * cos + rotated * sin`. The blend is done in float32 and
    cast back to the dtype of `x`.

    Args:
        x: Query or key tensor. The code indexes five dimensions after the
            pair split, so `x` is expected to be 4-D.
        freqs_cis: `(cos, sin)` tensors, each of shape [S, D].
        sequence_dim: Which axis of `x` is the sequence axis (1 or 2). It
            decides where the broadcast axes are inserted into cos/sin.

    Returns:
        Tensor shaped and typed like `x` with rotary embeddings applied.

    Raises:
        ValueError: If `sequence_dim` is neither 1 nor 2.
    """
    cos, sin = freqs_cis  # each [S, D]

    # Index expression that lifts [S, D] to a 4-D tensor broadcastable
    # against `x`, with S placed on the requested sequence axis.
    if sequence_dim == 2:
        expand = (None, None, slice(None), slice(None))
    elif sequence_dim == 1:
        expand = (None, slice(None), None, slice(None))
    else:
        raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")

    cos = cos[expand]
    sin = sin[expand]
    cos = cos.to(x.device)
    sin = sin.to(x.device)

    # Split the last axis into (pairs, 2) and pull the two components apart.
    pair_count = x.shape[-1] // 2
    paired = F.reshape(x, [*x.shape[:-1], pair_count, 2])
    components = F.split(paired, 1, axis=-1)
    real_part = F.squeeze(components[0], axis=-1)
    imag_part = F.squeeze(components[1], axis=-1)

    # Re-interleave as (-imag, real) and fold the pair axis back into the
    # last dimension so the result lines up with `x` again.
    rotated_pairs = F.stack([-imag_part, real_part], axis=-1)
    dim0 = rotated_pairs.shape[0]
    dim1 = rotated_pairs.shape[1]
    dim2 = rotated_pairs.shape[2]
    merged_last = rotated_pairs.shape[3] * rotated_pairs.shape[4]
    rotated = F.reshape(rotated_pairs, (dim0, dim1, dim2, merged_last))

    blended = (
        F.cast(x, DType.float32) * cos + F.cast(rotated, DType.float32) * sin
    )
    return blended.cast(x.dtype)
class Timesteps(Module):
    """Parameter-free wrapper around `get_timestep_embedding`."""

    def __init__(
        self,
        num_channels: int,
        flip_sin_to_cos: bool,
        downscale_freq_shift: float,
        scale: int = 1,
    ):
        """Store the settings forwarded to `get_timestep_embedding`.

        Args:
            num_channels: Width of the produced embedding.
            flip_sin_to_cos: Forwarded as `flip_sin_to_cos`.
            downscale_freq_shift: Forwarded as `downscale_freq_shift`.
            scale: Forwarded as `scale`.
        """
        super().__init__()
        self.scale = scale
        self.downscale_freq_shift = downscale_freq_shift
        self.flip_sin_to_cos = flip_sin_to_cos
        self.num_channels = num_channels

    def forward(self, timesteps: Tensor) -> Tensor:
        """Embed the given timesteps.

        Args:
            timesteps: Timestep values to embed.

        Returns:
            The embedding computed by `get_timestep_embedding` with the
            stored settings.
        """
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
class PixArtAlphaTextProjection(Module):
    """Two-layer projection (linear, activation, linear) for caption embeddings."""

    def __init__(
        self,
        in_features: int,
        hidden_size: int,
        out_features: int | None = None,
        act_fn: str = "gelu_tanh",
    ):
        """Create both linear layers and select the activation.

        Args:
            in_features: Width of the incoming caption embedding.
            hidden_size: Width of the intermediate layer.
            out_features: Width of the output; falls back to `hidden_size`
                when left as None.
            act_fn: Name of the activation between the two layers, either
                "gelu_tanh" or "silu".

        Raises:
            ValueError: If `act_fn` is not one of the supported names.
        """
        super().__init__()
        output_width = hidden_size if out_features is None else out_features
        self.linear_1 = Linear(in_features, hidden_size, bias=True)
        self.linear_2 = Linear(hidden_size, output_width, bias=True)

        if act_fn == "silu":
            self.act_fn = F.silu
        elif act_fn == "gelu_tanh":
            # GELU with the tanh approximation.
            self.act_fn = lambda x: F.gelu(x, approximate="tanh")
        else:
            raise ValueError(f"Invalid activation function: {act_fn}")

    def forward(self, caption: Tensor) -> Tensor:
        """Run the caption embedding through the projection.

        Args:
            caption: Caption embeddings to project.

        Returns:
            The projected caption embeddings.
        """
        activated = self.act_fn(self.linear_1(caption))
        return self.linear_2(activated)
class CombinedTimestepGuidanceTextProjEmbeddings(Module):
    """Sum of timestep, guidance, and pooled-text conditioning embeddings."""

    def __init__(
        self,
        embedding_dim: int,
        pooled_projection_dim: int,
        dtype: DType = DType.bfloat16,
    ):
        """Create the shared sinusoidal projection and the three embedders.

        Args:
            embedding_dim: Width of the conditioning vector that is produced.
            pooled_projection_dim: Width of the pooled text projection input.
            dtype: Accepted for interface parity; not read by this
                constructor.
        """
        super().__init__()

        # A single 256-channel sinusoidal projection serves both the timestep
        # and the guidance value; each has its own embedder afterwards.
        self.time_proj = Timesteps(
            num_channels=256,
            flip_sin_to_cos=True,
            downscale_freq_shift=0,
        )
        self.timestep_embedder = TimestepEmbedding(
            in_channels=256,
            time_embed_dim=embedding_dim,
        )
        self.guidance_embedder = TimestepEmbedding(
            in_channels=256,
            time_embed_dim=embedding_dim,
        )
        self.text_embedder = PixArtAlphaTextProjection(
            pooled_projection_dim,
            embedding_dim,
            act_fn="silu",
        )

    def forward(
        self,
        timestep: Tensor,
        guidance: Tensor,
        pooled_projection: Tensor,
    ) -> Tensor:
        """Build the conditioning vector.

        Args:
            timestep: Timestep values.
            guidance: Guidance values.
            pooled_projection: Pooled text projection; its dtype is the dtype
                the sinusoidal projections are cast to before embedding.

        Returns:
            Timestep embedding + guidance embedding + text embedding.
        """
        time_features = self.time_proj(timestep)
        time_embedding = self.timestep_embedder(
            F.cast(time_features, pooled_projection.dtype)
        )

        guidance_features = self.time_proj(guidance)
        guidance_embedding = self.guidance_embedder(
            F.cast(guidance_features, pooled_projection.dtype)
        )

        combined = time_embedding + guidance_embedding

        text_embedding = self.text_embedder(pooled_projection)
        return combined + text_embedding
class FluxAttention(Module):
    """Flux attention mechanism with QK normalization and optional dual stream."""

    def __init__(
        self,
        query_dim: int,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        added_kv_proj_dim: int | None = None,
        added_proj_bias: bool | None = True,
        out_bias: bool = True,
        eps: float = 1e-5,
        out_dim: int | None = None,
        context_pre_only: bool | None = None,
        pre_only: bool = False,
        elementwise_affine: bool = True,
        dtype: DType = DType.bfloat16,
    ):
        """Initialize Flux attention module.

        Args:
            query_dim: Dimension of query vectors.
            heads: Number of attention heads. Overridden by
                `out_dim // dim_head` when `out_dim` is given.
            dim_head: Dimension of each attention head.
            dropout: Dropout probability. Stored only; `forward` does not
                apply dropout.
            bias: Whether to use bias in the q/k/v projections.
            added_kv_proj_dim: Input width of the second (context) stream.
                When set, the context q/k/v projections, their norms and the
                context output projection are created.
            added_proj_bias: Whether to use bias in the context projections.
            out_bias: Whether to use bias in the output projections.
            eps: Epsilon for the RMSNorm layers.
            out_dim: Optional output dimension; also used as the inner
                dimension when given.
            context_pre_only: Stored only; not read by this module.
            pre_only: When True, no `to_out` projection is created.
            elementwise_affine: Accepted for interface parity; not read.
            dtype: Data type recorded on the module.
        """
        super().__init__()

        self.head_dim = dim_head
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.dropout = dropout
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only
        self.heads = out_dim // dim_head if out_dim is not None else heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.added_proj_bias = added_proj_bias
        self.dtype = dtype

        self.norm_q = RMSNorm(dim_head, eps=eps)
        self.norm_k = RMSNorm(dim_head, eps=eps)
        self.to_q = Linear(
            query_dim,
            self.inner_dim,
            bias=bias,
        )
        self.to_k = Linear(
            query_dim,
            self.inner_dim,
            bias=bias,
        )
        self.to_v = Linear(
            query_dim,
            self.inner_dim,
            bias=bias,
        )

        if not self.pre_only:
            self.to_out = Linear(
                self.inner_dim,
                self.out_dim,
                bias=out_bias,
            )

        if added_kv_proj_dim is not None:
            self.norm_added_q = RMSNorm(dim_head, eps=eps)
            self.norm_added_k = RMSNorm(dim_head, eps=eps)
            self.add_q_proj = Linear(
                added_kv_proj_dim,
                self.inner_dim,
                bias=added_proj_bias,
            )
            self.add_k_proj = Linear(
                added_kv_proj_dim,
                self.inner_dim,
                bias=added_proj_bias,
            )
            self.add_v_proj = Linear(
                added_kv_proj_dim,
                self.inner_dim,
                bias=added_proj_bias,
            )
            self.to_add_out = Linear(
                self.inner_dim,
                query_dim,
                bias=out_bias,
            )

    def forward(
        self,
        hidden_states: Tensor,
        encoder_hidden_states: Tensor | None = None,
        attention_mask: Tensor | None = None,
        image_rotary_emb: tuple[Tensor, Tensor] | None = None,
    ) -> Tensor | tuple[Tensor, Tensor]:
        """Apply Flux attention to hidden states.

        Args:
            hidden_states: Input hidden states.
            encoder_hidden_states: Optional second-stream (context) hidden
                states. They are projected, normalized and prepended to the
                sequence before attention.
            attention_mask: Accepted for interface compatibility.
                NOTE(review): it is not used; attention always runs with
                `MHAMaskVariant.NULL_MASK`.
            image_rotary_emb: Optional `(cos, sin)` rotary embeddings applied
                to query and key along the sequence axis.

        Returns:
            The attention output when `encoder_hidden_states` is None. This
            output is not passed through `to_out`. Otherwise a tuple
            `(hidden_states, encoder_hidden_states)`, each passed through its
            own output projection.

        Raises:
            ValueError: If `encoder_hidden_states` is given but the module
                was built without `added_kv_proj_dim`, or with
                `pre_only=True`.
        """
        if encoder_hidden_states is not None:
            # The dual-stream path needs the context projections and both
            # output projections. Fail with a clear message instead of an
            # AttributeError after the attention kernel has already run.
            if self.added_kv_proj_dim is None:
                raise ValueError(
                    "encoder_hidden_states was provided but this "
                    "FluxAttention was created without added_kv_proj_dim"
                )
            if self.pre_only:
                raise ValueError(
                    "encoder_hidden_states was provided but this "
                    "FluxAttention was created with pre_only=True"
                )

        batch_size = hidden_states.shape[0]

        def split_heads(projected: Tensor) -> Tensor:
            # [B, S, heads * head_dim] -> [B, S, heads, head_dim]
            return F.reshape(
                projected,
                (batch_size, projected.shape[1], self.heads, self.head_dim),
            )

        query = split_heads(self.to_q(hidden_states))
        key = split_heads(self.to_k(hidden_states))
        value = split_heads(self.to_v(hidden_states))

        # QK normalization is applied exactly once. RMSNorm has a learned
        # scale, so a second application would change the result.
        query = self.norm_q(query)
        key = self.norm_k(key)

        if encoder_hidden_states is not None:
            encoder_query = split_heads(self.add_q_proj(encoder_hidden_states))
            encoder_key = split_heads(self.add_k_proj(encoder_hidden_states))
            encoder_value = split_heads(
                self.add_v_proj(encoder_hidden_states)
            )

            encoder_query = self.norm_added_q(encoder_query)
            encoder_key = self.norm_added_k(encoder_key)

            # Context tokens go first; the outputs are split in that same
            # order after attention.
            query = F.concat([encoder_query, query], axis=1)
            key = F.concat([encoder_key, key], axis=1)
            value = F.concat([encoder_value, value], axis=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        hidden_states = flash_attention_gpu(
            query,
            key,
            value,
            mask_variant=MHAMaskVariant.NULL_MASK,
            # sqrt(1 / head_dim) == 1 / sqrt(head_dim)
            scale=math.sqrt(1.0 / self.head_dim),
        )

        # Merge the head axis back: [B, S, heads, head_dim] -> [B, S, inner]
        total_seq_len = hidden_states.shape[1]
        hidden_states = F.reshape(
            hidden_states,
            (batch_size, total_seq_len, self.heads * self.head_dim),
        )

        if encoder_hidden_states is None:
            return hidden_states

        encoder_seq_len = encoder_hidden_states.shape[1]
        context_out = hidden_states[:, :encoder_seq_len, :]
        image_out = hidden_states[:, encoder_seq_len:, :]

        image_out = self.to_out(image_out)
        context_out = self.to_add_out(context_out)

        return image_out, context_out
+ """ + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + elif activation_fn == "gelu-approximate": + act_fn = GELU( + dim, + inner_dim, + approximate="tanh", + bias=bias, + ) + else: + raise NotImplementedError( + f"Activation function {activation_fn} is not implemented" + ) + + self.net = ModuleList( + [ + act_fn, + Linear( + inner_dim, + dim_out, + bias=bias, + ), + ] + ) + + def forward(self, hidden_states: Tensor, *args, **kwargs) -> Tensor: + """Apply feedforward network to hidden states. + + Args: + hidden_states: Input hidden states. + *args: Additional positional arguments (unused). + **kwargs: Additional keyword arguments (unused). + + Returns: + Output hidden states after feedforward network. + """ + for layer in self.net: + hidden_states = layer(hidden_states) + return hidden_states + + +class FluxPosEmbed(Module): + """Flux Position Embedding module for 3D rotary position embeddings. + + This module computes separate rotary embeddings for each spatial dimension + (typically time, height, width) and concatenates them. + + Args: + theta: Base value for frequency computation (typically 10000) + axes_dim: List of dimensions for each axis (e.g., [16, 56, 56] for time, height, width) + """ + + def __init__( + self, theta: int = 10000, axes_dim: tuple[int, int, int] = (16, 56, 56) + ): + """Initialize Flux position embedding module. + + Args: + theta: Base value for frequency computation (typically 10000). + axes_dim: Dimensions for each axis (e.g., [16, 56, 56] for time, height, width). + """ + super().__init__() + self.theta = float(theta) + self.axes_dim = list(axes_dim) + + def _get_1d_rotary_pos_embed( + self, dim: int, pos: Tensor, device: DeviceRef + ) -> tuple[Tensor, Tensor]: + """Compute 1D rotary position embeddings for a single axis. 
+ + Args: + dim: Dimension of the embedding (should be even) + pos: Position indices, shape [batch_size] + device: Device to compute on + + Returns: + Tuple of (freqs_cos, freqs_sin), each with shape [batch_size, dim] + """ + # Ensure dim is even + assert dim % 2 == 0, f"dim must be even, got {dim}" + + # Cast position to float32 for computation + pos = F.cast(pos, DType.float32) + + # Compute frequencies: 1.0 / (theta ** (arange(0, dim, 2) / dim)) + # Shape: [dim/2] + arange_vals = F.arange( + 0, dim, step=2, dtype=DType.float32, device=device + ) + exponents = arange_vals / float(dim) + + # theta ** exponents + theta_tensor = F.constant(self.theta, DType.float32, device=device) + theta_powered = F.pow(theta_tensor, exponents) + + # 1.0 / theta_powered + freqs = 1.0 / theta_powered # Shape: [dim/2] + + # Outer product: pos [batch_size] x freqs [dim/2] = [batch_size, dim/2] + freqs_outer = F.outer(pos, freqs) + + # Compute cos and sin + freqs_cos_half = F.cos(freqs_outer) # [batch_size, dim/2] + freqs_sin_half = F.sin(freqs_outer) # [batch_size, dim/2] + + # Repeat interleave to get full dimension + # repeat_interleave(2, dim=1): [a, b, c] -> [a, a, b, b, c, c] + # Since repeat_interleave is not supported on GPU, we use reshape + tile + + # 1. Unsqueeze: [batch_size, dim/2] -> [batch_size, dim/2, 1] + freqs_cos_expanded = F.unsqueeze(freqs_cos_half, axis=2) + freqs_sin_expanded = F.unsqueeze(freqs_sin_half, axis=2) + + # 2. Concat to duplicate: [batch_size, dim/2, 1] -> [batch_size, dim/2, 2] + freqs_cos_tiled = F.concat( + [freqs_cos_expanded, freqs_cos_expanded], axis=2 + ) + freqs_sin_tiled = F.concat( + [freqs_sin_expanded, freqs_sin_expanded], axis=2 + ) + + # 3. 
Reshape to flatten: [batch_size, dim/2, 2] -> [batch_size, dim] + flattened_dim = freqs_cos_tiled.shape[1] * freqs_cos_tiled.shape[2] + freqs_cos = F.reshape( + freqs_cos_tiled, (freqs_cos_tiled.shape[0], flattened_dim) + ) + freqs_sin = F.reshape( + freqs_sin_tiled, (freqs_sin_tiled.shape[0], flattened_dim) + ) + + return freqs_cos, freqs_sin + + def forward(self, ids: Tensor) -> tuple[Tensor, Tensor]: + """Forward pass to compute rotary position embeddings. + + Args: + ids: Position indices tensor with shape [batch_size, n_axes] + where n_axes is the number of spatial dimensions (e.g., 3 for time/height/width) + + Returns: + Tuple of (freqs_cos, freqs_sin) with concatenated embeddings from all axes + """ + # Get number of axes from the last dimension + n_axes = ids.shape[-1] + device = ids.device + + cos_out = [] + sin_out = [] + + # Compute embeddings for each axis + for i in range(int(n_axes)): + # Extract position for this axis: ids[:, i] + pos = ids[:, i] + + # Compute 1D rotary embeddings for this axis + cos_embed, sin_embed = self._get_1d_rotary_pos_embed( + dim=self.axes_dim[i], pos=pos, device=device + ) + + cos_out.append(cos_embed) + sin_out.append(sin_embed) + + # Concatenate embeddings from all axes along the last dimension + freqs_cos = F.concat(cos_out, axis=-1) # [batch_size, sum(axes_dim)] + freqs_sin = F.concat(sin_out, axis=-1) # [batch_size, sum(axes_dim)] + + return freqs_cos, freqs_sin diff --git a/max/python/max/pipelines/architectures/flux1/layers/normalizations.py b/max/python/max/pipelines/architectures/flux1/layers/normalizations.py new file mode 100644 index 00000000000..451809fa88a --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/layers/normalizations.py @@ -0,0 +1,232 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. 
class AdaLayerNormZeroSingle(Module):
    """Adaptive layer norm producing one (shift, scale, gate) triple."""

    def __init__(
        self,
        embedding_dim: int,
        norm_type: str = "layer_norm",
        bias: bool = True,
        dtype: DType = DType.bfloat16,
    ) -> None:
        """Initialize adaptive layer normalization zero single module.

        Args:
            embedding_dim: Size of each embedding vector.
            norm_type: Type of normalization to use; only "layer_norm" is
                supported.
            bias: Whether to use bias in linear projection.
            dtype: Data type for the module (currently not used by this
                constructor).

        Raises:
            ValueError: If ``norm_type`` is not "layer_norm".
        """
        super().__init__()
        # Projects the conditioning embedding to shift, scale and gate.
        self.linear = Linear(
            embedding_dim,
            3 * embedding_dim,
            bias=bias,
        )
        if norm_type == "layer_norm":
            self.norm = LayerNorm(
                embedding_dim,
                use_bias=False,
                eps=1e-6,
                keep_dtype=True,
                elementwise_affine=False,
            )
        else:
            raise ValueError(
                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'."
            )

    def forward(
        self, x: Tensor, emb: Tensor | None = None
    ) -> tuple[Tensor, Tensor]:
        """Apply adaptive layer normalization.

        Args:
            x: Input tensor.
            emb: Conditioning embedding. Required despite the ``None``
                default, which is kept for signature compatibility.

        Returns:
            Tuple of (modulated normalized tensor, gate values).

        Raises:
            ValueError: If ``emb`` is None.
        """
        if emb is None:
            raise ValueError("`emb` must be provided.")
        emb = self.linear(F.silu(emb))
        shift_msa, scale_msa, gate_msa = F.chunk(emb, 3, axis=1)
        # `[:, None]` inserts an axis so the per-sample modulation
        # broadcasts against the normalized input.
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa


class AdaLayerNormZero(Module):
    r"""Norm layer adaptive layer norm zero (adaLN-Zero).

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        num_embeddings (`int`, *optional*): The size of the embeddings
            dictionary. Not supported yet; must be left as `None`.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_embeddings: int | None = None,
        norm_type: str = "layer_norm",
        bias: bool = True,
        dtype: DType = DType.bfloat16,
    ) -> None:
        """Initialize adaptive layer normalization zero module.

        Args:
            embedding_dim: Size of each embedding vector.
            num_embeddings: Optional size of the embeddings dictionary.
                Any non-None value raises NotImplementedError.
            norm_type: Type of normalization to use. "layer_norm" is
                implemented; "fp32_layer_norm" is recognised but raises
                NotImplementedError.
            bias: Whether to use bias in linear projection.
            dtype: Data type for the module (currently not used by this
                constructor).

        Raises:
            NotImplementedError: If ``num_embeddings`` is given or
                ``norm_type`` is "fp32_layer_norm".
            ValueError: If ``norm_type`` is unknown.
        """
        super().__init__()
        if num_embeddings is not None:
            raise NotImplementedError(
                "CombinedTimestepLabelEmbeddings is not implemented"
            )
        else:
            self.emb = None

        # Projects the conditioning embedding to six modulation tensors.
        self.linear = Linear(
            embedding_dim,
            6 * embedding_dim,
            bias=bias,
        )
        if norm_type == "layer_norm":
            self.norm = LayerNorm(
                embedding_dim,
                use_bias=False,
                eps=1e-6,
                keep_dtype=True,
                elementwise_affine=False,
            )
        elif norm_type == "fp32_layer_norm":
            raise NotImplementedError("FP32LayerNorm is not implemented")
        else:
            raise ValueError(
                f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
            )

    def forward(
        self,
        x: Tensor,
        timestep: Tensor | None = None,
        class_labels: Tensor | None = None,
        hidden_dtype: DType | None = None,
        emb: Tensor | None = None,
    ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Apply adaptive layer normalization with gate values for attention and MLP.

        Args:
            x: Input tensor.
            timestep: Optional timestep tensor; only read when ``self.emb``
                is set.
            class_labels: Optional class label tensor; only read when
                ``self.emb`` is set.
            hidden_dtype: Optional hidden data type; only read when
                ``self.emb`` is set.
            emb: Conditioning embedding. Required while ``self.emb`` is
                None, which is always the case with the current
                constructor.

        Returns:
            Tuple of (normalized tensor, gate_msa, shift_mlp, scale_mlp, gate_mlp).

        Raises:
            ValueError: If no conditioning embedding is available.
        """
        if self.emb is not None:
            emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
        if emb is None:
            raise ValueError("`emb` must be provided.")
        emb = self.linear(F.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            F.chunk(emb, 6, axis=1)
        )
        x = self.norm(x)
        x = x * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


class AdaLayerNormContinuous(Module):
    r"""Adaptive normalization layer with a norm layer (layer_norm or rms_norm).

    Args:
        embedding_dim (`int`): Embedding dimension to use during projection.
        conditioning_embedding_dim (`int`): Dimension of the input condition.
        eps (`float`, defaults to 1e-5): Epsilon factor.
        bias (`bool`, defaults to `True`): Boolean flag to denote if bias should be used.
        norm_type (`str`, defaults to `"layer_norm"`):
            Normalization layer to use. Values supported: "layer_norm", "rms_norm".
    """

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        # NOTE: The reference implementation exposes an `elementwise_affine`
        # option for the norm layer. It is not exposed here: the output is
        # immediately scaled and shifted by the projected conditioning
        # embedding, and the layer norm below is always built with
        # `elementwise_affine=False`.
        eps: float = 1e-5,
        bias: bool = True,
        norm_type: str = "layer_norm",
        dtype: DType = DType.bfloat16,
    ) -> None:
        """Initialize adaptive layer normalization continuous module.

        Args:
            embedding_dim: Embedding dimension to use during projection.
            conditioning_embedding_dim: Dimension of the input condition.
            eps: Epsilon factor for normalization.
            bias: Whether to use bias in linear projection.
            norm_type: Type of normalization to use ("layer_norm" or "rms_norm").
            dtype: Data type for the module (currently not used by this
                constructor).

        Raises:
            ValueError: If ``norm_type`` is unknown.
        """
        super().__init__()
        self.silu = F.silu
        # Projects the conditioning embedding to (scale, shift).
        self.linear = Linear(
            conditioning_embedding_dim,
            embedding_dim * 2,
            bias=bias,
        )
        if norm_type == "layer_norm":
            self.norm = LayerNorm(
                embedding_dim,
                eps=eps,
                keep_dtype=True,
                elementwise_affine=False,
            )
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(embedding_dim, eps=eps)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward(self, x: Tensor, conditioning_embedding: Tensor) -> Tensor:
        """Apply adaptive layer normalization with conditioning.

        Args:
            x: Input tensor.
            conditioning_embedding: Conditioning embedding tensor.

        Returns:
            Normalized and conditioned tensor.
        """
        # Cast back to x's dtype in case `conditioning_embedding` was
        # upcast to float32 upstream.
        emb = self.linear(F.cast(self.silu(conditioning_embedding), x.dtype))
        scale, shift = F.chunk(emb, 2, axis=1)
        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
        return x
class Flux1Model(MaxModel):
    """Wrapper that builds and compiles the Flux transformer on construction."""

    # File name of the model's configuration, taken from FluxConfig.
    config_name = FluxConfig.config_name

    def __init__(
        self,
        config: dict,
        encoding: SupportedEncoding,
        devices: list[Device],
        weights: Weights,
    ) -> None:
        """Parse the configuration and compile the model immediately.

        Args:
            config: Raw configuration dictionary.
            encoding: Encoding passed to the base class and to
                ``FluxConfig.generate``.
            devices: Devices passed to the base class; ``load_model`` places
                the module on the first one.
            weights: Checkpoint weights passed to the base class.
        """
        super().__init__(
            config,
            encoding,
            devices,
            weights,
        )
        self.config = FluxConfig.generate(
            config,
            encoding,
            devices,
        )
        # Compilation happens eagerly, as part of construction.
        self.load_model()

    def load_model(self) -> Model:
        """Build the transformer, compile it with the weights, and store it.

        Returns:
            The compiled model, also stored on ``self.model``.
        """
        # Materialise each weight, then rename checkpoint keys to the
        # attribute names used by this port.
        state_dict = {key: value.data() for key, value in self.weights.items()}
        state_dict = convert_safetensor_state_dict(state_dict)
        # NOTE(review): the source formatting was collapsed; confirm that
        # compile() is meant to run after the lazy context has exited.
        with F.lazy():
            flux = FluxTransformer2DModel(self.config)
            flux.to(self.devices[0])
        self.model = flux.compile(*flux.input_types(), weights=state_dict)
        return self.model

    def __call__(self, *args, **kwargs):
        """Forward all arguments to the compiled model and return its result."""
        return self.model(*args, **kwargs)
class FluxConfigBase(MAXModelConfigBase):
    """Hyper-parameters of the Flux transformer, with their default values."""

    patch_size: int = 1
    in_channels: int = 64
    out_channels: int | None = None
    num_layers: int = 19
    num_single_layers: int = 38
    attention_head_dim: int = 128
    num_attention_heads: int = 24
    joint_attention_dim: int = 4096
    pooled_projection_dim: int = 768
    guidance_embeds: bool = False
    axes_dims_rope: tuple[int, int, int] = (16, 56, 56)
    dtype: DType = DType.bfloat16
    device: DeviceRef = Field(default_factory=DeviceRef.GPU)


class FluxConfig(FluxConfigBase):
    """Factory for `FluxConfigBase` instances built from a raw config dict."""

    config_name: ClassVar[str] = "config.json"

    @staticmethod
    def generate(
        config_dict: dict,
        encoding: SupportedEncoding,
        devices: list[Device],
    ) -> FluxConfigBase:
        """Build a config from a raw dictionary plus runtime settings.

        Keys of ``config_dict`` that are not declared on `FluxConfigBase`
        are ignored. ``dtype`` and ``device`` always come from the runtime
        arguments, overriding any value in ``config_dict``.

        Args:
            config_dict: Raw configuration dictionary.
            encoding: Encoding whose ``dtype`` becomes the config dtype.
            devices: Devices; the first one becomes the config device.

        Returns:
            A populated `FluxConfigBase`.
        """
        declared_fields = FluxConfigBase.__annotations__
        settings = {}
        for name, raw_value in config_dict.items():
            if name in declared_fields:
                settings[name] = raw_value

        # Runtime-derived values take precedence over the file contents.
        settings["dtype"] = encoding.dtype
        settings["device"] = DeviceRef.from_device(devices[0])
        return FluxConfigBase(**settings)
def convert_safetensor_state_dict(
    state_dict: dict[str, WeightData],
) -> dict[str, WeightData]:
    """Rename Diffusers-style checkpoint keys to the names this port uses.

    The mapping is edited in place (renamed entries move to the end of the
    dict) and the same object is returned.

    Args:
        state_dict: Checkpoint weights keyed by parameter name.

    Returns:
        ``state_dict`` itself, with matching keys renamed.
    """
    # Each rule: (regex matched at the start of the key, old, new).
    rename_rules = (
        # Diffusers feed-forward is [GELU, Dropout, Linear]; MAX uses
        # [GELU, Linear], so the Linear moves from index 2 to index 1.
        (
            r"transformer_blocks\.\d+\.(ff|ff_context)\.net\.2\.(weight|bias)",
            "net.2.",
            "net.1.",
        ),
        # Diffusers stores `to_out` as a list ([Linear, Dropout]), giving
        # `to_out.0.*`; here `to_out` is a single Linear.
        (
            r"transformer_blocks\.\d+\.attn\.to_out\.0\.(weight|bias)",
            "to_out.0.",
            "to_out.",
        ),
        # Same pattern for the added/context stream output.
        (
            r"transformer_blocks\.\d+\.attn\.to_add_out\.0\.(weight|bias)",
            "to_add_out.0.",
            "to_add_out.",
        ),
    )

    # Iterate over a snapshot because the dict is modified in the loop.
    for old_name in list(state_dict):
        for pattern, old_part, new_part in rename_rules:
            if re.match(pattern, old_name) is None:
                continue
            weight = state_dict.pop(old_name)
            state_dict[old_name.replace(old_part, new_part)] = weight
            # The rules are mutually exclusive; one rename per key.
            break
    return state_dict
+# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ===----------------------------------------------------------------------=== # + +from .arch import flux1_arch diff --git a/max/python/max/pipelines/architectures/flux1/arch.py b/max/python/max/pipelines/architectures/flux1/arch.py new file mode 100644 index 00000000000..0fe73e91514 --- /dev/null +++ b/max/python/max/pipelines/architectures/flux1/arch.py @@ -0,0 +1,37 @@ +# ===----------------------------------------------------------------------=== # +# Copyright (c) 2025, Modular Inc. All rights reserved. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions: +# https://llvm.org/LICENSE.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Architecture descriptor for the Flux.1 image-generation pipeline. It is
# re-exported by this package's `__init__` and added to the list built in
# `register_all_models`.
flux1_arch = SupportedArchitecture(
    name="FluxPipeline",
    task=PipelineTask.IMAGE_GENERATION,
    default_encoding=SupportedEncoding.bfloat16,
    # bfloat16 is the only encoding declared for this architecture.
    supported_encodings={SupportedEncoding.bfloat16: []},
    example_repo_ids=[
        "black-forest-labs/FLUX.1-dev",
        "black-forest-labs/FLUX.1-schnell",
    ],
    pipeline_model=FluxPipeline,
    tokenizer=TextTokenizer,
    context_type=BaseContext,
    default_weights_format=WeightsFormat.safetensors,
)
From c16451a0dfd49d501b959ced01452283116dd806 Mon Sep 17 00:00:00 2001 From: Minkyu Kim Date: Tue, 3 Feb 2026 13:31:29 +0000 Subject: [PATCH 10/10] chore: add benchmark --- max/python/max/entrypoints/cli/generate.py | 56 ++++++++++++++++++---- max/python/max/entrypoints/pipelines.py | 15 +++++- 2 files changed, 61 insertions(+), 10 deletions(-) diff --git a/max/python/max/entrypoints/cli/generate.py b/max/python/max/entrypoints/cli/generate.py index 1bcd5aed3a1..aeb05c757a9 100644 --- a/max/python/max/entrypoints/cli/generate.py +++ b/max/python/max/entrypoints/cli/generate.py @@ -18,6 +18,7 @@ import asyncio import dataclasses import logging +import time from collections.abc import Iterable from pathlib import Path from typing import Any @@ -170,18 +171,57 @@ def generate_image( guidance_scale: float, num_images_per_prompt: int, output: Path, + benchmark: bool = False, ) -> None: from ..diffusion import DiffusionPipeline pipeline = DiffusionPipeline(pipeline_config) - result = pipeline( - prompt=prompt, - height=height, - width=width, - num_inference_steps=num_inference_steps, - guidance_scale=guidance_scale, - num_images_per_prompt=num_images_per_prompt, - ) + + def run_pipeline() -> Any: + return pipeline( + prompt=prompt, + height=height, + width=width, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + num_images_per_prompt=num_images_per_prompt, + ) + + if benchmark: + num_warmups = 0 + num_benchmark_runs = 2 + + print(f"\n{'='*60}") + print("BENCHMARK MODE") + print(f"{'='*60}") + + # Warm-up runs + print(f"\nRunning {num_warmups} warm-up iterations...") + for i in range(num_warmups): + print(f" Warm-up {i + 1}/{num_warmups}...") + run_pipeline() + + # Benchmark runs + print(f"\nRunning {num_benchmark_runs} benchmark iterations...") + times: list[float] = [] + for i in range(num_benchmark_runs): + start = time.perf_counter() + result = run_pipeline() + elapsed = time.perf_counter() - start + times.append(elapsed) + print(f" Run {i 
+ 1}/{num_benchmark_runs}: {elapsed:.4f}s") + + # Report results + avg_time = sum(times) / len(times) + print(f"\n{'='*60}") + print("BENCHMARK RESULTS") + print(f"{'='*60}") + print(f" Individual times: {', '.join(f'{t:.4f}s' for t in times)}") + print(f" Average time: {avg_time:.4f}s") + print(f"{'='*60}\n") + + else: + result = run_pipeline() images = result.images assert images, "Expected at least one generated image." diff --git a/max/python/max/entrypoints/pipelines.py b/max/python/max/entrypoints/pipelines.py index a99c8c6fb18..2276c01564e 100644 --- a/max/python/max/entrypoints/pipelines.py +++ b/max/python/max/entrypoints/pipelines.py @@ -454,6 +454,12 @@ def diffusion_group() -> None: show_default=True, help="Random seed for torch-based latent initialization.", ) +@click.option( + "--benchmark", + is_flag=True, + default=False, + help="Enable benchmarking: 2 warm-ups + 5 timed runs, report average.", +) def diffusion_generate( prompt: str, height: int, @@ -464,15 +470,19 @@ def diffusion_generate( output: Path, use_torch_randn: bool, seed: int, + benchmark: bool, **config_kwargs: Any, ) -> None: - """Generate images using a diffusion pipeline.""" + """ + Generate images using a diffusion pipeline. + Example: ./bazelw run //max/python/max/entrypoints:pipelines_diffusion -- generate --benchmark + """ from max.entrypoints.cli.generate import generate_image from max.experimental.realization_context import set_seed from max.pipelines import PipelineConfig set_seed(seed) - pipeline_config = PipelineConfig(**config_kwargs) + pipeline_config = PipelineConfig(model_path="black-forest-labs/FLUX.1-dev") pipeline_config.log_basic_config() try: @@ -485,6 +495,7 @@ def diffusion_generate( guidance_scale=guidance_scale, num_images_per_prompt=num_images_per_prompt, output=output, + benchmark=benchmark, ) except Exception as exc: logger.exception(