2121as you would on the target hardware.
2222"""
2323
24+ import datetime
2425import functools
26+ import inspect
2527import os
2628import pickle
29+ import warnings
2730from typing import Sequence
2831
2932from absl import app
@@ -64,11 +67,14 @@ def validate_config(config):
6467
# NOTE(review): this span is a scraped diff fragment — the leading digits on each
# line are fused old/new line numbers and the "@@" lines are hunk headers, so the
# text below is not directly runnable; comments only are added here.
# NOTE(review): every `warnings.warn("DEBUG: ...")` call added in this change is
# temporary tracing (source line number via inspect.currentframe().f_lineno plus a
# wall-clock timestamp). Remove before merge, or replace with logging.debug so the
# output is filterable instead of spamming stderr via the warnings machinery.
6568def get_topology_mesh (config ):
6669 """Get the target hardware devices, and create configured mesh with them"""
70+ warnings .warn (f"DEBUG: get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
# Branch 1: internal compile path — build devices from an explicit TPU topology
# description rather than the locally attached backend.
6771 if config .internal_compile :
6872 topology_devices = get_topology_desc (
6973 platform = "tpu" , topology_name = config .compile_topology , num_slices = config .compile_topology_num_slices
7074 ).devices
75+ warnings .warn (f"DEBUG: get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
7176 else :
77+ warnings .warn (f"DEBUG: get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
7278 target_hardware = accelerator_to_spec_map .get_system_characteristics (config .compile_topology )
# GPU path: mock the process count so jax.devices() reflects the target slice
# count without real hardware. (Lines old-75/76 are elided by the diff hunk.)
7379 if target_hardware .platform == "gpu" :
7480 # Disable sharded autotuning. This is an optimization to distribute
@@ -77,6 +83,7 @@ def get_topology_mesh(config):
7783 jax .config .update ("mock_num_gpu_processes" , config .compile_topology_num_slices )
7884 topology_devices = jax .devices ()
7985 else :
86+ warnings .warn (f"DEBUG: get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
8087 topology_devices = get_topology_desc (
8188 platform = target_hardware .platform ,
8289 topology_name = target_hardware .topology_name ,
@@ -85,11 +92,15 @@ def get_topology_mesh(config):
8592 num_slices = config .compile_topology_num_slices ,
8693 wrap = target_hardware .wrap ,
8794 ).devices
95+ warnings .warn (f"DEBUG: get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
8896 if config .shard_mode == ShardMode .EXPLICIT :
8997 jax .config .update ("jax_remove_size_one_mesh_axis_from_type" , True )
98+ warnings .warn (f"DEBUG: get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
# Arrange the topology devices into the configured logical mesh; axis types are
# Explicit only under ShardMode.EXPLICIT, otherwise Auto.
9099 topology_device_mesh = maxtext_utils .create_device_mesh (config , topology_devices )
100+ warnings .warn (f"DEBUG: get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
91101 mesh_axis_type = AxisType .Explicit if config .shard_mode == ShardMode .EXPLICIT else AxisType .Auto
92102 topology_mesh = Mesh (topology_device_mesh , config .mesh_axes , axis_types = (mesh_axis_type ,) * len (config .mesh_axes ))
103+ warnings .warn (f"DEBUG: get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
93104 return topology_mesh
94105
95106
@@ -163,18 +174,32 @@ def jit_and_compile(
# NOTE(review): fragment of jit_and_compile (the def line is outside this hunk).
# The added start_time/seconds_elapsed scaffolding is copy-pasted three times and
# each measurement is reported twice (print + warnings.warn). Consider one timing
# helper / context manager and a single logging call per phase before merging.
163174 # Use both jax.set_mesh (new API) and `with mesh:` (old API) so that drjax,
164175 # which reads from pxla.thread_resources.env.physical_mesh, can find the mesh.
165176 with jax .set_mesh (mesh ), mesh , logical_axis_rules :
177+ start_time = datetime .datetime .now ()
166178 jitted = jax .jit (
167179 func ,
168180 in_shardings = in_shardings ,
169181 out_shardings = out_shardings ,
170182 static_argnums = static_argnums ,
171183 donate_argnums = donate_argnums ,
172184 )
# NOTE(review): jax.jit is lazy — it only wraps `func`; tracing happens at
# .lower() below, so this first elapsed time is expected to be ~0. Verify this
# measurement is actually wanted. (See JAX AOT docs: jit/lower/compile.)
185+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
186+ print (f"train_compile jit { seconds_elapsed } secs" )
187+ warnings .warn (f"DEBUG: train_compile jit cost { seconds_elapsed } secs" )
173188 maxtext_utils .maybe_dump_jaxpr (config , jitted , func_input_args )
189+
# Time the staging (trace + lower to StableHLO) step separately from backend
# compilation below.
190+ start_time = datetime .datetime .now ()
174191 lowered = jitted .lower (* func_input_args , ** func_input_kwargs )
192+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
193+ print (f"train_compile lower { seconds_elapsed } secs" )
194+ warnings .warn (f"DEBUG: train_compile lower cost { seconds_elapsed } secs" )
175195 # Import libtpu flags as compiler options. Defaults to empty dict if string is empty.
176196 compiler_options = max_utils .parse_libtpu_flags_to_dict (config .compile_xla_flags )
197+
198+ start_time = datetime .datetime .now ()
177199 compiled = lowered .compile (compiler_options = compiler_options )
200+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
201+ print (f"train_compile compile { seconds_elapsed } secs" )
202+ warnings .warn (f"DEBUG: train_compile compile cost { seconds_elapsed } secs" )
178203 return compiled
179204
180205
@@ -257,13 +282,20 @@ def main(argv: Sequence[str]) -> None:
# NOTE(review): fragment of main() (def line outside this hunk; "@@" lines elide
# code). The per-phase start_time/seconds_elapsed/print/warnings.warn scaffolding
# added here repeats six times — extract a timing helper and use logging instead
# of warnings.warn before merging.
257282 config = pyconfig .initialize (argv )
258283 validate_config (config )
259284
285+ start_time = datetime .datetime .now ()
260286 # Create target mesh
287+ warnings .warn (f"DEBUG: before get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
261288 topology_mesh = get_topology_mesh (config )
289+ warnings .warn (f"DEBUG: after get_topology_mesh: { inspect .currentframe ().f_lineno } datetime: { datetime .datetime .now ()} " )
290+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
291+ print (f"train_compile get_topology_mesh { seconds_elapsed } secs" )
292+ warnings .warn (f"DEBUG: train_compile get_topology_mesh cost { seconds_elapsed } secs" )
262293
263294 # Print system information after building the compile topology to avoid
264295 # prematurely initializing the backend.
265296 max_utils .print_system_information ()
266297
298+ start_time = datetime .datetime .now ()
267299 # Get shaped inputs
268300 (
269301 shaped_train_args ,
@@ -272,7 +304,11 @@ def main(argv: Sequence[str]) -> None:
272304 logical_annotations ,
273305 model ,
274306 ) = get_shaped_inputs (topology_mesh , config )
307+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
308+ print (f"train_compile get_shaped_inputs { seconds_elapsed } secs" )
309+ warnings .warn (f"DEBUG: train_compile get_shaped_inputs cost { seconds_elapsed } secs" )
275310
311+ start_time = datetime .datetime .now ()
276312 # Get data sharding
277313 data_sharding = sharding .get_input_data_sharding (config , topology_mesh )
278314 if config .enable_diloco :
@@ -305,8 +341,12 @@ def main(argv: Sequence[str]) -> None:
305341 ) = maxtext_utils .get_functional_train_with_signature (
306342 train .train_step , data_sharding , state_mesh_shardings , model , config
307343 )
344+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
345+ print (f"train_compile get_functional_train_with_signature { seconds_elapsed } secs" )
346+ warnings .warn (f"DEBUG: train_compile get_functional_train_with_signature cost { seconds_elapsed } secs" )
308347
309348 # print weights sharding info under debug sharding mode
349+ start_time = datetime .datetime .now ()
310350 if config .debug_sharding :
311351 max_utils .print_non_trivial_mesh_axis (topology_mesh )
312352 if config .pure_nnx :
@@ -323,9 +363,13 @@ def main(argv: Sequence[str]) -> None:
323363 topology_mesh ,
324364 logical_annotations .params ,
325365 )
366+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
367+ print (f"train_compile print_shardings_params { seconds_elapsed } secs" )
368+ warnings .warn (f"DEBUG: train_compile print_shardings_params cost { seconds_elapsed } secs" )
326369
327370 # Compile
328371 print ("Jitting and compiling train step..." , flush = True )
372+ start_time = datetime .datetime .now ()
329373 compiled = jit_and_compile (
330374 func_to_compile ,
331375 shaped_train_args ,
@@ -338,9 +382,13 @@ def main(argv: Sequence[str]) -> None:
338382 config ,
339383 nn_partitioning .axis_rules (config .logical_axis_rules ),
340384 )
385+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
386+ print (f"train_compile jit_and_compile { seconds_elapsed } secs" )
387+ warnings .warn (f"DEBUG: train_compile jit_and_compile cost { seconds_elapsed } secs" )
341388 print ("Jitting and compilation complete!" , flush = True )
342389
343390 # Serialize and save the compiled object
391+ start_time = datetime .datetime .now ()
344392 if config .compiled_trainstep_file != "" :
345393 print ("Saving compiled object..." )
346394 save_compiled (compiled , config .compiled_trainstep_file )
@@ -349,6 +397,12 @@ def main(argv: Sequence[str]) -> None:
349397 print (f"Cost analysis: { compiled .cost_analysis ()} " )
350398 print (f"Memory analysis: { compiled .memory_analysis ()} " )
351399
400+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
401+ print (f"train_compile save_compiled { seconds_elapsed } secs" )
402+ warnings .warn (f"DEBUG: train_compile save_compiled cost { seconds_elapsed } secs" )
# BUG(review): the next line re-emits "Jitting and compilation complete!" which
# was already printed right after jit_and_compile above — the duplicate added by
# this change should be deleted (misleading: it fires after save, not compile).
403+ print ("Jitting and compilation complete!" , flush = True )
404+
405+ start_time = datetime .datetime .now ()
352406 # Dump HLO if requested
353407 if config .dump_hlo :
354408 gcs_utils .upload_dump (
@@ -359,6 +413,10 @@ def main(argv: Sequence[str]) -> None:
359413 all_host_upload = config .dump_hlo_upload_all ,
360414 )
361415
# NOTE(review): this reports "upload dump_hlo" time even when config.dump_hlo is
# false (the timer wraps the whole if-block) — value will be ~0 but still logged.
416+ seconds_elapsed = (datetime .datetime .now () - start_time ).total_seconds ()
417+ print (f"train_compile upload dump_hlo { seconds_elapsed } secs" )
418+ warnings .warn (f"DEBUG: train_compile upload dump_hlo cost { seconds_elapsed } secs" )
419+
362420
# Script entry point: absl's app.run parses flags then invokes main(argv).
363421if __name__ == "__main__" :
364422 app .run (main )
0 commit comments