1- # Copyright 2023–2025 Google LLC
1+ # Copyright 2023–2026 Google LLC
22#
33# Licensed under the Apache License, Version 2.0 (the "License");
44# you may not use this file except in compliance with the License.
1818import json
1919import os
2020import pytest
21+ import jax
22+ import jax .numpy as jnp
23+ # import optax
2124
25+ from MaxText .globals import MAXTEXT_PKG_DIR
2226from MaxText .train_compile import get_shaped_inputs , get_topology_mesh , validate_config
2327from MaxText import pyconfig
28+ from MaxText import maxtext_utils
29+ from MaxText .layers import models
30+ from MaxText .layers import quantizations
31+ from MaxText import optimizers
2432
25- from tests .utils .sharding_dump import named_shardings_to_json , load_named_sharding_json , TEST_CASES
26- from tests .utils .test_helpers import get_test_config_path
33+ from tests .utils .sharding_dump import load_json , TEST_CASES , named_shardings_to_json , partition_specs_to_json
34+
35+ Transformer = models .transformer_as_linen
2736
2837
2938def compute_checksum (d : dict ) -> str :
@@ -37,7 +46,7 @@ def compute_checksum(d: dict) -> str:
3746 return checksum
3847
3948
def compare_sharding_jsons(json1: dict, model1_name: str, json2: dict, model2_name: str) -> bool:
  """Compare two sharding-dump JSON dicts and print any differences.

  Each value is expected to be a dict with optional "mesh", "partition_spec"
  and "shape" entries; non-dict values are reported as a format mismatch.

  Args:
    json1: First sharding dict, keyed by parameter/state path.
    model1_name: Label used for ``json1`` in the printed diff.
    json2: Second sharding dict.
    model2_name: Label used for ``json2`` in the printed diff.

  Returns:
    True if ANY difference was found (note the inverted sense vs. an
    equality check), False if the two dicts match exactly.
  """
  keys1 = set(json1.keys())
  keys2 = set(json2.keys())

  only_in_1 = keys1 - keys2
  only_in_2 = keys2 - keys1
  shared_keys = keys1 & keys2

  has_diff = False

  if only_in_1:
    print(f"Keys only in {model1_name}:")
    for k in sorted(only_in_1):
      print(f"  {k}")
    has_diff = True

  if only_in_2:
    print(f"Keys only in {model2_name}:")
    for k in sorted(only_in_2):
      print(f"  {k}")
    has_diff = True

  for key in sorted(shared_keys):
    entry1 = json1[key]
    entry2 = json2[key]

    if isinstance(entry1, dict) and isinstance(entry2, dict):
      mesh1 = entry1.get("mesh", {})
      mesh2 = entry2.get("mesh", {})

      spec1 = entry1.get("partition_spec", [])
      spec2 = entry2.get("partition_spec", [])

      shape1 = entry1.get("shape")
      shape2 = entry2.get("shape")

      if mesh1 != mesh2:
        print(f"\nMesh mismatch at '{key}':")
        print(f"  {model1_name}: {mesh1}")
        print(f"  {model2_name}: {mesh2}")
        has_diff = True

      if spec1 != spec2:
        print(f"\nPartitionSpec mismatch at '{key}':")
        print(f"  {model1_name}: {spec1}")
        print(f"  {model2_name}: {spec2}")
        has_diff = True

      if shape1 != shape2:
        print(f"\nShape mismatch at '{key}':")
        print(f"  {model1_name}: {shape1}")
        print(f"  {model2_name}: {shape2}")
        has_diff = True

    else:
      # Entries disagree on structure (e.g. one side is a raw string);
      # report the types rather than attempting a field-wise diff.
      print(f"\nFormat mismatch at '{key}':")
      print(f"  {model1_name} type: {type(entry1)}")
      print(f"  {model2_name} type: {type(entry2)}")
      has_diff = True

  return has_diff
79111
80112
@pytest.mark.parametrize("model_name, topology, num_slice", TEST_CASES)
def test_sharding_dump_for_model(model_name: str, topology: str, num_slice: str) -> None:
  """Verify sharding configurations from train_compile.get_shaped_inputs.

  Compares both the physical (named) and logical (PartitionSpec) shardings
  produced for a model/topology/slice combination against golden JSON files,
  printing a detailed diff and failing on any mismatch. Skips when the golden
  files are absent.
  """
  params = [
      "/deps/MaxText/tests/unit/sharding_compare_test",
      os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
      f"compile_topology={topology}",
      f"compile_topology_num_slices={num_slice}",
      f"model_name={model_name}",
  ]

  root_dir = "tests/utils/sharding_info"
  base_path = os.path.join(root_dir, model_name, topology, f"slice_{num_slice}")

  named_json_path = os.path.join(base_path, "named_shardings.json")
  logical_json_path = os.path.join(base_path, "logical_shardings.json")

  # pytest.skip raises Skipped, so no explicit return is needed after it.
  if not os.path.exists(named_json_path):
    pytest.skip(f"Missing named_shardings.json for {model_name} {topology} slice {num_slice}")
  if not os.path.exists(logical_json_path):
    pytest.skip(f"Missing logical_shardings.json for {model_name} {topology} slice {num_slice}")

  config = pyconfig.initialize(params)
  validate_config(config)

  topology_mesh = get_topology_mesh(config)
  shaped_train_args, _, state_mesh_shardings, logical_shardings, _ = get_shaped_inputs(topology_mesh, config)

  error_messages = []

  # 1. Compare physical (named) shardings: cheap checksum gate first, full
  # field-wise diff only on mismatch.
  actual_named = named_shardings_to_json(state_mesh_shardings, shaped_train_args[0])
  expected_named = load_json(named_json_path)
  if compute_checksum(actual_named) != compute_checksum(expected_named):
    print(f"\n[FAIL] Physical Sharding Mismatch: {model_name} {topology} slice {num_slice}", flush=True)
    compare_sharding_jsons(expected_named, "Expected (Physical)", actual_named, "Actual (Physical)")
    error_messages.append(f"Physical sharding mismatch for {model_name} on {topology} slice {num_slice}")

  # 2. Compare logical shardings the same way.
  actual_logical = partition_specs_to_json(logical_shardings, shaped_train_args[0])
  expected_logical = load_json(logical_json_path)
  if compute_checksum(actual_logical) != compute_checksum(expected_logical):
    print(f"\n[FAIL] Logical Sharding Mismatch: {model_name} {topology} slice {num_slice}", flush=True)
    compare_sharding_jsons(expected_logical, "Expected (Logical)", actual_logical, "Actual (Logical)")
    error_messages.append(f"Logical sharding mismatch for {model_name} on {topology} slice {num_slice}")

  assert not error_messages, "\n".join(error_messages)
175+
176+
@pytest.fixture(
    scope="module",
    params=[pytest.param(case, id=f"{case[0]}-{case[1]}-{case[2]}") for case in TEST_CASES],
)
def abstract_state_and_shardings(request):
  """Module-scoped fixture that builds the abstract state and shardings once per case.

  For the parametrized (model_name, topology, num_slice) case this initializes
  the config, constructs the Transformer model and optimizer, and derives the
  abstract train state together with its physical (named) and logical
  (PartitionSpec) sharding annotations via maxtext_utils.

  Returns:
    Tuple of (model_name, topology, num_slice, abstract_state,
    state_mesh_shardings, logical_shardings).
  """
  model_name, topology, num_slice = request.param
  print(f"Testing model: {model_name}, topology: {topology}, num_slices: {num_slice}", flush=True)
  params = [
      "/deps/MaxText/tests/unit/sharding_compare_test",
      os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
      f"compile_topology={topology}",
      f"compile_topology_num_slices={num_slice}",
      f"model_name={model_name}",
      # float32 weights so the test can assert a concrete parameter dtype.
      "weight_dtype=float32",
  ]
  config = pyconfig.initialize(params)
  validate_config(config)

  topology_mesh = get_topology_mesh(config)
  quant = quantizations.configure_quantization(config)
  model = Transformer(config, mesh=topology_mesh, quant=quant)

  learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
  tx = optimizers.get_optimizer(config, learning_rate_schedule)
  rng = jax.random.PRNGKey(0)

  # Abstract state and physical (NamedSharding) annotations.
  abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(
      model, tx, config, rng, topology_mesh, is_training=True
  )

  # Logical (PartitionSpec) annotations for the same state.
  logical_shardings = maxtext_utils.get_logical_annotations(model, tx, config, rng, topology_mesh, is_training=True)

  return model_name, topology, num_slice, abstract_state, state_mesh_shardings, logical_shardings
214+
215+
class TestGetAbstractState:
  """Tests for maxtext_utils.get_abstract_state and its sharding annotations."""

  def test_get_abstract_state_sharding(self, abstract_state_and_shardings):  # pylint: disable=redefined-outer-name
    """Checks the abstract state structure and compares shardings against golden files."""

    model_name, topology, num_slice, abstract_state, state_mesh_shardings, logical_shardings = (
        abstract_state_and_shardings
    )

    # Structural sanity checks on the abstract train state.
    assert hasattr(abstract_state, "params")
    assert hasattr(abstract_state, "opt_state")
    param_leaf = jax.tree_util.tree_leaves(abstract_state.params)[0]
    assert isinstance(param_leaf, jax.ShapeDtypeStruct)
    assert param_leaf.dtype == jnp.float32

    root_dir = "tests/utils/sharding_info"
    base_path = os.path.join(root_dir, model_name, topology, f"slice_{num_slice}")
    os.makedirs(base_path, exist_ok=True)

    error_messages = []

    # 1. Physical (named) shardings vs. golden file. pytest.skip raises
    # Skipped, so no explicit return is needed after it.
    named_json_path = os.path.join(base_path, "named_shardings.json")
    if not os.path.exists(named_json_path):
      pytest.skip(f"Missing named_shardings.json for {model_name} {topology} slice {num_slice}")

    actual_named = named_shardings_to_json(state_mesh_shardings, abstract_state)
    expected_named = load_json(named_json_path)
    if compare_sharding_jsons(expected_named, "Expected (Physical)", actual_named, "Actual (Physical)"):
      error_messages.append(f"Physical sharding mismatch for {model_name} on {topology} slice {num_slice}")

    # 2. Logical shardings vs. golden file.
    logical_json_path = os.path.join(base_path, "logical_shardings.json")
    if not os.path.exists(logical_json_path):
      pytest.skip(f"Missing logical_shardings.json for {model_name} {topology} slice {num_slice}")

    actual_logical = partition_specs_to_json(logical_shardings, abstract_state)
    expected_logical = load_json(logical_json_path)
    if compare_sharding_jsons(expected_logical, "Expected (Logical)", actual_logical, "Actual (Logical)"):
      error_messages.append(f"Logical sharding mismatch for {model_name} on {topology} slice {num_slice}")

    assert not error_messages, "\n".join(error_messages)
0 commit comments