Skip to content

Commit 24191ed

Browse files
Refactor tensor handling and add dynamic SafeTensors loading.
PiperOrigin-RevId: 916957035
1 parent 3f9789f commit 24191ed

7 files changed

Lines changed: 458 additions & 27 deletions

File tree

src/maxtext/checkpoint_conversion/to_maxtext.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@
6767
from maxtext.common.common_types import MODEL_MODE_TRAIN
6868
from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
6969
from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
70-
from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys
70+
from maxtext.checkpoint_conversion.utils.tensor_handling import apply_hook_fns
71+
from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys
7172
from maxtext.inference.inference_utils import str2bool
7273
from maxtext.layers import quantizations
7374
from maxtext.models import models
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tensor handling utility functions for checkpoint conversion."""
16+
17+
from functools import partial
18+
from typing import Any, Callable, List
19+
import numpy as np
20+
21+
22+
def apply_hook_fns(weight, target_shape, hook_fns):
23+
"""Apply hook functions, essential for to_maxtext and to_huggingface"""
24+
# If hook is unsepecified, use identity
25+
if hook_fns is None:
26+
return weight
27+
if not isinstance(hook_fns, list):
28+
hook_fns = [hook_fns]
29+
# Apply a list of hooks, be careful of order
30+
for hook_fn in hook_fns:
31+
weight = hook_fn(weight, target_shape)
32+
return weight
33+
34+
35+
def _build_multi_axis_stacked_tensor(
36+
hf_source_keys: List[List[str]],
37+
tensor_getter_fn: Callable[[str], np.ndarray],
38+
hook_fns: Any,
39+
target_shape: tuple,
40+
config,
41+
) -> np.ndarray:
42+
"""Builds a MaxText tensor by stacking HF weights along two axes (experts and layers).
43+
44+
This function handles the complex case for scanned MoE layers, producing a tensor
45+
with the shape (num_experts, num_layers, ...).
46+
47+
Args:
48+
hf_source_keys: A nested (2D) list of Hugging Face parameter names.
49+
Outer list iterates experts, inner list iterates layers.
50+
tensor_getter_fn: A callable that takes a HF key and returns the tensor (as numpy array).
51+
hook_fns: The hook function(s) to apply to each individual weight.
52+
target_shape: The final shape of the target MaxText tensor.
53+
config: The MaxText pyconfig object.
54+
55+
Returns:
56+
The final, assembled NumPy array for the MaxText parameter.
57+
"""
58+
all_expert_tensors = []
59+
# The hook function needs the shape of an individual slice, not the full stacked tensor.
60+
# For multi-axis stacking (experts, layers, ...), the slice shape is target_shape[2:]
61+
mt_slice_shape = target_shape[2:]
62+
63+
# Outer loop iterates through experts
64+
for layer_keys_for_expert in hf_source_keys:
65+
layer_tensors_for_expert = []
66+
# Inner loop iterates through layers for the current expert
67+
for hf_key_single in layer_keys_for_expert:
68+
hf_tensor_numpy = tensor_getter_fn(hf_key_single)
69+
processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)
70+
layer_tensors_for_expert.append(processed_hf_tensor)
71+
all_expert_tensors.append(np.stack(layer_tensors_for_expert, axis=0))
72+
return np.stack(all_expert_tensors, axis=0)
73+
74+
75+
def _build_single_axis_stacked_tensor(
76+
hf_source_keys: List[str],
77+
tensor_getter_fn: Callable[[str], np.ndarray],
78+
hook_fns: Any,
79+
target_shape: tuple,
80+
config,
81+
) -> np.ndarray:
82+
"""Builds a MaxText tensor by stacking HF weights along a single axis.
83+
84+
This function handles both standard scanned layers (e.g., attention) and
85+
unscanned MoE layers (which are stacked along the expert axis).
86+
87+
Args:
88+
hf_source_keys: A 1D list of Hugging Face parameter names.
89+
tensor_getter_fn: A callable that takes a HF key and returns the tensor (as numpy array).
90+
hook_fns: The hook function(s) to apply to each individual weight.
91+
target_shape: The final shape of the target MaxText tensor.
92+
config: The MaxText pyconfig object.
93+
94+
Returns:
95+
The final, assembled NumPy array for the MaxText parameter.
96+
"""
97+
tensors_to_stack = []
98+
99+
if config.scan_layers:
100+
# If it's a standard scanned layer, we use the configured param_scan_axis.
101+
axis_to_stack = config.param_scan_axis
102+
else:
103+
# Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0).
104+
axis_to_stack = 0
105+
106+
# The hook function needs the shape of an individual slice, not the full stacked tensor.
107+
# We calculate it by removing the stacking dimension from the final target shape.
108+
mt_slice_shape_list = list(target_shape)
109+
del mt_slice_shape_list[axis_to_stack]
110+
mt_slice_shape = tuple(mt_slice_shape_list)
111+
112+
for hf_key_single in hf_source_keys:
113+
hf_tensor_numpy = tensor_getter_fn(hf_key_single)
114+
processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)
115+
tensors_to_stack.append(processed_hf_tensor)
116+
117+
# Stack all processed tensors along the determined axis.
118+
return np.stack(tensors_to_stack, axis=axis_to_stack)
119+
120+
121+
def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config):
122+
"""Determine the loading function for HF keys.
123+
HF keys can take four forms:
124+
Case 1: Unscanned (single string)
125+
Case 2: Scanned (list of strings)
126+
Case 3: Unscanned with expert stacking (list of strings)
127+
Case 4: Scanned with expert stacking (nested list of strings)
128+
"""
129+
load_fn = None
130+
if not isinstance(hf_source_keys_or_key, list):
131+
# Case 1: Single hf key (str)
132+
def _loader(getter, key, shape, hook):
133+
return apply_hook_fns(getter(key), shape, hook)
134+
135+
load_fn = partial(
136+
_loader,
137+
tensor_getter,
138+
hf_source_keys_or_key,
139+
mt_target_shape_or_shapes,
140+
hook_fn,
141+
)
142+
# Stacked mapping
143+
elif not isinstance(hf_source_keys_or_key[0], list):
144+
# Case 2 or 3: Single-Axis Stacked hf keys (un-nested list)
145+
load_fn = partial(
146+
_build_single_axis_stacked_tensor,
147+
hf_source_keys_or_key,
148+
tensor_getter,
149+
hook_fn,
150+
mt_target_shape_or_shapes,
151+
config,
152+
)
153+
else:
154+
# isinstance(hf_source_keys_or_key[0], list)
155+
# Case 4: Multi-Axis Stacked hf keys (nested list)
156+
load_fn = partial(
157+
_build_multi_axis_stacked_tensor,
158+
hf_source_keys_or_key,
159+
tensor_getter,
160+
hook_fn,
161+
mt_target_shape_or_shapes,
162+
config,
163+
)
164+
return load_fn

src/maxtext/checkpoint_conversion/utils/utils.py

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323
import time
2424
import json
2525
from concurrent.futures import ThreadPoolExecutor
26-
from typing import Any
26+
from functools import partial
27+
from typing import Any, Callable, List
2728
from tqdm import tqdm
2829
import resource
2930
import numpy as np
@@ -1165,3 +1166,135 @@ def save_weights_to_checkpoint(
11651166
checkpoint_manager.wait_until_finished()
11661167

11671168
max_logging.log(f"Elapse for checkpoint save: {(time.time() - start) / 60:.2f} min")
1169+
1170+
1171+
def _build_multi_axis_stacked_tensor(
1172+
hf_source_keys: List[List[str]],
1173+
tensor_getter_fn: Callable[[str], np.ndarray],
1174+
hook_fns: Any,
1175+
target_shape: tuple,
1176+
config,
1177+
) -> np.ndarray:
1178+
"""Builds a MaxText tensor by stacking HF weights along two axes (experts and layers).
1179+
1180+
This function handles the complex case for scanned MoE layers, producing a tensor
1181+
with the shape (num_experts, num_layers, ...).
1182+
1183+
Args:
1184+
hf_source_keys: A nested (2D) list of Hugging Face parameter names.
1185+
Outer list iterates experts, inner list iterates layers.
1186+
tensor_getter_fn: A callable that takes a HF key and returns the tensor (as numpy array).
1187+
hook_fns: The hook function(s) to apply to each individual weight.
1188+
target_shape: The final shape of the target MaxText tensor.
1189+
config: The MaxText pyconfig object.
1190+
1191+
Returns:
1192+
The final, assembled NumPy array for the MaxText parameter.
1193+
"""
1194+
all_expert_tensors = []
1195+
# The hook function needs the shape of an individual slice, not the full stacked tensor.
1196+
# For multi-axis stacking (experts, layers, ...), the slice shape is target_shape[2:]
1197+
mt_slice_shape = target_shape[2:]
1198+
1199+
# Outer loop iterates through experts
1200+
for layer_keys_for_expert in hf_source_keys:
1201+
layer_tensors_for_expert = []
1202+
# Inner loop iterates through layers for the current expert
1203+
for hf_key_single in layer_keys_for_expert:
1204+
hf_tensor_numpy = tensor_getter_fn(hf_key_single)
1205+
processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)
1206+
layer_tensors_for_expert.append(processed_hf_tensor)
1207+
all_expert_tensors.append(np.stack(layer_tensors_for_expert, axis=0))
1208+
return np.stack(all_expert_tensors, axis=0)
1209+
1210+
1211+
def _build_single_axis_stacked_tensor(
1212+
hf_source_keys: List[str],
1213+
tensor_getter_fn: Callable[[str], np.ndarray],
1214+
hook_fns: Any,
1215+
target_shape: tuple,
1216+
config,
1217+
) -> np.ndarray:
1218+
"""Builds a MaxText tensor by stacking HF weights along a single axis.
1219+
1220+
This function handles both standard scanned layers (e.g., attention) and
1221+
unscanned MoE layers (which are stacked along the expert axis).
1222+
1223+
Args:
1224+
hf_source_keys: A 1D list of Hugging Face parameter names.
1225+
tensor_getter_fn: A callable that takes a HF key and returns the tensor (as numpy array).
1226+
hook_fns: The hook function(s) to apply to each individual weight.
1227+
target_shape: The final shape of the target MaxText tensor.
1228+
config: The MaxText pyconfig object.
1229+
1230+
Returns:
1231+
The final, assembled NumPy array for the MaxText parameter.
1232+
"""
1233+
tensors_to_stack = []
1234+
1235+
if config.scan_layers:
1236+
# If it's a standard scanned layer, we use the configured param_scan_axis.
1237+
axis_to_stack = config.param_scan_axis
1238+
else:
1239+
# Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0).
1240+
axis_to_stack = 0
1241+
1242+
# The hook function needs the shape of an individual slice, not the full stacked tensor.
1243+
# We calculate it by removing the stacking dimension from the final target shape.
1244+
mt_slice_shape_list = list(target_shape)
1245+
del mt_slice_shape_list[axis_to_stack]
1246+
mt_slice_shape = tuple(mt_slice_shape_list)
1247+
1248+
for hf_key_single in hf_source_keys:
1249+
hf_tensor_numpy = tensor_getter_fn(hf_key_single)
1250+
processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)
1251+
tensors_to_stack.append(processed_hf_tensor)
1252+
1253+
# Stack all processed tensors along the determined axis.
1254+
return np.stack(tensors_to_stack, axis=axis_to_stack)
1255+
1256+
1257+
def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config):
1258+
"""Determine the loading function for HF keys.
1259+
HF keys can take four forms:
1260+
Case 1: Unscanned (single string)
1261+
Case 2: Scanned (list of strings)
1262+
Case 3: Unscanned with expert stacking (list of strings)
1263+
Case 4: Scanned with expert stacking (nested list of strings)
1264+
"""
1265+
load_fn = None
1266+
if not isinstance(hf_source_keys_or_key, list):
1267+
# Case 1: Single hf key (str)
1268+
def _loader(getter, key, shape, hook):
1269+
return apply_hook_fns(getter(key), shape, hook)
1270+
1271+
load_fn = partial(
1272+
_loader,
1273+
tensor_getter,
1274+
hf_source_keys_or_key,
1275+
mt_target_shape_or_shapes,
1276+
hook_fn,
1277+
)
1278+
# Stacked mapping
1279+
elif not isinstance(hf_source_keys_or_key[0], list):
1280+
# Case 2 or 3: Single-Axis Stacked hf keys (un-nested list)
1281+
load_fn = partial(
1282+
_build_single_axis_stacked_tensor,
1283+
hf_source_keys_or_key,
1284+
tensor_getter,
1285+
hook_fn,
1286+
mt_target_shape_or_shapes,
1287+
config,
1288+
)
1289+
else:
1290+
# isinstance(hf_source_keys_or_key[0], list)
1291+
# Case 4: Multi-Axis Stacked hf keys (nested list)
1292+
load_fn = partial(
1293+
_build_multi_axis_stacked_tensor,
1294+
hf_source_keys_or_key,
1295+
tensor_getter,
1296+
hook_fn,
1297+
mt_target_shape_or_shapes,
1298+
config,
1299+
)
1300+
return load_fn

0 commit comments

Comments
 (0)