2121as you would on the target hardware.
2222"""
2323
24+ import datetime
2425import functools
26+ import inspect
2527import os
2628import pickle
29+ import warnings
2730from typing import Sequence
2831
2932from absl import app
@@ -64,11 +67,14 @@ def validate_config(config):
6467
6568def get_topology_mesh (config ):
6669 """Get the target hardware devices, and create configured mesh with them"""
70+ warnings .warn (f"DEBUG: get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
6771 if config .internal_compile :
6872 topology_devices = get_topology_desc (
6973 platform = "tpu" , topology_name = config .compile_topology , num_slices = config .compile_topology_num_slices
7074 ).devices
75+ warnings .warn (f"DEBUG: get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
7176 else :
77+ warnings .warn (f"DEBUG: get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
7278 target_hardware = accelerator_to_spec_map .get_system_characteristics (config .compile_topology )
7379 if target_hardware .platform == "gpu" :
7480 # Disable sharded autotuning. This is an optimization to distribute
@@ -77,6 +83,7 @@ def get_topology_mesh(config):
7783 jax .config .update ("mock_num_gpu_processes" , config .compile_topology_num_slices )
7884 topology_devices = jax .devices ()
7985 else :
86+ warnings .warn (f"DEBUG: get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
8087 topology_devices = get_topology_desc (
8188 platform = target_hardware .platform ,
8289 topology_name = target_hardware .topology_name ,
@@ -85,10 +92,14 @@ def get_topology_mesh(config):
8592 num_slices = config .compile_topology_num_slices ,
8693 wrap = target_hardware .wrap ,
8794 ).devices
95+ warnings .warn (f"DEBUG: get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
8896 jax .config .update ("jax_remove_size_one_mesh_axis_from_type" , config .remove_size_one_mesh_axis_from_type )
97+ warnings .warn (f"DEBUG: get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
8998 topology_device_mesh = maxtext_utils .create_device_mesh (config , topology_devices )
99+ warnings .warn (f"DEBUG: get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
90100 mesh_axis_type = AxisType .Explicit if config .shard_mode == ShardMode .EXPLICIT else AxisType .Auto
91101 topology_mesh = Mesh (topology_device_mesh , config .mesh_axes , axis_types = (mesh_axis_type ,) * len (config .mesh_axes ))
102+ warnings .warn (f"DEBUG: get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
92103 return topology_mesh
93104
94105
@@ -162,18 +173,32 @@ def jit_and_compile(
162173 # Use both jax.set_mesh (new API) and `with mesh:` (old API) so that drjax,
163174 # which reads from pxla.thread_resources.env.physical_mesh, can find the mesh.
164175 with jax .set_mesh (mesh ), mesh , logical_axis_rules :
176+ start_time = datetime .datetime .now ()
165177 jitted = jax .jit (
166178 func ,
167179 in_shardings = in_shardings ,
168180 out_shardings = out_shardings ,
169181 static_argnums = static_argnums ,
170182 donate_argnums = donate_argnums ,
171183 )
184+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
185+ print (f"train_compile jit { seconds_elapsed } secs" )
186+ warnings .warn (f"DEBUG: train_compile jit cost { seconds_elapsed } secs" )
172187 maxtext_utils .maybe_dump_jaxpr (config , jitted , func_input_args )
188+
189+ start_time = datetime .datetime .now ()
173190 lowered = jitted .lower (* func_input_args , ** func_input_kwargs )
191+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
192+ print (f"train_compile lower { seconds_elapsed } secs" )
193+ warnings .warn (f"DEBUG: train_compile lower cost { seconds_elapsed } secs" )
174194 # Import libtpu flags as compiler options. Defaults to empty dict if string is empty.
175195 compiler_options = max_utils .parse_libtpu_flags_to_dict (config .compile_xla_flags )
196+
197+ start_time = datetime .datetime .now ()
176198 compiled = lowered .compile (compiler_options = compiler_options )
199+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
200+ print (f"train_compile compile { seconds_elapsed } secs" )
201+ warnings .warn (f"DEBUG: train_compile compile cost { seconds_elapsed } secs" )
177202 return compiled
178203
179204
@@ -256,13 +281,20 @@ def main(argv: Sequence[str]) -> None:
256281 config = pyconfig .initialize (argv )
257282 validate_config (config )
258283
284+ start_time = datetime .datetime .now ()
259285 # Create target mesh
286+ warnings .warn (f"DEBUG: before get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
260287 topology_mesh = get_topology_mesh (config )
288+ warnings .warn (f"DEBUG: after get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
289+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
290+ print (f"train_compile get_topology_mesh { seconds_elapsed } secs" )
291+ warnings .warn (f"DEBUG: train_compile get_topology_mesh cost { seconds_elapsed } secs" )
261292
262293 # Print system information after building the compile topology to avoid
263294 # prematurely initializing the backend.
264295 max_utils .print_system_information ()
265296
297+ start_time = datetime .datetime .now ()
266298 # Get shaped inputs
267299 (
268300 shaped_train_args ,
@@ -271,7 +303,11 @@ def main(argv: Sequence[str]) -> None:
271303 logical_annotations ,
272304 model ,
273305 ) = get_shaped_inputs (topology_mesh , config )
306+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
307+ print (f"train_compile get_shaped_inputs { seconds_elapsed } secs" )
308+ warnings .warn (f"DEBUG: train_compile get_shaped_inputs cost { seconds_elapsed } secs" )
274309
310+ start_time = datetime .datetime .now ()
275311 # Get data sharding
276312 data_sharding = sharding .get_input_data_sharding (config , topology_mesh )
277313 if config .enable_diloco :
@@ -304,8 +340,12 @@ def main(argv: Sequence[str]) -> None:
304340 ) = maxtext_utils .get_functional_train_with_signature (
305341 train .train_step , data_sharding , state_mesh_shardings , model , config
306342 )
343+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
344+ print (f"train_compile get_functional_train_with_signature { seconds_elapsed } secs" )
345+ warnings .warn (f"DEBUG: train_compile get_functional_train_with_signature cost { seconds_elapsed } secs" )
307346
308347 # print weights sharding info under debug sharding mode
348+ start_time = datetime .datetime .now ()
309349 if config .debug_sharding :
310350 max_utils .print_non_trivial_mesh_axis (topology_mesh )
311351 if config .pure_nnx :
@@ -322,9 +362,13 @@ def main(argv: Sequence[str]) -> None:
322362 topology_mesh ,
323363 logical_annotations .params ,
324364 )
365+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
366+ print (f"train_compile print_shardings_params { seconds_elapsed } secs" )
367+ warnings .warn (f"DEBUG: train_compile print_shardings_params cost { seconds_elapsed } secs" )
325368
326369 # Compile
327370 print ("Jitting and compiling train step..." , flush = True )
371+ start_time = datetime .datetime .now ()
328372 compiled = jit_and_compile (
329373 func_to_compile ,
330374 shaped_train_args ,
@@ -337,9 +381,13 @@ def main(argv: Sequence[str]) -> None:
337381 config ,
338382 nn_partitioning .axis_rules (config .logical_axis_rules ),
339383 )
384+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
385+ print (f"train_compile jit_and_compile { seconds_elapsed } secs" )
386+ warnings .warn (f"DEBUG: train_compile jit_and_compile cost { seconds_elapsed } secs" )
340387 print ("Jitting and compilation complete!" , flush = True )
341388
342389 # Serialize and save the compiled object
390+ start_time = datetime .datetime .now ()
343391 if config .compiled_trainstep_file != "" :
344392 print ("Saving compiled object..." )
345393 save_compiled (compiled , config .compiled_trainstep_file )
@@ -348,6 +396,12 @@ def main(argv: Sequence[str]) -> None:
348396 print (f"Cost analysis: { compiled .cost_analysis ()} " )
349397 print (f"Memory analysis: { compiled .memory_analysis ()} " )
350398
399+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
400+ print (f"train_compile save_compiled { seconds_elapsed } secs" )
401+ warnings .warn (f"DEBUG: train_compile save_compiled cost { seconds_elapsed } secs" )
402+ print ("Jitting and compilation complete!" , flush = True )
403+
404+ start_time = datetime .datetime .now ()
351405 # Dump HLO if requested
352406 if config .dump_hlo :
353407 gcs_utils .upload_dump (
@@ -358,6 +412,10 @@ def main(argv: Sequence[str]) -> None:
358412 all_host_upload = config .dump_hlo_upload_all ,
359413 )
360414
415+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
416+ print (f"train_compile upload dump_hlo { seconds_elapsed } secs" )
417+ warnings .warn (f"DEBUG: train_compile upload dump_hlo cost { seconds_elapsed } secs" )
418+
361419
362420if __name__ == "__main__" :
363421 app .run (main )