diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py
index 002934d179..4e61f72d21 100644
--- a/src/maxtext/trainers/pre_train/train.py
+++ b/src/maxtext/trainers/pre_train/train.py
@@ -23,6 +23,7 @@
 import datetime
 import functools
 import os
+import warnings

 from absl import app
@@ -524,6 +525,7 @@ def train_loop(config, recorder, state=None):
   params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings)

   with jax.set_mesh(mesh), mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+    start_time = datetime.datetime.now()
     p_train_step, p_eval_step = train_utils.jit_train_and_eval_step(
         config,
         model,
@@ -535,6 +537,9 @@ def train_loop(config, recorder, state=None):
         eval_data_iterator,
         params_shardings,
     )
+    seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
+    print(f"train jit {seconds_elapsed} secs")
+    warnings.warn(f"DEBUG: train jit cost {seconds_elapsed} secs")
     shaped_batch = maxtext_utils.get_shaped_batch(config)
     if config.shard_optimizer_over_data:
       state = sharding.maybe_shard_with_name(state, state_mesh_shardings, config.shard_mode)
diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py
index a2981f67ed..73e63bdcc4 100644
--- a/src/maxtext/trainers/pre_train/train_compile.py
+++ b/src/maxtext/trainers/pre_train/train_compile.py
@@ -21,9 +21,12 @@
 as you would on the target hardware.
 """

+import datetime
 import functools
+import inspect
 import os
 import pickle
+import warnings
 from typing import Sequence

 from absl import app
@@ -64,11 +67,14 @@ def validate_config(config):

 def get_topology_mesh(config):
   """Get the target hardware devices, and create configured mesh with them"""
+  warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
   if config.internal_compile:
     topology_devices = get_topology_desc(
         platform="tpu", topology_name=config.compile_topology, num_slices=config.compile_topology_num_slices
     ).devices
+    warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
   else:
+    warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
     target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology)
     if target_hardware.platform == "gpu":
       # Disable sharded autotuning. This is an optimization to distribute
@@ -77,6 +83,7 @@ def get_topology_mesh(config):
       jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices)
       topology_devices = jax.devices()
     else:
+      warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
       topology_devices = get_topology_desc(
           platform=target_hardware.platform,
           topology_name=target_hardware.topology_name,
@@ -85,10 +92,14 @@ def get_topology_mesh(config):
           num_slices=config.compile_topology_num_slices,
           wrap=target_hardware.wrap,
       ).devices
+      warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")

   jax.config.update("jax_remove_size_one_mesh_axis_from_type", config.remove_size_one_mesh_axis_from_type)
+  warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
   topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices)
+  warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
   mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto
   topology_mesh = Mesh(topology_device_mesh, config.mesh_axes, axis_types=(mesh_axis_type,) * len(config.mesh_axes))
+  warnings.warn(f"DEBUG: get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
   return topology_mesh

@@ -162,6 +173,7 @@ def jit_and_compile(
   # Use both jax.set_mesh (new API) and `with mesh:` (old API) so that drjax,
   # which reads from pxla.thread_resources.env.physical_mesh, can find the mesh.
   with jax.set_mesh(mesh), mesh, logical_axis_rules:
+    start_time = datetime.datetime.now()
     jitted = jax.jit(
         func,
         in_shardings=in_shardings,
@@ -169,11 +181,24 @@
         out_shardings=out_shardings,
         static_argnums=static_argnums,
         donate_argnums=donate_argnums,
     )
+    seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
+    print(f"train_compile jit {seconds_elapsed} secs")
+    warnings.warn(f"DEBUG: train_compile jit cost {seconds_elapsed} secs")
     maxtext_utils.maybe_dump_jaxpr(config, jitted, func_input_args)
+
+    start_time = datetime.datetime.now()
     lowered = jitted.lower(*func_input_args, **func_input_kwargs)
+    seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
+    print(f"train_compile lower {seconds_elapsed} secs")
+    warnings.warn(f"DEBUG: train_compile lower cost {seconds_elapsed} secs")
     # Import libtpu flags as compiler options. Defaults to empty dict if string is empty.
     compiler_options = max_utils.parse_libtpu_flags_to_dict(config.compile_xla_flags)
+
+    start_time = datetime.datetime.now()
     compiled = lowered.compile(compiler_options=compiler_options)
+    seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
+    print(f"train_compile compile {seconds_elapsed} secs")
+    warnings.warn(f"DEBUG: train_compile compile cost {seconds_elapsed} secs")
     return compiled

@@ -256,13 +281,20 @@ def main(argv: Sequence[str]) -> None:
   config = pyconfig.initialize(argv)
   validate_config(config)

+  start_time = datetime.datetime.now()
   # Create target mesh
+  warnings.warn(f"DEBUG: before get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
   topology_mesh = get_topology_mesh(config)
+  warnings.warn(f"DEBUG: after get_topology_mesh: {inspect.currentframe().f_lineno} datetime: {datetime.datetime.now()}")
+  seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
+  print(f"train_compile get_topology_mesh {seconds_elapsed} secs")
+  warnings.warn(f"DEBUG: train_compile get_topology_mesh cost {seconds_elapsed} secs")

   # Print system information after building the compile topology to avoid
   # prematurely initializing the backend.
   max_utils.print_system_information()

+  start_time = datetime.datetime.now()
   # Get shaped inputs
   (
       shaped_train_args,
@@ -271,7 +303,11 @@ def main(argv: Sequence[str]) -> None:
       logical_annotations,
       model,
   ) = get_shaped_inputs(topology_mesh, config)
+  seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
+  print(f"train_compile get_shaped_inputs {seconds_elapsed} secs")
+  warnings.warn(f"DEBUG: train_compile get_shaped_inputs cost {seconds_elapsed} secs")

+  start_time = datetime.datetime.now()
   # Get data sharding
   data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
   if config.enable_diloco:
@@ -304,8 +340,12 @@ def main(argv: Sequence[str]) -> None:
   ) = maxtext_utils.get_functional_train_with_signature(
       train.train_step, data_sharding, state_mesh_shardings, model, config
   )
+  seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
+  print(f"train_compile get_functional_train_with_signature {seconds_elapsed} secs")
+  warnings.warn(f"DEBUG: train_compile get_functional_train_with_signature cost {seconds_elapsed} secs")

   # print weights sharding info under debug sharding mode
+  start_time = datetime.datetime.now()
   if config.debug_sharding:
     max_utils.print_non_trivial_mesh_axis(topology_mesh)
     if config.pure_nnx:
@@ -322,9 +362,13 @@ def main(argv: Sequence[str]) -> None:
         topology_mesh,
         logical_annotations.params,
     )
+  seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
+  print(f"train_compile print_shardings_params {seconds_elapsed} secs")
+  warnings.warn(f"DEBUG: train_compile print_shardings_params cost {seconds_elapsed} secs")

   # Compile
   print("Jitting and compiling train step...", flush=True)
+  start_time = datetime.datetime.now()
   compiled = jit_and_compile(
       func_to_compile,
       shaped_train_args,
@@ -337,9 +381,13 @@ def main(argv: Sequence[str]) -> None:
       config,
       nn_partitioning.axis_rules(config.logical_axis_rules),
   )
+  seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
+  print(f"train_compile jit_and_compile {seconds_elapsed} secs")
+  warnings.warn(f"DEBUG: train_compile jit_and_compile cost {seconds_elapsed} secs")
   print("Jitting and compilation complete!", flush=True)

   # Serialize and save the compiled object
+  start_time = datetime.datetime.now()
   if config.compiled_trainstep_file != "":
     print("Saving compiled object...")
     save_compiled(compiled, config.compiled_trainstep_file)
@@ -348,6 +396,12 @@ def main(argv: Sequence[str]) -> None:
     print(f"Cost analysis: {compiled.cost_analysis()}")
     print(f"Memory analysis: {compiled.memory_analysis()}")

+  seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
+  print(f"train_compile save_compiled {seconds_elapsed} secs")
+  warnings.warn(f"DEBUG: train_compile save_compiled cost {seconds_elapsed} secs")
+  print("Jitting and compilation complete!", flush=True)
+
+  start_time = datetime.datetime.now()
   # Dump HLO if requested
   if config.dump_hlo:
     gcs_utils.upload_dump(
@@ -358,6 +412,10 @@ def main(argv: Sequence[str]) -> None:
         all_host_upload=config.dump_hlo_upload_all,
     )

+  seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
+  print(f"train_compile upload dump_hlo {seconds_elapsed} secs")
+  warnings.warn(f"DEBUG: train_compile upload dump_hlo cost {seconds_elapsed} secs")
+

 if __name__ == "__main__":
   app.run(main)
diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py
index 9e0a00c8e6..f851dbb6af 100644
--- a/src/maxtext/utils/train_utils.py
+++ b/src/maxtext/utils/train_utils.py
@@ -15,7 +15,9 @@
 # pylint: disable=bare-except, consider-using-generator
 """Utils that are only interesting for training in MaxText."""

+import datetime
 import os
+import warnings
 from functools import partial

 import jax
@@ -176,9 +178,13 @@ def jit_train_and_eval_step(
     train_step_partial = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings)
     train_step = diloco.build_diloco_train_step(config, train_step_partial, mesh=mesh)
   data_sharding = sharding.get_input_data_sharding(config, mesh)
+  start_time = datetime.datetime.now()
   p_train_step = jit_train_step(
       config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings, mesh=mesh
   )
+  seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
+  print(f"train_util train_step jit {seconds_elapsed} secs")
+  warnings.warn(f"DEBUG: train_util train_step jit cost {seconds_elapsed} secs")
   p_eval_step = None
   if eval_data_iterator:
     p_eval_step = jit_eval_step(config, model, state_mesh_shardings, data_sharding, eval_step)
diff --git a/tests/integration/aot_identical_test.py b/tests/integration/aot_identical_test.py
index ca95593cf3..357ce3a876 100644
--- a/tests/integration/aot_identical_test.py
+++ b/tests/integration/aot_identical_test.py
@@ -18,12 +18,14 @@
 training run (using train.py).
 """

+import datetime
 import tempfile
 import unittest
 import pytest
 import os
 import shutil
 import hashlib
+import warnings
 import re
 import jax
 from tests.utils.test_helpers import get_test_config_path
@@ -133,7 +135,7 @@ def assert_compile_and_real_match_hlo(self, test_name, *extra_args):
     )
     train.main(train_argv)
     shutil.move(local_landing_dir, train_dump_dir)
-    jax.clear_caches()
+    # jax.clear_caches()

     # Generate train_compile.py HLO
     os.makedirs(local_landing_dir, exist_ok=True)
@@ -142,7 +144,7 @@ def assert_compile_and_real_match_hlo(self, test_name, *extra_args):
     compile_argv = (None, get_test_config_path()) + tuple(shared_args) + tuple(aot_args)
     train_compile.main(compile_argv)
     shutil.move(local_landing_dir, compile_dump_dir)
-    jax.clear_caches()
+    # jax.clear_caches()

     # Compare
     compile_hlo, real_hlo = self.find_HLO_files(compile_dump_dir, train_dump_dir)
@@ -194,7 +196,11 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
         get_test_config_path(),
         f"dump_jaxpr_local_dir={train_dump_dir}",
     ) + tuple(shared_args)
+    start_time = datetime.datetime.now()
     train.main(train_argv)
+    seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
+    print(f"train cost {seconds_elapsed} secs")
+    warnings.warn(f"DEBUG: train cost {seconds_elapsed} secs")
     jax.clear_caches()

     # Run train_compile.py and dump jaxpr
@@ -206,7 +212,11 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
         f"compile_topology={topology}",
         "compile_topology_num_slices=1",
     ) + tuple(shared_args)
+    start_time = datetime.datetime.now()
     train_compile.main(compile_argv)
+    seconds_elapsed = (datetime.datetime.now() - start_time).total_seconds()
+    print(f"train_compile cost {seconds_elapsed} secs")
+    warnings.warn(f"DEBUG: train_compile cost {seconds_elapsed} secs")
     jax.clear_caches()

     # Compare results
@@ -218,5 +228,6 @@ def assert_compile_and_real_match_jaxpr(self, test_name, *extra_args):
     )

   @pytest.mark.tpu_only
+  @pytest.mark.filterwarnings("always::UserWarning")
   def test_default_jaxpr_match(self):
     self.assert_compile_and_real_match_jaxpr("default_run")