Skip to content

Commit ff01066

Browse files
committed
Add `from_pretrained` as a simple API
1 parent 4909a0a commit ff01066

12 files changed

Lines changed: 296 additions & 198 deletions

File tree

src/maxtext/configs/post_train/rl.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ checkpoint_storage_use_ocdbt: False # For Pathways
9696
checkpoint_storage_use_zarr3: False # For Pathways
9797
use_pathways: True
9898
log_period: 20
99+
convert_checkpoint_if_possible: True
99100

100101
# ====== Debugging ======
101102
debug:

src/maxtext/configs/pyconfig.py

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,48 @@ def _module_from_path(path: str) -> str | None:
7777
return None
7878

7979

80-
def _resolve_or_infer_config(argv: list[str]) -> tuple[str, list[str]]:
80+
def _resolve_or_infer_config(argv: list[str] | None = None, **kwargs) -> tuple[str, list[str]]:
8181
"""Resolves or infers config file path from module."""
82+
if argv is None:
83+
argv = [""]
84+
85+
if kwargs.get("base_config"):
86+
logger.info("Using config : %s", kwargs["base_config"])
87+
return resolve_config_path(kwargs["base_config"]), argv[1:] if len(argv) > 1 else []
88+
89+
# if passing at least two arguments via list (no kwargs), then we have to specify
90+
# the first one as either "" or a Python script such as train_rl.py or train.py
91+
# the second argument is the yaml file
8292
if len(argv) >= 2 and argv[1].endswith(".yml"):
8393
return resolve_config_path(argv[1]), argv[2:]
84-
module = _module_from_path(argv[0])
94+
module = _module_from_path(argv[0]) if len(argv) > 0 else None
8595
if module not in _CONFIG_FILE_MAPPING:
86-
raise ValueError(f"No config file provided and no default config found for module '{module}'")
87-
config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
88-
logger.warning("No config file provided, using default config mapping: %s", config_path)
89-
return config_path, argv[1:]
96+
config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")
97+
logger.warning("No config file provided and no default config found for module '%s', using base.yml", module)
98+
else:
99+
config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
100+
logger.warning("No config file provided, using default config mapping: %s", config_path)
101+
remaining_argv = argv[1:] if len(argv) > 1 else []
102+
103+
return config_path, remaining_argv
104+
105+
106+
def _resolve_or_infer_addl_config(**kwargs):
107+
"""Resolves or infers more configs from module."""
108+
inferred_kwargs = {}
109+
# if base_output_directory key is not seen
110+
if not kwargs.get("base_output_directory"):
111+
max_logging.warning("base_output_directory is not provided; Using local directory called maxtext_output")
112+
base_output_directory = os.path.abspath("maxtext_output")
113+
inferred_kwargs["base_output_directory"] = base_output_directory
114+
115+
# if hf_access_token key is not seen
116+
if not kwargs.get("hf_access_token"):
117+
hf_access_token = os.environ.get("HF_TOKEN")
118+
if hf_access_token:
119+
inferred_kwargs["hf_access_token"] = hf_access_token
120+
121+
return inferred_kwargs
90122

91123

92124
def yaml_key_to_env_key(s: str) -> str:
@@ -289,28 +321,35 @@ def get_keys(self) -> dict[str, Any]:
289321
return self._flat_config
290322

291323

292-
def initialize(argv: list[str], **kwargs) -> HyperParameters:
324+
def initialize(argv: list[str] | None = None, **kwargs) -> HyperParameters:
293325
"""Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides."""
294326
pydantic_config = initialize_pydantic(argv, **kwargs)
295327
config = HyperParameters(pydantic_config)
296328
return config
297329

298330

299-
def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
331+
def initialize_pydantic(argv: list[str] | None = None, **kwargs) -> MaxTextConfig:
300332
"""Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides.
301333
Returns the pydantic MaxTextConfig class, whereas `initialize` returns the original `HyperParameters`
302334
"""
303335
# 1. Load base and inherited configs from file(s)
304-
config_path, cli_args = _resolve_or_infer_config(argv)
336+
config_path, cli_args = _resolve_or_infer_config(argv, **kwargs)
305337
base_yml_config = _load_config(config_path)
306338

307339
# 2. Get overrides from CLI and kwargs
308340
cli_cfg = omegaconf.OmegaConf.from_cli(cli_args)
309341
kwargs_cfg = omegaconf.OmegaConf.create(kwargs)
310342
overrides_cfg = omegaconf.OmegaConf.merge(cli_cfg, kwargs_cfg)
311343

312-
# 3. Handle model-specific config
344+
temp_cfg1 = omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg)
345+
# 3.1. infer more configs if possible
346+
temp_cfg1 = _resolve_or_infer_addl_config(**temp_cfg1)
347+
# update overrides_cfg with temp_cfg1
348+
overrides_cfg = omegaconf.OmegaConf.merge(overrides_cfg, temp_cfg1)
313349
temp_cfg = omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg)
350+
351+
# 3.2. Handle model-specific config
352+
314353
model_name = temp_cfg.get("model_name", "default")
315354
# The architecture for -Instruct v/s base models are the same, so for identifying the
316355
# architecture we replace "-Instruct" from the model_name and get the base model name
@@ -437,3 +476,13 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
437476
# Shim for backward compatibility with pyconfig_deprecated_test.py
438477
validate_and_update_keys = pyconfig_deprecated.validate_and_update_keys
439478
__all__ = ["initialize", "initialize_pydantic"]
479+
480+
481+
class _CallablePyconfigModule(sys.modules[__name__].__class__):
482+
"""Allows calling the module directly as mt.pyconfig()."""
483+
484+
def __call__(self, argv: list[str] | None = None, **kwargs) -> HyperParameters:
485+
return initialize(argv, **kwargs)
486+
487+
488+
sys.modules[__name__].__class__ = _CallablePyconfigModule

src/maxtext/configs/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1931,6 +1931,11 @@ class DerivedValues(BaseModel):
19311931
None,
19321932
description="The full path to the checkpoint directory, derived from `run_name`.",
19331933
)
1934+
convert_checkpoint_if_possible: bool = Field(
1935+
False,
1936+
description="Whether to convert checkpoint on the fly if not provided via\
1937+
load_parameters_path or base_output_directory",
1938+
)
19341939
metrics_dir: None | str = Field(
19351940
None,
19361941
description="The full path to the metrics directory, derived from `run_name`.",

src/maxtext/examples/rl_llama3_demo.ipynb

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -135,27 +135,7 @@
135135
"execution_count": null,
136136
"metadata": {},
137137
"outputs": [],
138-
"source": [
139-
"import datetime\n",
140-
"import os\n",
141-
"import sys\n",
142-
"import subprocess\n",
143-
"from pathlib import Path\n",
144-
"from huggingface_hub import login\n",
145-
"from etils import epath\n",
146-
"import jax\n",
147-
"\n",
148-
"from maxtext.trainers.post_train.rl.train_rl import rl_train, setup_configs_and_devices\n",
149-
"from maxtext.utils.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n",
150-
"\n",
151-
"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\n",
152-
"os.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n",
153-
"# Suppress vLLM logging with a severity level below ERROR\n",
154-
"os.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n",
155-
"\n",
156-
"\n",
157-
"print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")"
158-
]
138+
"source": "import datetime\nimport os\nimport sys\nimport subprocess\nfrom pathlib import Path\nfrom huggingface_hub import login\nfrom etils import epath\nimport jax\n\nfrom maxtext.trainers.post_train.rl.train_rl import rl_train\nfrom maxtext.utils.model_creation_utils import setup_configs_and_devices\nfrom maxtext.utils.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n\nos.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\nos.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n# Suppress vLLM logging with a severity level below ERROR\nos.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n\n\nprint(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")"
159139
},
160140
{
161141
"cell_type": "code",
@@ -386,4 +366,4 @@
386366
},
387367
"nbformat": 4,
388368
"nbformat_minor": 4
389-
}
369+
}

src/maxtext/inference/vllm_decode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def main(argv: Sequence[str]) -> None:
241241
config = pyconfig.initialize(argv)
242242

243243
if FLAGS.use_tunix:
244-
maxtext_model, mesh = model_creation_utils.create_nnx_model(config)
244+
maxtext_model, mesh = model_creation_utils.from_pretrained(config)
245245
decode_with_tunix(config, model=maxtext_model, mesh=mesh)
246246
else:
247247
decode_with_vllm(config)

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
251251
return
252252

253253
with self.mesh, nn.logical_axis_rules(""):
254-
model, _ = model_creation_utils.create_nnx_model(
254+
model = model_creation_utils.from_pretrained(
255255
self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
256256
)
257257
self.model = nnx.data(model)

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh)
458458
The loaded MaxText model.
459459
"""
460460
max_logging.log(f"Initializing model: {config.model_name}...")
461-
model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
461+
model = model_creation_utils.from_pretrained(config, mesh=mesh)
462462
return model
463463

464464

0 commit comments

Comments
 (0)