Skip to content

Commit 411e4f2

Browse files
Add support for On-The-Fly Dynamic SafeTensors loading.
PiperOrigin-RevId: 916957035
1 parent 0747df9 commit 411e4f2

8 files changed

Lines changed: 663 additions & 15 deletions

File tree

src/maxtext/checkpoint_conversion/to_maxtext.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@
6767
from maxtext.common.common_types import MODEL_MODE_TRAIN
6868
from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
6969
from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
70-
from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys
70+
from maxtext.checkpoint_conversion.utils.tensor_handling import apply_hook_fns
71+
from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys
7172
from maxtext.inference.inference_utils import str2bool
7273
from maxtext.layers import quantizations
7374
from maxtext.models import models
Lines changed: 338 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,338 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Dynamic loading of HuggingFace checkpoints during training/eval workloads directly in the target format."""
16+
17+
import jax
18+
from flax import traverse_util
19+
from flax import nnx
20+
from orbax.checkpoint import v1 as ocp_v1
21+
from orbax.checkpoint._src.arrays import sharding as sharding_utils
22+
23+
from maxtext.utils import max_logging
24+
from maxtext.checkpoint_conversion.utils.tensor_handling import _get_hf_loading_function
25+
from maxtext.checkpoint_conversion.utils import param_mapping
26+
from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
27+
import time
28+
29+
30+
def get_hf_config_and_mappings(maxtext_config):
31+
"""Gets HF config and parameter mapping based on the MaxText config."""
32+
model_key = maxtext_config.model_name
33+
if "-Instruct" in model_key:
34+
model_key = model_key.replace("-Instruct", "")
35+
hf_config_obj = HF_MODEL_CONFIGS[model_key]
36+
hf_config_dict = hf_config_obj.to_dict()
37+
38+
param_map_mt_to_hf = param_mapping.PARAM_MAPPING[model_key](
39+
hf_config_dict, maxtext_config, scan_layers=maxtext_config.scan_layers
40+
)
41+
hook_fn_map_mt = param_mapping.HOOK_FNS[model_key](
42+
hf_config_dict, maxtext_config, scan_layers=maxtext_config.scan_layers, saving_to_hf=False
43+
)
44+
return param_map_mt_to_hf, hook_fn_map_mt
45+
46+
47+
def load_sharded_hf_state(path):
48+
"""Loads HF state with maximal sharding across TPU mesh to avoid host OOM."""
49+
t0 = time.time()
50+
context = ocp_v1.Context(checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS)
51+
with context:
52+
metadata = ocp_v1.pytree_metadata(path)
53+
simple_abstract_state = metadata.metadata
54+
55+
# Distributed Sharded Download: Tell JAX to shard the HF Safetensors download
56+
# across the entire TPU mesh to avoid Host OOM.
57+
current_global_devices = jax.devices()
58+
shardings = sharding_utils.construct_maximal_shardings(simple_abstract_state, devices=current_global_devices)
59+
60+
def combine_sharding(sds, single_sharding):
61+
return jax.ShapeDtypeStruct(shape=sds.shape, dtype=sds.dtype, sharding=single_sharding)
62+
63+
sharded_abstract_state = jax.tree.map(combine_sharding, simple_abstract_state, shardings)
64+
65+
max_logging.log("Reading raw Safetensors into memory (Distributed Sharded GCS Download)...")
66+
hf_state = ocp_v1.load_pytree(path, sharded_abstract_state)
67+
max_logging.log(f"load_sharded_hf_state took {time.time() - t0:.2f}s")
68+
return hf_state
69+
70+
71+
def transform_hf_state_to_mt_state(
72+
hf_state, target_tree, param_map_mt_to_hf, hook_fn_map_mt, maxtext_config
73+
):
74+
"""Transforms HF state into MaxText state by applying param mappings and mathematical hooks."""
75+
t0 = time.time()
76+
def tensor_getter(key):
77+
return hf_state.pop(key)
78+
79+
flat_target = traverse_util.flatten_dict(target_tree, sep=".")
80+
flat_restored = flat_target.copy()
81+
82+
mapped_count = 0
83+
keys_missed = []
84+
max_logging.log("Starting fast in-memory Distributed Transformations...")
85+
86+
for mt_key, hf_source in param_map_mt_to_hf.items():
87+
mt_name = mt_key.replace("params-", "").replace("-", ".")
88+
89+
# Determine the correct key in flat_target
90+
check_name = mt_name
91+
if check_name not in flat_target:
92+
if ("params." + mt_name) in flat_target:
93+
check_name = "params." + mt_name
94+
elif mt_key.replace("-", ".") in flat_target:
95+
check_name = mt_key.replace("-", ".")
96+
97+
if check_name not in flat_target:
98+
keys_missed.append(mt_name)
99+
continue
100+
101+
target_shape = flat_target[check_name].shape
102+
hook_fn = hook_fn_map_mt.get(mt_key)
103+
104+
load_fn = _get_hf_loading_function(
105+
hf_source,
106+
tensor_getter,
107+
hook_fn,
108+
target_shape,
109+
maxtext_config,
110+
)
111+
112+
# Execute transformation and assign to flat_restored
113+
t_layer = time.time()
114+
unsharded_array = load_fn()
115+
116+
# Ensure it's Sharded explicitly matching the JAX model expectations
117+
target_sharding = flat_target[check_name].sharding
118+
flat_restored[check_name] = jax.device_put(unsharded_array, device=target_sharding, donate=True)
119+
120+
max_logging.log(f"Transformed {check_name} from {hf_source} in {time.time() - t_layer:.4f}s")
121+
mapped_count += 1
122+
123+
if mapped_count == 0:
124+
max_logging.log(f"All transformations missed! Sample missed mt_names: {keys_missed[:5]}")
125+
max_logging.log(f"Sample flat_target keys: {list(flat_target.keys())[:5]}")
126+
127+
max_logging.log(f"Successfully mapped {mapped_count} parameters.")
128+
restored_params = traverse_util.unflatten_dict(flat_restored, sep=".")
129+
130+
if "params" in restored_params:
131+
restored_params = restored_params["params"]
132+
133+
max_logging.log(f"transform_hf_state_to_mt_state took {time.time() - t0:.2f}s")
134+
135+
return {"params": restored_params}
136+
137+
138+
def load_safetensors_dynamic_state(path, abstract_unboxed_pre_state, maxtext_config):
139+
"""Main entry point to dynamically build and load safetensors into MaxText format.
140+
141+
Splits execution into:
142+
1. Deriving Mappings
143+
2. Loading Sharded arrays directly to TPUs
144+
3. Processing the transformations natively on TPUs
145+
"""
146+
if maxtext_config is None:
147+
raise ValueError("maxtext_config must be provided for safetensors_dynamic loading.")
148+
149+
import os
150+
from maxtext.utils.globals import HF_IDS
151+
152+
model_name = maxtext_config.model_name
153+
if "-Instruct" in model_name:
154+
model_name = model_name.replace("-Instruct", "")
155+
156+
if not path:
157+
if model_name not in HF_IDS:
158+
raise ValueError(f"Unsupported model name for automatic HF repo resolution: {model_name}.")
159+
path = HF_IDS[model_name]
160+
161+
if path.startswith("hf://"):
162+
path = path[5:]
163+
164+
if not path.startswith("gs://") and not os.path.isdir(path):
165+
from huggingface_hub import HfFileSystem
166+
import concurrent.futures
167+
import json
168+
import jax
169+
170+
fs = HfFileSystem(token=maxtext_config.hf_access_token)
171+
repo_id = path
172+
173+
files = fs.glob(f"{repo_id}/*.safetensors")
174+
175+
process_count = max(1, jax.process_count())
176+
host_id = jax.process_index()
177+
HEADER_NUM_BYTES = 8
178+
179+
if maxtext_config.async_checkpointing_dir != "":
180+
gcs_cache_dir = f"{maxtext_config.async_checkpointing_dir}/hf_cache/{repo_id.replace('/', '_')}"
181+
path = gcs_cache_dir
182+
183+
# Only Host 0 downloads to the shared GCS cache
184+
if host_id == 0:
185+
import tensorflow as tf
186+
if not tf.io.gfile.exists(gcs_cache_dir):
187+
tf.io.gfile.makedirs(gcs_cache_dir)
188+
189+
max_logging.log(f"Dynamic HF Hub Fast DL: Host 0 is downloading to shared GCS Cache: {gcs_cache_dir}")
190+
191+
def fetch_shard_gcs(fpath):
192+
time.sleep(random.uniform(0.0, 5.0))
193+
gcs_path = os.path.join(gcs_cache_dir, os.path.basename(fpath))
194+
195+
# Check if it already exists to avoid redundant downloads
196+
if tf.io.gfile.exists(gcs_path):
197+
return
198+
199+
max_retries = 5
200+
for attempt in range(max_retries):
201+
try:
202+
with fs.open(fpath, "rb") as remote_f:
203+
with tf.io.gfile.GFile(gcs_path, "wb") as gcs_f:
204+
buffer_size = 1024 * 1024 * 16
205+
while True:
206+
buf = remote_f.read(buffer_size)
207+
if not buf:
208+
break
209+
gcs_f.write(buf)
210+
break
211+
except Exception as e:
212+
if attempt < max_retries - 1:
213+
max_logging.log(f"Error fetching {fpath} to GCS: {e}. Retrying in 15 seconds... (Attempt {attempt+1}/{max_retries})")
214+
time.sleep(15)
215+
else:
216+
max_logging.log(f"Failed to fetch {fpath} to GCS after {max_retries} attempts.")
217+
raise
218+
219+
import time as time_module
220+
t_gcs_start = time_module.time()
221+
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
222+
list(executor.map(fetch_shard_gcs, files))
223+
t_gcs_end = time_module.time()
224+
max_logging.log(f"GCS caching complete in {t_gcs_end - t_gcs_start:.2f}s.")
225+
226+
# Global barrier to ensure all hosts wait for Host 0 to finish downloading to the shared GCS bucket
227+
jax.distributed.barrier_wait()
228+
229+
else:
230+
# Fallback to local /tmp caching across all hosts with distributed downloading
231+
local_dir = f"/tmp/hf_checkpoints/{repo_id.replace('/', '_')}"
232+
os.makedirs(local_dir, exist_ok=True)
233+
234+
max_logging.log(f"Dynamic HF Hub Fast DL: Resolving metadata and partial chunks via HTTP Range Requests for Host {host_id}/{process_count}")
235+
import random
236+
import time
237+
238+
def fetch_shard(fpath):
239+
max_retries = 5
240+
for attempt in range(max_retries):
241+
try:
242+
time.sleep(random.uniform(0.0, 5.0))
243+
local_path = os.path.join(local_dir, os.path.basename(fpath))
244+
245+
if os.path.exists(local_path):
246+
return
247+
248+
with fs.open(fpath, "rb") as remote_f:
249+
header_size_bytes = remote_f.read(HEADER_NUM_BYTES)
250+
header_size = int.from_bytes(header_size_bytes, byteorder="little")
251+
header_bytes = remote_f.read(header_size)
252+
header = json.loads(header_bytes)
253+
254+
data_start_offset = HEADER_NUM_BYTES + header_size
255+
256+
tensors = {k: v for k, v in header.items() if k != "__metadata__"}
257+
sorted_tensors = sorted(tensors.items(), key=lambda item: item[1]["data_offsets"][0])
258+
259+
with open(local_path, "wb") as local_f:
260+
local_f.write(header_size_bytes)
261+
local_f.write(header_bytes)
262+
263+
if not sorted_tensors:
264+
return
265+
266+
total_size = sorted_tensors[-1][1]["data_offsets"][1]
267+
current_bundle = 0
268+
cumulative_size = 0
269+
host_start_offset = None
270+
host_end_offset = None
271+
272+
for name, info in sorted_tensors:
273+
start, end = info["data_offsets"]
274+
tensor_size = end - start
275+
if current_bundle < process_count - 1:
276+
ideal = (current_bundle + 1) * (total_size / process_count)
277+
dist_if_cut = abs(cumulative_size - ideal)
278+
dist_if_keep = abs((cumulative_size + tensor_size) - ideal)
279+
if dist_if_cut < dist_if_keep and cumulative_size > 0:
280+
current_bundle += 1
281+
282+
if current_bundle == host_id:
283+
if host_start_offset is None:
284+
host_start_offset = start
285+
host_end_offset = end
286+
287+
cumulative_size += tensor_size
288+
289+
if host_start_offset is not None:
290+
chunk_size = host_end_offset - host_start_offset
291+
remote_f.seek(data_start_offset + host_start_offset)
292+
local_f.seek(data_start_offset + host_start_offset)
293+
294+
buffer_size = 1024 * 1024 * 16
295+
bytes_remaining = chunk_size
296+
while bytes_remaining > 0:
297+
sz = min(buffer_size, bytes_remaining)
298+
buf = remote_f.read(sz)
299+
if not buf:
300+
break
301+
local_f.write(buf)
302+
bytes_remaining -= len(buf)
303+
break
304+
except Exception as e:
305+
if attempt < max_retries - 1:
306+
max_logging.log(f"Error fetching {fpath}: {e}. Retrying in 15 seconds... (Attempt {attempt+1}/{max_retries})")
307+
time.sleep(15)
308+
else:
309+
max_logging.log(f"Failed to fetch {fpath} after {max_retries} attempts.")
310+
raise
311+
312+
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
313+
list(executor.map(fetch_shard, files))
314+
315+
path = local_dir
316+
317+
t_total = time.time()
318+
param_map_mt_to_hf, hook_fn_map_mt = get_hf_config_and_mappings(maxtext_config)
319+
max_logging.log(f"[1/3] Mappings derived in {time.time() - t_total:.2f}s")
320+
321+
target_tree = (
322+
abstract_unboxed_pre_state.to_pure_dict()
323+
if isinstance(abstract_unboxed_pre_state, nnx.State)
324+
else abstract_unboxed_pre_state.params
325+
)
326+
327+
t1 = time.time()
328+
hf_state = load_sharded_hf_state(path)
329+
max_logging.log(f"[2/3] Distributed Sharded GCS load completed in {time.time() - t1:.2f}s")
330+
331+
t2 = time.time()
332+
restored_params = transform_hf_state_to_mt_state(
333+
hf_state, target_tree, param_map_mt_to_hf, hook_fn_map_mt, maxtext_config
334+
)
335+
max_logging.log(f"[3/3] CPU Transformations completed in {time.time() - t2:.2f}s")
336+
max_logging.log(f"Total safetensors_dynamic duration: {time.time() - t_total:.2f}s")
337+
338+
return None, restored_params

0 commit comments

Comments
 (0)