diff --git a/src/maxtext/configs/post_train/rl.yml b/src/maxtext/configs/post_train/rl.yml
index c31af3f770..f5ce74ced0 100644
--- a/src/maxtext/configs/post_train/rl.yml
+++ b/src/maxtext/configs/post_train/rl.yml
@@ -96,6 +96,7 @@ checkpoint_storage_use_ocdbt: False # For Pathways
 checkpoint_storage_use_zarr3: False # For Pathways
 use_pathways: True
 log_period: 20
+convert_checkpoint_if_possible: True

 # ====== Debugging ======
 debug:
diff --git a/src/maxtext/configs/pyconfig.py b/src/maxtext/configs/pyconfig.py
index ad8142748a..c562a05357 100644
--- a/src/maxtext/configs/pyconfig.py
+++ b/src/maxtext/configs/pyconfig.py
@@ -77,16 +77,48 @@ def _module_from_path(path: str) -> str | None:
   return None


-def _resolve_or_infer_config(argv: list[str]) -> tuple[str, list[str]]:
+def _resolve_or_infer_config(argv: list[str] | None = None, **kwargs) -> tuple[str, list[str]]:
   """Resolves or infers config file path from module."""
+  if argv is None:
+    argv = [""]
+
+  if kwargs.get("base_config"):
+    logger.info("Using config : %s", kwargs["base_config"])
+    return resolve_config_path(kwargs["base_config"]), argv[1:]
+
+  # If passing at least two arguments via a list (no kwargs), the first one must be
+  # either "" or a python script such as train_rl.py or train.py;
+  # the second argument is the yaml config file.
   if len(argv) >= 2 and argv[1].endswith(".yml"):
     return resolve_config_path(argv[1]), argv[2:]
-  module = _module_from_path(argv[0])
+  module = _module_from_path(argv[0]) if len(argv) > 0 else None
   if module not in _CONFIG_FILE_MAPPING:
-    raise ValueError(f"No config file provided and no default config found for module '{module}'")
-  config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
-  logger.warning("No config file provided, using default config mapping: %s", config_path)
-  return config_path, argv[1:]
+    config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")
+    logger.warning("No config file provided and no default config found for module '%s', using base.yml", module)
+  else:
+    config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
+    logger.warning("No config file provided, using default config mapping: %s", config_path)
+  remaining_argv = argv[1:]
+
+  return config_path, remaining_argv
+
+
+def _resolve_or_infer_addl_config(**kwargs):
+  """Infers additional config values (output directory, HF access token) when they are not provided."""
+  inferred_kwargs = {}
+  # if base_output_directory key is not seen
+  if not kwargs.get("base_output_directory"):
+    max_logging.warning("base_output_directory is not provided; Using local directory called maxtext_output")
+    base_output_directory = os.path.abspath("maxtext_output")
+    inferred_kwargs["base_output_directory"] = base_output_directory
+
+  # if hf_access_token key is not seen
+  if not kwargs.get("hf_access_token"):
+    hf_access_token = os.environ.get("HF_TOKEN")
+    if hf_access_token:
+      inferred_kwargs["hf_access_token"] = hf_access_token
+
+  return inferred_kwargs


 def yaml_key_to_env_key(s: str) -> str:
@@ -289,19 +321,19 @@ def get_keys(self) -> dict[str, Any]:
     return self._flat_config


-def initialize(argv: list[str], **kwargs) -> HyperParameters:
+def initialize(argv: list[str] | None = None, **kwargs) -> HyperParameters:
   """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides."""
   pydantic_config = initialize_pydantic(argv, **kwargs)
   config = HyperParameters(pydantic_config)
   return config

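+# Illustrative usage of initialize()/initialize_pydantic() (the YAML path and override value
+# below are placeholders):
+#   config = initialize(["", "src/maxtext/configs/base.yml", "run_name=my_run"])
+#   config = initialize(base_config="src/maxtext/configs/base.yml", run_name="my_run")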
-def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
+def initialize_pydantic(argv: list[str] | None = None, **kwargs) -> MaxTextConfig:
   """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides.

   Returns pydantic MaxTextConfig class whereas `initialize` returns the og `HyperParameters`
   """
   # 1. Load base and inherited configs from file(s)
-  config_path, cli_args = _resolve_or_infer_config(argv)
+  config_path, cli_args = _resolve_or_infer_config(argv, **kwargs)
   base_yml_config = _load_config(config_path)

   # 2. Get overrides from CLI and kwargs
@@ -309,8 +341,15 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
   kwargs_cfg = omegaconf.OmegaConf.create(kwargs)
   overrides_cfg = omegaconf.OmegaConf.merge(cli_cfg, kwargs_cfg)

-  # 3. Handle model-specific config
+  temp_cfg1 = omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg)
+  # 3.1. infer more configs if possible
+  temp_cfg1 = _resolve_or_infer_addl_config(**temp_cfg1)
+  # update overrides_cfg with temp_cfg1
+  overrides_cfg = omegaconf.OmegaConf.merge(overrides_cfg, temp_cfg1)
   temp_cfg = omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg)
+
+  # 3.2. Handle model-specific config
+  model_name = temp_cfg.get("model_name", "default")

   # The architecture for -Instruct v/s base models are the same, so for identifying the
   # architecture we replace "-Instruct" from the model_name and get the base model name
@@ -437,3 +476,13 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
 # Shim for backward compatibility with pyconfig_deprecated_test.py
 validate_and_update_keys = pyconfig_deprecated.validate_and_update_keys
 __all__ = ["initialize", "initialize_pydantic"]
+
+
+class _CallablePyconfigModule(sys.modules[__name__].__class__):
+  """Allows calling the module directly as mt.pyconfig()."""
+
+  def __call__(self, argv: list[str] | None = None, **kwargs) -> HyperParameters:
+    return initialize(argv, **kwargs)
+
+
+sys.modules[__name__].__class__ = _CallablePyconfigModule
diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py
index c734ac2f87..7c26c8f23b 100644
--- a/src/maxtext/configs/types.py
+++ b/src/maxtext/configs/types.py
@@ -1931,6 +1931,11 @@ class DerivedValues(BaseModel):
       None,
       description="The full path to the checkpoint directory, derived from `run_name`.",
   )
+  convert_checkpoint_if_possible: bool = Field(
+      False,
+      description="Whether to convert the checkpoint on the fly when it is not provided via "
+      "load_parameters_path or base_output_directory",
+  )
   metrics_dir: None | str = Field(
       None,
       description="The full path to the metrics directory, derived from `run_name`.",
   )
diff --git a/src/maxtext/examples/rl_llama3_demo.ipynb b/src/maxtext/examples/rl_llama3_demo.ipynb
index ef5cdb9fc2..f0d18b12d0 100644
--- a/src/maxtext/examples/rl_llama3_demo.ipynb
+++ b/src/maxtext/examples/rl_llama3_demo.ipynb
@@ -135,27 +135,7 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": [
-    "import datetime\n",
-    "import os\n",
-    "import sys\n",
-    "import subprocess\n",
-    "from pathlib import Path\n",
-    "from huggingface_hub import login\n",
-    "from etils import epath\n",
-    "import jax\n",
-    "\n",
-    "from maxtext.trainers.post_train.rl.train_rl import rl_train, setup_configs_and_devices\n",
-    "from maxtext.utils.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n",
-    "\n",
-    "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\n",
-    "os.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n",
-    "# Suppress vLLM logging with a severity level below ERROR\n",
-    "os.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n",
-    "\n",
-    "\n",
"print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")" - ] + "source": "import datetime\nimport os\nimport sys\nimport subprocess\nfrom pathlib import Path\nfrom huggingface_hub import login\nfrom etils import epath\nimport jax\n\nfrom maxtext.trainers.post_train.rl.train_rl import rl_train\nfrom maxtext.utils.model_creation_utils import setup_configs_and_devices\nfrom maxtext.utils.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n\nos.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\nos.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n# Suppress vLLM logging with a severity level below ERROR\nos.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n\n\nprint(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")" }, { "cell_type": "code", @@ -386,4 +366,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/src/maxtext/inference/vllm_decode.py b/src/maxtext/inference/vllm_decode.py index 66af92e209..f3628d054b 100644 --- a/src/maxtext/inference/vllm_decode.py +++ b/src/maxtext/inference/vllm_decode.py @@ -241,7 +241,7 @@ def main(argv: Sequence[str]) -> None: config = pyconfig.initialize(argv) if FLAGS.use_tunix: - maxtext_model, mesh = model_creation_utils.create_nnx_model(config) + maxtext_model, mesh = model_creation_utils.from_pretrained(config) decode_with_tunix(config, model=maxtext_model, mesh=mesh) else: decode_with_vllm(config) diff --git a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py index a4cd924672..bd38bb0650 100644 --- a/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py +++ b/src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py @@ -251,7 +251,7 @@ def load_weights(self, rng_key: jax.Array) -> None: return with self.mesh, nn.logical_axis_rules(""): - model, _ = model_creation_utils.create_nnx_model( + model = model_creation_utils.from_pretrained( self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key ) self.model = nnx.data(model) diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 5681119d86..93aafa2bd1 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -458,7 +458,7 @@ def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh) The loaded MaxText model. 
""" max_logging.log(f"Initializing model: {config.model_name}...") - model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh) + model = model_creation_utils.from_pretrained(config, mesh=mesh) return model diff --git a/src/maxtext/trainers/post_train/rl/train_rl.py b/src/maxtext/trainers/post_train/rl/train_rl.py index 8663add55f..cccfa7439b 100644 --- a/src/maxtext/trainers/post_train/rl/train_rl.py +++ b/src/maxtext/trainers/post_train/rl/train_rl.py @@ -47,7 +47,6 @@ from functools import wraps from typing import Sequence -import collections import grain import jax import json @@ -60,7 +59,6 @@ from absl import logging as absl_logging from etils import epath from flax import nnx -from jax.sharding import Mesh from orbax import checkpoint as ocp from pprint import pprint from transformers import AutoTokenizer @@ -74,35 +72,10 @@ from maxtext.configs import pyconfig from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR -from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter from maxtext.trainers.post_train.rl.evaluate_rl import evaluate from maxtext.trainers.post_train.rl import utils_rl from maxtext.input_pipeline.instruction_data_processing import load_template_from_file -from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils - - -def get_maxtext_model(config, devices=None): - """ - Load MaxText model with Tunix adapter. - # Note: pass the path to your scanned checkpoint for 'load_parameters_path'. - # To create a scanned checkpoint, you can use /maxtext/src/maxtext/checkpoint_conversion/to_maxtext.py and if - # using Pathways, please set `USE_PATHWAYS=1` and use `$((1 - USE_PATHWAYS))` for storage flags: - # export USE_PATHWAYS=1 - # python src/maxtext/checkpoint_conversion/to_maxtext.py \ - # --model_name="gemma2-2b" \ - # --base_output_directory="/path/to/your/output/directory" \ - # --scan_layers=True \ - # --checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \ - # --checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS)) - # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e., - # load_parameters_path=/path/to/your/output/directory/0/items - """ - model, mesh = model_creation_utils.create_nnx_model(config, devices=devices) - with mesh: - use_no_op_mappings = "maxtext_config" in config.vllm_additional_config - tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings) - tunix_model.config = None - return tunix_model, mesh +from maxtext.utils import max_logging, max_utils, model_creation_utils def get_dataset( @@ -150,86 +123,6 @@ def get_dataset( return loaded_dataset -def setup_configs_and_devices(argv: list[str]): - """Setup device allocation and configs for training and inference.""" - config = pyconfig.initialize_pydantic(argv) - devices = jax.devices() - if config.num_trainer_slices == -1 and config.num_samplers_slices == -1: - max_logging.log("Running RL on a single slice") - num_vms = len(devices) // config.chips_per_vm - trainer_devices = devices - sampler_devices = devices - if num_vms >= 2 and config.use_pathways: - # Multiple hosts with Pathways - potentially split devices for trainer and sampler - # based on trainer_devices_fraction and sampler_devices_fraction - max_logging.log(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.") - num_devices = len(devices) - num_trainer_devices = int(num_devices * config.trainer_devices_fraction) - num_sampler_devices = int(num_devices * 
-      num_sampler_devices = int(num_devices * config.sampler_devices_fraction)
-      trainer_devices = devices[:num_trainer_devices]
-      sampler_devices = devices[num_devices - num_sampler_devices :]
-      if config.trainer_devices_fraction != 1.0:
-        max_logging.log(f"Using first {len(trainer_devices)} devices as Trainer devices")
-      if config.sampler_devices_fraction != 1.0:
-        max_logging.log(f"Using last {len(sampler_devices)} devices as Sampler devices")
-    trainer_config = config
-    sampler_config = config
-  elif config.num_trainer_slices > 0 and config.num_samplers_slices > 0:
-    max_logging.log("Running RL with Multislice")
-    devices_by_slice = collections.defaultdict(list)
-    for d in devices:
-      devices_by_slice[d.slice_index].append(d)
-    slice_indices = sorted(devices_by_slice.keys())
-
-    if len(slice_indices) < config.num_trainer_slices + config.num_samplers_slices:
-      raise ValueError("Not enough slices for trainer and samplers")
-
-    trainer_devices = []
-    for i in range(config.num_trainer_slices):
-      trainer_devices.extend(devices_by_slice[slice_indices[i]])
-
-    sampler_devices = []
-    for i in range(config.num_trainer_slices, config.num_trainer_slices + config.num_samplers_slices):
-      sampler_devices.extend(devices_by_slice[slice_indices[i]])
-
-    trainer_devices_per_slice = len(trainer_devices) // config.num_trainer_slices
-    trainer_fsdp = trainer_devices_per_slice
-    tp = config.ici_tensor_parallelism
-    if tp > 1:
-      if trainer_devices_per_slice % tp != 0:
-        raise ValueError(
-            f"trainer_devices_per_slice ({trainer_devices_per_slice}) must be divisible by tensor parallelism ({tp})"
-        )
-      if config.ici_fsdp_parallelism != -1 and config.ici_fsdp_parallelism * tp != trainer_devices_per_slice:
-        raise ValueError(
-            f"ici_fsdp_parallelism ({config.ici_fsdp_parallelism}) * ici_tensor_parallelism ({tp}) must equal "
-            f"devices_per_slice ({trainer_devices_per_slice})"
-        )
-      trainer_fsdp = trainer_devices_per_slice // tp
-
-    trainer_update = {
-        "num_slices": config.num_trainer_slices,
-        "ici_fsdp_parallelism": trainer_fsdp,
-        "ici_tensor_parallelism": tp,
-        "dcn_data_parallelism": config.num_trainer_slices,
-    }
-
-    sampler_update = {
-        "num_slices": config.num_samplers_slices,
-        "ici_fsdp_parallelism": len(sampler_devices) // config.num_samplers_slices,
-        "ici_tensor_parallelism": -1,
-        "dcn_data_parallelism": config.num_samplers_slices,
-    }
-
-    trainer_config = pyconfig.initialize_pydantic(argv, **trainer_update)
-    sampler_config = pyconfig.initialize_pydantic(argv, **sampler_update)
-
-  else:
-    raise ValueError("num_trainer_slices and num_samplers_slices should be both -1 or positive")
-
-  return trainer_config, sampler_config, trainer_devices, sampler_devices
-
-
 def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices):
   """Get rollout kwargs for vLLM rollout when using data parallelism."""
   dp = sampler_config.rollout_data_parallelism
@@ -411,28 +304,6 @@ def _use_raw_prompt(x):
   return train_dataset, test_dataset


-def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sampler_devices):
-  """Create reference and actor models and their respective meshes."""
-  max_logging.log("Creating reference model and also meshes for reference and rollout")
-  reference_model, reference_mesh = get_maxtext_model(trainer_config, trainer_devices)
-  devices_array = maxtext_utils.create_device_mesh(sampler_config, sampler_devices)
-  rollout_mesh = Mesh(devices_array, sampler_config.mesh_axes)
-
-  if trainer_config.load_checkpoint_only_once:
from checkpoint again.") - with reference_mesh: - actor_base_model = nnx.clone(reference_model.base) - use_no_op_mappings = "maxtext_config" in trainer_config.vllm_additional_config - actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings) - actor_model.config = None - actor_mesh = reference_mesh - else: - max_logging.log("Creating policy model with same config as reference model on trainer mesh") - actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices) - - return reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh - - def create_rl_components( trainer_config, sampler_config, @@ -652,7 +523,7 @@ def _reward_fn(**kwargs): return rl_cluster, rl_trainer, optimizer -def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): +def rl_train(argv: Sequence[str], kwargs: dict): """ Run RL training with the provided configuration. @@ -662,13 +533,21 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): trainer_devices: JAX devices for the trainer. sampler_devices: JAX devices for the sampler. """ + trainer_config, sampler_config, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices( + argv, kwargs + ) + + reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = model_creation_utils.create_models_and_meshes( + trainer_config, sampler_config, trainer_devices, sampler_devices + ) + if not trainer_config.debug.rl: # Apply filter to suppress noisy logs noise_filter = max_logging.NoisyLogFilter() logging.getLogger().addFilter(noise_filter) absl_logging.get_absl_logger().addFilter(noise_filter) + os.environ["VLLM_LOGGING_LEVEL"] = "ERROR" - max_logging.log("Starting RL Training") if not epath.Path(trainer_config.tensorboard_dir).exists(): epath.Path(trainer_config.tensorboard_dir).mkdir(parents=True, exist_ok=True) @@ -692,10 +571,6 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): break pprint(ele) - reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = create_models_and_meshes( - trainer_config, sampler_config, trainer_devices, sampler_devices - ) - if trainer_config.debug.rl: max_logging.log("Reference Model initialized successfully") nnx.display(reference_model) @@ -703,6 +578,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): max_logging.log("Policy Model initialized successfully") nnx.display(actor_model) max_logging.log(f"Policy mesh shape: {actor_mesh.shape}") + max_logging.log(f"Rollout_mesh shape: {rollout_mesh.shape}") rl_cluster, rl_trainer, _ = create_rl_components( trainer_config, @@ -759,18 +635,18 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices): max_logging.warning(f"Post RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%") -def main(argv: Sequence[str]) -> None: +def main(argv: Sequence[str], kwargs: dict = None) -> None: """Main function to run RL training. Args: argv: Command-line arguments. 
""" + kwargs = kwargs or {} pathwaysutils.initialize() os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" max_utils.print_system_information() - trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv) - rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices) + rl_train(argv, kwargs) if __name__ == "__main__": diff --git a/src/maxtext/trainers/post_train/sft/train_sft.py b/src/maxtext/trainers/post_train/sft/train_sft.py index 90595a05fd..e2b9407a63 100644 --- a/src/maxtext/trainers/post_train/sft/train_sft.py +++ b/src/maxtext/trainers/post_train/sft/train_sft.py @@ -146,7 +146,7 @@ def setup_trainer_state(mt_config, goodput_recorder=None): tunix_config = get_tunix_config(mt_config) with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT): - model, mesh = model_creation_utils.create_nnx_model(mt_config) + model, mesh = model_creation_utils.from_pretrained(mt_config) learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config) # pass in model for muon optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model) diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py index 49fb9d3490..8570514b2c 100644 --- a/src/maxtext/utils/model_creation_utils.py +++ b/src/maxtext/utils/model_creation_utils.py @@ -16,8 +16,12 @@ """ Utils that are only interesting for creating a model in MaxText. """ import dataclasses +import collections from collections.abc import Sequence from functools import partial +import os +import subprocess +import sys from typing import overload from etils import epath @@ -33,6 +37,7 @@ from maxtext.utils import max_logging from maxtext.utils import max_utils from maxtext.utils import maxtext_utils +from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter from orbax import checkpoint as ocp try: @@ -258,8 +263,187 @@ def _create_model(rng_key=None): return _create_model_partial, abstract_model -def create_nnx_model(config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None): +def setup_configs_and_devices(argv: list[str] | None = None, kwargs: dict | None = None, **extra_kwargs): + """Setup device allocation and configs for training and inference. 
+  """
+  if argv is None:
+    argv = [""]
+
+  combined_kwargs = dict(kwargs) if kwargs else {}
+  combined_kwargs.update(extra_kwargs)
+  config = pyconfig.initialize_pydantic(argv, **combined_kwargs)
+  devices = jax.devices()
+  if config.num_trainer_slices == -1 and config.num_samplers_slices == -1:
+    max_logging.log("Running on a single slice")
+    num_vms = len(devices) // config.chips_per_vm
+    trainer_devices = devices
+    sampler_devices = devices
+    if num_vms >= 2 and config.use_pathways:
+      # Multiple hosts with Pathways - potentially split devices for trainer and sampler
+      # based on trainer_devices_fraction and sampler_devices_fraction
+      max_logging.log(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.")
+      num_devices = len(devices)
+      num_trainer_devices = int(num_devices * config.trainer_devices_fraction)
+      num_sampler_devices = int(num_devices * config.sampler_devices_fraction)
+      trainer_devices = devices[:num_trainer_devices]
+      sampler_devices = devices[num_devices - num_sampler_devices :]
+      if config.trainer_devices_fraction != 1.0:
+        max_logging.log(f"Using first {len(trainer_devices)} devices as Trainer devices")
+      if config.sampler_devices_fraction != 1.0:
+        max_logging.log(f"Using last {len(sampler_devices)} devices as Sampler devices")
+    trainer_config = config
+    sampler_config = config
+  elif config.num_trainer_slices > 0 and config.num_samplers_slices > 0:
+    max_logging.log("Running with Multislice")
+    devices_by_slice = collections.defaultdict(list)
+    for d in devices:
+      devices_by_slice[d.slice_index].append(d)
+    slice_indices = sorted(devices_by_slice.keys())
+
+    if len(slice_indices) < config.num_trainer_slices + config.num_samplers_slices:
+      raise ValueError("Not enough slices for trainer and samplers")
+
+    trainer_devices = []
+    for i in range(config.num_trainer_slices):
+      trainer_devices.extend(devices_by_slice[slice_indices[i]])
+
+    sampler_devices = []
+    for i in range(config.num_trainer_slices, config.num_trainer_slices + config.num_samplers_slices):
+      sampler_devices.extend(devices_by_slice[slice_indices[i]])
+
+    trainer_devices_per_slice = len(trainer_devices) // config.num_trainer_slices
+    trainer_fsdp = trainer_devices_per_slice
+    tp = config.ici_tensor_parallelism
+    if tp > 1:
+      if trainer_devices_per_slice % tp != 0:
+        raise ValueError(
+            f"trainer_devices_per_slice ({trainer_devices_per_slice}) must be divisible by tensor parallelism ({tp})"
+        )
+      if config.ici_fsdp_parallelism != -1 and config.ici_fsdp_parallelism * tp != trainer_devices_per_slice:
+        raise ValueError(
+            f"ici_fsdp_parallelism ({config.ici_fsdp_parallelism}) * ici_tensor_parallelism ({tp}) must equal "
+            f"devices_per_slice ({trainer_devices_per_slice})"
+        )
+      trainer_fsdp = trainer_devices_per_slice // tp
+
+    trainer_kwargs = dict(combined_kwargs)
+    trainer_kwargs.update(
+        {
+            "num_slices": config.num_trainer_slices,
+            "ici_fsdp_parallelism": trainer_fsdp,
+            "ici_tensor_parallelism": tp,
+            "dcn_data_parallelism": config.num_trainer_slices,
+        }
+    )
+
+    sampler_kwargs = dict(combined_kwargs)
+    sampler_kwargs.update(
+        {
+            "num_slices": config.num_samplers_slices,
+            "ici_fsdp_parallelism": len(sampler_devices) // config.num_samplers_slices,
+            "ici_tensor_parallelism": -1,
+            "dcn_data_parallelism": config.num_samplers_slices,
+        }
+    )
+
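+    # Build one config per role: the trainer and sampler each get their own slice count and
+    # parallelism layout on top of the shared CLI/kwargs overrides assembled above.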
+    trainer_config = pyconfig.initialize_pydantic(argv, **trainer_kwargs)
+    sampler_config = pyconfig.initialize_pydantic(argv, **sampler_kwargs)
+
+  else:
+    raise ValueError("num_trainer_slices and num_samplers_slices should be both -1 or positive")
+
+  return trainer_config, sampler_config, trainer_devices, sampler_devices
+
+
+def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sampler_devices):
+  """Create reference and actor models and their respective meshes.
+  This API is particularly useful for Reinforcement Learning (RL) where we need 2 models (wrapped in TunixMaxTextAdapter
+  so that they are compatible with default Tunix APIs) and meshes for reference, actor and rollout (which can be disjoint
+  in case of disaggregated RL training).
+  """
+  max_logging.log("Creating reference model and also meshes for reference and rollout")
+  reference_model, reference_mesh = from_pretrained(trainer_config, devices=trainer_devices, wrap_with_tunix_adapter=True)
+  devices_array = maxtext_utils.create_device_mesh(sampler_config, sampler_devices)
+  rollout_mesh = Mesh(devices_array, sampler_config.mesh_axes)
+
+  if trainer_config.load_checkpoint_only_once:
+    max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.")
+    with reference_mesh:
+      actor_base_model = nnx.clone(reference_model.base)
+      use_no_op_mappings = "maxtext_config" in trainer_config.vllm_additional_config
+      # TunixMaxTextAdapter wraps MaxText models so that they are compatible with Tunix's default APIs.
+      # The weight mappings for vLLM (which MaxText interfaces with via Tunix) are model specific.
+      # The mappings are defined inside src/maxtext/integration/tunix/weight_mapping
+      actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings)
+      actor_model.config = None
+    actor_mesh = reference_mesh
+  else:
+    max_logging.log("Creating policy model with same config as reference model on trainer mesh")
+    actor_model, actor_mesh = from_pretrained(trainer_config, devices=trainer_devices, wrap_with_tunix_adapter=True)
+
+  return reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh
+
+
+def from_pretrained(
+    config, mesh=None, devices=None, model_mode=MODEL_MODE_TRAIN, rng_key=None, wrap_with_tunix_adapter=False
+):
   """Creates a NNX model with sharded parameters, possibly loading from a checkpoint."""
+  original_mesh = mesh
+  if config.convert_checkpoint_if_possible and not config.load_parameters_path:
+    if not (epath.Path(config.base_output_directory) / "0" / "items").exists():
+      # Try to convert checkpoint on the fly
+      if not config.hf_access_token:
+        raise ValueError("hf_access_token must be provided when not providing a pre-existing checkpoint")
+
+      # Only process 0 performs the conversion; other processes wait at the barrier below.
+      # Otherwise every host would race to download from HF and concurrently write the same
+      # GCS checkpoint, wasting work and risking corruption.
+      if jax.process_index() == 0:
+        max_logging.warning("Checkpoint path is not provided, converting checkpoint to orbax format for MaxText")
+
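+        # Illustrative equivalent of the conversion subprocess assembled below (model name and
+        # bucket are example values only):
+        #   python -m maxtext.checkpoint_conversion.to_maxtext model_name=llama3.1-8b \
+        #     base_output_directory=gs://my-bucket/maxtext-ckpts scan_layers=True use_multimodal=false
+        # The converted checkpoint is then loaded back from <base_output_directory>/0/items.
+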
+        # This is an empirically derived value. Simulating multiple CPU devices is needed so that orbax creates
+        # multiple shards of the checkpoint. Without simulating multiple devices, orbax writes a single giant
+        # checkpoint file when running on CPU, which could lead to OOM on TPU generations with smaller memory.
+        simulated_cpu_devices_count = 16
+
+        # Run the conversion in a completely isolated subprocess so its CPU
+        # JAX/XLA requirements do not interfere with the parent's Pathways TPU mesh.
+        conversion_env = os.environ.copy()
+        conversion_env["JAX_PLATFORMS"] = "cpu"
+        # conversion_env["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={simulated_cpu_devices_count}"
+
+        to_maxtext_cmd = [
+            sys.executable,
+            "-m",
+            "maxtext.checkpoint_conversion.to_maxtext",
+        ] + [
+            f"model_name={config.model_name}",
+            f"base_output_directory={config.base_output_directory}",
+            f"scan_layers={config.scan_layers}",
+            f"hf_access_token={config.hf_access_token}",
+            "use_multimodal=false",
+            "skip_jax_distributed_system=True",
+            "--lazy_load_tensors=True",
+            f"--simulated_cpu_devices_count={simulated_cpu_devices_count}",
+        ]
+
+        try:
+          subprocess.run(to_maxtext_cmd, env=conversion_env, check=True)
+        except subprocess.CalledProcessError as e:
+          raise RuntimeError(f"Checkpoint conversion failed with exit code {e.returncode}") from e
+
+      jax.experimental.multihost_utils.sync_global_devices("from_pretrained_convert_checkpoint")
+    load_parameters_path = epath.Path(config.base_output_directory) / "0" / "items"
+    # Create a copied Pydantic model with the updated values
+    pydantic_config = getattr(config, "_pydantic_config", config)
+    new_config = pydantic_config.model_copy(
+        update={
+            "load_parameters_path": load_parameters_path,
+        }
+    )
+    config = pyconfig.HyperParameters(new_config)

   def _create_model(mesh: Mesh | None = None, model_mode: str = MODEL_MODE_TRAIN, rng_key: jax.Array | None = None):
     if rng_key is None:
@@ -406,4 +590,13 @@ def create_sharded_state():
     except Exception as e:
       raise ValueError(f"Checkpoint loading failed: {e}") from e

-  return model, mesh
+  if wrap_with_tunix_adapter:
+    with mesh:
+      use_no_op_mappings = "maxtext_config" in config.vllm_additional_config
+      model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings)
+      model.config = None
+
+  if original_mesh:
+    return model
+  else:
+    return model, mesh
diff --git a/tests/post_training/unit/train_rl_test.py b/tests/post_training/unit/train_rl_test.py
index ed2018d657..8f07b01433 100644
--- a/tests/post_training/unit/train_rl_test.py
+++ b/tests/post_training/unit/train_rl_test.py
@@ -29,6 +29,7 @@
     "maxtext.trainers.post_train.rl.train_rl",
     reason="Tunix is not installed on the GPU image",
 )
+from maxtext.utils import model_creation_utils


 def _get_mock_devices(devices_per_slice, num_slices=1):
@@ -62,9 +63,9 @@ def test_setup_configs_and_devices_pathways_split(self):
     # Following the pattern in distillation_checkpointing_test.py for mocking jax objects
     with (
         mock.patch.object(jax, "devices", return_value=mock_devices),
-        mock.patch("maxtext.trainers.post_train.rl.train_rl.pyconfig.initialize_pydantic", return_value=mock_config),
+        mock.patch("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", return_value=mock_config),
     ):
-      trainer_config, sampler_config, trainer_devices, sampler_devices = train_rl.setup_configs_and_devices(
+      trainer_config, sampler_config, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(
          ["dummy", "dummy"]
      )
@@ -91,9 +92,9 @@ def test_setup_configs_and_devices_pathways_fractional_split(self):

     with (
         mock.patch.object(jax, "devices", return_value=mock_devices),
-        mock.patch("maxtext.trainers.post_train.rl.train_rl.pyconfig.initialize_pydantic", return_value=mock_config),
mock.patch("maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", return_value=mock_config), ): - _, _, trainer_devices, sampler_devices = train_rl.setup_configs_and_devices(["dummy", "dummy"]) + _, _, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(["dummy", "dummy"]) self.assertEqual(len(trainer_devices), 2) self.assertEqual(len(sampler_devices), 6) @@ -118,12 +119,12 @@ def side_effect(argv, **kwargs): with ( mock.patch.object(jax, "devices", return_value=mock_devices), mock.patch( - "maxtext.trainers.post_train.rl.train_rl.pyconfig.initialize_pydantic", + "maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", side_effect=side_effect, ), ): with self.assertRaisesRegex(ValueError, "Not enough slices for trainer and samplers"): - train_rl.setup_configs_and_devices(["dummy", "dummy"]) + model_creation_utils.setup_configs_and_devices(["dummy", "dummy"]) @pytest.mark.cpu_only def test_setup_configs_and_devices_multislice_invalid_tp(self): @@ -145,12 +146,12 @@ def side_effect(argv, **kwargs): with ( mock.patch.object(jax, "devices", return_value=mock_devices), mock.patch( - "maxtext.trainers.post_train.rl.train_rl.pyconfig.initialize_pydantic", + "maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", side_effect=side_effect, ), ): with self.assertRaisesRegex(ValueError, "must be divisible by tensor parallelism"): - train_rl.setup_configs_and_devices(["dummy", "dummy"]) + model_creation_utils.setup_configs_and_devices(["dummy", "dummy"]) @pytest.mark.cpu_only def test_setup_configs_and_devices_multislice_invalid_tp_fsdp(self): @@ -172,12 +173,12 @@ def side_effect(argv, **kwargs): with ( mock.patch.object(jax, "devices", return_value=mock_devices), mock.patch( - "maxtext.trainers.post_train.rl.train_rl.pyconfig.initialize_pydantic", + "maxtext.utils.model_creation_utils.pyconfig.initialize_pydantic", side_effect=side_effect, ), ): with self.assertRaisesRegex(ValueError, "must equal devices_per_slice"): - train_rl.setup_configs_and_devices(["dummy", "dummy"]) + model_creation_utils.setup_configs_and_devices(["dummy", "dummy"]) @pytest.mark.cpu_only def test_get_rollout_kwargs_no_dp(self): diff --git a/tests/unit/model_creation_utils_test.py b/tests/unit/model_creation_utils_test.py index bed2e699fa..f785fb15fa 100644 --- a/tests/unit/model_creation_utils_test.py +++ b/tests/unit/model_creation_utils_test.py @@ -292,39 +292,38 @@ def test_prefill_model_mode(self): class TestCreateNnxModel(unittest.TestCase): - """Tests for create_nnx_model().""" + """Tests for from_pretrained().""" def setUp(self): self.config = _make_config() self.mesh = _make_mesh(self.config) def test_no_checkpoint_returns_model_and_mesh(self): - """Without load_parameters_path, should return (model, mesh) cleanly.""" - model, mesh = model_creation_utils.create_nnx_model(self.config, self.mesh) + """Without load_parameters_path, should return the model cleanly.""" + model = model_creation_utils.from_pretrained(self.config, self.mesh) self.assertIsInstance(model, models.Transformer) - self.assertIsInstance(mesh, Mesh) def test_mesh_none_uses_abstract_model_mesh(self): """When mesh=None is passed, the function resolves it from the abstract model.""" - model, mesh = model_creation_utils.create_nnx_model(self.config, mesh=None) + model, mesh = model_creation_utils.from_pretrained(self.config, mesh=None) self.assertIsInstance(model, models.Transformer) self.assertIsInstance(mesh, Mesh) def test_explicit_rng_key(self): """An explicit rng_key should be accepted 
without error.""" rng_key = jax.random.PRNGKey(99) - model, _ = model_creation_utils.create_nnx_model(self.config, self.mesh, rng_key=rng_key) + model = model_creation_utils.from_pretrained(self.config, self.mesh, rng_key=rng_key) self.assertIsInstance(model, models.Transformer) def test_inference_mode_disables_dropout_rng(self): """MODEL_MODE_PREFILL should create rngs without a dropout key.""" - model, _ = model_creation_utils.create_nnx_model(self.config, self.mesh, model_mode=MODEL_MODE_PREFILL) + model = model_creation_utils.from_pretrained(self.config, self.mesh, model_mode=MODEL_MODE_PREFILL) self.assertIsInstance(model, models.Transformer) def test_debug_sharding_flag(self): """debug_sharding=True should execute the sharding-print path without error.""" cfg = _make_config(debug_sharding=True) - model, _ = model_creation_utils.create_nnx_model(cfg, self.mesh) + model = model_creation_utils.from_pretrained(cfg, self.mesh) self.assertIsInstance(model, models.Transformer) # ---- checkpoint loading: mocked paths ---- @@ -367,7 +366,7 @@ def test_load_nnx_checkpoint(self, mock_ocp): mock_ocp.ArrayRestoreArgs = ocp.ArrayRestoreArgs cfg = _make_config(enable_checkpointing=True, load_parameters_path="gs://fake/nnx_ckpt") - model, _ = model_creation_utils.create_nnx_model(cfg, self.mesh) + model = model_creation_utils.from_pretrained(cfg, self.mesh) self.assertIsInstance(model, models.Transformer) @patch("maxtext.utils.model_creation_utils.ocp") @@ -395,7 +394,7 @@ def test_load_linen_checkpoint(self, mock_ocp): mock_ocp.ArrayRestoreArgs = ocp.ArrayRestoreArgs cfg = _make_config(enable_checkpointing=True, load_parameters_path="gs://fake/linen_ckpt") - model, _ = model_creation_utils.create_nnx_model(cfg, self.mesh) + model = model_creation_utils.from_pretrained(cfg, self.mesh) self.assertIsInstance(model, models.Transformer) @patch("maxtext.utils.model_creation_utils.ocp") @@ -408,7 +407,7 @@ def test_checkpoint_load_error_raises_value_error(self, mock_ocp): cfg = _make_config(enable_checkpointing=True, load_parameters_path="gs://fake/bad_ckpt") with self.assertRaises(ValueError): - model_creation_utils.create_nnx_model(cfg, self.mesh) + model_creation_utils.from_pretrained(cfg, self.mesh) if __name__ == "__main__": diff --git a/tests/unit/pyconfig_test.py b/tests/unit/pyconfig_test.py index a7b1379066..26a6bb1fcc 100644 --- a/tests/unit/pyconfig_test.py +++ b/tests/unit/pyconfig_test.py @@ -149,9 +149,10 @@ def test_hlo_dump_module_names_none_coercion(self): self.assertEqual(config.dump_hlo_local_module_name, "") self.assertEqual(config.dump_hlo_module_name, "") - def test_unknown_module_raises(self): - with self.assertRaises(ValueError): - pyconfig.initialize_pydantic(["/custom_rl/module.py", "run_name=test"]) + def test_unknown_module_falls_back_to_base_yml(self): + """An unknown module should fall back to base.yml with a warning (not raise).""" + config = pyconfig.initialize_pydantic(["/custom_rl/module.py", "run_name=test", "skip_jax_distributed_system=True"]) + self.assertEqual(config.run_name, "test") if __name__ == "__main__":