Skip to content

Commit f8a4707

Browse files
author
Charles Li
committed
Add more logging to measure time
1 parent 4f4c2a3 commit f8a4707

3 files changed

Lines changed: 27 additions & 0 deletions

File tree

src/maxtext/trainers/pre_train/train.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
import datetime
2424
import functools
2525
import os
26+
import warnings
2627

2728
from absl import app
2829

@@ -525,6 +526,7 @@ def train_loop(config, recorder, state=None):
525526
params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)
526527

527528
with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
529+
start_time = datetime.datetime.now()
528530
p_train_step, p_eval_step = train_utils.jit_train_and_eval_step(
529531
config,
530532
model,
@@ -536,6 +538,9 @@ def train_loop(config, recorder, state=None):
536538
eval_data_iterator,
537539
params_shardings,
538540
)
541+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
542+
print(f"train jit {seconds_elapsed} secs")
543+
warnings.warn(f"DEUBG: train jit cost {seconds_elapsed} secs")
539544
shaped_batch = maxtext_utils.get_shaped_batch(config)
540545
if config.shard_optimizer_over_data:
541546
state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)

src/maxtext/trainers/pre_train/train_compile.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,11 @@
2121
as you would on the target hardware.
2222
"""
2323

24+
import datetime
2425
import functools
2526
import os
2627
import pickle
28+
import warnings
2729
from typing import Sequence
2830

2931
from absl import app
@@ -163,18 +165,32 @@ def jit_and_compile(
163165
# Use both jax.set_mesh (new API) and `with mesh:` (old API) so that drjax,
164166
# which reads from pxla.thread_resources.env.physical_mesh, can find the mesh.
165167
with jax.set_mesh(mesh), mesh, logical_axis_rules:
168+
start_time = datetime.datetime.now()
166169
jitted = jax.jit(
167170
func,
168171
in_shardings=in_shardings,
169172
out_shardings=out_shardings,
170173
static_argnums=static_argnums,
171174
donate_argnums=donate_argnums,
172175
)
176+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
177+
print(f"train_compile jit {seconds_elapsed} secs")
178+
warnings.warn(f"DEUBG: train_compile jit cost {seconds_elapsed} secs")
173179
maxtext_utils.maybe_dump_jaxpr(config, jitted, func_input_args)
180+
181+
start_time = datetime.datetime.now()
174182
lowered = jitted.lower(*func_input_args, **func_input_kwargs)
183+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
184+
print(f"train_compile lower {seconds_elapsed} secs")
185+
warnings.warn(f"DEUBG: train_compile lower cost {seconds_elapsed} secs")
175186
# Import libtpu flags as compiler options. Defaults to empty dict if string is empty.
176187
compiler_options = max_utils.parse_libtpu_flags_to_dict(config.compile_xla_flags)
188+
189+
start_time = datetime.datetime.now()
177190
compiled = lowered.compile(compiler_options=compiler_options)
191+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
192+
print(f"train_compile compile {seconds_elapsed} secs")
193+
warnings.warn(f"DEUBG: train_compile compile cost {seconds_elapsed} secs")
178194
return compiled
179195

180196

src/maxtext/utils/train_utils.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,9 @@
1515
# pylint: disable=bare-except, consider-using-generator
1616
"""Utils that are only interesting for training in MaxText."""
1717

18+
import datetime
1819
import os
20+
import warnings
1921
from functools import partial
2022

2123
import jax
@@ -176,9 +178,13 @@ def jit_train_and_eval_step(
176178
train_step_partial = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings)
177179
train_step = diloco.build_diloco_train_step(config, train_step_partial, mesh=mesh)
178180
data_sharding = sharding.get_input_data_sharding(config, mesh)
181+
start_time = datetime.datetime.now()
179182
p_train_step = jit_train_step(
180183
config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings, mesh=mesh
181184
)
185+
seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
186+
print(f"train_util train_step jit {seconds_elapsed} secs")
187+
warnings.warn(f"DEUBG: train_util train_step jit cost {seconds_elapsed} secs")
182188
p_eval_step = None
183189
if eval_data_iterator:
184190
p_eval_step = jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step)

0 commit comments

Comments
 (0)