
Commit fb04bc7

Merge pull request #3034 from CIeNET-International:user/sharony/exp_sharding_dump
PiperOrigin-RevId: 867643814
2 parents 95ef3e1 + c0a6b81 commit fb04bc7

64 files changed

Lines changed: 81565 additions & 42716 deletions


tests/unit/maxtext_utils_test.py

Lines changed: 32 additions & 0 deletions
@@ -827,5 +827,37 @@ def test_wsd_schedule(self):
     self.assertIn("wsd_decay_steps_fraction", str(cm.exception))
 
 
+class TestGetAbstractState(unittest.TestCase):
+  """Test class for get_abstract_state."""
+
+  def setUp(self):
+    self.config = pyconfig.initialize(
+        [None, get_test_config_path()],
+        enable_checkpointing=False,
+        model_name="llama3.1-8b",
+        per_device_batch_size=1,
+        max_target_length=16,
+    )
+    devices_array = maxtext_utils.create_device_mesh(self.config)
+    self.mesh = Mesh(devices_array, self.config.mesh_axes)
+    quant = quantizations.configure_quantization(self.config)
+    self.model = Transformer(self.config, mesh=self.mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
+    self.rng = jax.random.PRNGKey(0)
+    self.tx = optax.adam(learning_rate=0.001)
+
+  def test_get_abstract_state(self):
+    """Tests that get_abstract_state returns abstract arrays."""
+    # get_abstract_state returns a tuple; the first element is the abstract state.
+    abstract_state, _, _ = maxtext_utils.get_abstract_state(self.model, self.tx, self.config, self.rng, self.mesh, None)
+
+    # Check that params are abstract.
+    param_leaves = jax.tree_util.tree_leaves(abstract_state.params)
+    self.assertTrue(all(isinstance(leaf, jax.ShapeDtypeStruct) for leaf in param_leaves))
+
+    # Check that opt_state is abstract.
+    opt_state_leaves = jax.tree_util.tree_leaves(abstract_state.opt_state)
+    self.assertTrue(all(isinstance(leaf, jax.ShapeDtypeStruct) for leaf in opt_state_leaves))
+
+
 if __name__ == "__main__":
   unittest.main()
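The class added above leans on maxtext_utils.get_abstract_state evaluating the train state abstractly, so every leaf it returns is a jax.ShapeDtypeStruct (shape and dtype only, no device buffers). Below is a minimal, self-contained sketch of that abstract-evaluation pattern using only public JAX APIs; init_fn here is an illustrative stand-in, not MaxText's actual state initializer.

import jax
import jax.numpy as jnp


def init_fn(rng):
  # A toy "train state": params plus a same-shaped optimizer accumulator.
  params = {"w": jax.random.normal(rng, (1024, 1024)), "b": jnp.zeros((1024,))}
  opt_state = jax.tree_util.tree_map(jnp.zeros_like, params)
  return {"params": params, "opt_state": opt_state}


# jax.eval_shape traces init_fn without allocating any device memory, so each
# leaf of the result is a jax.ShapeDtypeStruct rather than a concrete array.
abstract_state = jax.eval_shape(init_fn, jax.random.PRNGKey(0))
leaves = jax.tree_util.tree_leaves(abstract_state)
assert all(isinstance(leaf, jax.ShapeDtypeStruct) for leaf in leaves)

Because nothing is materialized, the same check stays cheap even for states far larger than device memory, which is what makes compile-time sharding tests like these practical.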
Lines changed: 183 additions & 30 deletions
@@ -1,4 +1,4 @@
-# Copyright 2023–2025 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,12 +18,21 @@
 import json
 import os
 import pytest
+import jax
+import jax.numpy as jnp
+# import optax
 
+from MaxText.globals import MAXTEXT_PKG_DIR
 from MaxText.train_compile import get_shaped_inputs, get_topology_mesh, validate_config
 from MaxText import pyconfig
+from MaxText import maxtext_utils
+from MaxText.layers import models
+from MaxText.layers import quantizations
+from MaxText import optimizers
 
-from tests.utils.sharding_dump import named_shardings_to_json, load_named_sharding_json, TEST_CASES
-from tests.utils.test_helpers import get_test_config_path
+from tests.utils.sharding_dump import load_json, TEST_CASES, named_shardings_to_json, partition_specs_to_json
+
+Transformer = models.transformer_as_linen
 
 
 def compute_checksum(d: dict) -> str:
@@ -37,7 +46,7 @@ def compute_checksum(d: dict) -> str:
   return checksum
 
 
-def compare_named_sharding_jsons(json1: dict, model1_name: str, json2: dict, model2_name: str) -> bool:
+def compare_sharding_jsons(json1: dict, model1_name: str, json2: dict, model2_name: str) -> bool:
   """Compare two json files and print the differences if any."""
   keys1 = set(json1.keys())
   keys2 = set(json2.keys())
@@ -46,66 +55,210 @@ def compare_named_sharding_jsons(json1: dict, model1_name: str, json2: dict, mod
   only_in_2 = keys2 - keys1
   shared_keys = keys1 & keys2
 
+  has_diff = False
+
   if only_in_1:
     print(f"Keys only in {model1_name}:")
     for k in sorted(only_in_1):
       print(f"  {k}")
+    has_diff = True
 
   if only_in_2:
     print(f"Keys only in {model2_name}:")
     for k in sorted(only_in_2):
       print(f"  {k}")
+    has_diff = True
 
   for key in sorted(shared_keys):
     entry1 = json1[key]
     entry2 = json2[key]
 
-    mesh1 = entry1.get("mesh", {})
-    mesh2 = entry2.get("mesh", {})
-    spec1 = entry1.get("partition_spec", [])
-    spec2 = entry2.get("partition_spec", [])
+    if isinstance(entry1, dict) and isinstance(entry2, dict):
+      mesh1 = entry1.get("mesh", {})
+      mesh2 = entry2.get("mesh", {})
+
+      spec1 = entry1.get("partition_spec", [])
+      spec2 = entry2.get("partition_spec", [])
+
+      shape1 = entry1.get("shape")
+      shape2 = entry2.get("shape")
+
+      if mesh1 != mesh2:
+        print(f"\nMesh mismatch at '{key}':")
+        print(f"  {model1_name}: {mesh1}")
+        print(f"  {model2_name}: {mesh2}")
+        has_diff = True
+
+      if spec1 != spec2:
+        print(f"\nPartitionSpec mismatch at '{key}':")
+        print(f"  {model1_name}: {spec1}")
+        print(f"  {model2_name}: {spec2}")
+        has_diff = True
 
-    if mesh1 != mesh2:
-      print(f"\nMesh mismatch at '{key}':")
-      print(f"  mesh1: {mesh1}")
-      print(f"  mesh2: {mesh2}")
+      if shape1 != shape2:
+        print(f"\nShape mismatch at '{key}':")
+        print(f"  {model1_name}: {shape1}")
+        print(f"  {model2_name}: {shape2}")
+        has_diff = True
 
-    if spec1 != spec2:
-      print(f"\nPartitionSpec mismatch at '{key}':")
-      print(f"  spec1: {spec1}")
-      print(f"  spec2: {spec2}")
+    else:
+      print(f"\nFormat mismatch at '{key}':")
+      print(f"  {model1_name} type: {type(entry1)}")
+      print(f"  {model2_name} type: {type(entry2)}")
+      has_diff = True
 
-  return not only_in_1 and not only_in_2 and all(json1[k] == json2[k] for k in shared_keys)
+  return has_diff
 
 
 @pytest.mark.parametrize("model_name, topology, num_slice", TEST_CASES)
 def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) -> None:
-  """Test if the sharding of new model implementation is as expected."""
+  """
+  Test sharding configurations from train_compile.get_shaped_inputs.
+  This test verifies that the sharding configurations for various models and topologies remain consistent with golden files.
+  """
   params = [
       "/deps/MaxText/tests/unit/sharding_compare_test",
-      get_test_config_path(),
+      os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
       f"compile_topology={topology}",
       f"compile_topology_num_slices={num_slice}",
       f"model_name={model_name}",
   ]
 
-  json_path = f"sharding_info/" f"{model_name}/" f"{topology}/" f"slice_{num_slice}/named_shardings.json"
-  if not os.path.exists(json_path):
+  root_dir = "tests/utils/sharding_info"
+  base_path = os.path.join(root_dir, model_name, topology, f"slice_{num_slice}")
+
+  named_json_path = os.path.join(base_path, "named_shardings.json")
+  logical_json_path = os.path.join(base_path, "logical_shardings.json")
+
+  if not os.path.exists(named_json_path):
+    pytest.skip(f"Missing named_shardings.json for {model_name} {topology} slice {num_slice}")
+    return
+  if not os.path.exists(logical_json_path):
+    pytest.skip(f"Missing logical_shardings.json for {model_name} {topology} slice {num_slice}")
     return
 
   config = pyconfig.initialize(params)
   validate_config(config)
 
   topology_mesh = get_topology_mesh(config)
-  _, _, state_mesh_shardings, _, _ = get_shaped_inputs(topology_mesh, config)
-  actual_json = named_shardings_to_json(state_mesh_shardings)
-  expected_json = load_named_sharding_json(json_path)
+  shaped_train_args, _, state_mesh_shardings, logical_shardings, _ = get_shaped_inputs(topology_mesh, config)
+
+  error_messages = []
+
+  # 1. Compare named shardings.
+  actual_named = named_shardings_to_json(state_mesh_shardings, shaped_train_args[0])
+  expected_named = load_json(named_json_path)
+  # Calculate checksums.
+  actual_named_sum = compute_checksum(actual_named)
+  expected_named_sum = compute_checksum(expected_named)
+  named_match = actual_named_sum == expected_named_sum
+
+  if not named_match:
+    print(f"\n[FAIL] Physical Sharding Mismatch: {model_name} {topology} slice {num_slice}", flush=True)
+    compare_sharding_jsons(expected_named, "Expected (Physical)", actual_named, "Actual (Physical)")
+    error_messages.append(f"Physical sharding mismatch for {model_name} on {topology} slice {num_slice}")
+
+  # 2. Compare logical shardings.
+  actual_logical = partition_specs_to_json(logical_shardings, shaped_train_args[0])
+  expected_logical = load_json(logical_json_path)
+  # Calculate checksums.
+  actual_logical_sum = compute_checksum(actual_logical)
+  expected_logical_sum = compute_checksum(expected_logical)
+  logical_match = actual_logical_sum == expected_logical_sum
+
+  if not logical_match:
+    print(f"\n[FAIL] Logical Sharding Mismatch: {model_name} {topology} slice {num_slice}", flush=True)
+    compare_sharding_jsons(expected_logical, "Expected (Logical)", actual_logical, "Actual (Logical)")
+    error_messages.append(f"Logical sharding mismatch for {model_name} on {topology} slice {num_slice}")
+
+  assert not error_messages, "\n".join(error_messages)
+
+
+@pytest.fixture(
+    scope="module",
+    params=[pytest.param(case, id=f"{case[0]}-{case[1]}-{case[2]}") for case in TEST_CASES],
+)
+def abstract_state_and_shardings(request):
+  """Pytest fixture to set up model, config, and generate abstract state once per test case."""
+  model_name, topology, num_slice = request.param
+  print(f"Testing model: {model_name}, topology: {topology}, num_slices: {num_slice}", flush=True)
+  params = [
+      "/deps/MaxText/tests/unit/sharding_compare_test",
+      os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
+      f"compile_topology={topology}",
+      f"compile_topology_num_slices={num_slice}",
+      f"model_name={model_name}",
+      "weight_dtype=float32",
+  ]
+  config = pyconfig.initialize(params)
+  validate_config(config)
+
+  topology_mesh = get_topology_mesh(config)
+  quant = quantizations.configure_quantization(config)
+  model = Transformer(config, mesh=topology_mesh, quant=quant)
+
+  learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
+  # tx = optax.adam(learning_rate=learning_rate_schedule)
+  tx = optimizers.get_optimizer(config, learning_rate_schedule)
+  rng = jax.random.PRNGKey(0)
+
+  # Get abstract state and physical shardings from maxtext_utils.
+  abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(
+      model, tx, config, rng, topology_mesh, is_training=True
+  )
+
+  # Get logical shardings from maxtext_utils.
+  logical_shardings = maxtext_utils.get_logical_annotations(model, tx, config, rng, topology_mesh, is_training=True)
+
+  return model_name, topology, num_slice, abstract_state, state_mesh_shardings, logical_shardings
+
+
+class TestGetAbstractState:
+  """Test class for get_abstract_state function and sharding comparison."""
+
+  def test_get_abstract_state_sharding(self, abstract_state_and_shardings):  # pylint: disable=redefined-outer-name
+    """Tests that get_abstract_state returns a state with the correct abstract structure and compares sharding."""
+
+    model_name, topology, num_slice, abstract_state, state_mesh_shardings, logical_shardings = (
+        abstract_state_and_shardings
+    )
+
+    assert hasattr(abstract_state, "params")
+    assert hasattr(abstract_state, "opt_state")
+    param_leaf = jax.tree_util.tree_leaves(abstract_state.params)[0]
+    assert isinstance(param_leaf, jax.ShapeDtypeStruct)
+    assert param_leaf.dtype == jnp.float32
+
+    root_dir = "tests/utils/sharding_info"
+    base_path = os.path.join(root_dir, model_name, topology, f"slice_{num_slice}")
+    os.makedirs(base_path, exist_ok=True)  # Ensure the directory exists.
+
+    error_messages = []
+
+    # 1. Compare physical (named) shardings.
+    named_json_path = os.path.join(base_path, "named_shardings.json")
+    if not os.path.exists(named_json_path):
+      pytest.skip(f"Missing named_shardings.json for {model_name} {topology} slice {num_slice}")
+      return
+
+    # Use state_mesh_shardings from the fixture.
+    actual_named = named_shardings_to_json(state_mesh_shardings, abstract_state)
+    expected_named = load_json(named_json_path)
+
+    if compare_sharding_jsons(expected_named, "Expected (Physical)", actual_named, "Actual (Physical)"):
+      error_messages.append(f"Physical sharding mismatch for {model_name} on {topology} slice {num_slice}")
+
+    # 2. Compare logical shardings.
+    logical_json_path = os.path.join(base_path, "logical_shardings.json")
+    if not os.path.exists(logical_json_path):
+      pytest.skip(f"Missing logical_shardings.json for {model_name} {topology} slice {num_slice}")
+      return
 
-  actual_checksum = compute_checksum(actual_json)
-  expected_checksum2 = compute_checksum(expected_json)
-  result = actual_checksum == expected_checksum2
+    # Use logical_shardings from the fixture.
+    actual_logical = partition_specs_to_json(logical_shardings, abstract_state)
+    expected_logical = load_json(logical_json_path)
 
-  if not result:
-    compare_named_sharding_jsons(expected_json, f"expected_{model_name}", actual_json, f"actual_{model_name}")
+    if compare_sharding_jsons(expected_logical, "Expected (Logical)", actual_logical, "Actual (Logical)"):
+      error_messages.append(f"Logical sharding mismatch for {model_name} on {topology} slice {num_slice}")
 
-  assert result is True
+    assert not error_messages, "\n".join(error_messages)
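Both tests above use the same golden-file pattern: hash the actual and expected JSON dumps, and only walk the keys with compare_sharding_jsons to print details once the digests disagree. The body of compute_checksum is not shown in this diff; a hypothetical implementation consistent with its call sites (a JSON-serializable dict in, a comparable string out) would canonicalize the serialization so that only content, never key order, affects the digest:

import hashlib
import json


# Hypothetical sketch only; the committed compute_checksum body is hidden in
# this diff. Sorted keys and fixed separators make the JSON form canonical.
def compute_checksum(d: dict) -> str:
  canonical = json.dumps(d, sort_keys=True, separators=(",", ":"))
  return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


# Equal content in a different key order produces the same digest...
x = {"mesh": {"data": 4, "fsdp": 2}, "partition_spec": [["data"], None]}
y = {"partition_spec": [["data"], None], "mesh": {"fsdp": 2, "data": 4}}
assert compute_checksum(x) == compute_checksum(y)

# ...while any content change produces a different one.
z = {"mesh": {"data": 8, "fsdp": 2}, "partition_spec": [["data"], None]}
assert compute_checksum(x) != compute_checksum(z)

Comparing checksums first keeps the passing path fast and the failing path informative: the detailed mesh, PartitionSpec, and shape diffs are only printed once a mismatch is already known to exist.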
