
Commit 5182e3b

Merge pull request #3450 from AI-Hypercomputer:anisha-from-pretrained3
PiperOrigin-RevId: 902844587
2 parents: 9608068 + ef03866

16 files changed: 399 additions & 204 deletions


New file (Bazel BUILD)

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+# Copyright 2023–2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+package(
+    default_applicable_licenses = ["//third_party/py/maxtext:license"],
+    default_visibility = ["//third_party/py/maxtext:__subpackages__"],
+)
+
+filegroup(
+    name = "param_mapping_file",
+    srcs = ["param_mapping.py"],
+    visibility = ["//third_party/py/maxtext:__pkg__"],
+)
+
+filegroup(
+    name = "hf_model_configs_file",
+    srcs = ["hf_model_configs.py"],
+    visibility = ["//third_party/py/maxtext:__pkg__"],
+)

src/maxtext/checkpoint_conversion/utils/hf_model_configs.py

Lines changed: 3 additions & 3 deletions

@@ -20,7 +20,7 @@
 import transformers

 if transformers.__version__ >= "5.0.0":
-  from transformers.configuration_utils import PreTrainedConfig as PTConfig
+  from transformers.configuration_utils import PreTrainedConfig as PTConfig  # pytype: disable=import-error
 else:
   from transformers.configuration_utils import PretrainedConfig as PTConfig

@@ -151,8 +151,8 @@
   gemma4_31b_config = transformers.Gemma4Config(**gemma4_31b_dict)
 except AttributeError:
   # Graceful fallback to raw dict-based PTConfig if Gemma 4 is natively missing
-  gemma4_26b_config = PTConfig(**gemma4_26b_dict)
-  gemma4_31b_config = PTConfig(**gemma4_31b_dict)
+  gemma4_26b_config = PTConfig(**gemma4_26b_dict)  # pytype: disable=wrong-arg-types
+  gemma4_31b_config = PTConfig(**gemma4_31b_dict)  # pytype: disable=wrong-arg-types


 gemma3_4b_config = transformers.Gemma3Config(
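
The guarded import above exists because transformers 5.x renamed PretrainedConfig to PreTrainedConfig. For context, here is a minimal sketch of an equivalent shim using a try/except import instead of the commit's version-string comparison (string comparison is lexicographic, so e.g. "10.0.0" < "5.0.0" as strings); only the alias name PTConfig is taken from the diff, the rest is illustrative:

# Pick whichever config base class the installed transformers provides.
try:
  from transformers.configuration_utils import PreTrainedConfig as PTConfig  # 5.x name
except ImportError:
  from transformers.configuration_utils import PretrainedConfig as PTConfig  # 4.x name

# Either class accepts arbitrary kwargs, so it can wrap a raw config dict
# when a model-specific config class (e.g. Gemma4Config) is unavailable.
cfg = PTConfig(**{"hidden_size": 1024, "num_attention_heads": 8})
print(type(cfg).__name__, cfg.hidden_size)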

src/maxtext/configs/post_train/rl.yml

Lines changed: 1 addition & 0 deletions

@@ -96,6 +96,7 @@ checkpoint_storage_use_ocdbt: False # For Pathways
 checkpoint_storage_use_zarr3: False # For Pathways
 use_pathways: True
 log_period: 20
+convert_checkpoint_if_possible: True

 # ====== Debugging ======
 debug:
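
The new flag participates in the usual override chain (YAML defaults, then CLI, then kwargs), so it can be flipped per run without editing rl.yml. A hypothetical invocation through the initialize entry point shown in the pyconfig diff below; the import path is inferred from this tree's layout and the config path is illustrative:

from maxtext.configs import pyconfig

# kwargs are merged over the YAML values, so this run disables
# on-the-fly checkpoint conversion even though rl.yml sets it to True.
config = pyconfig.initialize(
    ["", "src/maxtext/configs/post_train/rl.yml"],
    convert_checkpoint_if_possible=False,
)
print(config.convert_checkpoint_if_possible)  # False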

src/maxtext/configs/pyconfig.py

Lines changed: 59 additions & 10 deletions

@@ -77,16 +77,48 @@ def _module_from_path(path: str) -> str | None:
   return None


-def _resolve_or_infer_config(argv: list[str]) -> tuple[str, list[str]]:
+def _resolve_or_infer_config(argv: list[str] | None = None, **kwargs) -> tuple[str, list[str]]:
   """Resolves or infers the config file path from the module."""
+  if argv is None:
+    argv = [""]
+
+  if kwargs.get("base_config"):
+    logger.info("Using config: %s", kwargs["base_config"])
+    return resolve_config_path(kwargs["base_config"]), argv[1:]
+
+  # When passing at least two arguments via the list (no kwargs), the first must be
+  # either "" or a Python script such as train_rl.py or train.py;
+  # the second argument is the YAML file.
   if len(argv) >= 2 and argv[1].endswith(".yml"):
     return resolve_config_path(argv[1]), argv[2:]
-  module = _module_from_path(argv[0])
+  module = _module_from_path(argv[0]) if len(argv) > 0 else None
   if module not in _CONFIG_FILE_MAPPING:
-    raise ValueError(f"No config file provided and no default config found for module '{module}'")
-  config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
-  logger.warning("No config file provided, using default config mapping: %s", config_path)
-  return config_path, argv[1:]
+    config_path = os.path.join(MAXTEXT_CONFIGS_DIR, "base.yml")
+    logger.warning("No config file provided and no default config found for module '%s', using base.yml", module)
+  else:
+    config_path = os.path.join(MAXTEXT_CONFIGS_DIR, _CONFIG_FILE_MAPPING[module])
+    logger.warning("No config file provided, using default config mapping: %s", config_path)
+  remaining_argv = argv[1:]
+
+  return config_path, remaining_argv
+
+
+def _resolve_or_infer_addl_config(**kwargs):
+  """Resolves or infers additional configs not supplied by the caller."""
+  inferred_kwargs = {}
+  # If the base_output_directory key is not set, fall back to a local directory.
+  if not kwargs.get("base_output_directory"):
+    max_logging.warning("base_output_directory is not provided; using local directory maxtext_output")
+    base_output_directory = os.path.abspath("maxtext_output")
+    inferred_kwargs["base_output_directory"] = base_output_directory
+
+  # If the hf_access_token key is not set, fall back to the HF_TOKEN env var.
+  if not kwargs.get("hf_access_token"):
+    hf_access_token = os.environ.get("HF_TOKEN")
+    if hf_access_token:
+      inferred_kwargs["hf_access_token"] = hf_access_token
+
+  return inferred_kwargs


 def yaml_key_to_env_key(s: str) -> str:

@@ -289,28 +321,35 @@ def get_keys(self) -> dict[str, Any]:
     return self._flat_config


-def initialize(argv: list[str], **kwargs) -> HyperParameters:
+def initialize(argv: list[str] | None = None, **kwargs) -> HyperParameters:
   """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides."""
   pydantic_config = initialize_pydantic(argv, **kwargs)
   config = HyperParameters(pydantic_config)
   return config


-def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
+def initialize_pydantic(argv: list[str] | None = None, **kwargs) -> MaxTextConfig:
   """Initializes the configuration by loading YAML files, and applying CLI, env, and kwarg overrides.

   Returns the pydantic MaxTextConfig class, whereas `initialize` returns the original `HyperParameters`.
   """
   # 1. Load base and inherited configs from file(s)
-  config_path, cli_args = _resolve_or_infer_config(argv)
+  config_path, cli_args = _resolve_or_infer_config(argv, **kwargs)
   base_yml_config = _load_config(config_path)

   # 2. Get overrides from CLI and kwargs
   cli_cfg = omegaconf.OmegaConf.from_cli(cli_args)
   kwargs_cfg = omegaconf.OmegaConf.create(kwargs)
   overrides_cfg = omegaconf.OmegaConf.merge(cli_cfg, kwargs_cfg)

-  # 3. Handle model-specific config
+  temp_cfg1 = omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg)
+  # 3.1. Infer more configs if possible
+  temp_cfg1 = _resolve_or_infer_addl_config(**temp_cfg1)
+  # Update overrides_cfg with the inferred values
+  overrides_cfg = omegaconf.OmegaConf.merge(overrides_cfg, temp_cfg1)
   temp_cfg = omegaconf.OmegaConf.merge(base_yml_config, overrides_cfg)
+
+  # 3.2. Handle model-specific config
+
   model_name = temp_cfg.get("model_name", "default")
   # The architecture for -Instruct v/s base models is the same, so to identify the
   # architecture we strip "-Instruct" from the model_name to get the base model name

@@ -437,3 +476,13 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:
 # Shim for backward compatibility with pyconfig_deprecated_test.py
 validate_and_update_keys = pyconfig_deprecated.validate_and_update_keys
 __all__ = ["initialize", "initialize_pydantic"]
+
+
+class _CallablePyconfigModule(sys.modules[__name__].__class__):
+  """Allows calling the module directly as mt.pyconfig()."""
+
+  def __call__(self, argv: list[str] | None = None, **kwargs) -> HyperParameters:
+    return initialize(argv, **kwargs)
+
+
+sys.modules[__name__].__class__ = _CallablePyconfigModule
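
The _CallablePyconfigModule shim at the end makes the pyconfig module itself callable, so mt.pyconfig(...) becomes shorthand for pyconfig.initialize(...). A self-contained sketch of the same trick, with an illustrative module name and payload rather than the PR's actual config logic:

# callable_mod.py
import sys


def initialize(argv=None, **kwargs):
  """Stand-in for pyconfig.initialize."""
  return {"argv": argv, **kwargs}


class _CallableModule(sys.modules[__name__].__class__):
  """Routes calls on the module object itself to initialize()."""

  def __call__(self, argv=None, **kwargs):
    return initialize(argv, **kwargs)


# Reassigning a module's __class__ to a ModuleType subclass is the documented
# CPython hook for giving a module special-method behavior.
sys.modules[__name__].__class__ = _CallableModule

After this runs, `import callable_mod; callable_mod(run_name="test")` behaves exactly like `callable_mod.initialize(run_name="test")`.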

src/maxtext/configs/types.py

Lines changed: 5 additions & 0 deletions

@@ -1950,6 +1950,11 @@ class DerivedValues(BaseModel):
       None,
       description="The full path to the checkpoint directory, derived from `run_name`.",
   )
+  convert_checkpoint_if_possible: bool = Field(
+      False,
+      description="Whether to convert checkpoint on the fly if not provided via\
+      load_parameters_path or base_output_directory",
+  )
   metrics_dir: None | str = Field(
       None,
       description="The full path to the metrics directory, derived from `run_name`.",

src/maxtext/examples/rl_llama3_demo.ipynb

Lines changed: 2 additions & 22 deletions

@@ -135,27 +135,7 @@
   "execution_count": null,
   "metadata": {},
   "outputs": [],
-  "source": [
-   "import datetime\n",
-   "import os\n",
-   "import sys\n",
-   "import subprocess\n",
-   "from pathlib import Path\n",
-   "from huggingface_hub import login\n",
-   "from etils import epath\n",
-   "import jax\n",
-   "\n",
-   "from maxtext.trainers.post_train.rl.train_rl import rl_train, setup_configs_and_devices\n",
-   "from maxtext.utils.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n",
-   "\n",
-   "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\n",
-   "os.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n",
-   "# Suppress vLLM logging with a severity level below ERROR\n",
-   "os.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n",
-   "\n",
-   "\n",
-   "print(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")"
-  ]
+  "source": "import datetime\nimport os\nimport sys\nimport subprocess\nfrom pathlib import Path\nfrom huggingface_hub import login\nfrom etils import epath\nimport jax\n\nfrom maxtext.trainers.post_train.rl.train_rl import rl_train\nfrom maxtext.utils.model_creation_utils import setup_configs_and_devices\nfrom maxtext.utils.globals import MAXTEXT_REPO_ROOT, MAXTEXT_PKG_DIR\n\nos.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"0\"\nos.environ[\"SKIP_JAX_PRECOMPILE\"] = \"1\" # Faster startup for vLLM\n# Suppress vLLM logging with a severity level below ERROR\nos.environ[\"VLLM_LOGGING_LEVEL\"] = \"ERROR\"\n\n\nprint(f\"MaxText installation path: {MAXTEXT_PKG_DIR}\")"
  },
  {
   "cell_type": "code",

@@ -386,4 +366,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 4
-}
+}

src/maxtext/inference/vllm_decode.py

Lines changed: 1 addition & 1 deletion

@@ -241,7 +241,7 @@ def main(argv: Sequence[str]) -> None:
   config = pyconfig.initialize(argv)

   if FLAGS.use_tunix:
-    maxtext_model, mesh = model_creation_utils.create_nnx_model(config)
+    maxtext_model, mesh = model_creation_utils.from_pretrained(config)
     decode_with_tunix(config, model=maxtext_model, mesh=mesh)
   else:
     decode_with_vllm(config)
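
This PR renames model_creation_utils.create_nnx_model to from_pretrained, echoing the Hugging Face idiom. Judging only from the call sites in this diff, the function returns a (model, mesh) tuple when no mesh is passed and just the model when an existing mesh is supplied; a hypothetical sketch under that assumption, with an illustrative config path (the model_creation_utils import path matches the notebook diff above):

from maxtext.configs import pyconfig
from maxtext.utils import model_creation_utils

config = pyconfig.initialize(["", "src/maxtext/configs/base.yml"])

# No mesh given: from_pretrained builds one and returns it with the model.
model, mesh = model_creation_utils.from_pretrained(config)

# Mesh given (as in adapter.py and train_distill.py below): model only.
model = model_creation_utils.from_pretrained(config, mesh=mesh)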
New file (Bazel BUILD)

Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
+# Copyright 2023–2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+load(
+    "//devtools/python/blaze:pytype.bzl",
+    "pytype_strict_library",
+)
+
+package(
+    default_applicable_licenses = ["//third_party/py/maxtext:license"],
+    default_visibility = ["//third_party/py/maxtext:__subpackages__"],
+)
+
+pytype_strict_library(
+    name = "weight_mapping",
+    srcs = [
+        "weight_mapping/__init__.py",
+        "weight_mapping/deepseek3.py",
+        "weight_mapping/gpt_oss.py",
+        "weight_mapping/llama3.py",
+        "weight_mapping/qwen2.py",
+        "weight_mapping/qwen3.py",
+    ],
+    deps = [
+        "//third_party/py/jax",
+        "//third_party/py/numpy",
+    ],
+)
+
+pytype_strict_library(
+    name = "utils",
+    srcs = ["utils.py"],
+    deps = [
+        ":weight_mapping",
+        "//third_party/py/maxtext:checkpoint_conversion_utils_param_mapping",
+    ],
+)
+
+pytype_strict_library(
+    name = "tunix_adapter",
+    srcs = ["tunix_adapter.py"],
+    deps = [
+        ":utils",
+        "//third_party/py/flax/nnx",
+        "//third_party/py/jax",
+        "//third_party/py/maxtext:checkpoint_conversion_utils_hf_model_configs",
+        "//third_party/py/maxtext:layers",
+    ],
+)

src/maxtext/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 1 addition & 1 deletion

@@ -251,7 +251,7 @@ def load_weights(self, rng_key: jax.Array) -> None:
       return

     with self.mesh, nn.logical_axis_rules(""):
-      model, _ = model_creation_utils.create_nnx_model(
+      model = model_creation_utils.from_pretrained(
           self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
       )
     self.model = nnx.data(model)

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 1 addition & 1 deletion

@@ -463,7 +463,7 @@ def get_maxtext_model(config: pyconfig.HyperParameters, mesh: jax.sharding.Mesh)
     The loaded MaxText model.
   """
   max_logging.log(f"Initializing model: {config.model_name}...")
-  model, _ = model_creation_utils.create_nnx_model(config, mesh=mesh)
+  model = model_creation_utils.from_pretrained(config, mesh=mesh)
   return model