Skip to content

Commit 6252c71

Browse files
committed
Merge branch 'add/flux1-pipeline/models' into add-flux1-pipeline
2 parents 0213cbd + 2f99887 commit 6252c71

31 files changed

Lines changed: 4888 additions & 15 deletions

max/python/max/dtype/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@
1111
# limitations under the License.
1212
# ===----------------------------------------------------------------------=== #
1313

14+
from . import dtype_extension
1415
from .dtype import DType
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# ===----------------------------------------------------------------------=== #
2+
# Copyright (c) 2025, Modular Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License v2.0 with LLVM Exceptions:
5+
# https://llvm.org/LICENSE.txt
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
# ===----------------------------------------------------------------------=== #
13+
14+
"""Extension for max.dtype to support additional attributes."""
15+
16+
from numpy import finfo as np_finfo
17+
18+
from .dtype import DType
19+
20+
21+
class finfo:
    """Numerical properties of a floating-point ``max.dtype.DType``.

    This class mimics ``torch.finfo`` behavior without a torch dependency.
    bfloat16 is handled explicitly instead of going through ``numpy.finfo``;
    every other dtype is converted with ``DType.to_numpy()`` and delegated
    to ``numpy.finfo``.

    NOTE: Currently this is applied to ``DType`` through patching. It would
    be better implemented in the dtype library itself.

    Attributes:
        min: Most negative finite value representable by the dtype.
        max: Largest finite value representable by the dtype.
        bits: Number of bits occupied by the dtype.
        eps: Difference between 1.0 and the next representable value.
        resolution: Approximate decimal resolution of the dtype.
        tiny: Smallest positive normal value.
        dtype: Name of the dtype as a string.
    """

    def __init__(self, dtype: DType):
        """Initialize finfo for a given max.dtype.DType.

        Args:
            dtype: The data type to get limits for.
        """
        if dtype == DType.bfloat16:
            # bfloat16 has 8 exponent bits and 7 stored mantissa bits.
            # Deriving the limits from the format yields the exact values
            # (max == 3.3895313892515355e+38, tiny == 1.1754943508222875e-38)
            # rather than truncated decimal approximations.
            largest = (2.0 - 2.0**-7) * 2.0**127
            self.min = -largest
            self.max = largest
            self.bits = 16
            self.eps = 2.0**-7  # 0.0078125
            self.resolution = 0.01
            self.tiny = 2.0**-126
            self.dtype = "bfloat16"
        else:
            # Delegate to NumPy, converting its scalar types to plain Python
            # floats so both branches expose the same attribute types.
            info = np_finfo(dtype.to_numpy())
            self.min = float(info.min)
            self.max = float(info.max)
            self.bits = info.bits
            self.eps = float(info.eps)
            self.resolution = float(info.resolution)
            self.tiny = float(info.tiny)
            self.dtype = str(info.dtype)
54+
55+
56+
# Monkey-patch: attach the finfo class to DType so it is reachable as
# ``DType.finfo(dtype)``. This assignment runs when the module is imported.
DType.finfo = finfo

max/python/max/nn/norm/group_norm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def __init__(
4545
eps: float = 1e-5,
4646
affine: bool = True,
4747
device: DeviceRef = DeviceRef.GPU(),
48+
dtype: DType = DType.float32,
4849
) -> None:
4950
super().__init__()
5051
self.num_groups = num_groups
@@ -65,13 +66,13 @@ def __init__(
6566
self.weight = Weight(
6667
name="weight",
6768
shape=(self.num_channels,),
68-
dtype=DType.float32,
69+
dtype=dtype,
6970
device=device,
7071
)
7172
self.bias = Weight(
7273
name="bias",
7374
shape=(self.num_channels,),
74-
dtype=DType.float32,
75+
dtype=dtype,
7576
device=device,
7677
)
7778

max/python/max/nn/norm/layer_norm.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -36,37 +36,56 @@ def __init__(
3636
dtype: DType,
3737
eps: float = 1e-5,
3838
use_bias: bool = True,
39+
keep_dtype: bool = False,
40+
elementwise_affine: bool = True,
3941
) -> None:
4042
super().__init__()
4143
self.devices = devices
42-
self.weight = Weight("weight", dtype, (dims,), device=self.devices[0])
43-
self.bias = (
44-
Weight("bias", dtype, (dims,), device=self.devices[0])
45-
if use_bias
46-
else None
47-
)
44+
if elementwise_affine:
45+
self.weight = Weight(
46+
"weight", dtype, (dims,), device=self.devices[0]
47+
)
48+
self.bias = (
49+
Weight("bias", dtype, (dims,), device=self.devices[0])
50+
if use_bias
51+
else None
52+
)
53+
else:
54+
self.weight = None
55+
self.bias = None
4856
self.eps = eps
4957
self.dim = dims
5058
self.dtype = dtype
59+
self.keep_dtype = keep_dtype
5160
self._sharding_strategy: ShardingStrategy | None = None
5261

5362
def __call__(self, input: TensorValue):
5463
# TODO: AIPIPE-95 Replace with a broadcasting rmo.layer_norm
5564
bias = (
56-
ops.cast(self.bias, DType.float32)
65+
self.bias
5766
if self.bias
5867
# If bias wasn't passed then use bias-less layer norm (beta = 0).
5968
else ops.broadcast_to(
60-
ops.constant(0.0, DType.float32, self.weight.device),
69+
ops.constant(0.0, self.dtype, input.device),
70+
shape=(input.shape[-1],),
71+
)
72+
)
73+
gamma = (
74+
self.weight
75+
if self.weight
76+
else ops.broadcast_to(
77+
ops.constant(1.0, self.dtype, input.device),
6178
shape=(input.shape[-1],),
6279
)
6380
)
64-
return ops.layer_norm(
65-
input.cast(DType.float32),
66-
gamma=ops.cast(self.weight, DType.float32),
67-
beta=bias,
81+
82+
output = ops.layer_norm(
83+
input=input if self.keep_dtype else input.cast(DType.float32),
84+
gamma=gamma if self.keep_dtype else ops.cast(gamma, DType.float32),
85+
beta=bias if self.keep_dtype else ops.cast(bias, DType.float32),
6886
epsilon=self.eps,
69-
).cast(input.dtype)
87+
)
88+
return output if self.keep_dtype else output.cast(input.dtype)
7089

7190
@property
7291
def sharding_strategy(self) -> ShardingStrategy | None:

max/python/max/pipelines/architectures/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def register_all_models() -> None:
2828
from .deepseekV3 import deepseekV3_arch
2929
from .eagle_llama3 import eagle_llama_arch
3030
from .exaone import exaone_arch
31+
from .flux1 import flux1_arch
3132
from .gemma3 import gemma3_arch
3233
from .gemma3multimodal import gemma3_multimodal_arch
3334
from .gpt_oss import gpt_oss_arch
@@ -54,6 +55,7 @@ def register_all_models() -> None:
5455
deepseekV2_arch,
5556
deepseekV3_arch,
5657
eagle_llama_arch,
58+
flux1_arch,
5759
gemma3_arch,
5860
gemma3_multimodal_arch,
5961
granite_arch,
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# ===----------------------------------------------------------------------=== #
2+
# Copyright (c) 2025, Modular Inc. All rights reserved.
3+
#
4+
# Licensed under the Apache License v2.0 with LLVM Exceptions:
5+
# https://llvm.org/LICENSE.txt
6+
#
7+
# Unless required by applicable law or agreed to in writing, software
8+
# distributed under the License is distributed on an "AS IS" BASIS,
9+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
# See the License for the specific language governing permissions and
11+
# limitations under the License.
12+
# ===----------------------------------------------------------------------=== #
13+
14+
from .model import AutoencoderKLModel

0 commit comments

Comments
 (0)