Skip to content

Commit 9c4fbdf

Browse files
author
Charles Li
committed
Add more logging to measure time
1 parent 4f4c2a3 commit 9c4fbdf

4 files changed

Lines changed: 71 additions & 2 deletions

File tree

src/maxtext/trainers/pre_train/train.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import datetime
2424
import functools
2525
import os
26+
import warnings
2627

2728
from absl import app
2829

@@ -525,6 +526,7 @@ def train_loop(config, recorder, state=None):
525526
params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)
526527

527528
with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
529+
start_time = datetime.datetime.now()
528530
p_train_step, p_eval_step = train_utils.jit_train_and_eval_step(
529531
config,
530532
model,
@@ -536,6 +538,9 @@ def train_loop(config, recorder, state=None):
536538
eval_data_iterator,
537539
params_shardings,
538540
)
541+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
542+
print(f"train jit {seconds_elapsed} secs")
543+
warnings.warn(f"DEBUG: train jit cost {seconds_elapsed} secs")
539544
shaped_batch = maxtext_utils.get_shaped_batch(config)
540545
if config.shard_optimizer_over_data:
541546
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)

src/maxtext/trainers/pre_train/train_compile.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121
as you would on the target hardware.
2222
"""
2323

24+
import datetime
2425
import functools
26+
import inspect
2527
import os
2628
import pickle
29+
import warnings
2730
from typing import Sequence
2831

2932
from absl import app
@@ -64,11 +67,14 @@ def validate_config(config):
6467

6568
def get_topology_mesh(config):
6669
"""Get the target hardware devices, and create configured mesh with them"""
70+
warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
6771
if config.internal_compile:
6872
topology_devices = get_topology_desc(
6973
platform="tpu", topology_name=config.compile_topology, num_slices=config.compile_topology_num_slices
7074
).devices
75+
warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
7176
else:
77+
warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
7278
target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology)
7379
if target_hardware.platform == "gpu":
7480
# Disable sharded autotuning. This is an optimization to distribute
@@ -77,6 +83,7 @@ def get_topology_mesh(config):
7783
jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices)
7884
topology_devices = jax.devices()
7985
else:
86+
warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
8087
topology_devices = get_topology_desc(
8188
platform=target_hardware.platform,
8289
topology_name=target_hardware.topology_name,
@@ -85,11 +92,15 @@ def get_topology_mesh(config):
8592
num_slices=config.compile_topology_num_slices,
8693
wrap=target_hardware.wrap,
8794
).devices
95+
warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
8896
if config.shard_mode == ShardMode.EXPLICIT:
8997
jax.config.update("jax_remove_size_one_mesh_axis_from_type", True)
98+
warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
9099
topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices)
100+
warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
91101
mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto
92102
topology_mesh = Mesh(topology_device_mesh, config.mesh_axes, axis_types=(mesh_axis_type,) * len(config.mesh_axes))
103+
warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
93104
return topology_mesh
94105

95106

@@ -163,18 +174,32 @@ def jit_and_compile(
163174
# Use both jax.set_mesh (new API) and `with mesh:` (old API) so that drjax,
164175
# which reads from pxla.thread_resources.env.physical_mesh, can find the mesh.
165176
with jax.set_mesh(mesh), mesh, logical_axis_rules:
177+
start_time = datetime.datetime.now()
166178
jitted = jax.jit(
167179
func,
168180
in_shardings=in_shardings,
169181
out_shardings=out_shardings,
170182
static_argnums=static_argnums,
171183
donate_argnums=donate_argnums,
172184
)
185+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
186+
print(f"train_compile jit {seconds_elapsed} secs")
187+
warnings.warn(f"DEBUG: train_compile jit cost {seconds_elapsed} secs")
173188
maxtext_utils.maybe_dump_jaxpr(config, jitted, func_input_args)
189+
190+
start_time = datetime.datetime.now()
174191
lowered = jitted.lower(*func_input_args, **func_input_kwargs)
192+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
193+
print(f"train_compile lower {seconds_elapsed} secs")
194+
warnings.warn(f"DEBUG: train_compile lower cost {seconds_elapsed} secs")
175195
# Import libtpu flags as compiler options. Defaults to empty dict if string is empty.
176196
compiler_options = max_utils.parse_libtpu_flags_to_dict(config.compile_xla_flags)
197+
198+
start_time = datetime.datetime.now()
177199
compiled = lowered.compile(compiler_options=compiler_options)
200+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
201+
print(f"train_compile compile {seconds_elapsed} secs")
202+
warnings.warn(f"DEBUG: train_compile compile cost {seconds_elapsed} secs")
178203
return compiled
179204

180205

@@ -257,13 +282,20 @@ def main(argv: Sequence[str]) -> None:
257282
config = pyconfig.initialize(argv)
258283
validate_config(config)
259284

285+
start_time = datetime.datetime.now()
260286
# Create target mesh
287+
warnings.warn(f"DEBUG: before get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
261288
topology_mesh = get_topology_mesh(config)
289+
warnings.warn(f"DEBUG: after get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
290+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
291+
print(f"train_compile get_topology_mesh {seconds_elapsed} secs")
292+
warnings.warn(f"DEBUG: train_compile get_topology_mesh cost {seconds_elapsed} secs")
262293

263294
# Print system information after building the compile topology to avoid
264295
# prematurely initializing the backend.
265296
max_utils.print_system_information()
266297

298+
start_time = datetime.datetime.now()
267299
# Get shaped inputs
268300
(
269301
shaped_train_args,
@@ -272,7 +304,11 @@ def main(argv: Sequence[str]) -> None:
272304
logical_annotations,
273305
model,
274306
) = get_shaped_inputs(topology_mesh, config)
307+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
308+
print(f"train_compile get_shaped_inputs {seconds_elapsed} secs")
309+
warnings.warn(f"DEBUG: train_compile get_shaped_inputs cost {seconds_elapsed} secs")
275310

311+
start_time = datetime.datetime.now()
276312
# Get data sharding
277313
data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
278314
if config.enable_diloco:
@@ -305,8 +341,12 @@ def main(argv: Sequence[str]) -> None:
305341
) = maxtext_utils.get_functional_train_with_signature(
306342
train.train_step, data_sharding, state_mesh_shardings, model, config
307343
)
344+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
345+
print(f"train_compile get_functional_train_with_signature {seconds_elapsed} secs")
346+
warnings.warn(f"DEBUG: train_compile get_functional_train_with_signature cost {seconds_elapsed} secs")
308347

309348
# print weights sharding info under debug sharding mode
349+
start_time = datetime.datetime.now()
310350
if config.debug_sharding:
311351
max_utils.print_non_trivial_mesh_axis(topology_mesh)
312352
if config.pure_nnx:
@@ -323,9 +363,13 @@ def main(argv: Sequence[str]) -> None:
323363
topology_mesh,
324364
logical_annotations.params,
325365
)
366+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
367+
print(f"train_compile print_shardings_params {seconds_elapsed} secs")
368+
warnings.warn(f"DEBUG: train_compile print_shardings_params cost {seconds_elapsed} secs")
326369

327370
# Compile
328371
print("Jitting and compiling train step...", flush=True)
372+
start_time = datetime.datetime.now()
329373
compiled = jit_and_compile(
330374
func_to_compile,
331375
shaped_train_args,
@@ -338,9 +382,13 @@ def main(argv: Sequence[str]) -> None:
338382
config,
339383
nn_partitioning.axis_rules(config.logical_axis_rules),
340384
)
385+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
386+
print(f"train_compile jit_and_compile {seconds_elapsed} secs")
387+
warnings.warn(f"DEBUG: train_compile jit_and_compile cost {seconds_elapsed} secs")
341388
print("Jitting and compilation complete!", flush=True)
342389

343390
# Serialize and save the compiled object
391+
start_time = datetime.datetime.now()
344392
if config.compiled_trainstep_file != "":
345393
print("Saving compiled object...")
346394
save_compiled(compiled, config.compiled_trainstep_file)
@@ -349,6 +397,12 @@ def main(argv: Sequence[str]) -> None:
349397
print(f"Cost analysis: {compiled.cost_analysis()}")
350398
print(f"Memory analysis: {compiled.memory_analysis()}")
351399

400+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
401+
print(f"train_compile save_compiled {seconds_elapsed} secs")
402+
warnings.warn(f"DEBUG: train_compile save_compiled cost {seconds_elapsed} secs")
403+
print("Jitting and compilation complete!", flush=True)
404+
405+
start_time = datetime.datetime.now()
352406
# Dump HLO if requested
353407
if config.dump_hlo:
354408
gcs_utils.upload_dump(
@@ -359,6 +413,10 @@ def main(argv: Sequence[str]) -> None:
359413
all_host_upload=config.dump_hlo_upload_all,
360414
)
361415

416+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
417+
print(f"train_compile upload dump_hlo {seconds_elapsed} secs")
418+
warnings.warn(f"DEBUG: train_compile upload dump_hlo cost {seconds_elapsed} secs")
419+
362420

363421
if __name__ == "__main__":
364422
app.run(main)

src/maxtext/utils/train_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
# pylint: disable=bare-except, consider-using-generator
1616
"""Utils that are only interesting for training in MaxText."""
1717

18+
import datetime
1819
import os
20+
import warnings
1921
from functools import partial
2022

2123
import jax
@@ -176,9 +178,13 @@ def jit_train_and_eval_step(
176178
train_step_partial = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings)
177179
train_step = diloco.build_diloco_train_step(config, train_step_partial, mesh=mesh)
178180
data_sharding = sharding.get_input_data_sharding(config, mesh)
181+
start_time = datetime.datetime.now()
179182
p_train_step = jit_train_step(
180183
config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings, mesh=mesh
181184
)
185+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
186+
print(f"train_util train_step jit {seconds_elapsed} secs")
187+
warnings.warn(f"DEBUG: train_util train_step jit cost {seconds_elapsed} secs")
182188
p_eval_step = None
183189
if eval_data_iterator:
184190
p_eval_step = jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step)

tests/integration/aot_identical_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
200200
train.main(train_argv)
201201
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
202202
print(f"train cost {seconds_elapsed} secs")
203-
warnings.warn(f"DEUBG: train cost {seconds_elapsed} secs")
203+
warnings.warn(f"DEBUG: train cost {seconds_elapsed} secs")
204204
jax.clear_caches()
205205

206206
# Run train_compile.py and dump jaxpr
@@ -216,7 +216,7 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
216216
train_compile.main(compile_argv)
217217
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
218218
print(f"train_compile cost {seconds_elapsed} secs")
219-
warnings.warn(f"DEUBG: train_compile cost {seconds_elapsed} secs")
219+
warnings.warn(f"DEBUG: train_compile cost {seconds_elapsed} secs")
220220
jax.clear_caches()
221221

222222
# Compare results

0 commit comments

Comments
 (0)