Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions max/examples/diffusion/offline_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

import argparse
from pathlib import Path

from max.entrypoints.diffusion import DiffusionPipeline
from max.experimental.realization_context import set_seed
from max.pipelines import PipelineConfig


def main() -> None:
    """Generate a single image offline and write it to ``output.png``.

    Command-line options:
        --model-path: Model identifier handed to ``PipelineConfig``
            (default ``black-forest-labs/FLUX.1-dev``).
        --seed: Integer passed to ``set_seed`` before the pipeline is
            constructed (default ``42``).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model-path", type=str, default="black-forest-labs/FLUX.1-dev"
    )
    cli.add_argument("--seed", type=int, default=42)
    options = cli.parse_args()

    # Seed first, then build the pipeline from the requested model.
    set_seed(options.seed)
    pipe = DiffusionPipeline(PipelineConfig(model_path=options.model_path))

    prompt = "A cat holding a sign that says hello world"
    print(f"Prompt: {prompt}")

    # Fixed generation settings for this example.
    generation_settings = {
        "prompt": prompt,
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 50,
        "guidance_scale": 3.5,
    }
    rendered = pipe(**generation_settings).images

    destination = Path("output.png")
    destination.parent.mkdir(parents=True, exist_ok=True)
    # Only the first generated image is saved.
    rendered[0].save(destination)
    print(f"Image saved to: {destination}")


if __name__ == "__main__":
    main()
22 changes: 19 additions & 3 deletions max/python/max/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

import argparse
import enum
import json
import logging
import os
import types
from abc import abstractmethod
from collections.abc import Mapping
Expand Down Expand Up @@ -506,6 +508,7 @@ def _extract_max_config_data(
config_dict: The loaded YAML configuration dictionary.
config_class: The config class we're extracting data for.
section_name: Optional specific section name to look for.
config_file_path: Path to the config file for resolving inheritance.

Returns:
Configuration data for the specific config class.
Expand Down Expand Up @@ -854,9 +857,9 @@ def _add_field_as_argument(
):
# For enums, use the string value as default but we'll need to convert back
arg_kwargs = {
"default": field_value.value
if field_value
else field_obj.default
"default": (
field_value.value if field_value else field_obj.default
)
}
else:
arg_kwargs = {"default": field_value}
Expand Down Expand Up @@ -1071,6 +1074,19 @@ def parse_args( # type: ignore[override] # noqa: ANN202
return MAXConfigArgumentParser(parser, self)


def load_config(config_path: str | os.PathLike) -> dict:
if not os.path.exists(config_path):
raise FileNotFoundError(f"Configuration file not found: {config_path}")
try:
with open(config_path, encoding="utf-8") as f:
config_dict = json.loads(f.read())
except Exception as e:
raise ValueError(
f"Failed to load configuration from {config_path}: {e}"
) from e
Comment on lines +1083 to +1086
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Catching a generic Exception is too broad and can hide unexpected errors. It's better to catch more specific exceptions that you expect to handle, such as json.JSONDecodeError for parsing issues and IOError for file reading problems.

Suggested change
except Exception as e:
raise ValueError(
f"Failed to load configuration from {config_path}: {e}"
) from e
except (json.JSONDecodeError, IOError) as e:
raise ValueError(
f"Failed to load configuration from {config_path}: {e}"
) from e

return config_dict


all = [
"MAXBaseModel",
"ConfigFileModel",
Expand Down
1 change: 1 addition & 0 deletions max/python/max/dtype/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from . import dtype_extension
from .dtype import DType
56 changes: 56 additions & 0 deletions max/python/max/dtype/dtype_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

"""Extension for max.dtype to support additional attributes."""

from numpy import finfo as np_finfo

from .dtype import DType


class finfo:
    """Numerical properties of a floating-point ``max.dtype.DType``.

    This class mimics torch.finfo behavior without a torch dependency,
    including support for bfloat16.

    Instances expose ``min``, ``max``, ``bits``, ``eps``, ``resolution``,
    ``tiny`` and ``dtype`` (the dtype name as a string).

    NOTE: Currently, it is applied through patching (``DType.finfo`` is
    assigned below). This extension would be better implemented in the dtype
    library itself.
    """

    def __init__(self, dtype: DType):
        """Initialize finfo for a given max.dtype.DType.

        Args:
            dtype: The data type to get limits for.
        """
        if dtype == DType.bfloat16:
            # bfloat16 is handled explicitly instead of going through numpy.
            # The limits are derived from the format (8 exponent bits,
            # 7 mantissa bits) so they are exact bfloat16 values rather than
            # truncated decimal approximations.
            mantissa_bits = 7
            self.eps = 2.0**-mantissa_bits  # 0.0078125
            self.max = (2.0 - self.eps) * 2.0**127  # ~3.38953e38
            self.min = -self.max
            self.bits = 16
            self.resolution = 0.01
            self.tiny = 2.0**-126  # smallest positive normal, ~1.17549e-38
            self.dtype = "bfloat16"
        else:
            # Every other dtype is delegated to numpy's finfo; values are
            # converted to plain Python floats/strings.
            np_finfo_obj = np_finfo(dtype.to_numpy())
            self.min = float(np_finfo_obj.min)
            self.max = float(np_finfo_obj.max)
            self.bits = np_finfo_obj.bits
            self.eps = float(np_finfo_obj.eps)
            self.resolution = float(np_finfo_obj.resolution)
            self.tiny = float(np_finfo_obj.tiny)
            self.dtype = str(np_finfo_obj.dtype)


# Patch the helper onto DType so it is reachable as ``DType.finfo(...)``.
DType.finfo = finfo
36 changes: 36 additions & 0 deletions max/python/max/entrypoints/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,42 @@ modular_py_binary(
],
)

# Binary target wrapping the `pipelines_diffusion.py` entrypoint script.
modular_py_binary(
name = "pipelines_diffusion",
srcs = [
"pipelines_diffusion.py",
],
data = [
"@nvshmem_prebuilt//:host",
],
# Environment variables set for the binary.
env = {
"OTEL_EXPORTER_OTLP_METRICS_DEFAULT_HISTOGRAM_AGGREGATION": "base2_exponential_bucket_histogram",
"MODULAR_SHMEM_LIB_DIR": "../+http_archive+nvshmem_prebuilt",
},
# Mojo packages are only attached when `//:emit_mojo_enabled` is selected.
mojo_deps = select({
"//:emit_mojo_enabled": PROD_MOJOPKGS,
"//conditions:default": [],
}),
deps = [
# Provides the `max.entrypoints.pipelines` module for the wrapper to import.
":_pipelines",
":entrypoints",
"//max/python/max:_core",
"//max/python/max/benchmark:benchmark_serving_lib",
"//max/python/max/interfaces",
"//max/python/max/pipelines",
"//max/python/max/serve:config",
"//max/python/max/serve/telemetry",
requirement("typing-extensions"),
requirement("click"),
# torch is only added as a dependency under the `//:nvidia_gpu` condition.
] + select({
"//:nvidia_gpu": [
requirement("torch"),
],
"//conditions:default": [],
}),
)

modular_py_binary(
name = "replay_recording",
srcs = ["replay_recording.py"],
Expand Down
79 changes: 79 additions & 0 deletions max/python/max/entrypoints/cli/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import asyncio
import dataclasses
import logging
import time
from collections.abc import Iterable
from pathlib import Path
from typing import Any

import requests
Expand Down Expand Up @@ -158,3 +160,80 @@ def generate_text_for_pipeline(
print_tokens=True,
)
)


def generate_image(
    pipeline_config: PipelineConfig,
    prompt: str,
    height: int,
    width: int,
    num_inference_steps: int,
    guidance_scale: float,
    num_images_per_prompt: int,
    output: Path,
    benchmark: bool = False,
) -> None:
    """Generate image(s) for a prompt and save them to disk.

    Args:
        pipeline_config: Configuration used to build the diffusion pipeline.
        prompt: Text prompt forwarded to the pipeline.
        height: Forwarded to the pipeline as ``height``.
        width: Forwarded to the pipeline as ``width``.
        num_inference_steps: Forwarded to the pipeline unchanged.
        guidance_scale: Forwarded to the pipeline unchanged.
        num_images_per_prompt: Number of images requested. When it is 1 the
            first image is written to ``output``; otherwise each image is
            written next to ``output`` as ``<stem>_<index><suffix>`` with a
            1-based index.
        output: Destination path; its parent directory is created if needed.
        benchmark: When true, run the pipeline several times, print timing
            statistics, and save the images from the last run.

    Raises:
        RuntimeError: If the pipeline returns no images.
    """
    from ..diffusion import DiffusionPipeline

    pipeline = DiffusionPipeline(pipeline_config)

    def run_pipeline() -> Any:
        return pipeline(
            prompt=prompt,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
        )

    if benchmark:
        num_warmups = 0
        num_benchmark_runs = 2
        banner = "=" * 60

        print(f"\n{banner}")
        print("BENCHMARK MODE")
        print(banner)

        # Warm-up runs (untimed).
        print(f"\nRunning {num_warmups} warm-up iterations...")
        for i in range(num_warmups):
            print(f"  Warm-up {i + 1}/{num_warmups}...")
            run_pipeline()

        # Timed runs; the result of the last run is the one saved below.
        print(f"\nRunning {num_benchmark_runs} benchmark iterations...")
        times: list[float] = []
        for i in range(num_benchmark_runs):
            start = time.perf_counter()
            result = run_pipeline()
            elapsed = time.perf_counter() - start
            times.append(elapsed)
            print(f"  Run {i + 1}/{num_benchmark_runs}: {elapsed:.4f}s")

        # Report results.
        avg_time = sum(times) / len(times)
        print(f"\n{banner}")
        print("BENCHMARK RESULTS")
        print(banner)
        print(f"  Individual times: {', '.join(f'{t:.4f}s' for t in times)}")
        print(f"  Average time:     {avg_time:.4f}s")
        print(f"{banner}\n")

    else:
        result = run_pipeline()

    images = result.images
    # Explicit check instead of `assert`: assertions are removed under
    # `python -O`, which would let an empty result pass silently.
    if not images:
        raise RuntimeError("Expected at least one generated image.")

    output.parent.mkdir(parents=True, exist_ok=True)
    if num_images_per_prompt == 1:
        images[0].save(output)
        logger.info("Image saved to: %s", output)
    else:
        stem = output.stem
        suffix = output.suffix
        for i, image in enumerate(images):
            numbered_path = output.parent / f"{stem}_{i + 1}{suffix}"
            image.save(numbered_path)
        logger.info("%d images saved to: %s", len(images), output.parent)
61 changes: 61 additions & 0 deletions max/python/max/entrypoints/diffusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from max.interfaces import (
ImageGenerationInputs,
ImageGenerationOutput,
PipelineTask,
)
from max.pipelines.lib import PIPELINE_REGISTRY, PipelineConfig


class DiffusionPipeline:
    """Callable entrypoint that wraps an image-generation pipeline."""

    def __init__(self, pipeline_config: PipelineConfig) -> None:
        """Build the pipeline registered for the given configuration.

        Args:
            pipeline_config: Configuration used to look up the pipeline
                factory for the image-generation task.
        """
        # NOTE: This entrypoint is intentionally minimal and currently only
        # covers offline image generation; serving support is planned.
        self.pipeline_config = pipeline_config
        _unused, build_pipeline = PIPELINE_REGISTRY.retrieve_factory(
            pipeline_config,
            task=PipelineTask.IMAGE_GENERATION,
        )
        self.pipeline = build_pipeline()

    def __call__(
        self,
        prompt: str,
        negative_prompt: str | None = None,
        true_cfg_scale: float = 1.0,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 50,
        guidance_scale: float = 3.5,
        num_images_per_prompt: int = 1,
    ) -> ImageGenerationOutput:
        """Run the configured pipeline on a prompt and return its output.

        All arguments are packed unchanged into an ``ImageGenerationInputs``
        request, which is handed to the pipeline's ``execute`` method.
        """
        # TODO: consider all possible diffusion tasks,
        # e.g. T2I, I2I, T2V, I2V, V2V.
        request = ImageGenerationInputs(
            prompt=prompt,
            negative_prompt=negative_prompt,
            true_cfg_scale=true_cfg_scale,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
        )
        return self.pipeline.execute(request)
Loading