5 changes: 5 additions & 0 deletions src/maxtext/trainers/pre_train/train.py
@@ -23,6 +23,7 @@
import datetime
import functools
import os
import warnings

from absl import app

@@ -524,6 +525,7 @@ def train_loop(config, recorder, state=None):
  params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)

  with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
    start_time = datetime.datetime.now()
    p_train_step, p_eval_step = train_utils.jit_train_and_eval_step(
        config,
        model,
@@ -535,6 +537,9 @@
        eval_data_iterator,
        params_shardings,
    )
    seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
    print(f"train jit {seconds_elapsed} secs")
    warnings.warn(f"DEBUG: train jit cost {seconds_elapsed} secs")
    shaped_batch = maxtext_utils.get_shaped_batch(config)
    if config.shard_optimizer_over_data:
      state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
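
Note on the pattern above: the same time-and-report sequence (capture start_time, compute seconds_elapsed, then print and warnings.warn) recurs throughout this PR. A minimal sketch of a context manager that could factor it out is below; the name report_elapsed is hypothetical and not part of this change.

import contextlib
import datetime
import warnings


@contextlib.contextmanager
def report_elapsed(label):
  """Time a block and report the elapsed seconds via print and warnings.warn."""
  start_time = datetime.datetime.now()
  try:
    yield
  finally:
    seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
    print(f"{label} {seconds_elapsed} secs")
    warnings.warn(f"DEBUG: {label} cost {seconds_elapsed} secs")

Usage would mirror the diff above, e.g. `with report_elapsed("train jit"): p_train_step, p_eval_step = train_utils.jit_train_and_eval_step(...)`.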
58 changes: 58 additions & 0 deletions src/maxtext/trainers/pre_train/train_compile.py
@@ -21,9 +21,12 @@
as you would on the target hardware.
"""

import datetime
import functools
import inspect
import os
import pickle
import warnings
from typing import Sequence

from absl import app
@@ -64,11 +67,14 @@ def validate_config(config):

def get_topology_mesh(config):
  """Get the target hardware devices, and create configured mesh with them"""
  warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
  if config.internal_compile:
    topology_devices = get_topology_desc(
        platform="tpu", topology_name=config.compile_topology, num_slices=config.compile_topology_num_slices
    ).devices
    warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
  else:
    warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
    target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology)
    if target_hardware.platform == "gpu":
      # Disable sharded autotuning. This is an optimization to distribute
@@ -77,6 +83,7 @@
      jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices)
      topology_devices = jax.devices()
    else:
      warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
      topology_devices = get_topology_desc(
          platform=target_hardware.platform,
          topology_name=target_hardware.topology_name,
@@ -85,10 +92,14 @@
          num_slices=config.compile_topology_num_slices,
          wrap=target_hardware.wrap,
      ).devices
      warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
  jax.config.update("jax_remove_size_one_mesh_axis_from_type", config.remove_size_one_mesh_axis_from_type)
  warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
  topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices)
  warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
  mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto
  topology_mesh = Mesh(topology_device_mesh, config.mesh_axes, axis_types=(mesh_axis_type,) * len(config.mesh_axes))
  warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
  return topology_mesh


@@ -162,18 +173,32 @@ def jit_and_compile(
  # Use both jax.set_mesh (new API) and `with mesh:` (old API) so that drjax,
  # which reads from pxla.thread_resources.env.physical_mesh, can find the mesh.
  with jax.set_mesh(mesh), mesh, logical_axis_rules:
    start_time = datetime.datetime.now()
    jitted = jax.jit(
        func,
        in_shardings=in_shardings,
        out_shardings=out_shardings,
        static_argnums=static_argnums,
        donate_argnums=donate_argnums,
    )
    seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
    print(f"train_compile jit {seconds_elapsed} secs")
    warnings.warn(f"DEBUG: train_compile jit cost {seconds_elapsed} secs")
    maxtext_utils.maybe_dump_jaxpr(config, jitted, func_input_args)

    start_time = datetime.datetime.now()
    lowered = jitted.lower(*func_input_args, **func_input_kwargs)
    seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
    print(f"train_compile lower {seconds_elapsed} secs")
    warnings.warn(f"DEBUG: train_compile lower cost {seconds_elapsed} secs")
    # Import libtpu flags as compiler options. Defaults to empty dict if string is empty.
    compiler_options = max_utils.parse_libtpu_flags_to_dict(config.compile_xla_flags)

    start_time = datetime.datetime.now()
    compiled = lowered.compile(compiler_options=compiler_options)
    seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
    print(f"train_compile compile {seconds_elapsed} secs")
    warnings.warn(f"DEBUG: train_compile compile cost {seconds_elapsed} secs")
  return compiled


@@ -256,13 +281,20 @@ def main(argv: Sequence[str]) -> None:
  config = pyconfig.initialize(argv)
  validate_config(config)

  start_time = datetime.datetime.now()
  # Create target mesh
  warnings.warn(f"DEBUG: before get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
  topology_mesh = get_topology_mesh(config)
  warnings.warn(f"DEBUG: after get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
  seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
  print(f"train_compile get_topology_mesh {seconds_elapsed} secs")
  warnings.warn(f"DEBUG: train_compile get_topology_mesh cost {seconds_elapsed} secs")

  # Print system information after building the compile topology to avoid
  # prematurely initializing the backend.
  max_utils.print_system_information()

  start_time = datetime.datetime.now()
  # Get shaped inputs
  (
      shaped_train_args,
@@ -271,7 +303,11 @@
      logical_annotations,
      model,
  ) = get_shaped_inputs(topology_mesh, config)
  seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
  print(f"train_compile get_shaped_inputs {seconds_elapsed} secs")
  warnings.warn(f"DEBUG: train_compile get_shaped_inputs cost {seconds_elapsed} secs")

  start_time = datetime.datetime.now()
  # Get data sharding
  data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
  if config.enable_diloco:
@@ -304,8 +340,12 @@
  ) = maxtext_utils.get_functional_train_with_signature(
      train.train_step, data_sharding, state_mesh_shardings, model, config
  )
  seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
  print(f"train_compile get_functional_train_with_signature {seconds_elapsed} secs")
  warnings.warn(f"DEBUG: train_compile get_functional_train_with_signature cost {seconds_elapsed} secs")

  # print weights sharding info under debug sharding mode
  start_time = datetime.datetime.now()
  if config.debug_sharding:
    max_utils.print_non_trivial_mesh_axis(topology_mesh)
    if config.pure_nnx:
@@ -322,9 +362,13 @@
        topology_mesh,
        logical_annotations.params,
    )
  seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
  print(f"train_compile print_shardings_params {seconds_elapsed} secs")
  warnings.warn(f"DEBUG: train_compile print_shardings_params cost {seconds_elapsed} secs")

  # Compile
  print("Jitting and compiling train step...", flush=True)
  start_time = datetime.datetime.now()
  compiled = jit_and_compile(
      func_to_compile,
      shaped_train_args,
@@ -337,9 +381,13 @@
      config,
      nn_partitioning.axis_rules(config.logical_axis_rules),
  )
  seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
  print(f"train_compile jit_and_compile {seconds_elapsed} secs")
  warnings.warn(f"DEBUG: train_compile jit_and_compile cost {seconds_elapsed} secs")
  print("Jitting and compilation complete!", flush=True)

  # Serialize and save the compiled object
  start_time = datetime.datetime.now()
  if config.compiled_trainstep_file != "":
    print("Saving compiled object...")
    save_compiled(compiled, config.compiled_trainstep_file)
@@ -348,6 +396,12 @@
    print(f"Cost analysis: {compiled.cost_analysis()}")
    print(f"Memory analysis: {compiled.memory_analysis()}")

  seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
  print(f"train_compile save_compiled {seconds_elapsed} secs")
  warnings.warn(f"DEBUG: train_compile save_compiled cost {seconds_elapsed} secs")
  print("Jitting and compilation complete!", flush=True)

  start_time = datetime.datetime.now()
  # Dump HLO if requested
  if config.dump_hlo:
    gcs_utils.upload_dump(
@@ -358,6 +412,10 @@
        all_host_upload=config.dump_hlo_upload_all,
    )

  seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
  print(f"train_compile upload dump_hlo {seconds_elapsed} secs")
  warnings.warn(f"DEBUG: train_compile upload dump_hlo cost {seconds_elapsed} secs")


if __name__ == "__main__":
  app.run(main)
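
For context, jit_and_compile above times the three ahead-of-time stages separately: jax.jit (wrapping, effectively free because jit is lazy), .lower() (tracing to StableHLO), and .compile() (XLA compilation). A self-contained toy illustration of the same split, with no MaxText dependencies (the loss function and shapes here are invented for the example):

import datetime

import jax
import jax.numpy as jnp


def loss(w, x):
  # Toy objective; stands in for the real train step.
  return jnp.sum((w @ x) ** 2)


jitted = jax.jit(loss)  # lazy: no tracing or compilation happens yet
args = (jnp.ones((8, 8)), jnp.ones((8, 4)))

t0 = datetime.datetime.now()
lowered = jitted.lower(*args)  # traces the function and emits StableHLO
t1 = datetime.datetime.now()
compiled = lowered.compile()  # runs the XLA compiler on the lowered module
t2 = datetime.datetime.now()
print(f"lower {(t1 - t0).total_seconds()} secs, compile {(t2 - t1).total_seconds()} secs")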
6 changes: 6 additions & 0 deletions src/maxtext/utils/train_utils.py
@@ -15,7 +15,9 @@
# pylint: disable=bare-except, consider-using-generator
"""Utils that are only interesting for training in MaxText."""

import datetime
import os
import warnings
from functools import partial

import jax
@@ -176,9 +178,13 @@ def jit_train_and_eval_step(
  train_step_partial = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings)
  train_step = diloco.build_diloco_train_step(config, train_step_partial, mesh=mesh)
  data_sharding = sharding.get_input_data_sharding(config, mesh)
  start_time = datetime.datetime.now()
  p_train_step = jit_train_step(
      config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings, mesh=mesh
  )
  seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
  print(f"train_util train_step jit {seconds_elapsed} secs")
  warnings.warn(f"DEBUG: train_util train_step jit cost {seconds_elapsed} secs")
  p_eval_step = None
  if eval_data_iterator:
    p_eval_step = jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step)
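
A caveat about the timing added here and in train.py: jax.jit is lazy, so if jit_train_step only builds the jitted callable (its internals are not shown in this diff), the measurement covers wrapper construction, while the expensive tracing and XLA compilation happen at the first call or an explicit .lower()/.compile(). A tiny demonstration of the asymmetry:

import datetime

import jax
import jax.numpy as jnp

t0 = datetime.datetime.now()
f = jax.jit(lambda x: jnp.tanh(x) * 2.0)  # returns immediately
t1 = datetime.datetime.now()
_ = f(jnp.ones((1024, 1024))).block_until_ready()  # first call triggers trace + compile + run
t2 = datetime.datetime.now()
print(f"wrap {(t1 - t0).total_seconds()} secs, first call {(t2 - t1).total_seconds()} secs")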
15 changes: 13 additions & 2 deletions tests/integration/aot_identical_test.py
@@ -18,12 +18,14 @@
training run (using train.py).
"""

import datetime
import tempfile
import unittest
import pytest
import os
import shutil
import hashlib
import warnings
import re
import jax
from tests.utils.test_helpers import get_test_config_path
@@ -133,7 +135,7 @@ def assert_compile_and_real_match_hlo(self, test_name, *extra_args):
    )
    train.main(train_argv)
    shutil.move(local_landing_dir, train_dump_dir)
    jax.clear_caches()
    # jax.clear_caches()

    # Generate train_compile.py HLO
    os.makedirs(local_landing_dir, exist_ok=True)
@@ -142,7 +144,7 @@
    compile_argv = (None, get_test_config_path()) + tuple(shared_args) + tuple(aot_args)
    train_compile.main(compile_argv)
    shutil.move(local_landing_dir, compile_dump_dir)
    jax.clear_caches()
    # jax.clear_caches()

    # Compare
    compile_hlo, real_hlo = self.find_HLO_files(compile_dump_dir, train_dump_dir)
@@ -194,7 +196,11 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
        get_test_config_path(),
        f"dump_jaxpr_local_dir={train_dump_dir}",
    ) + tuple(shared_args)
    start_time = datetime.datetime.now()
    train.main(train_argv)
    seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
    print(f"train cost {seconds_elapsed} secs")
    warnings.warn(f"DEBUG: train cost {seconds_elapsed} secs")
    jax.clear_caches()

    # Run train_compile.py and dump jaxpr
@@ -206,7 +212,11 @@
        f"compile_topology={topology}",
        "compile_topology_num_slices=1",
    ) + tuple(shared_args)
    start_time = datetime.datetime.now()
    train_compile.main(compile_argv)
    seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
    print(f"train_compile cost {seconds_elapsed} secs")
    warnings.warn(f"DEBUG: train_compile cost {seconds_elapsed} secs")
    jax.clear_caches()

    # Compare results
@@ -218,5 +228,6 @@
    )

  @pytest.mark.tpu_only
  @pytest.mark.filterwarnings("always::UserWarning")
  def test_default_jaxpr_match(self):
    self.assert_compile_and_real_match_jaxpr("default_run")
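
One note on the new @pytest.mark.filterwarnings("always::UserWarning") marker: Python's default warning filter reports a given warning only once per source location, so repeated DEBUG timing warnings could be deduplicated away. The "always" action pins the filter for this test so every occurrence surfaces in the log regardless of the surrounding configuration. A standalone equivalent (illustrative only, not part of this PR):

import warnings

warnings.simplefilter("always", UserWarning)
for step in range(3):
  warnings.warn(f"DEBUG: step {step} timing")  # emitted all three times, not just once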