Skip to content

Commit 2fbe229

Browse files
Add support for On-The-Fly Dynamic SafeTensors loading.
PiperOrigin-RevId: 916957035
1 parent be4fd71 commit 2fbe229

8 files changed

Lines changed: 496 additions & 15 deletions

File tree

src/maxtext/checkpoint_conversion/to_maxtext.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@
6767
from maxtext.common.common_types import MODEL_MODE_TRAIN
6868
from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
6969
from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
70-
from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys
70+
from maxtext.checkpoint_conversion.utils.tensor_handling import apply_hook_fns
71+
from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys
7172
from maxtext.inference.inference_utils import str2bool
7273
from maxtext.layers import quantizations
7374
from maxtext.models import models
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Dynamic loading of HuggingFace checkpoints during training/eval workloads directly in the target format."""
16+
17+
import jax
18+
from flax import traverse_util
19+
from flax import nnx
20+
from orbax.checkpoint import v1 as ocp_v1
21+
from orbax.checkpoint._src.arrays import sharding as sharding_utils
22+
23+
from maxtext.utils import max_logging
24+
from maxtext.checkpoint_conversion.utils.tensor_handling import _get_hf_loading_function
25+
from maxtext.checkpoint_conversion.utils import param_mapping
26+
from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
27+
import time
28+
29+
30+
def get_hf_config_and_mappings(maxtext_config):
  """Builds the MaxText->HF parameter mapping and the matching hook functions.

  The lookup key is `maxtext_config.model_name` with every "-Instruct"
  substring removed. That key indexes `HF_MODEL_CONFIGS`,
  `param_mapping.PARAM_MAPPING` and `param_mapping.HOOK_FNS`.

  Args:
    maxtext_config: MaxText config object; `model_name` and `scan_layers` are
      read from it and the object itself is forwarded to the mapping builders.

  Returns:
    A `(param_map_mt_to_hf, hook_fn_map_mt)` tuple, i.e. the results of the
    `PARAM_MAPPING` and `HOOK_FNS` builders registered for the model.

  Raises:
    KeyError: If the derived key is missing from any of the lookup tables.
  """
  # str.replace is a no-op when the substring is absent, so no membership test is needed.
  lookup_key = maxtext_config.model_name.replace("-Instruct", "")
  hf_config = HF_MODEL_CONFIGS[lookup_key].to_dict()

  mt_to_hf_names = param_mapping.PARAM_MAPPING[lookup_key](
      hf_config,
      maxtext_config,
      scan_layers=maxtext_config.scan_layers,
  )
  # Hooks are requested for the HF -> MaxText direction (saving_to_hf=False).
  mt_hooks = param_mapping.HOOK_FNS[lookup_key](
      hf_config,
      maxtext_config,
      scan_layers=maxtext_config.scan_layers,
      saving_to_hf=False,
  )
  return mt_to_hf_names, mt_hooks
45+
46+
47+
def load_sharded_hf_state(path):
  """Loads a SafeTensors checkpoint as a pytree of sharded arrays.

  The checkpoint metadata is read first to obtain an abstract pytree. Each
  leaf is then paired with a sharding from orbax's
  `construct_maximal_shardings`, restricted to the devices that share this
  host's `slice_index`, and the pytree is loaded against that sharded
  abstract state.

  Args:
    path: Checkpoint location, passed to `ocp_v1.pytree_metadata` and
      `ocp_v1.load_pytree`.

  Returns:
    The pytree returned by `ocp_v1.load_pytree`.
  """
  start = time.time()
  safetensors_ctx = ocp_v1.Context(checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS)
  # Both the metadata read and the load run inside the SAFETENSORS layout context.
  with safetensors_ctx:
    abstract_tree = ocp_v1.pytree_metadata(path).metadata

    # Only use devices from the slice of this host's first local device; devices
    # without a `slice_index` attribute are treated as slice 0.
    local_slice = getattr(jax.local_devices()[0], 'slice_index', 0)
    slice_devices = [dev for dev in jax.devices() if getattr(dev, 'slice_index', 0) == local_slice]
    maximal_shardings = sharding_utils.construct_maximal_shardings(abstract_tree, devices=slice_devices)

    # Attach the chosen sharding to every abstract leaf.
    sharded_abstract_state = jax.tree.map(
        lambda leaf, leaf_sharding: jax.ShapeDtypeStruct(
            shape=leaf.shape, dtype=leaf.dtype, sharding=leaf_sharding
        ),
        abstract_tree,
        maximal_shardings,
    )

    max_logging.log("Reading raw Safetensors into memory (Distributed Sharded GCS Download)...")
    loaded_state = ocp_v1.load_pytree(path, sharded_abstract_state)

  max_logging.log(f"load_sharded_hf_state took {time.time() - start:.2f}s")
  return loaded_state
70+
71+
72+
def transform_hf_state_to_mt_state(
    hf_state, target_tree, param_map_mt_to_hf, hook_fn_map_mt, maxtext_config
):
  """Transforms HF state into MaxText state by applying param mappings and hooks.

  Args:
    hf_state: Mapping from HF parameter name to tensor. It is consumed: every
      tensor is removed with `pop` as it is read, so the caller's mapping is
      mutated. NOTE(review): an HF key referenced by more than one mapping
      entry would fail on its second lookup -- confirm the mappings never
      reuse a key.
    target_tree: Nested dict of target leaves; each leaf must expose `.shape`
      and `.sharding`.
    param_map_mt_to_hf: Mapping from MaxText key (dash separated, e.g.
      "params-a-b") to the HF source: a single key, a list of keys, or a
      nested list of keys.
    hook_fn_map_mt: Mapping from MaxText key to hook function(s). Keys without
      an entry get no hook.
    maxtext_config: MaxText config, forwarded to `_get_hf_loading_function`.

  Returns:
    `{"params": tree}`. Leaves of `target_tree` that no mapping entry resolved
    to are returned unchanged (i.e. still the target placeholders).
  """
  t0 = time.time()

  def tensor_getter(key):
    # Pop instead of index so each source tensor is released once consumed.
    return hf_state.pop(key)

  flat_target = traverse_util.flatten_dict(target_tree, sep=".")
  flat_restored = flat_target.copy()

  mapped_count = 0
  keys_missed = []
  max_logging.log("Starting fast in-memory Distributed Transformations...")

  for mt_key, hf_source in param_map_mt_to_hf.items():
    dotted_key = mt_key.replace("-", ".")
    mt_name = mt_key.replace("params-", "").replace("-", ".")

    # The target tree may or may not carry the "params" prefix; try the bare
    # name first, then the prefixed name, then the raw dotted key.
    check_name = next(
        (name for name in (mt_name, "params." + mt_name, dotted_key) if name in flat_target),
        None,
    )
    if check_name is None:
      keys_missed.append(mt_name)
      continue

    target_leaf = flat_target[check_name]
    load_fn = _get_hf_loading_function(
        hf_source,
        tensor_getter,
        hook_fn_map_mt.get(mt_key),
        target_leaf.shape,
        maxtext_config,
    )

    # Execute the transformation, then place the result with the sharding the
    # target leaf expects.
    t_layer = time.time()
    unsharded_array = load_fn()
    flat_restored[check_name] = jax.device_put(unsharded_array, device=target_leaf.sharding, donate=True)

    max_logging.log(f"Transformed {check_name} from {hf_source} in {time.time() - t_layer:.4f}s")
    mapped_count += 1

  if mapped_count == 0:
    max_logging.log(f"All transformations missed! Sample missed mt_names: {keys_missed[:5]}")
    max_logging.log(f"Sample flat_target keys: {list(flat_target.keys())[:5]}")
  elif keys_missed:
    # A partial miss used to go unreported; surface it so unrestored
    # parameters are not discovered only later.
    max_logging.log(
        f"Skipped {len(keys_missed)} mapping entries with no matching target parameter. "
        f"Sample missed mt_names: {keys_missed[:5]}"
    )

  max_logging.log(f"Successfully mapped {mapped_count} parameters.")
  restored_params = traverse_util.unflatten_dict(flat_restored, sep=".")

  # Normalise so the result has exactly one top-level "params" level.
  if "params" in restored_params:
    restored_params = restored_params["params"]

  max_logging.log(f"transform_hf_state_to_mt_state took {time.time() - t0:.2f}s")

  return {"params": restored_params}
137+
138+
139+
def load_safetensors_dynamic_state(path, abstract_unboxed_pre_state, maxtext_config):
  """Main entry point to dynamically build and load safetensors into MaxText format.

  Splits execution into:
  1. Deriving the MaxText<->HF mappings and hook functions.
  2. Loading the SafeTensors checkpoint as sharded arrays.
  3. Transforming the loaded arrays into the MaxText parameter tree.

  Args:
    path: Checkpoint location, forwarded to `load_sharded_hf_state`.
    abstract_unboxed_pre_state: Either an `nnx.State` (converted with
      `to_pure_dict()`) or an object exposing a `.params` attribute; it
      provides the target tree.
    maxtext_config: MaxText config object. Must not be None.

  Returns:
    A `(None, restored_params)` tuple. The first element is always None and
    `restored_params` is the result of `transform_hf_state_to_mt_state`.

  Raises:
    ValueError: If `maxtext_config` is None.
  """
  if maxtext_config is None:
    raise ValueError("maxtext_config must be provided for safetensors_dynamic loading.")

  t_total = time.time()
  param_map_mt_to_hf, hook_fn_map_mt = get_hf_config_and_mappings(maxtext_config)
  max_logging.log(f"[1/3] Mappings derived in {time.time() - t_total:.2f}s")

  if isinstance(abstract_unboxed_pre_state, nnx.State):
    target_tree = abstract_unboxed_pre_state.to_pure_dict()
  else:
    target_tree = abstract_unboxed_pre_state.params

  t1 = time.time()
  hf_state = load_sharded_hf_state(path)
  max_logging.log(f"[2/3] Distributed Sharded GCS load completed in {time.time() - t1:.2f}s")

  t2 = time.time()
  restored_params = transform_hf_state_to_mt_state(
      hf_state, target_tree, param_map_mt_to_hf, hook_fn_map_mt, maxtext_config
  )
  # The label used to read "CPU Transformations"; the transforms run on
  # whichever devices hold the sharded arrays, so the label is kept neutral.
  max_logging.log(f"[3/3] Transformations completed in {time.time() - t2:.2f}s")
  max_logging.log(f"Total safetensors_dynamic duration: {time.time() - t_total:.2f}s")

  return None, restored_params
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tensor handling utility functions for checkpoint conversion."""
16+
17+
from functools import partial
18+
from typing import Any, Callable, List
19+
import jax.numpy as np
20+
21+
22+
def apply_hook_fns(weight, target_shape, hook_fns):
  """Apply hook functions, essential for to_maxtext and to_huggingface.

  Args:
    weight: The value to transform.
    target_shape: Passed unchanged as the second argument to every hook.
    hook_fns: None (identity), a single callable `hook(weight, target_shape)`,
      or a list/tuple of such callables.

  Returns:
    `weight` after every hook has been applied in sequence order, or `weight`
    itself when `hook_fns` is None or empty.
  """
  # If hook is unspecified, use identity
  if hook_fns is None:
    return weight
  # Accept tuples as well as lists; a tuple would otherwise be wrapped and
  # then called as if it were a single hook.
  if not isinstance(hook_fns, (list, tuple)):
    hook_fns = [hook_fns]
  # Apply the hooks one after another; order matters because each hook
  # receives the previous hook's output.
  for hook_fn in hook_fns:
    weight = hook_fn(weight, target_shape)
  return weight
33+
34+
35+
def _build_multi_axis_stacked_tensor(
    hf_source_keys: List[List[str]],
    tensor_getter_fn: Callable[[str], np.ndarray],
    hook_fns: Any,
    target_shape: tuple,
    config,
) -> np.ndarray:
  """Assembles a MaxText tensor from a 2D grid of HF weights.

  Used for scanned MoE layers: the inner lists are stacked first, then the
  resulting per-expert tensors are stacked, both along axis 0, giving a
  result laid out as (num_experts, num_layers, ...).

  Args:
    hf_source_keys: A nested (2D) list of Hugging Face parameter names.
      Outer list iterates experts, inner list iterates layers.
    tensor_getter_fn: A callable that takes a HF key and returns the tensor.
    hook_fns: The hook function(s) to apply to each individual weight.
    target_shape: The final shape of the target MaxText tensor.
    config: The MaxText pyconfig object (not read by this function).

  Returns:
    The assembled array for the MaxText parameter (`np` is `jax.numpy` in
    this module).
  """
  # Hooks operate on one (expert, layer) slice, so they get the target shape
  # without its two leading stacking dimensions.
  per_slice_shape = target_shape[2:]

  stacked_per_expert = [
      np.stack(
          [
              apply_hook_fns(tensor_getter_fn(hf_name), per_slice_shape, hook_fns)
              for hf_name in expert_layer_keys
          ],
          axis=0,
      )
      for expert_layer_keys in hf_source_keys
  ]
  return np.stack(stacked_per_expert, axis=0)
73+
74+
75+
def _build_single_axis_stacked_tensor(
    hf_source_keys: List[str],
    tensor_getter_fn: Callable[[str], np.ndarray],
    hook_fns: Any,
    target_shape: tuple,
    config,
) -> np.ndarray:
  """Assembles a MaxText tensor by stacking HF weights along one axis.

  Covers standard scanned layers (stacked along `config.param_scan_axis`) and
  unscanned MoE layers (stacked along the expert axis, 0).

  Args:
    hf_source_keys: A 1D list of Hugging Face parameter names.
    tensor_getter_fn: A callable that takes a HF key and returns the tensor.
    hook_fns: The hook function(s) to apply to each individual weight.
    target_shape: The final shape of the target MaxText tensor.
    config: The MaxText pyconfig object; `scan_layers` and, when that is
      truthy, `param_scan_axis` are read from it.

  Returns:
    The assembled array for the MaxText parameter (`np` is `jax.numpy` in
    this module).
  """
  # Scanned layers stack along the configured scan axis; otherwise this is an
  # unscanned MoE layer and the expert axis (0) is used.
  stack_axis = config.param_scan_axis if config.scan_layers else 0

  # Hooks see a single slice, so drop the stacking dimension from the target shape.
  slice_dims = list(target_shape)
  slice_dims.pop(stack_axis)
  per_slice_shape = tuple(slice_dims)

  processed_slices = [
      apply_hook_fns(tensor_getter_fn(hf_name), per_slice_shape, hook_fns)
      for hf_name in hf_source_keys
  ]
  return np.stack(processed_slices, axis=stack_axis)
119+
120+
121+
def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config):
122+
"""Determine the loading function for HF keys.
123+
HF keys can take four forms:
124+
Case 1: Unscanned (single string)
125+
Case 2: Scanned (list of strings)
126+
Case 3: Unscanned with expert stacking (list of strings)
127+
Case 4: Scanned with expert stacking (nested list of strings)
128+
"""
129+
load_fn = None
130+
if not isinstance(hf_source_keys_or_key, list):
131+
# Case 1: Single hf key (str)
132+
def _loader(getter, key, shape, hook):
133+
return apply_hook_fns(getter(key), shape, hook)
134+
135+
load_fn = partial(
136+
_loader,
137+
tensor_getter,
138+
hf_source_keys_or_key,
139+
mt_target_shape_or_shapes,
140+
hook_fn,
141+
)
142+
# Stacked mapping
143+
elif not isinstance(hf_source_keys_or_key[0], list):
144+
# Case 2 or 3: Single-Axis Stacked hf keys (un-nested list)
145+
load_fn = partial(
146+
_build_single_axis_stacked_tensor,
147+
hf_source_keys_or_key,
148+
tensor_getter,
149+
hook_fn,
150+
mt_target_shape_or_shapes,
151+
config,
152+
)
153+
else:
154+
# isinstance(hf_source_keys_or_key[0], list)
155+
# Case 4: Multi-Axis Stacked hf keys (nested list)
156+
load_fn = partial(
157+
_build_multi_axis_stacked_tensor,
158+
hf_source_keys_or_key,
159+
tensor_getter,
160+
hook_fn,
161+
mt_target_shape_or_shapes,
162+
config,
163+
)
164+
return load_fn

0 commit comments

Comments
 (0)