Skip to content

Commit f5f6760

Browse files
xibinliuecnal-cienet
authored and committed
NNX migration: NNX utils
- Add utils to manipulate the NNX shardings with abstract state of a model - also add unit tests for the utils - Extract mesh creation function to maxtext_utils.get_mesh_from_config() - also add unit tests for this func Note: flax v0.12 has DeprecationWarning in multiple places: - DeprecationWarning: '.value' access is now deprecated. Use variable.get_value() or variable[...] (for [Array]). - DeprecationWarning: 'VariableState' was removed, this is just an alias to 'Variable'. Please use 'Variable' directly instead. But since the code needs to work with post-training, which currently requires flax v0.11, we didn't change code for these warnings.
1 parent 95cd2b2 commit f5f6760

4 files changed

Lines changed: 289 additions & 57 deletions

File tree

src/maxtext/utils/maxtext_utils.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import functools
1919
import pickle
2020
import os
21+
from typing import Sequence
2122

2223
from flax import linen as nn
2324
from flax.linen import partitioning as nn_partitioning
@@ -27,6 +28,7 @@
2728

2829
from jax.experimental import mesh_utils
2930
from jax.experimental.serialize_executable import deserialize_and_load
31+
from jax.sharding import AxisType, Mesh
3032

3133
import jax
3234
import jax.numpy as jnp
@@ -36,7 +38,8 @@
3638
import orbax.checkpoint.experimental.emergency.checkpoint_manager as emergency_checkpoint_manager
3739
import orbax.checkpoint.experimental.emergency.replicator_checkpoint_manager as emergency_replicator_checkpoint_manager
3840

39-
from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE
41+
from maxtext.configs import pyconfig
42+
from maxtext.common.common_types import DecoderBlockType, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE, ShardMode
4043
from maxtext.configs import types
4144
from maxtext.inference.page_manager import PageState
4245
from maxtext.common import checkpointing
@@ -1521,3 +1524,27 @@ def maybe_dump_jaxpr(config, p_train_step, train_step_inputs):
15211524
delete_local_after=config.dump_jaxpr_delete_local_after, # Keeping local for debugging
15221525
all_host_upload=False, # Only upload from lead host (Host 0)
15231526
)
1527+
1528+
1529+
def get_mesh_from_config(
    config: pyconfig.HyperParameters,
    devices: Sequence[jax.Device] | None = None,
) -> Mesh:
  """Get the device mesh from the configuration.

  Args:
    config: the configuration; provides `shard_mode` and `mesh_axes`.
    devices: the devices to build the mesh from; when None,
      `create_device_mesh` falls back to its own default device set.

  Returns:
    the device mesh, with one axis type per mesh axis: Explicit when
    `config.shard_mode` is `ShardMode.EXPLICIT`, Auto otherwise.
  """
  devices_array = create_device_mesh(config, devices)

  # Every mesh axis shares the same axis type, selected by the shard mode.
  if config.shard_mode == ShardMode.EXPLICIT:
    axis_type = AxisType.Explicit
  else:
    axis_type = AxisType.Auto
  axis_types = (axis_type,) * len(config.mesh_axes)

  return Mesh(devices_array, config.mesh_axes, axis_types=axis_types)

src/maxtext/utils/model_creation_utils.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,16 @@
1818
from collections.abc import Sequence
1919
from functools import partial
2020
from typing import overload
21-
2221
from etils import epath
2322
from flax import nnx
2423
import flax.linen as nn
2524
import jax
26-
from jax.sharding import AxisType, Mesh
25+
from jax.sharding import Mesh
2726
from maxtext.configs import pyconfig
28-
from maxtext.common.common_types import MODEL_MODE_TRAIN, ShardMode
27+
from maxtext.common.common_types import MODEL_MODE_TRAIN
2928
from maxtext.layers import quantizations
3029
from maxtext.models import models
31-
from maxtext.utils import max_utils
32-
from maxtext.utils import maxtext_utils
30+
from maxtext.utils import max_utils, maxtext_utils, maxtext_utils_nnx
3331
from orbax import checkpoint as ocp
3432

3533

@@ -40,6 +38,7 @@ def from_config(
4038
mesh: Mesh | None = None,
4139
*,
4240
model_mode: str = MODEL_MODE_TRAIN,
41+
rngs: None = None,
4342
) -> nn.Module:
4443
...
4544

@@ -80,15 +79,7 @@ def from_config(
8079
model = from_config(config)
8180
"""
8281
if mesh is None:
83-
devices_array = maxtext_utils.create_device_mesh(config, devices)
84-
85-
if config.shard_mode == ShardMode.EXPLICIT:
86-
axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes))
87-
else:
88-
axis_types = tuple([AxisType.Auto] * len(config.mesh_axes))
89-
90-
mesh = Mesh(devices_array, config.mesh_axes, axis_types=axis_types)
91-
82+
mesh = maxtext_utils.get_mesh_from_config(config, devices)
9283
model = create_model(config, mesh, model_mode=model_mode, rngs=rngs)
9384

9485
# Return only the model
@@ -114,16 +105,10 @@ def create_model(config, mesh, model_mode: str = MODEL_MODE_TRAIN, rngs: nnx.Rng
114105

115106
def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None):
116107
"""Creates a NNX model with sharded parameters, possibly loading from a checkpoint."""
108+
is_training = model_mode == MODEL_MODE_TRAIN
117109

118110
def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None):
119-
if rng_key is None:
120-
rng_key = jax.random.PRNGKey(config.init_weights_seed)
121-
122-
if model_mode == MODEL_MODE_TRAIN:
123-
rngs = nnx.Rngs(params=rng_key, dropout=1)
124-
else:
125-
rngs = nnx.Rngs(params=rng_key) # disable dropout RNG for inference
126-
111+
rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=is_training, rng_key=rng_key)
127112
return from_config(config, devices, mesh, rngs=rngs, model_mode=model_mode)
128113

129114
_create_model_partial = partial(_create_model, mesh=mesh, model_mode=model_mode, rng_key=rng_key)
@@ -136,6 +121,17 @@ def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN,
136121
if mesh is None:
137122
mesh = abstract_model.mesh
138123

124+
# Note for pure_nnx:
125+
# Currently, the NNX model returned has a linen decoder wrapped to NNX. So it is not a pure NNX model and
126+
# we still need to use nn.logical_axis_rules(config.logical_axis_rules) to get the out sharding from the linen
127+
# LogicallyPartitioned structure.
128+
# In the future if the pure NNX model is used, with pure NNX's eager sharding, there will be no LogicallyPartitioned
129+
# structure in the abstract state and we can get the sharded state with the following code:
130+
# graphdef, state = nnx.get_abstract_model(_create_model_partial, mesh)
131+
# abstract_model = nnx.merge(graphdef, state)
132+
# model = maxtext_utils_nnx.create_nnx_sharded_model(abstract_model, _create_model_partial, mesh=mesh)
133+
# sharded_state = nnx.state(model)
134+
139135
# JIT a function that creates the model state with proper sharding from the start.
140136
# By providing out_shardings, we instruct JAX to produce sharded output directly,
141137
# avoiding a large intermediate allocation on a single device.
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
""" Tests for the common MaxText NNX utilities """
16+
import unittest
17+
from dataclasses import dataclass
18+
from typing import Any
19+
import jax
20+
from flax import nnx
21+
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
22+
from jax.experimental import mesh_utils
23+
24+
from maxtext.utils import maxtext_utils_nnx
25+
26+
27+
class TestMaxTextUtilsNNX(unittest.TestCase):
  """Unit tests for the MaxText NNX sharding utilities."""

  @dataclass
  class MockConfig:
    """Minimal stand-in for pyconfig.HyperParameters."""

    init_weights_seed: int = 42

  class TinyModel(nnx.Module):
    """A minimal NNX model whose parameters carry logical sharding annotations.

    The annotations are what the sharding-extraction utilities operate on.
    """

    def __init__(self, rngs: nnx.Rngs):
      self.linear = nnx.Linear(
          jax.device_count(),
          jax.device_count(),
          kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("data", None)),
          # NOTE: nnx.initializers.zeros is already an initializer function,
          # unlike lecun_normal() which is a factory — so no call parentheses.
          bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("data",)),
          rngs=rngs,
      )

  def tiny_model_init_fn(self):
    """Build a TinyModel with a fixed RNG seed."""
    return self.TinyModel(rngs=nnx.Rngs(0))

  def setUp(self):
    # NamedSharding needs a live Mesh to resolve axis names; span every
    # available device along a single "data" axis.
    self.devices = mesh_utils.create_device_mesh((jax.device_count(),))
    self.mesh = Mesh(self.devices, axis_names=("data",))

  def test_create_nnx_rngs_training(self):
    # Typed as Any so the mock satisfies static checkers where a real
    # HyperParameters is expected.
    cfg: Any = self.MockConfig(init_weights_seed=123)
    rng_streams = maxtext_utils_nnx.create_nnx_rngs(cfg, is_training=True)

    self.assertIsInstance(rng_streams, nnx.Rngs)
    # nnx.Rngs exposes streams as attributes (there is no .streams field).
    for stream in ("params", "dropout", "aqt"):
      self.assertTrue(hasattr(rng_streams, stream))

  def test_create_nnx_rngs_inference(self):
    cfg: Any = self.MockConfig(init_weights_seed=123)
    rng_streams = maxtext_utils_nnx.create_nnx_rngs(cfg, is_training=False)

    self.assertIsInstance(rng_streams, nnx.Rngs)
    # Inference keeps 'params' but drops the training-only streams.
    self.assertTrue(hasattr(rng_streams, "params"))
    self.assertFalse(hasattr(rng_streams, "dropout"))
    self.assertFalse(hasattr(rng_streams, "aqt"))

  def test_move_memory(self):
    base = NamedSharding(self.mesh, P("data"))
    self.assertNotEqual(base.memory_kind, "pinned_host")

    param_path = ("layers", "linear", "kernel")

    # Host transfer changes only the memory kind, never the spec.
    on_host = maxtext_utils_nnx.move_memory_to_host(param_path, base)
    self.assertEqual(on_host.memory_kind, "pinned_host")
    self.assertEqual(on_host.spec, P("data"))

    on_device = maxtext_utils_nnx.move_memory_to_device(param_path, base)
    self.assertEqual(on_device.memory_kind, "device")
    self.assertEqual(on_device.spec, P("data"))

  def test_get_set_named_sharding_nnx(self):
    # 1. Build the abstract state through the standard NNX functional API.
    _, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh)

    # 2. Extraction should reflect the P("data") annotations on TinyModel.
    sharding_tree = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state)
    self.assertEqual(sharding_tree.linear.kernel.get_value().spec, P("data", None))
    self.assertEqual(sharding_tree.linear.bias.get_value().spec, P("data"))

    # Rewrite the kernel's spec while leaving every other leaf untouched.
    new_kernel_spec = P(None, "data")

    def rewrite_kernel(path, leaf):
      key = jax.tree_util.keystr(path)
      if "linear" in key and "kernel" in key:
        # Fresh NamedSharding carrying the requested logical spec.
        return NamedSharding(leaf.mesh, new_kernel_spec)
      return leaf

    sharding_tree = jax.tree.map_with_path(rewrite_kernel, sharding_tree)

    # 3. Push everything to pinned host memory and write it back.
    host_tree = jax.tree_util.tree_map_with_path(maxtext_utils_nnx.move_memory_to_host, sharding_tree)
    updated_abstract = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, host_tree)

    self.assertEqual(updated_abstract.linear.kernel.sharding.memory_kind, "pinned_host")
    self.assertEqual(updated_abstract.linear.kernel.sharding.spec, new_kernel_spec)

    # 4. The sharding must survive an nnx merge (update) / split (state) round trip.
    model = self.tiny_model_init_fn()
    nnx.update(model, updated_abstract)
    round_trip = maxtext_utils_nnx.get_named_sharding_nnx(nnx.state(model))

    self.assertEqual(round_trip.linear.kernel.get_value().spec, new_kernel_spec)
    self.assertEqual(round_trip.linear.bias.get_value().spec, P("data"))

  def test_create_nnx_sharded_model(self):
    # Abstract model plus shardings redirected to pinned host memory.
    graphdef, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh)
    abstract_model = nnx.merge(graphdef, abstract_state)

    sharding_tree = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state)
    host_tree = jax.tree_util.tree_map_with_path(maxtext_utils_nnx.move_memory_to_host, sharding_tree)

    # Instantiate concretely using the customized shardings.
    sharded_model = maxtext_utils_nnx.create_nnx_sharded_model(
        abstract_model, self.tiny_model_init_fn, mesh=self.mesh, named_sharding=host_tree
    )

    # The result holds real arrays, placed in pinned host memory.
    self.assertIsInstance(sharded_model.linear.kernel[...], jax.Array)
    self.assertEqual(sharded_model.linear.kernel[...].sharding.memory_kind, "pinned_host")

  def test_get_partition_spec_nnx(self):
    """PartitionSpecs can be pulled out of a tree of NamedShardings."""
    _, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh)
    sharding_tree = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state)

    spec = maxtext_utils_nnx.get_partition_spec_nnx(sharding_tree)

    # Leaves are raw PartitionSpecs matching TinyModel's annotations,
    # not NamedSharding objects.
    self.assertEqual(spec["linear"]["kernel"], P("data", None))
    self.assertEqual(spec["linear"]["bias"], P("data"))
    self.assertNotIsInstance(spec["linear"]["kernel"], NamedSharding)
181+
# Allow running this test module directly (e.g. `python maxtext_utils_nnx_test.py`).
if __name__ == "__main__":
  unittest.main()

0 commit comments

Comments
 (0)