Skip to content

Commit 6d9b928

Browse files
Add support for On-The-Fly Dynamic SafeTensors loading.
PiperOrigin-RevId: 916957035
1 parent 0747df9 commit 6d9b928

8 files changed

Lines changed: 715 additions & 15 deletions

File tree

src/maxtext/checkpoint_conversion/to_maxtext.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@
6767
from maxtext.common.common_types import MODEL_MODE_TRAIN
6868
from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
6969
from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
70-
from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys
70+
from maxtext.checkpoint_conversion.utils.tensor_handling import apply_hook_fns
71+
from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys
7172
from maxtext.inference.inference_utils import str2bool
7273
from maxtext.layers import quantizations
7374
from maxtext.models import models

src/maxtext/checkpoint_conversion/utils/load_dynamic.py

Lines changed: 390 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tensor handling utility functions for checkpoint conversion."""
16+
17+
from functools import partial
18+
from typing import Any, Callable, List
19+
import jax.numpy as np
20+
21+
22+
def apply_hook_fns(weight, target_shape, hook_fns):
  """Apply one or more hook functions to a weight.

  Used by both conversion directions (to_maxtext and to_huggingface).

  Args:
    weight: The tensor to transform.
    target_shape: Shape handed unchanged, as the second argument, to every hook.
    hook_fns: None (identity), a single callable, or a list/tuple of callables.
      Sequences are applied in order, each hook receiving the previous output.

  Returns:
    The transformed weight, or `weight` itself when `hook_fns` is None.
  """
  # If hook is unspecified, use identity.
  if hook_fns is None:
    return weight
  # Normalize a lone callable to a one-element list. Tuples are accepted next to
  # lists so that a tuple of hooks is not mistaken for a single callable.
  if not isinstance(hook_fns, (list, tuple)):
    hook_fns = [hook_fns]
  # Hooks are chained, so their order matters.
  for hook_fn in hook_fns:
    weight = hook_fn(weight, target_shape)
  return weight
33+
34+
35+
def _build_multi_axis_stacked_tensor(
    hf_source_keys: List[List[str]],
    tensor_getter_fn: Callable[[str], np.ndarray],
    hook_fns: Any,
    target_shape: tuple,
    config,
) -> np.ndarray:
  """Assemble a MaxText tensor from HF weights stacked over experts and layers.

  This is the scanned MoE case: every weight is fetched, passed through the
  hooks, stacked over layers, and the per-expert results are then stacked over
  experts, giving a tensor laid out as (num_experts, num_layers, ...).

  Args:
    hf_source_keys: Nested (2D) list of Hugging Face parameter names. The outer
      list enumerates experts, each inner list enumerates that expert's layers.
    tensor_getter_fn: Callable mapping a HF key to its tensor.
    hook_fns: Hook function(s) applied to every individual weight.
    target_shape: Shape of the complete stacked MaxText tensor.
    config: The MaxText pyconfig object (not read by this function).

  Returns:
    The stacked array for the MaxText parameter.
  """
  # Hooks see a single slice, i.e. the target shape minus the two leading
  # (experts, layers) stacking dimensions.
  slice_shape = target_shape[2:]

  # Inner comprehension walks the layers of one expert; outer walks the experts.
  per_expert = [
      np.stack(
          [apply_hook_fns(tensor_getter_fn(hf_key), slice_shape, hook_fns) for hf_key in expert_layer_keys],
          axis=0,
      )
      for expert_layer_keys in hf_source_keys
  ]
  return np.stack(per_expert, axis=0)
73+
74+
75+
def _build_single_axis_stacked_tensor(
    hf_source_keys: List[str],
    tensor_getter_fn: Callable[[str], np.ndarray],
    hook_fns: Any,
    target_shape: tuple,
    config,
) -> np.ndarray:
  """Assemble a MaxText tensor by stacking per-key HF weights along one axis.

  Serves both standard scanned layers (e.g., attention) and unscanned MoE
  layers, which are stacked along the expert axis.

  Args:
    hf_source_keys: Flat (1D) list of Hugging Face parameter names.
    tensor_getter_fn: Callable mapping a HF key to its tensor.
    hook_fns: Hook function(s) applied to every individual weight.
    target_shape: Shape of the complete stacked MaxText tensor.
    config: The MaxText pyconfig object; `scan_layers` and, when that is
      truthy, `param_scan_axis` are read from it.

  Returns:
    The stacked array for the MaxText parameter.
  """
  # Scanned layers stack on the configured scan axis; otherwise this is an
  # unscanned MoE layer and the keys are stacked on the expert axis (0).
  stack_axis = config.param_scan_axis if config.scan_layers else 0

  # Hooks see a single slice, so remove the stacking dimension from the full shape.
  remaining_dims = list(target_shape)
  del remaining_dims[stack_axis]
  slice_shape = tuple(remaining_dims)

  slices = [apply_hook_fns(tensor_getter_fn(hf_key), slice_shape, hook_fns) for hf_key in hf_source_keys]
  return np.stack(slices, axis=stack_axis)
119+
120+
121+
def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config):
122+
"""Determine the loading function for HF keys.
123+
HF keys can take four forms:
124+
Case 1: Unscanned (single string)
125+
Case 2: Scanned (list of strings)
126+
Case 3: Unscanned with expert stacking (list of strings)
127+
Case 4: Scanned with expert stacking (nested list of strings)
128+
"""
129+
load_fn = None
130+
if not isinstance(hf_source_keys_or_key, list):
131+
# Case 1: Single hf key (str)
132+
def _loader(getter, key, shape, hook):
133+
return apply_hook_fns(getter(key), shape, hook)
134+
135+
load_fn = partial(
136+
_loader,
137+
tensor_getter,
138+
hf_source_keys_or_key,
139+
mt_target_shape_or_shapes,
140+
hook_fn,
141+
)
142+
# Stacked mapping
143+
elif not isinstance(hf_source_keys_or_key[0], list):
144+
# Case 2 or 3: Single-Axis Stacked hf keys (un-nested list)
145+
load_fn = partial(
146+
_build_single_axis_stacked_tensor,
147+
hf_source_keys_or_key,
148+
tensor_getter,
149+
hook_fn,
150+
mt_target_shape_or_shapes,
151+
config,
152+
)
153+
else:
154+
# isinstance(hf_source_keys_or_key[0], list)
155+
# Case 4: Multi-Axis Stacked hf keys (nested list)
156+
load_fn = partial(
157+
_build_multi_axis_stacked_tensor,
158+
hf_source_keys_or_key,
159+
tensor_getter,
160+
hook_fn,
161+
mt_target_shape_or_shapes,
162+
config,
163+
)
164+
return load_fn

src/maxtext/checkpoint_conversion/utils/utils.py

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
import time
2424
import json
2525
from concurrent.futures import ThreadPoolExecutor
26-
from typing import Any
26+
from functools import partial
27+
from typing import Any, Callable, List
2728
from tqdm import tqdm
2829
import resource
2930
import numpy as np
@@ -1188,3 +1189,135 @@ def save_weights_to_checkpoint(
11881189
checkpoint_manager.wait_until_finished()
11891190

11901191
max_logging.log(f"Elapse for checkpoint save: {(time.time() - start) / 60:.2f} min")
1192+
1193+
1194+
def _build_multi_axis_stacked_tensor(
    hf_source_keys: List[List[str]],
    tensor_getter_fn: Callable[[str], np.ndarray],
    hook_fns: Any,
    target_shape: tuple,
    config,
) -> np.ndarray:
  """Assemble a MaxText tensor from HF weights stacked over experts and layers.

  This is the scanned MoE case: every weight is fetched, passed through the
  hooks, stacked over layers, and the per-expert results are then stacked over
  experts, giving a tensor laid out as (num_experts, num_layers, ...).

  Args:
    hf_source_keys: Nested (2D) list of Hugging Face parameter names. The outer
      list enumerates experts, each inner list enumerates that expert's layers.
    tensor_getter_fn: Callable mapping a HF key to its tensor.
    hook_fns: Hook function(s) applied to every individual weight.
    target_shape: Shape of the complete stacked MaxText tensor.
    config: The MaxText pyconfig object (not read by this function).

  Returns:
    The stacked NumPy array for the MaxText parameter.
  """
  # Hooks see a single slice, i.e. the target shape minus the two leading
  # (experts, layers) stacking dimensions.
  slice_shape = target_shape[2:]

  # Inner comprehension walks the layers of one expert; outer walks the experts.
  per_expert = [
      np.stack(
          [apply_hook_fns(tensor_getter_fn(hf_key), slice_shape, hook_fns) for hf_key in expert_layer_keys],
          axis=0,
      )
      for expert_layer_keys in hf_source_keys
  ]
  return np.stack(per_expert, axis=0)
1232+
1233+
1234+
def _build_single_axis_stacked_tensor(
    hf_source_keys: List[str],
    tensor_getter_fn: Callable[[str], np.ndarray],
    hook_fns: Any,
    target_shape: tuple,
    config,
) -> np.ndarray:
  """Assemble a MaxText tensor by stacking per-key HF weights along one axis.

  Serves both standard scanned layers (e.g., attention) and unscanned MoE
  layers, which are stacked along the expert axis.

  Args:
    hf_source_keys: Flat (1D) list of Hugging Face parameter names.
    tensor_getter_fn: Callable mapping a HF key to its tensor.
    hook_fns: Hook function(s) applied to every individual weight.
    target_shape: Shape of the complete stacked MaxText tensor.
    config: The MaxText pyconfig object; `scan_layers` and, when that is
      truthy, `param_scan_axis` are read from it.

  Returns:
    The stacked NumPy array for the MaxText parameter.
  """
  # Scanned layers stack on the configured scan axis; otherwise this is an
  # unscanned MoE layer and the keys are stacked on the expert axis (0).
  stack_axis = config.param_scan_axis if config.scan_layers else 0

  # Hooks see a single slice, so remove the stacking dimension from the full shape.
  remaining_dims = list(target_shape)
  del remaining_dims[stack_axis]
  slice_shape = tuple(remaining_dims)

  slices = [apply_hook_fns(tensor_getter_fn(hf_key), slice_shape, hook_fns) for hf_key in hf_source_keys]
  return np.stack(slices, axis=stack_axis)
1278+
1279+
1280+
def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config):
1281+
"""Determine the loading function for HF keys.
1282+
HF keys can take four forms:
1283+
Case 1: Unscanned (single string)
1284+
Case 2: Scanned (list of strings)
1285+
Case 3: Unscanned with expert stacking (list of strings)
1286+
Case 4: Scanned with expert stacking (nested list of strings)
1287+
"""
1288+
load_fn = None
1289+
if not isinstance(hf_source_keys_or_key, list):
1290+
# Case 1: Single hf key (str)
1291+
def _loader(getter, key, shape, hook):
1292+
return apply_hook_fns(getter(key), shape, hook)
1293+
1294+
load_fn = partial(
1295+
_loader,
1296+
tensor_getter,
1297+
hf_source_keys_or_key,
1298+
mt_target_shape_or_shapes,
1299+
hook_fn,
1300+
)
1301+
# Stacked mapping
1302+
elif not isinstance(hf_source_keys_or_key[0], list):
1303+
# Case 2 or 3: Single-Axis Stacked hf keys (un-nested list)
1304+
load_fn = partial(
1305+
_build_single_axis_stacked_tensor,
1306+
hf_source_keys_or_key,
1307+
tensor_getter,
1308+
hook_fn,
1309+
mt_target_shape_or_shapes,
1310+
config,
1311+
)
1312+
else:
1313+
# isinstance(hf_source_keys_or_key[0], list)
1314+
# Case 4: Multi-Axis Stacked hf keys (nested list)
1315+
load_fn = partial(
1316+
_build_multi_axis_stacked_tensor,
1317+
hf_source_keys_or_key,
1318+
tensor_getter,
1319+
hook_fn,
1320+
mt_target_shape_or_shapes,
1321+
config,
1322+
)
1323+
return load_fn

src/maxtext/common/checkpointing.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -734,6 +734,7 @@ def load_state_if_possible(
734734
checkpoint_conversion_fn=None,
735735
source_checkpoint_layout="orbax",
736736
expansion_factor_real_data: int = -1,
737+
maxtext_config: Any | None = None,
737738
):
738739
"""Loads TrainState as possible from the inputs.
739740
@@ -856,7 +857,16 @@ def map_to_pspec(data):
856857
case _:
857858
return (checkpoint_manager.restore(step, args=Composite(items=checkpoint_args)), None)
858859

859-
if load_parameters_from_path != "":
860+
if source_checkpoint_layout == "safetensors_dynamic":
861+
path = load_parameters_from_path or load_full_state_from_path
862+
max_logging.log(f"Dynamic On-the-Fly Formatting: Loading SafeTensors from {path}")
863+
864+
from maxtext.checkpoint_conversion.utils.load_dynamic import load_safetensors_dynamic_state
865+
866+
return load_safetensors_dynamic_state(
867+
path, abstract_unboxed_pre_state, maxtext_config
868+
)
869+
elif load_parameters_from_path != "":
860870
if isinstance(abstract_unboxed_pre_state, nnx.State):
861871
_, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...)
862872
else:
@@ -868,6 +878,9 @@ def map_to_pspec(data):
868878
checkpoint_storage_concurrent_gb,
869879
use_ocdbt=use_ocdbt,
870880
use_zarr3=use_zarr3,
881+
enable_orbax_v1=enable_orbax_v1,
882+
source_checkpoint_layout=source_checkpoint_layout,
883+
checkpoint_conversion_fn=checkpoint_conversion_fn,
871884
)
872885
return None, restored_params
873886
elif load_full_state_from_path != "":
@@ -908,22 +921,19 @@ def setup_checkpoint_logger(config) -> Any | None: # pytype: disable=attribute-
908921

909922

910923
def load_params_from_path(
911-
load_parameters_from_path, abstract_unboxed_params, checkpoint_storage_concurrent_gb, use_ocdbt=True, use_zarr3=True
924+
load_parameters_from_path,
925+
abstract_unboxed_params,
926+
checkpoint_storage_concurrent_gb,
927+
use_ocdbt=True,
928+
use_zarr3=True,
929+
enable_orbax_v1=False,
930+
source_checkpoint_layout="orbax",
931+
checkpoint_conversion_fn=None,
912932
):
913933
"""Load decode params from checkpoint at specified path."""
914934
assert load_parameters_from_path, "load_parameters_from_path is not defined."
915935
max_logging.log(f"restoring params from {load_parameters_from_path}")
916936

917-
# NNX target: the on-disk checkpoint is in Linen layout; reshape it into the NNX params state.
918-
if isinstance(abstract_unboxed_params, nnx.State):
919-
return _load_linen_params_into_nnx(
920-
load_parameters_from_path,
921-
abstract_unboxed_params,
922-
checkpoint_storage_concurrent_gb,
923-
use_ocdbt,
924-
use_zarr3,
925-
)
926-
927937
# *_concurrent_gb should be set for large models, the default is 96.
928938
max_logging.log(f"Creating checkpoint manager with ocdbt={use_ocdbt} and zarr3={use_zarr3}")
929939
ckptr = ocp.Checkpointer(

src/maxtext/configs/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ class Checkpointing(BaseModel):
344344
save_quantized_params_path: PathStr = Field("", description="Path to save params quantized on the fly.")
345345
enable_orbax_v1: bool = Field(False, description="Bool flag for enabling Orbax v1.")
346346
checkpoint_conversion_fn: None | str = Field(None, description="Function for processing loaded checkpoint dict.")
347-
source_checkpoint_layout: Literal["orbax", "safetensors"] = Field(
347+
source_checkpoint_layout: Literal["orbax", "safetensors", "safetensors_dynamic"] = Field(
348348
"orbax", description="The layout of the source checkpoint to load."
349349
)
350350
save_checkpoint_on_completion: bool = Field(

0 commit comments

Comments
 (0)