|
| 1 | +# Copyright 2023–2026 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +""" Tests for the common MaxText NNX utilities """ |
| 16 | +import unittest |
| 17 | +from dataclasses import dataclass |
| 18 | +from typing import Any |
| 19 | +import jax |
| 20 | +from flax import nnx |
| 21 | +from jax.sharding import Mesh, NamedSharding, PartitionSpec as P |
| 22 | +from jax.experimental import mesh_utils |
| 23 | + |
| 24 | +from maxtext.utils import maxtext_utils_nnx |
| 25 | + |
| 26 | + |
| 27 | +class TestMaxTextUtilsNNX(unittest.TestCase): |
| 28 | + """Test the functions for MaxText Utils.""" |
| 29 | + |
| 30 | + @dataclass |
| 31 | + class MockConfig: |
| 32 | + """Minimal mock for pyconfig.HyperParameters.""" |
| 33 | + |
| 34 | + init_weights_seed: int = 42 |
| 35 | + |
| 36 | + class TinyModel(nnx.Module): |
| 37 | + """ |
| 38 | + A tiny NNX model with logical annotations. |
| 39 | + Annotations are required to test that sharding extraction logic works. |
| 40 | + """ |
| 41 | + |
| 42 | + def __init__(self, rngs: nnx.Rngs): |
| 43 | + self.linear = nnx.Linear( |
| 44 | + jax.device_count(), |
| 45 | + jax.device_count(), |
| 46 | + kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("data", None)), |
| 47 | + # FIX: Removed () from zeros. zeros is the initializer function itself, |
| 48 | + # not a factory like lecun_normal(). |
| 49 | + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("data",)), |
| 50 | + rngs=rngs, |
| 51 | + ) |
| 52 | + |
| 53 | + def tiny_model_init_fn(self): |
| 54 | + """Factory function for model initialization.""" |
| 55 | + return self.TinyModel(rngs=nnx.Rngs(0)) |
| 56 | + |
| 57 | + def setUp(self): |
| 58 | + # Create a mesh for sharding tests. |
| 59 | + # NamedSharding requires an active Mesh to resolve logical names. |
| 60 | + self.devices = mesh_utils.create_device_mesh((jax.device_count(),)) |
| 61 | + self.mesh = Mesh(self.devices, axis_names=("data",)) |
| 62 | + |
| 63 | + def test_create_nnx_rngs_training(self): |
| 64 | + # Using Any to satisfy static type checkers for the MockConfig |
| 65 | + config: Any = self.MockConfig(init_weights_seed=123) |
| 66 | + rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=True) |
| 67 | + |
| 68 | + self.assertIsInstance(rngs, nnx.Rngs) |
| 69 | + # FIX: nnx.Rngs does not have a .streams attribute. |
| 70 | + # Check for stream attributes directly on the object. |
| 71 | + self.assertTrue(hasattr(rngs, "params")) |
| 72 | + self.assertTrue(hasattr(rngs, "dropout")) |
| 73 | + self.assertTrue(hasattr(rngs, "aqt")) |
| 74 | + |
| 75 | + def test_create_nnx_rngs_inference(self): |
| 76 | + config: Any = self.MockConfig(init_weights_seed=123) |
| 77 | + rngs = maxtext_utils_nnx.create_nnx_rngs(config, is_training=False) |
| 78 | + |
| 79 | + self.assertIsInstance(rngs, nnx.Rngs) |
| 80 | + # Check that 'params' exists but 'dropout' and 'aqt' were excluded |
| 81 | + self.assertTrue(hasattr(rngs, "params")) |
| 82 | + self.assertFalse(hasattr(rngs, "dropout")) |
| 83 | + self.assertFalse(hasattr(rngs, "aqt")) |
| 84 | + |
| 85 | + def test_move_memory(self): |
| 86 | + sharding = NamedSharding(self.mesh, P("data")) |
| 87 | + self.assertNotEqual(sharding.memory_kind, "pinned_host") |
| 88 | + |
| 89 | + path = ("layers", "linear", "kernel") |
| 90 | + host_sharding = maxtext_utils_nnx.move_memory_to_host(path, sharding) |
| 91 | + |
| 92 | + self.assertEqual(host_sharding.memory_kind, "pinned_host") |
| 93 | + self.assertEqual(host_sharding.spec, P("data")) |
| 94 | + |
| 95 | + device_sharding = maxtext_utils_nnx.move_memory_to_device(path, sharding) |
| 96 | + |
| 97 | + self.assertEqual(device_sharding.memory_kind, "device") |
| 98 | + self.assertEqual(device_sharding.spec, P("data")) |
| 99 | + |
| 100 | + def test_get_set_named_sharding_nnx(self): |
| 101 | + # 1. Create the abstract state using standard NNX functional API |
| 102 | + _, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh) |
| 103 | + |
| 104 | + # 2. Test extraction |
| 105 | + extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) |
| 106 | + |
| 107 | + # Verify kernel and bias match the P("data") annotations from TinyModel |
| 108 | + self.assertEqual(extracted_shardings.linear.kernel.get_value().spec, P("data", None)) |
| 109 | + self.assertEqual(extracted_shardings.linear.bias.get_value().spec, P("data")) |
| 110 | + |
| 111 | + # Target kernel spec update |
| 112 | + new_kernel_spec = P(None, "data") |
| 113 | + |
| 114 | + def update_spec_fn(path, leaf_sharding): |
| 115 | + path_str = jax.tree_util.keystr(path) |
| 116 | + if "linear" in path_str and "kernel" in path_str: |
| 117 | + # Construct a new NamedSharding with the requested logical spec |
| 118 | + return NamedSharding(leaf_sharding.mesh, new_kernel_spec) |
| 119 | + return leaf_sharding |
| 120 | + |
| 121 | + # Apply the spec change to the extracted sharding tree |
| 122 | + extracted_shardings = jax.tree.map_with_path(update_spec_fn, extracted_shardings) |
| 123 | + |
| 124 | + # 3. Test setting new shardings |
| 125 | + # Transform the extracted shardings to host memory |
| 126 | + new_shardings = jax.tree_util.tree_map_with_path(maxtext_utils_nnx.move_memory_to_host, extracted_shardings) |
| 127 | + updated_abstract = maxtext_utils_nnx.set_named_sharding_nnx(abstract_state, new_shardings) |
| 128 | + |
| 129 | + # Verify the metadata inside the abstract state leaf has updated its sharding |
| 130 | + self.assertEqual(updated_abstract.linear.kernel.sharding.memory_kind, "pinned_host") |
| 131 | + # Also verify the spec was updated successfully |
| 132 | + self.assertEqual(updated_abstract.linear.kernel.sharding.spec, new_kernel_spec) |
| 133 | + |
| 134 | + # 4. Verify named sharding is preserved after NNX merge (update) and split (state) |
| 135 | + model = self.tiny_model_init_fn() |
| 136 | + nnx.update(model, updated_abstract) |
| 137 | + re_extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(nnx.state(model)) |
| 138 | + |
| 139 | + # Verify kernel and bias have expected sharding |
| 140 | + self.assertEqual(re_extracted_shardings.linear.kernel.get_value().spec, new_kernel_spec) |
| 141 | + self.assertEqual(re_extracted_shardings.linear.bias.get_value().spec, P("data")) |
| 142 | + |
| 143 | + def test_create_nnx_sharded_model(self): |
| 144 | + # 1. Create abstract model |
| 145 | + graphdef, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh) |
| 146 | + abstract_model = nnx.merge(graphdef, abstract_state) |
| 147 | + |
| 148 | + # 2. Modify shardings to trigger host offloading |
| 149 | + extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) |
| 150 | + new_shardings = jax.tree_util.tree_map_with_path(maxtext_utils_nnx.move_memory_to_host, extracted_shardings) |
| 151 | + |
| 152 | + # 3. Run the sharded creation |
| 153 | + # We pass the abstract model and use the custom sharding for instantiation |
| 154 | + sharded_model = maxtext_utils_nnx.create_nnx_sharded_model( |
| 155 | + abstract_model, self.tiny_model_init_fn, mesh=self.mesh, named_sharding=new_shardings |
| 156 | + ) |
| 157 | + |
| 158 | + # 4. Verify the model is concrete (contains Arrays) and sharded on host |
| 159 | + self.assertIsInstance(sharded_model.linear.kernel[...], jax.Array) |
| 160 | + self.assertEqual(sharded_model.linear.kernel[...].sharding.memory_kind, "pinned_host") |
| 161 | + |
| 162 | + def test_get_partition_spec_nnx(self): |
| 163 | + """Verifies extraction of PartitionSpecs from NamedShardings.""" |
| 164 | + # 1. Create abstract state and get sharding |
| 165 | + _, abstract_state = nnx.get_abstract_model(self.tiny_model_init_fn, self.mesh) |
| 166 | + extracted_shardings = maxtext_utils_nnx.get_named_sharding_nnx(abstract_state) |
| 167 | + |
| 168 | + # 2. Execute extraction |
| 169 | + spec = maxtext_utils_nnx.get_partition_spec_nnx(extracted_shardings) |
| 170 | + |
| 171 | + # 3. Verify that the leaves are now raw PartitionSpecs |
| 172 | + # Expected values derived from TinyModel definition |
| 173 | + expected_spec_k = P("data", None) |
| 174 | + expected_spec_b = P("data") |
| 175 | + |
| 176 | + self.assertEqual(spec["linear"]["kernel"], expected_spec_k) |
| 177 | + self.assertEqual(spec["linear"]["bias"], expected_spec_b) |
| 178 | + self.assertNotIsInstance(spec["linear"]["kernel"], NamedSharding) |
| 179 | + |
| 180 | + |
| 181 | +if __name__ == "__main__": |
| 182 | + unittest.main() |
0 commit comments