6161def validate_config (config ):
6262 """Validates the config is is setup correctly to compile, returning a useful error message if not."""
6363 assert config .compile_topology != "" , (
64- "You must pass your desired target hardware in compile_topology, e.g."
65- " compile_topology=v5e-256"
64+ "You must pass your desired target hardware in compile_topology, e.g." " compile_topology=v5e-256"
6665 )
67- assert (
68- config .compile_topology_num_slices > 0
69- ), "You must set compile_topology_num_slices to a positive integer"
66+ assert config .compile_topology_num_slices > 0 , "You must set compile_topology_num_slices to a positive integer"
7067
7168
7269def get_topology_mesh (config ):
@@ -78,18 +75,12 @@ def get_topology_mesh(config):
7875 num_slices = config .compile_topology_num_slices ,
7976 ).devices
8077 else :
81- target_hardware = accelerator_to_spec_map .get_system_characteristics (
82- config .compile_topology
83- )
78+ target_hardware = accelerator_to_spec_map .get_system_characteristics (config .compile_topology )
8479 if target_hardware .platform == "gpu" :
8580 # Disable sharded autotuning. This is an optimization to distribute
8681 # autotuning across the fleet, but can cause hangs with AoT compilation.
87- os .environ ["XLA_FLAGS" ] = (
88- os .environ .get ("XLA_FLAGS" , "" ) + " --xla_gpu_shard_autotuning=false"
89- )
90- jax .config .update (
91- "mock_num_gpu_processes" , config .compile_topology_num_slices
92- )
82+ os .environ ["XLA_FLAGS" ] = os .environ .get ("XLA_FLAGS" , "" ) + " --xla_gpu_shard_autotuning=false"
83+ jax .config .update ("mock_num_gpu_processes" , config .compile_topology_num_slices )
9384 topology_devices = jax .devices ()
9485 else :
9586 topology_devices = get_topology_desc (
@@ -104,14 +95,8 @@ def get_topology_mesh(config):
10495 "jax_remove_size_one_mesh_axis_from_type" ,
10596 config .remove_size_one_mesh_axis_from_type ,
10697 )
107- topology_device_mesh = maxtext_utils .create_device_mesh (
108- config , topology_devices
109- )
110- mesh_axis_type = (
111- AxisType .Explicit
112- if config .shard_mode == ShardMode .EXPLICIT
113- else AxisType .Auto
114- )
98+ topology_device_mesh = maxtext_utils .create_device_mesh (config , topology_devices )
99+ mesh_axis_type = AxisType .Explicit if config .shard_mode == ShardMode .EXPLICIT else AxisType .Auto
115100 topology_mesh = Mesh (
116101 topology_device_mesh ,
117102 config .mesh_axes ,
@@ -129,9 +114,7 @@ def _collect_nnx_activation_shardings(create_model_fn, config, mesh):
129114 input_shape = (config .micro_batch_size_to_train_on , config .max_target_length )
130115 abstract_input = jax .ShapeDtypeStruct (input_shape , jnp .int32 )
131116
132- def _nnx_forward (
133- decoder_input_tokens , decoder_positions , decoder_segment_ids
134- ):
117+ def _nnx_forward (decoder_input_tokens , decoder_positions , decoder_segment_ids ):
135118 model_instance = create_model_fn ()
136119 return model_instance (
137120 decoder_input_tokens = decoder_input_tokens ,
@@ -140,9 +123,7 @@ def _nnx_forward(
140123 enable_dropout = False ,
141124 )
142125
143- with jax .set_mesh (mesh ), nn_partitioning .axis_rules (
144- config .logical_axis_rules
145- ):
126+ with jax .set_mesh (mesh ), nn_partitioning .axis_rules (config .logical_axis_rules ):
146127 jax .eval_shape (_nnx_forward , abstract_input , abstract_input , abstract_input )
147128
148129
@@ -151,13 +132,9 @@ def get_shaped_inputs(topology_mesh, config):
151132 # Construct the model and optimizer to get shaped versions of the state
152133 quant = quantizations .configure_quantization (config )
153134 if config .pure_nnx :
154- _create_model_partial , model = (
155- model_creation_utils .create_nnx_abstract_model (config , topology_mesh )
156- )
135+ _create_model_partial , model = model_creation_utils .create_nnx_abstract_model (config , topology_mesh )
157136 else :
158- model = Transformer (
159- config , topology_mesh , quant = quant , model_mode = MODEL_MODE_TRAIN
160- )
137+ model = Transformer (config , topology_mesh , quant = quant , model_mode = MODEL_MODE_TRAIN )
161138 # The learning_rate_schedule is baked into the compiled object.
162139 learning_rate_schedule = maxtext_utils .create_learning_rate_schedule (config )
163140 # pass in model for muon
@@ -176,20 +153,14 @@ def create_train_state_fn():
176153
177154 init_state_fn = create_train_state_fn
178155 else :
179- init_state_fn = functools .partial (
180- maxtext_utils .init_initial_state , model , tx , config , True , example_rng
181- )
156+ init_state_fn = functools .partial (maxtext_utils .init_initial_state , model , tx , config , True , example_rng )
182157
183158 # Shaped state
184- abstract_state , _ , state_mesh_shardings = maxtext_utils .get_abstract_state (
185- config , topology_mesh , init_state_fn , True
186- )
159+ abstract_state , _ , state_mesh_shardings = maxtext_utils .get_abstract_state (config , topology_mesh , init_state_fn , True )
187160
188161 if config .pure_nnx :
189162 # NNX doesn't use Linen logical annotations; derive PartitionSpecs from the physical shardings.
190- logical_annotations = maxtext_utils_nnx .get_partition_spec_nnx (
191- state_mesh_shardings
192- )
163+ logical_annotations = maxtext_utils_nnx .get_partition_spec_nnx (state_mesh_shardings )
193164 # For NNX, get_functional_train_with_signature expects the graphdef (static structure),
194165 # not the raw model — mirroring how the training loop does nnx.split(train_state).
195166 with nn_partitioning .axis_rules (config .logical_axis_rules ):
@@ -198,9 +169,7 @@ def create_train_state_fn():
198169 model = graphdef
199170 else :
200171 # unsharded logical annotations
201- logical_annotations = maxtext_utils .get_logical_annotations (
202- config , topology_mesh , init_state_fn
203- )
172+ logical_annotations = maxtext_utils .get_logical_annotations (config , topology_mesh , init_state_fn )
204173
205174 # Shaped batch
206175 shaped_batch = maxtext_utils .get_shaped_batch (config )
@@ -217,9 +186,7 @@ def create_train_state_fn():
217186 # Collect NNX activation shardings via an abstract forward pass (must run
218187 # after get_abstract_state, which only traces __init__).
219188 if config .debug_sharding and config .pure_nnx :
220- _collect_nnx_activation_shardings (
221- _create_model_partial , config , topology_mesh
222- )
189+ _collect_nnx_activation_shardings (_create_model_partial , config , topology_mesh )
223190
224191 return (
225192 shaped_train_args ,
@@ -256,9 +223,7 @@ def jit_and_compile(
256223 maxtext_utils .maybe_dump_jaxpr (config , jitted , func_input_args )
257224 lowered = jitted .lower (* func_input_args , ** func_input_kwargs )
258225 # Import libtpu flags as compiler options. Defaults to empty dict if string is empty.
259- compiler_options = max_utils .parse_libtpu_flags_to_dict (
260- config .compile_xla_flags
261- )
226+ compiler_options = max_utils .parse_libtpu_flags_to_dict (config .compile_xla_flags )
262227 compiled = lowered .compile (compiler_options = compiler_options )
263228 return compiled
264229
@@ -293,20 +258,11 @@ def is_oom(argv: Sequence[str]) -> bool:
293258 ) = get_shaped_inputs (topology_mesh , config )
294259
295260 # Update params_shardings when shard_optimizer_over_data is enabled (Zero-1)
296- params_shardings , state_mesh_shardings = (
297- sharding .maybe_update_params_sharding_with_opt (
298- config , state_mesh_shardings
299- )
300- )
261+ params_shardings , state_mesh_shardings = sharding .maybe_update_params_sharding_with_opt (config , state_mesh_shardings )
301262
302- # When ZeRO-1 is enabled, we need to use the original params_shardings for input shardings
303- # but keep the updated state_mesh_shardings for the optimizer state
304- if config .shard_optimizer_over_data :
305- input_state_mesh_shardings = state_mesh_shardings .replace (
306- params = params_shardings
307- )
308- else :
309- input_state_mesh_shardings = state_mesh_shardings
263+ input_state_mesh_shardings = sharding .build_zero1_input_state_mesh_shardings (
264+ config , state_mesh_shardings , params_shardings
265+ )
310266
311267 # Get data sharding
312268 data_sharding = sharding .get_input_data_sharding (config , topology_mesh )
@@ -355,8 +311,7 @@ def is_oom(argv: Sequence[str]) -> bool:
355311def main (argv : Sequence [str ]) -> None :
356312 jax .config .update ("jax_default_prng_impl" , "unsafe_rbg" )
357313 os .environ ["LIBTPU_INIT_ARGS" ] = (
358- os .environ .get ("LIBTPU_INIT_ARGS" , "" )
359- + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
314+ os .environ .get ("LIBTPU_INIT_ARGS" , "" ) + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
360315 )
361316 print ("Starting train_compile.py..." , flush = True )
362317
@@ -381,41 +336,26 @@ def main(argv: Sequence[str]) -> None:
381336 ) = get_shaped_inputs (topology_mesh , config )
382337
383338 # Update params_shardings when shard_optimizer_over_data is enabled (Zero-1)
384- params_shardings , state_mesh_shardings = (
385- sharding .maybe_update_params_sharding_with_opt (
386- config , state_mesh_shardings
387- )
388- )
339+ params_shardings , state_mesh_shardings = sharding .maybe_update_params_sharding_with_opt (config , state_mesh_shardings )
389340
390- # When ZeRO-1 is enabled, we need to use the original params_shardings for input shardings
391- # but keep the updated state_mesh_shardings for the optimizer state
392- if config .shard_optimizer_over_data :
393- input_state_mesh_shardings = state_mesh_shardings .replace (
394- params = params_shardings
395- )
396- else :
397- input_state_mesh_shardings = state_mesh_shardings
341+ input_state_mesh_shardings = sharding .build_zero1_input_state_mesh_shardings (
342+ config , state_mesh_shardings , params_shardings
343+ )
398344
399345 # Get data sharding
400346 data_sharding = sharding .get_input_data_sharding (config , topology_mesh )
401347 if config .enable_diloco :
402348 # Build abstract DiLoCo state and shardings for AOT compilation
403349 abstract_state = shaped_train_args [0 ]
404- diloco_state , state_mesh_shardings , inner_state_shardings = (
405- diloco .build_abstract_diloco_state (
406- config , abstract_state , state_mesh_shardings , topology_mesh
407- )
350+ diloco_state , state_mesh_shardings , inner_state_shardings = diloco .build_abstract_diloco_state (
351+ config , abstract_state , state_mesh_shardings , topology_mesh
408352 )
409353 # For NNX, shaped_train_args has 2 elements (state, batch) — no rng; pass None for prng.
410- shaped_rng_arg = (
411- shaped_train_args [2 ] if len (shaped_train_args ) > 2 else None
412- )
354+ shaped_rng_arg = shaped_train_args [2 ] if len (shaped_train_args ) > 2 else None
413355 shaped_train_args = (diloco_state , shaped_train_args [1 ], shaped_rng_arg )
414356
415357 # Wrap train_step with diloco
416- train_step_partial = functools .partial (
417- train .train_step , model , config , inner_state_shardings , params_shardings
418- )
358+ train_step_partial = functools .partial (train .train_step , model , config , inner_state_shardings , params_shardings )
419359 train_step_fn = diloco .build_diloco_train_step (config , train_step_partial )
420360
421361 # For DiLoCo, the train_step_fn is already fully wrapped and takes (state, batch, prng)
@@ -480,10 +420,7 @@ def main(argv: Sequence[str]) -> None:
480420 if config .compiled_trainstep_file != "" :
481421 print ("Saving compiled object..." )
482422 save_compiled (compiled , config .compiled_trainstep_file )
483- print (
484- "Successfully saved compiled object as"
485- f" { config .compiled_trainstep_file } "
486- )
423+ print ("Successfully saved compiled object as" f" { config .compiled_trainstep_file } " )
487424 print ("Finished train_compile.py successfully!" , flush = True )
488425 print (f"Cost analysis: { compiled .cost_analysis ()} " )
489426 print (f"Memory analysis: { compiled .memory_analysis ()} " )
0 commit comments