
Commit d47d478

add from_pretrained as simple API
1 parent b83ff02

9 files changed: 258 additions & 155 deletions


src/maxtext/configs/post_train/rl.yml

Lines changed: 1 addition & 0 deletions
@@ -96,6 +96,7 @@ checkpoint_storage_use_ocdbt: False # For Pathways
 checkpoint_storage_use_zarr3: False # For Pathways
 use_pathways: True
 log_period: 20
+convert_checkpoint_if_possible: True

 # ====== Debugging ======
 debug:
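
The new key defaults to False in the schema (see the src/maxtext/configs/types.py hunk below) and is switched on here for the RL recipe. Since pyconfig merges kwargs on top of the YAML, the value can also be flipped per run; a minimal sketch in Python, assuming the config path touched in this commit:

from maxtext.configs import pyconfig

# kwargs act as overrides, so this wins over the True set in rl.yml above.
config = pyconfig.initialize(
    ["train_rl.py", "src/maxtext/configs/post_train/rl.yml"],
    convert_checkpoint_if_possible=False,
)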

src/maxtext/configs/pyconfig.py

Lines changed: 56 additions & 10 deletions
@@ -77,16 +77,45 @@ def _module_from_path(path: str) -> str | None:
   return None


-def _resolve_or_infer_config(argv: list[str]) -> tuple[str, list[str]]:
+def _resolve_or_infer_config(argv: list[str] | None = None, **kwargs) -> tuple[str, list[str]]:
   """Resolves or infers config file path from module."""
+  if argv is None:
+    argv = [""]
+
+  if kwargs.get("base_config"):
+    logger.info("Using config : %s", kwargs["base_config"])
+    return resolve_config_path(kwargs["base_config"]), argv[1:] if len(argv) > 1 else []
+
   if len(argv) >= 2 and argv[1].endswith(".yml"):
     return resolve_config_path(argv[1]), argv[2:]
-  module = _module_from_path(argv[0])
+  module = _module_from_path(argv[0]) if len(argv) > 0 else None
   if module not in _CONFIG_FILE_MAPPING:
-    raise ValueError(f"No config file provided and no default config found for module '{module}'")
-  config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
-  logger.warning("No config file provided, using default config mapping: %s", config_path)
-  return config_path, argv[1:]
+    config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")
+    logger.warning("No config file provided and no default config found for module '%s', using base.yml", module)
+  else:
+    config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
+    logger.warning("No config file provided, using default config mapping: %s", config_path)
+  remaining_argv = argv[1:] if len(argv) > 1 else []
+
+  return config_path, remaining_argv
+
+
+def _resolve_or_infer_addl_config(**kwargs):
+  """Resolves or infers more configs from module."""
+  inferred_kwargs = {}
+  # if base_output_directory key is not seen
+  if not kwargs.get("base_output_directory"):
+    max_logging.warning("base_output_directory is not provided; Using local directory called maxtext_output")
+    base_output_directory = os.path.abspath("maxtext_output")
+    inferred_kwargs["base_output_directory"] = base_output_directory
+
+  # if hf_access_token key is not seen
+  if not kwargs.get("hf_access_token"):
+    hf_access_token = os.environ.get("HF_TOKEN")
+    if hf_access_token:
+      inferred_kwargs["hf_access_token"] = hf_access_token
+
+  return inferred_kwargs


 def yaml_key_to_env_key(s: str) -> str:
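
With this hunk, config resolution has three explicit entry paths plus a fallback: a base_config kwarg (checked first), a positional .yml argument, the module-to-config mapping, and finally base.yml; _resolve_or_infer_addl_config then backfills base_output_directory and hf_access_token. A minimal sketch of the behavior, assuming the signatures above (these are private helpers, called here purely for illustration):

from maxtext.configs.pyconfig import _resolve_or_infer_config, _resolve_or_infer_addl_config

_resolve_or_infer_config(["train.py", "src/maxtext/configs/base.yml"])  # positional .yml
_resolve_or_infer_config(base_config="base.yml")  # kwarg path, takes precedence
_resolve_or_infer_config()  # no hints at all: warns and falls back to base.yml

# HF_TOKEN is read from the environment when hf_access_token is not supplied,
# and base_output_directory defaults to a local ./maxtext_output directory.
inferred = _resolve_or_infer_addl_config()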
@@ -280,28 +309,35 @@ def get_keys(self) -> dict[str, Any]:
   return self._flat_config


-def initialize(argv: list[str], **kwargs) -> HyperParameters:
+def initialize(argv: list[str] | None = None, **kwargs) -> HyperParameters:
   """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides."""
   pydantic_config = initialize_pydantic(argv, **kwargs)
   config = HyperParameters(pydantic_config)
   return config


-def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
+def initialize_pydantic(argv: list[str] | None = None, **kwargs) -> MaxTextConfig:
   """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides.

   Returns pydantic MaxTextConfig class whereas `initialize` returns the og `HyperParameters`
   """
   # 1. Load base and inherited configs from file(s)
-  config_path, cli_args = _resolve_or_infer_config(argv)
+  config_path, cli_args = _resolve_or_infer_config(argv, **kwargs)
   base_yml_config = _load_config(config_path)

   # 2. Get overrides from CLI and kwargs
   cli_cfg = omegaconf.OmegaConf.from_cli(cli_args)
   kwargs_cfg = omegaconf.OmegaConf.create(kwargs)
   overrides_cfg = omegaconf.OmegaConf.merge(cli_cfg, kwargs_cfg)

-  # 3. Handle model-specific config
+  temp_cfg1 = omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg)
+  # 3.1. infer more configs if possible
+  temp_cfg1 = _resolve_or_infer_addl_config(**temp_cfg1)
+  # update overrides_cfg with temp_cfg1
+  overrides_cfg = omegaconf.OmegaConf.merge(overrides_cfg, temp_cfg1)
   temp_cfg = omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg)
+
+  # 3.2. Handle model-specific config
+
   model_name = temp_cfg.get("model_name", "default")
   # The architecture for -Instruct v/s base models are the same, so for identifying the
   # architecture we replace "-Instruct" from the model_name and get the base model name
@@ -419,3 +455,13 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
 # Shim for backward compatibility with pyconfig_deprecated_test.py
 validate_and_update_keys = pyconfig_deprecated.validate_and_update_keys
 __all__ = ["initialize", "initialize_pydantic"]
+
+
+class _CallablePyconfigModule(sys.modules[__name__].__class__):
+  """Allows calling the module directly as mt.pyconfig()."""
+
+  def __call__(self, argv: list[str] | None = None, **kwargs) -> HyperParameters:
+    return initialize(argv, **kwargs)
+
+
+sys.modules[__name__].__class__ = _CallablePyconfigModule
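
The class-swap shim at the end is what makes the commit's "simple API" ergonomic: replacing a module object's class with a ModuleType subclass that defines __call__ turns the imported module itself into a callable. A hedged usage sketch (the mt.pyconfig() spelling in the docstring presumably assumes an import maxtext as mt style re-export):

from maxtext.configs import pyconfig

config = pyconfig()  # no argv and no config file: falls back to base.yml
config = pyconfig(base_config="base.yml", model_name="default")  # kwargs become overrides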

src/maxtext/configs/types.py

Lines changed: 5 additions & 0 deletions
@@ -1925,6 +1925,11 @@ class DerivedValues(BaseModel):
       None,
       description="The full path to the checkpoint directory, derived from `run_name`.",
   )
+  convert_checkpoint_if_possible: bool = Field(
+      False,
+      description="Whether to convert checkpoint on the fly if not provided via\
+      load_parameters_path or base_output_directory",
+  )
   metrics_dir: None | str = Field(
       None,
       description="The full path to the metrics directory, derived from `run_name`.",

src/maxtext/inference/vllm_decode.py

Lines changed: 1 addition & 1 deletion
@@ -241,7 +241,7 @@ def main(argv: Sequence[str]) -> None:
   config = pyconfig.initialize(argv)

   if FLAGS.use_tunix:
-    maxtext_model, mesh = model_creation_utils.create_nnx_model(config)
+    maxtext_model, mesh = model_creation_utils.from_pretrained(config)
     decode_with_tunix(config, model=maxtext_model, mesh=mesh)
   else:
     decode_with_vllm(config)
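
The same one-line swap from create_nnx_model to from_pretrained recurs in the vLLM adapter and the distillation trainer below, with arguments unchanged, which suggests the signature is preserved. A hedged sketch of the renamed entry point, based only on these call sites:

from maxtext.configs import pyconfig
from maxtext.utils import model_creation_utils

config = pyconfig.initialize(["decode.py", "src/maxtext/configs/base.yml"])
# Per the call sites in this commit it returns (model, mesh), and also
# accepts mesh=, model_mode=, and rng_key= keyword arguments.
model, mesh = model_creation_utils.from_pretrained(config)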

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 1 addition & 1 deletion
@@ -251,7 +251,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
       return

     with self.mesh, nn.logical_axis_rules(""):
-      model, _ = model_creation_utils.create_nnx_model(
+      model, _ = model_creation_utils.from_pretrained(
           self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
       )
     self.model = nnx.data(model)

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 1 addition & 1 deletion
@@ -458,7 +458,7 @@ def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh)
     The loaded MaxText model.
   """
   max_logging.log(f"Initializing model: {config.model_name}...")
-  model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
+  model, _ = model_creation_utils.from_pretrained(config, mesh=mesh)
   return model

src/maxtext/trainers/post_train/rl/train_rl.py

Lines changed: 15 additions & 139 deletions
@@ -47,7 +47,6 @@
 from functools import wraps
 from typing import Sequence

-import collections
 import grain
 import jax
 import json
@@ -60,7 +59,6 @@
 from absl import logging as absl_logging
 from etils import epath
 from flax import nnx
-from jax.sharding import Mesh
 from orbax import checkpoint as ocp
 from pprint import pprint
 from transformers import AutoTokenizer
@@ -74,35 +72,10 @@

 from maxtext.configs import pyconfig
 from maxtext.utils.globals import MAXTEXT_CONFIGS_DIR
-from maxtext.integration.tunix.tunix_adapter import TunixMaxTextAdapter
 from maxtext.trainers.post_train.rl.evaluate_rl import evaluate
 from maxtext.trainers.post_train.rl import utils_rl
 from maxtext.input_pipeline.instruction_data_processing import load_template_from_file
-from maxtext.utils import max_logging, max_utils, maxtext_utils, model_creation_utils
-
-
-def get_maxtext_model(config, devices=None):
-  """
-  Load MaxText model with Tunix adapter.
-  # Note: pass the path to your scanned checkpoint for 'load_parameters_path'.
-  # To create a scanned checkpoint, you can use /maxtext/src/maxtext/checkpoint_conversion/to_maxtext.py and if
-  # using Pathways, please set `USE_PATHWAYS=1` and use `$((1 - USE_PATHWAYS))` for storage flags:
-  #   export USE_PATHWAYS=1
-  #   python src/maxtext/checkpoint_conversion/to_maxtext.py \
-  #     --model_name="gemma2-2b" \
-  #     --base_output_directory="/path/to/your/output/directory" \
-  #     --scan_layers=True \
-  #     --checkpoint_storage_use_ocdbt=$((1 - USE_PATHWAYS)) \
-  #     --checkpoint_storage_use_zarr3=$((1 - USE_PATHWAYS))
-  # Please ensure that you pass the full path ending in `/0/items` for load_parameters_path to train_rl.py i.e.,
-  #   load_parameters_path=/path/to/your/output/directory/0/items
-  """
-  model, mesh = model_creation_utils.create_nnx_model(config, devices=devices)
-  with mesh:
-    use_no_op_mappings = "maxtext_config" in config.vllm_additional_config
-    tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings)
-    tunix_model.config = None
-  return tunix_model, mesh
+from maxtext.utils import max_logging, max_utils, model_creation_utils


 def get_dataset(
@@ -150,86 +123,6 @@ def get_dataset(
   return loaded_dataset


-def setup_configs_and_devices(argv: list[str]):
-  """Setup device allocation and configs for training and inference."""
-  config = pyconfig.initialize_pydantic(argv)
-  devices = jax.devices()
-  if config.num_trainer_slices == -1 and config.num_samplers_slices == -1:
-    max_logging.log("Running RL on a single slice")
-    num_vms = len(devices) // config.chips_per_vm
-    trainer_devices = devices
-    sampler_devices = devices
-    if num_vms >= 2 and config.use_pathways:
-      # Multiple hosts with Pathways - potentially split devices for trainer and sampler
-      # based on trainer_devices_fraction and sampler_devices_fraction
-      max_logging.log(f"{num_vms} VMs detected, allocating trainer and sampler devices, and using Pathways.")
-      num_devices = len(devices)
-      num_trainer_devices = int(num_devices * config.trainer_devices_fraction)
-      num_sampler_devices = int(num_devices * config.sampler_devices_fraction)
-      trainer_devices = devices[:num_trainer_devices]
-      sampler_devices = devices[num_devices - num_sampler_devices :]
-      if config.trainer_devices_fraction != 1.0:
-        max_logging.log(f"Using first {len(trainer_devices)} devices as Trainer devices")
-      if config.sampler_devices_fraction != 1.0:
-        max_logging.log(f"Using last {len(sampler_devices)} devices as Sampler devices")
-    trainer_config = config
-    sampler_config = config
-  elif config.num_trainer_slices > 0 and config.num_samplers_slices > 0:
-    max_logging.log("Running RL with Multislice")
-    devices_by_slice = collections.defaultdict(list)
-    for d in devices:
-      devices_by_slice[d.slice_index].append(d)
-    slice_indices = sorted(devices_by_slice.keys())
-
-    if len(slice_indices) < config.num_trainer_slices + config.num_samplers_slices:
-      raise ValueError("Not enough slices for trainer and samplers")
-
-    trainer_devices = []
-    for i in range(config.num_trainer_slices):
-      trainer_devices.extend(devices_by_slice[slice_indices[i]])
-
-    sampler_devices = []
-    for i in range(config.num_trainer_slices, config.num_trainer_slices + config.num_samplers_slices):
-      sampler_devices.extend(devices_by_slice[slice_indices[i]])
-
-    trainer_devices_per_slice = len(trainer_devices) // config.num_trainer_slices
-    trainer_fsdp = trainer_devices_per_slice
-    tp = config.ici_tensor_parallelism
-    if tp > 1:
-      if trainer_devices_per_slice % tp != 0:
-        raise ValueError(
-            f"trainer_devices_per_slice ({trainer_devices_per_slice}) must be divisible by tensor parallelism ({tp})"
-        )
-      if config.ici_fsdp_parallelism != -1 and config.ici_fsdp_parallelism * tp != trainer_devices_per_slice:
-        raise ValueError(
-            f"ici_fsdp_parallelism ({config.ici_fsdp_parallelism}) * ici_tensor_parallelism ({tp}) must equal "
-            f"devices_per_slice ({trainer_devices_per_slice})"
-        )
-      trainer_fsdp = trainer_devices_per_slice // tp

-    trainer_update = {
-        "num_slices": config.num_trainer_slices,
-        "ici_fsdp_parallelism": trainer_fsdp,
-        "ici_tensor_parallelism": tp,
-        "dcn_data_parallelism": config.num_trainer_slices,
-    }
-
-    sampler_update = {
-        "num_slices": config.num_samplers_slices,
-        "ici_fsdp_parallelism": len(sampler_devices) // config.num_samplers_slices,
-        "ici_tensor_parallelism": -1,
-        "dcn_data_parallelism": config.num_samplers_slices,
-    }
-
-    trainer_config = pyconfig.initialize_pydantic(argv, **trainer_update)
-    sampler_config = pyconfig.initialize_pydantic(argv, **sampler_update)
-
-  else:
-    raise ValueError("num_trainer_slices and num_samplers_slices should be both -1 or positive")
-
-  return trainer_config, sampler_config, trainer_devices, sampler_devices
-
-
 def get_rollout_kwargs_for_parallelism(sampler_config, num_sampler_devices):
   """Get rollout kwargs for vLLM rollout when using data parallelism."""
   dp = sampler_config.rollout_data_parallelism
@@ -411,28 +304,6 @@ def _use_raw_prompt(x):
   return train_dataset, test_dataset


-def create_models_and_meshes(trainer_config, sampler_config, trainer_devices, sampler_devices):
-  """Create reference and actor models and their respective meshes."""
-  max_logging.log("Creating reference model and also meshes for reference and rollout")
-  reference_model, reference_mesh = get_maxtext_model(trainer_config, trainer_devices)
-  devices_array = maxtext_utils.create_device_mesh(sampler_config, sampler_devices)
-  rollout_mesh = Mesh(devices_array, sampler_config.mesh_axes)
-
-  if trainer_config.load_checkpoint_only_once:
-    max_logging.log("Creating policy model by copying reference model instead of restoring from checkpoint again.")
-    with reference_mesh:
-      actor_base_model = nnx.clone(reference_model.base)
-      use_no_op_mappings = "maxtext_config" in trainer_config.vllm_additional_config
-      actor_model = TunixMaxTextAdapter(base_model=actor_base_model, use_no_op_mappings=use_no_op_mappings)
-      actor_model.config = None
-    actor_mesh = reference_mesh
-  else:
-    max_logging.log("Creating policy model with same config as reference model on trainer mesh")
-    actor_model, actor_mesh = get_maxtext_model(trainer_config, trainer_devices)
-
-  return reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh
-
-
 def create_rl_components(
     trainer_config,
     sampler_config,
@@ -652,7 +523,7 @@ def _reward_fn(**kwargs):
   return rl_cluster, rl_trainer, optimizer


-def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
+def rl_train(argv: Sequence[str], kwargs: dict):
   """
   Run RL training with the provided configuration.

@@ -662,13 +533,21 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
     trainer_devices: JAX devices for the trainer.
     sampler_devices: JAX devices for the sampler.
   """
+  trainer_config, sampler_config, trainer_devices, sampler_devices = model_creation_utils.setup_configs_and_devices(
+      argv, kwargs
+  )
+
+  reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = model_creation_utils.create_models_and_meshes(
+      trainer_config, sampler_config, trainer_devices, sampler_devices
+  )
+
   if not trainer_config.debug.rl:
     # Apply filter to suppress noisy logs
     noise_filter = max_logging.NoisyLogFilter()
     logging.getLogger().addFilter(noise_filter)
     absl_logging.get_absl_logger().addFilter(noise_filter)
+    os.environ["VLLM_LOGGING_LEVEL"] = "ERROR"

-  max_logging.log("Starting RL Training")
   if not epath.Path(trainer_config.tensorboard_dir).exists():
     epath.Path(trainer_config.tensorboard_dir).mkdir(parents=True, exist_ok=True)

@@ -692,17 +571,14 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
         break
       pprint(ele)

-  reference_model, reference_mesh, actor_model, actor_mesh, rollout_mesh = create_models_and_meshes(
-      trainer_config, sampler_config, trainer_devices, sampler_devices
-  )
-
   if trainer_config.debug.rl:
     max_logging.log("Reference Model initialized successfully")
     nnx.display(reference_model)
     max_logging.log(f"Reference mesh shape: {reference_mesh.shape}")
     max_logging.log("Policy Model initialized successfully")
     nnx.display(actor_model)
     max_logging.log(f"Policy mesh shape: {actor_mesh.shape}")
+    max_logging.log(f"Rollout_mesh shape: {rollout_mesh.shape}")

   rl_cluster, rl_trainer, _ = create_rl_components(
       trainer_config,
@@ -759,18 +635,18 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   max_logging.warning(f"Post RL Training: {corr=}, {total=}, {accuracy=}%, {partial_accuracy=}%," f" {format_accuracy=}%")


-def main(argv: Sequence[str]) -> None:
+def main(argv: Sequence[str], kwargs: dict = None) -> None:
   """Main function to run RL training.

   Args:
     argv: Command-line arguments.
   """
+  kwargs = kwargs or {}
   pathwaysutils.initialize()
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

   max_utils.print_system_information()
-  trainer_config, sampler_config, trainer_devices, sampler_devices = setup_configs_and_devices(argv)
-  rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices)
+  rl_train(argv, kwargs)


 if __name__ == "__main__":
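
With setup_configs_and_devices and create_models_and_meshes relocated to model_creation_utils, rl_train now receives the raw argv plus a kwargs dict and performs its own config and model setup, and main grows an optional kwargs parameter. A hedged invocation sketch (the config path and override are illustrative; kwargs are presumably forwarded to pyconfig as overrides inside setup_configs_and_devices):

from maxtext.trainers.post_train.rl import train_rl

# argv[1] points at the RL recipe config touched in this commit;
# the kwargs dict carries programmatic overrides.
train_rl.main(
    ["train_rl.py", "src/maxtext/configs/post_train/rl.yml"],
    kwargs={"model_name": "gemma2-2b"},
)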
