|
46 | 46 | from __future__ import annotations |
47 | 47 | from typing import Sequence |
48 | 48 |
|
49 | | -import collections |
50 | 49 | import grain |
51 | 50 | import jax |
52 | 51 | import json |
|
77 | 76 | from maxtext.trainers.post_train.rl.evaluate_rl import evaluate |
78 | 77 | from maxtext.trainers.post_train.rl import utils_rl |
79 | 78 | from maxtext.input_pipeline.instruction_data_processing import load_template_from_file |
80 | | -from maxtext.utils import max_logging, max_utils, maxtext_utils |
| 79 | +from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils |
81 | 80 | import maxtext as mt |
82 | 81 |
|
def get_maxtext_model(config, devices=None):
  """Load a MaxText model and wrap it in the Tunix adapter.

  Note: 'load_parameters_path' must point at a *scanned* checkpoint, using the
  full path ending in `/0/items`, i.e.
  load_parameters_path=/path/to/your/output/directory/0/items
  A scanned checkpoint can be produced with
  /maxtext/src/MaxText/checkpoint_conversion/to_maxtext.py; when using
  Pathways, set `USE_PATHWAYS=1` and pass `$((1 - USE_PATHWAYS))` for the
  storage flags:
    export USE_PATHWAYS=1
    python src/MaxText/checkpoint_conversion/to_maxtext.py \
      --model_name="gemma2-2b" \
      --base_output_directory="/path/to/your/output/directory" \
      --scan_layers=True \
      --checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
      --checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS))
  """
  base_model, mesh = mt.from_pretrained(config, devices=devices)
  with mesh:
    # Skip weight-name mappings when a maxtext_config is forwarded to vLLM.
    no_op = "maxtext_config" in config.vllm_additional_config
    adapter = TunixMaxTextAdapter(base_model=base_model, use_no_op_mappings=no_op)
    # Tunix expects a `config` attribute; none is needed for MaxText models.
    adapter.config = None
    return adapter, mesh
105 | | - |
106 | 82 |
|
107 | 83 | def get_dataset( |
108 | 84 | model_tokenizer, tmvp_config, data_dir, split="train", data_files=None, dataset_name=None |
@@ -149,85 +125,6 @@ def get_dataset( |
149 | 125 | return loaded_dataset |
150 | 126 |
|
151 | 127 |
|
def setup_configs_and_devices(argv: list[str], kwargs):
  """Setup device allocation and configs for training and inference.

  Supports two modes, selected by `num_trainer_slices` / `num_samplers_slices`:
    * both == -1: single-slice mode; trainer and sampler share devices (with an
      optional fractional split under Pathways on multi-VM setups).
    * both > 0: multislice mode; whole slices are assigned to the trainer and
      sampler, and separate configs are built for each.

  Args:
    argv: Command-line argument list forwarded to pyconfig.
    kwargs: Extra config overrides forwarded to pyconfig.

  Returns:
    A tuple (trainer_config, sampler_config, trainer_devices, sampler_devices).

  Raises:
    ValueError: if the slice counts are inconsistent, there are not enough
      slices for the requested allocation, or the parallelism settings do not
      divide the per-slice device count evenly.
  """
  config = pyconfig.initialize_pydantic(argv, **kwargs)
  devices = jax.devices()
  if config.num_trainer_slices == -1 and config.num_samplers_slices == -1:
    max_logging.log("Running RL on a single slice")
    num_vms = len(devices) // config.chips_per_vm
    # Default: trainer and sampler share all devices.
    trainer_devices = devices
    sampler_devices = devices
    if num_vms >= 2 and config.use_pathways:
      # Multiple hosts with Pathways - potentially split devices for trainer and sampler
      # based on trainer_devices_fraction and sampler_devices_fraction
      max_logging.log(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.")
      num_devices = len(devices)
      num_trainer_devices = int(num_devices * config.trainer_devices_fraction)
      num_sampler_devices = int(num_devices * config.sampler_devices_fraction)
      # Trainer takes the leading devices, sampler the trailing ones; the two
      # ranges overlap when the fractions sum to more than 1.0.
      trainer_devices = devices[:num_trainer_devices]
      sampler_devices = devices[num_devices - num_sampler_devices :]
      if config.trainer_devices_fraction != 1.0:
        max_logging.log(f"Using first {len(trainer_devices)} devices as Trainer devices")
      if config.sampler_devices_fraction != 1.0:
        max_logging.log(f"Using last {len(sampler_devices)} devices as Sampler devices")
    # Single-slice mode reuses one config object for both roles.
    trainer_config = config
    sampler_config = config
  elif config.num_trainer_slices > 0 and config.num_samplers_slices > 0:
    max_logging.log("Running RL with Multislice")
    # Group devices by their hardware slice so whole slices can be assigned.
    devices_by_slice = collections.defaultdict(list)
    for d in devices:
      devices_by_slice[d.slice_index].append(d)
    slice_indices = sorted(devices_by_slice.keys())

    if len(slice_indices) < config.num_trainer_slices + config.num_samplers_slices:
      raise ValueError("Not enough slices for trainer and samplers")

    # First num_trainer_slices slices go to the trainer, the next
    # num_samplers_slices to the sampler (in sorted slice-index order).
    trainer_devices = []
    for i in range(config.num_trainer_slices):
      trainer_devices.extend(devices_by_slice[slice_indices[i]])

    sampler_devices = []
    for i in range(config.num_trainer_slices, config.num_trainer_slices + config.num_samplers_slices):
      sampler_devices.extend(devices_by_slice[slice_indices[i]])

    # FSDP defaults to using every device in a slice; tensor parallelism, if
    # requested, carves devices out of the FSDP dimension.
    trainer_devices_per_slice = len(trainer_devices) // config.num_trainer_slices
    trainer_fsdp = trainer_devices_per_slice
    tp = config.ici_tensor_parallelism
    if tp > 1:
      if trainer_devices_per_slice % tp != 0:
        raise ValueError(
            f"trainer_devices_per_slice ({trainer_devices_per_slice}) must be divisible by tensor parallelism ({tp})"
        )
      # An explicit ici_fsdp_parallelism must agree with tp * fsdp == devices/slice.
      if config.ici_fsdp_parallelism != -1 and config.ici_fsdp_parallelism * tp != trainer_devices_per_slice:
        raise ValueError(
            f"ici_fsdp_parallelism ({config.ici_fsdp_parallelism}) * ici_tensor_parallelism ({tp}) must equal "
            f"devices_per_slice ({trainer_devices_per_slice})"
        )
      trainer_fsdp = trainer_devices_per_slice // tp

    trainer_update = {
        "num_slices": config.num_trainer_slices,
        "ici_fsdp_parallelism": trainer_fsdp,
        "ici_tensor_parallelism": tp,
        "dcn_data_parallelism": config.num_trainer_slices,
    }

    # Sampler uses pure FSDP within each slice (tensor parallelism disabled).
    sampler_update = {
        "num_slices": config.num_samplers_slices,
        "ici_fsdp_parallelism": len(sampler_devices) // config.num_samplers_slices,
        "ici_tensor_parallelism": -1,
        "dcn_data_parallelism": config.num_samplers_slices,
    }

    # NOTE(review): these re-initializations pass only the parallelism updates,
    # not the original `kwargs` overrides applied to `config` above — confirm
    # whether `kwargs` should be merged in here as well.
    trainer_config = pyconfig.initialize_pydantic(argv, **trainer_update)
    sampler_config = pyconfig.initialize_pydantic(argv, **sampler_update)

  else:
    raise ValueError("num_trainer_slices and num_samplers_slices should be both -1 or positive")

  return trainer_config, sampler_config, trainer_devices, sampler_devices
230 | | - |
231 | 128 |
|
232 | 129 | def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices): |
233 | 130 | """Get rollout kwargs for vLLM rollout when using data parallelism.""" |
@@ -400,27 +297,6 @@ def _filter_long_prompts(x): |
400 | 297 | return train_dataset, test_dataset |
401 | 298 |
|
402 | 299 |
|
def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sampler_devices):
  """Create reference and actor models and their respective meshes.

  Args:
    trainer_config: Config used to build the reference/actor (policy) models.
    sampler_config: Config used to build the rollout mesh.
    trainer_devices: JAX devices hosting the trainer models.
    sampler_devices: JAX devices hosting the rollout sampler.

  Returns:
    A tuple (reference_model, reference_mesh, actor_model, actor_mesh,
    rollout_mesh).
  """
  max_logging.log("Creating reference model and also meshes for reference and rollout")
  reference_model, reference_mesh = get_maxtext_model(trainer_config, trainer_devices)
  # Rollout gets its own mesh over the sampler devices, laid out per sampler_config.
  devices_array = maxtext_utils.create_device_mesh(sampler_config, sampler_devices)
  rollout_mesh = Mesh(devices_array, sampler_config.mesh_axes)

  if trainer_config.load_checkpoint_only_once:
    # Clone the already-restored reference weights rather than reading the
    # checkpoint from storage a second time.
    max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.")
    with reference_mesh:
      actor_base_model = nnx.clone(reference_model.base)
      use_no_op_mappings = "maxtext_config" in trainer_config.vllm_additional_config
      actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings)
      # Tunix expects a `config` attribute; none is needed for MaxText models.
      actor_model.config = None
      # Cloned actor lives on the same mesh as the reference model.
      actor_mesh = reference_mesh
  else:
    # Restores from checkpoint again, on the same trainer devices.
    max_logging.log("Creating policy model with same config as reference model on trainer mesh")
    actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)

  return reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh
423 | | - |
424 | 300 |
|
425 | 301 | def create_rl_components( |
426 | 302 | trainer_config, |
@@ -590,9 +466,9 @@ def rl_train(argv: Sequence[str], kwargs: dict): |
590 | 466 | trainer_devices: JAX devices for the trainer. |
591 | 467 | sampler_devices: JAX devices for the sampler. |
592 | 468 | """ |
593 | | - trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv, kwargs) |
| 469 | + trainer_config, sampler_config, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(argv, kwargs) |
594 | 470 |
|
595 | | - reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = create_models_and_meshes( |
| 471 | + reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = model_creation_utils.create_models_and_meshes( |
596 | 472 | trainer_config, sampler_config, trainer_devices, sampler_devices |
597 | 473 | ) |
598 | 474 |
|
|
0 commit comments