
Commit 1a51440

single host run works

1 parent 2c9349f commit 1a51440

5 files changed: 186 additions & 130 deletions


src/maxtext/configs/post_train/rl.yml

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ checkpoint_storage_use_ocdbt: False # For Pathways
 checkpoint_storage_use_zarr3: False # For Pathways
 use_pathways: True
 log_period: 20
+convert_checkpoint_if_possible: True

 # ====== Debugging ======
 debug:

src/maxtext/configs/pyconfig.py

Lines changed: 0 additions & 1 deletion
@@ -106,7 +106,6 @@ def _resolve_or_infer_addl_config(**kwargs):
   hf_access_token = os.environ.get("HF_TOKEN")
   if hf_access_token:
     inferred_kwargs["hf_access_token"] = hf_access_token
-  breakpoint()


   return inferred_kwargs

src/maxtext/configs/types.py

Lines changed: 5 additions & 0 deletions
@@ -1793,6 +1793,11 @@ class DerivedValues(BaseModel):
       None,
       description="The full path to the checkpoint directory, derived from `run_name`.",
   )
+  convert_checkpoint_if_possible: bool = Field(
+      False,
+      description="Whether to convert the checkpoint on the fly if it is not provided via "
+      "`load_parameters_path` or `base_output_directory`.",
+  )
   metrics_dir: None | str = Field(
       None,
       description="The full path to the metrics directory, derived from `run_name`.",

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 3 additions & 127 deletions
@@ -46,7 +46,6 @@
 from __future__ import annotations
 from typing import Sequence

-import collections
 import grain
 import jax
 import json
@@ -77,32 +76,9 @@
 from maxtext.trainers.post_train.rl.evaluate_rl import evaluate
 from maxtext.trainers.post_train.rl import utils_rl
 from maxtext.input_pipeline.instruction_data_processing import load_template_from_file
-from maxtext.utils import max_logging, max_utils, maxtext_utils
+from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils
 import maxtext as mt

-def get_maxtext_model(config, devices=None):
-  """
-  Load MaxText model with Tunix adapter.
-  # Note: pass the path to your scanned checkpoint for 'load_parameters_path'.
-  # To create a scanned checkpoint, you can use /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py and if
-  # using Pathways, please set `USE_PATHWAYS=1` and use `$((1 - USE_PATHWAYS))` for storage flags:
-  #   export USE_PATHWAYS=1
-  #   python src/MaxText/checkpoint_conversion/to_maxtext.py \
-  #     --model_name="gemma2-2b" \
-  #     --base_output_directory="/path/to/your/output/directory" \
-  #     --scan_layers=True \
-  #     --checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
-  #     --checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS))
-  # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py, i.e.,
-  #   load_parameters_path=/path/to/your/output/directory/0/items
-  """
-  model, mesh = mt.from_pretrained(config, devices=devices)
-  with mesh:
-    use_no_op_mappings = "maxtext_config" in config.vllm_additional_config
-    tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings)
-    tunix_model.config = None
-  return tunix_model, mesh
-

 def get_dataset(
     model_tokenizer, tmvp_config, data_dir, split="train", data_files=None, dataset_name=None
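Judging by the import change above and the call sites in rl_train below, get_maxtext_model is relocated rather than deleted. A sketch of the presumed post-move call pattern (the module path comes from the new import; the function keeping its name after the move is an assumption):

# Assumed usage after the move to maxtext.utils.model_creation_utils:
from maxtext.utils import model_creation_utils

tunix_model, mesh = model_creation_utils.get_maxtext_model(trainer_config, devices=trainer_devices)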
@@ -149,85 +125,6 @@ def get_dataset(
   return loaded_dataset


-def setup_configs_and_devices(argv: list[str], kwargs):
-  """Setup device allocation and configs for training and inference."""
-  config = pyconfig.initialize_pydantic(argv, **kwargs)
-  devices = jax.devices()
-  if config.num_trainer_slices == -1 and config.num_samplers_slices == -1:
-    max_logging.log("Running RL on a single slice")
-    num_vms = len(devices) // config.chips_per_vm
-    trainer_devices = devices
-    sampler_devices = devices
-    if num_vms >= 2 and config.use_pathways:
-      # Multiple hosts with Pathways - potentially split devices for trainer and sampler
-      # based on trainer_devices_fraction and sampler_devices_fraction
-      max_logging.log(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.")
-      num_devices = len(devices)
-      num_trainer_devices = int(num_devices * config.trainer_devices_fraction)
-      num_sampler_devices = int(num_devices * config.sampler_devices_fraction)
-      trainer_devices = devices[:num_trainer_devices]
-      sampler_devices = devices[num_devices - num_sampler_devices :]
-      if config.trainer_devices_fraction != 1.0:
-        max_logging.log(f"Using first {len(trainer_devices)} devices as Trainer devices")
-      if config.sampler_devices_fraction != 1.0:
-        max_logging.log(f"Using last {len(sampler_devices)} devices as Sampler devices")
-    trainer_config = config
-    sampler_config = config
-  elif config.num_trainer_slices > 0 and config.num_samplers_slices > 0:
-    max_logging.log("Running RL with Multislice")
-    devices_by_slice = collections.defaultdict(list)
-    for d in devices:
-      devices_by_slice[d.slice_index].append(d)
-    slice_indices = sorted(devices_by_slice.keys())
-
-    if len(slice_indices) < config.num_trainer_slices + config.num_samplers_slices:
-      raise ValueError("Not enough slices for trainer and samplers")
-
-    trainer_devices = []
-    for i in range(config.num_trainer_slices):
-      trainer_devices.extend(devices_by_slice[slice_indices[i]])
-
-    sampler_devices = []
-    for i in range(config.num_trainer_slices, config.num_trainer_slices + config.num_samplers_slices):
-      sampler_devices.extend(devices_by_slice[slice_indices[i]])
-
-    trainer_devices_per_slice = len(trainer_devices) // config.num_trainer_slices
-    trainer_fsdp = trainer_devices_per_slice
-    tp = config.ici_tensor_parallelism
-    if tp > 1:
-      if trainer_devices_per_slice % tp != 0:
-        raise ValueError(
-            f"trainer_devices_per_slice ({trainer_devices_per_slice}) must be divisible by tensor parallelism ({tp})"
-        )
-      if config.ici_fsdp_parallelism != -1 and config.ici_fsdp_parallelism * tp != trainer_devices_per_slice:
-        raise ValueError(
-            f"ici_fsdp_parallelism ({config.ici_fsdp_parallelism}) * ici_tensor_parallelism ({tp}) must equal "
-            f"devices_per_slice ({trainer_devices_per_slice})"
-        )
-      trainer_fsdp = trainer_devices_per_slice // tp
-
-    trainer_update = {
-        "num_slices": config.num_trainer_slices,
-        "ici_fsdp_parallelism": trainer_fsdp,
-        "ici_tensor_parallelism": tp,
-        "dcn_data_parallelism": config.num_trainer_slices,
-    }
-
-    sampler_update = {
-        "num_slices": config.num_samplers_slices,
-        "ici_fsdp_parallelism": len(sampler_devices) // config.num_samplers_slices,
-        "ici_tensor_parallelism": -1,
-        "dcn_data_parallelism": config.num_samplers_slices,
-    }
-
-    trainer_config = pyconfig.initialize_pydantic(argv, **trainer_update)
-    sampler_config = pyconfig.initialize_pydantic(argv, **sampler_update)
-
-  else:
-    raise ValueError("num_trainer_slices and num_samplers_slices should be both -1 or positive")
-
-  return trainer_config, sampler_config, trainer_devices, sampler_devices
-

 def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices):
   """Get rollout kwargs for vLLM rollout when using data parallelism."""
@@ -400,27 +297,6 @@ def _filter_long_prompts(x):
   return train_dataset, test_dataset


-def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sampler_devices):
-  """Create reference and actor models and their respective meshes."""
-  max_logging.log("Creating reference model and also meshes for reference and rollout")
-  reference_model, reference_mesh = get_maxtext_model(trainer_config, trainer_devices)
-  devices_array = maxtext_utils.create_device_mesh(sampler_config, sampler_devices)
-  rollout_mesh = Mesh(devices_array, sampler_config.mesh_axes)
-
-  if trainer_config.load_checkpoint_only_once:
-    max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.")
-    with reference_mesh:
-      actor_base_model = nnx.clone(reference_model.base)
-      use_no_op_mappings = "maxtext_config" in trainer_config.vllm_additional_config
-      actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings)
-      actor_model.config = None
-    actor_mesh = reference_mesh
-  else:
-    max_logging.log("Creating policy model with same config as reference model on trainer mesh")
-    actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)
-
-  return reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh
-

 def create_rl_components(
     trainer_config,
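The load_checkpoint_only_once branch in the removed function is worth calling out: the actor model is built by deep-copying the already-restored reference model instead of reading the checkpoint from storage a second time. A condensed sketch of that pattern, reusing the flax.nnx and TunixMaxTextAdapter calls from the removed code:

# Condensed from the removed branch: derive the actor from the reference.
from flax import nnx

with reference_mesh:
  actor_base_model = nnx.clone(reference_model.base)  # deep copy of the restored weights
  actor_model = TunixMaxTextAdapter(
      base_model=actor_base_model,
      use_no_op_mappings="maxtext_config" in trainer_config.vllm_additional_config,
  )
  actor_model.config = None  # mirrors the removed helper
actor_mesh = reference_mesh  # the actor shares the reference model's mesh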
@@ -590,9 +466,9 @@ def rl_train(argv: Sequence[str], kwargs: dict):
     trainer_devices: JAX devices for the trainer.
     sampler_devices: JAX devices for the sampler.
   """
-  trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv, kwargs)
+  trainer_config, sampler_config, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(argv, kwargs)

-  reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = create_models_and_meshes(
+  reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = model_creation_utils.create_models_and_meshes(
       trainer_config, sampler_config, trainer_devices, sampler_devices
   )
