Skip to content

Commit d174d4a

Browse files
author
Charles Li
committed
Add more logging to measure time
1 parent 45ffccf commit d174d4a

4 files changed

Lines changed: 71 additions & 2 deletions

File tree

src/maxtext/trainers/pre_train/train.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import datetime
2424
import functools
2525
import os
26+
import warnings
2627

2728
from absl import app
2829

@@ -524,6 +525,7 @@ def train_loop(config, recorder, state=None):
524525
params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)
525526

526527
with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
528+
start_time = datetime.datetime.now()
527529
p_train_step, p_eval_step = train_utils.jit_train_and_eval_step(
528530
config,
529531
model,
@@ -535,6 +537,9 @@ def train_loop(config, recorder, state=None):
535537
eval_data_iterator,
536538
params_shardings,
537539
)
540+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
541+
print(f"train jit {seconds_elapsed} secs")
542+
warnings.warn(f"DEBUG: train jit cost {seconds_elapsed} secs")
538543
shaped_batch = maxtext_utils.get_shaped_batch(config)
539544
if config.shard_optimizer_over_data:
540545
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)

src/maxtext/trainers/pre_train/train_compile.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121
as you would on the target hardware.
2222
"""
2323

24+
import datetime
2425
import functools
26+
import inspect
2527
import os
2628
import pickle
29+
import warnings
2730
from typing import Sequence
2831

2932
from absl import app
@@ -64,11 +67,14 @@ def validate_config(config):
6467

6568
def get_topology_mesh(config):
6669
"""Get the target hardware devices, and create configured mesh with them"""
70+
warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
6771
if config.internal_compile:
6872
topology_devices = get_topology_desc(
6973
platform="tpu", topology_name=config.compile_topology, num_slices=config.compile_topology_num_slices
7074
).devices
75+
warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
7176
else:
77+
warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
7278
target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology)
7379
if target_hardware.platform == "gpu":
7480
# Disable sharded autotuning. This is an optimization to distribute
@@ -77,6 +83,7 @@ def get_topology_mesh(config):
7783
jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices)
7884
topology_devices = jax.devices()
7985
else:
86+
warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
8087
topology_devices = get_topology_desc(
8188
platform=target_hardware.platform,
8289
topology_name=target_hardware.topology_name,
@@ -85,10 +92,14 @@ def get_topology_mesh(config):
8592
num_slices=config.compile_topology_num_slices,
8693
wrap=target_hardware.wrap,
8794
).devices
95+
warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
8896
jax.config.update("jax_remove_size_one_mesh_axis_from_type", config.remove_size_one_mesh_axis_from_type)
97+
warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
8998
topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices)
99+
warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
90100
mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto
91101
topology_mesh = Mesh(topology_device_mesh, config.mesh_axes, axis_types=(mesh_axis_type,) * len(config.mesh_axes))
102+
warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
92103
return topology_mesh
93104

94105

@@ -162,18 +173,32 @@ def jit_and_compile(
162173
# Use both jax.set_mesh (new API) and `with mesh:` (old API) so that drjax,
163174
# which reads from pxla.thread_resources.env.physical_mesh, can find the mesh.
164175
with jax.set_mesh(mesh), mesh, logical_axis_rules:
176+
start_time = datetime.datetime.now()
165177
jitted = jax.jit(
166178
func,
167179
in_shardings=in_shardings,
168180
out_shardings=out_shardings,
169181
static_argnums=static_argnums,
170182
donate_argnums=donate_argnums,
171183
)
184+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
185+
print(f"train_compile jit {seconds_elapsed} secs")
186+
warnings.warn(f"DEBUG: train_compile jit cost {seconds_elapsed} secs")
172187
maxtext_utils.maybe_dump_jaxpr(config, jitted, func_input_args)
188+
189+
start_time = datetime.datetime.now()
173190
lowered = jitted.lower(*func_input_args, **func_input_kwargs)
191+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
192+
print(f"train_compile lower {seconds_elapsed} secs")
193+
warnings.warn(f"DEBUG: train_compile lower cost {seconds_elapsed} secs")
174194
# Import libtpu flags as compiler options. Defaults to empty dict if string is empty.
175195
compiler_options = max_utils.parse_libtpu_flags_to_dict(config.compile_xla_flags)
196+
197+
start_time = datetime.datetime.now()
176198
compiled = lowered.compile(compiler_options=compiler_options)
199+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
200+
print(f"train_compile compile {seconds_elapsed} secs")
201+
warnings.warn(f"DEBUG: train_compile compile cost {seconds_elapsed} secs")
177202
return compiled
178203

179204

@@ -256,13 +281,20 @@ def main(argv: Sequence[str]) -> None:
256281
config = pyconfig.initialize(argv)
257282
validate_config(config)
258283

284+
start_time = datetime.datetime.now()
259285
# Create target mesh
286+
warnings.warn(f"DEBUG: before get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
260287
topology_mesh = get_topology_mesh(config)
288+
warnings.warn(f"DEBUG: after get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
289+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
290+
print(f"train_compile get_topology_mesh {seconds_elapsed} secs")
291+
warnings.warn(f"DEBUG: train_compile get_topology_mesh cost {seconds_elapsed} secs")
261292

262293
# Print system information after building the compile topology to avoid
263294
# prematurely initializing the backend.
264295
max_utils.print_system_information()
265296

297+
start_time = datetime.datetime.now()
266298
# Get shaped inputs
267299
(
268300
shaped_train_args,
@@ -271,7 +303,11 @@ def main(argv: Sequence[str]) -> None:
271303
logical_annotations,
272304
model,
273305
) = get_shaped_inputs(topology_mesh, config)
306+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
307+
print(f"train_compile get_shaped_inputs {seconds_elapsed} secs")
308+
warnings.warn(f"DEBUG: train_compile get_shaped_inputs cost {seconds_elapsed} secs")
274309

310+
start_time = datetime.datetime.now()
275311
# Get data sharding
276312
data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
277313
if config.enable_diloco:
@@ -304,8 +340,12 @@ def main(argv: Sequence[str]) -> None:
304340
) = maxtext_utils.get_functional_train_with_signature(
305341
train.train_step, data_sharding, state_mesh_shardings, model, config
306342
)
343+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
344+
print(f"train_compile get_functional_train_with_signature {seconds_elapsed} secs")
345+
warnings.warn(f"DEBUG: train_compile get_functional_train_with_signature cost {seconds_elapsed} secs")
307346

308347
# print weights sharding info under debug sharding mode
348+
start_time = datetime.datetime.now()
309349
if config.debug_sharding:
310350
max_utils.print_non_trivial_mesh_axis(topology_mesh)
311351
if config.pure_nnx:
@@ -322,9 +362,13 @@ def main(argv: Sequence[str]) -> None:
322362
topology_mesh,
323363
logical_annotations.params,
324364
)
365+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
366+
print(f"train_compile print_shardings_params {seconds_elapsed} secs")
367+
warnings.warn(f"DEBUG: train_compile print_shardings_params cost {seconds_elapsed} secs")
325368

326369
# Compile
327370
print("Jitting and compiling train step...", flush=True)
371+
start_time = datetime.datetime.now()
328372
compiled = jit_and_compile(
329373
func_to_compile,
330374
shaped_train_args,
@@ -337,9 +381,13 @@ def main(argv: Sequence[str]) -> None:
337381
config,
338382
nn_partitioning.axis_rules(config.logical_axis_rules),
339383
)
384+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
385+
print(f"train_compile jit_and_compile {seconds_elapsed} secs")
386+
warnings.warn(f"DEBUG: train_compile jit_and_compile cost {seconds_elapsed} secs")
340387
print("Jitting and compilation complete!", flush=True)
341388

342389
# Serialize and save the compiled object
390+
start_time = datetime.datetime.now()
343391
if config.compiled_trainstep_file != "":
344392
print("Saving compiled object...")
345393
save_compiled(compiled, config.compiled_trainstep_file)
@@ -348,6 +396,12 @@ def main(argv: Sequence[str]) -> None:
348396
print(f"Cost analysis: {compiled.cost_analysis()}")
349397
print(f"Memory analysis: {compiled.memory_analysis()}")
350398

399+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
400+
print(f"train_compile save_compiled {seconds_elapsed} secs")
401+
warnings.warn(f"DEBUG: train_compile save_compiled cost {seconds_elapsed} secs")
402+
print("Jitting and compilation complete!", flush=True)
403+
404+
start_time = datetime.datetime.now()
351405
# Dump HLO if requested
352406
if config.dump_hlo:
353407
gcs_utils.upload_dump(
@@ -358,6 +412,10 @@ def main(argv: Sequence[str]) -> None:
358412
all_host_upload=config.dump_hlo_upload_all,
359413
)
360414

415+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
416+
print(f"train_compile upload dump_hlo {seconds_elapsed} secs")
417+
warnings.warn(f"DEBUG: train_compile upload dump_hlo cost {seconds_elapsed} secs")
418+
361419

362420
if __name__ == "__main__":
363421
app.run(main)

src/maxtext/utils/train_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
# pylint: disable=bare-except, consider-using-generator
1616
"""Utils that are only interesting for training in MaxText."""
1717

18+
import datetime
1819
import os
20+
import warnings
1921
from functools import partial
2022

2123
import jax
@@ -176,9 +178,13 @@ def jit_train_and_eval_step(
176178
train_step_partial = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings)
177179
train_step = diloco.build_diloco_train_step(config, train_step_partial, mesh=mesh)
178180
data_sharding = sharding.get_input_data_sharding(config, mesh)
181+
start_time = datetime.datetime.now()
179182
p_train_step = jit_train_step(
180183
config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings, mesh=mesh
181184
)
185+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
186+
print(f"train_util train_step jit {seconds_elapsed} secs")
187+
warnings.warn(f"DEBUG: train_util train_step jit cost {seconds_elapsed} secs")
182188
p_eval_step = None
183189
if eval_data_iterator:
184190
p_eval_step = jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step)

tests/integration/aot_identical_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
200200
train.main(train_argv)
201201
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
202202
print(f"train cost {seconds_elapsed} secs")
203-
warnings.warn(f"DEUBG: train cost {seconds_elapsed} secs")
203+
warnings.warn(f"DEBUG: train cost {seconds_elapsed} secs")
204204
jax.clear_caches()
205205

206206
# Run train_compile.py and dump jaxpr
@@ -216,7 +216,7 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
216216
train_compile.main(compile_argv)
217217
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
218218
print(f"train_compile cost {seconds_elapsed} secs")
219-
warnings.warn(f"DEUBG: train_compile cost {seconds_elapsed} secs")
219+
warnings.warn(f"DEBUG: train_compile cost {seconds_elapsed} secs")
220220
jax.clear_caches()
221221

222222
# Compare results

0 commit comments

Comments
 (0)