Skip to content

Commit c217881

Browse files
Add support for On-The-Fly Dynamic SafeTensors loading.
PiperOrigin-RevId: 916957035
1 parent 0747df9 commit c217881

8 files changed

Lines changed: 673 additions & 15 deletions

File tree

src/maxtext/checkpoint_conversion/to_maxtext.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@
6767
from maxtext.common.common_types import MODEL_MODE_TRAIN
6868
from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
6969
from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
70-
from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys
70+
from maxtext.checkpoint_conversion.utils.tensor_handling import apply_hook_fns
71+
from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys
7172
from maxtext.inference.inference_utils import str2bool
7273
from maxtext.layers import quantizations
7374
from maxtext.models import models
Lines changed: 348 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,348 @@
1+
# Copyright 2023–2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Dynamic loading of HuggingFace checkpoints during training/eval workloads directly in the target format."""
16+
17+
import jax
18+
from flax import traverse_util
19+
from flax import nnx
20+
from orbax.checkpoint import v1 as ocp_v1
21+
from orbax.checkpoint._src.arrays import sharding as sharding_utils
22+
23+
from maxtext.utils import max_logging
24+
from maxtext.checkpoint_conversion.utils.tensor_handling import _get_hf_loading_function
25+
from maxtext.checkpoint_conversion.utils import param_mapping
26+
from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
27+
import time
28+
29+
def build_gcs_cache_worker(fpath, gcs_cache_dir, hf_access_token):
  """Copy a single HuggingFace Hub file into a GCS cache directory.

  The file is streamed in 16 MiB chunks into a temporary object next to the
  destination and renamed into place only after the whole copy succeeded.
  Because the existence check below is the only cache-validity test, writing
  directly to the final name would let an interrupted transfer be mistaken
  for a complete cached file on a later run.

  Args:
    fpath: Remote file path, opened through `HfFileSystem`.
    gcs_cache_dir: Destination directory; the file keeps its base name.
    hf_access_token: Token handed to `HfFileSystem`.

  Raises:
    Exception: The error of the last attempt, once all 5 attempts have failed.
  """
  # Imports are local so the function is self-contained when it runs in a worker process.
  import os
  import time
  import random
  import tensorflow as tf
  from huggingface_hub import HfFileSystem
  from maxtext.utils import max_logging

  gcs_path = os.path.join(gcs_cache_dir, os.path.basename(fpath))

  # Already cached: nothing to download, so skip the jitter sleep as well.
  if tf.io.gfile.exists(gcs_path):
    return

  fs = HfFileSystem(token=hf_access_token)
  # Random jitter so that concurrently started workers do not all open their remote file at once.
  time.sleep(random.uniform(0.0, 5.0))

  # The temporary name deliberately does not end in ".safetensors".
  tmp_path = f"{gcs_path}.partial-{os.getpid()}"
  buffer_size = 1024 * 1024 * 16
  max_retries = 5
  for attempt in range(max_retries):
    try:
      with fs.open(fpath, "rb") as remote_f:
        # "wb" truncates, so a retry starts again from an empty temporary object.
        with tf.io.gfile.GFile(tmp_path, "wb") as gcs_f:
          while True:
            buf = remote_f.read(buffer_size)
            if not buf:
              break
            gcs_f.write(buf)
      # Publish under the final name only once the copy is complete.
      tf.io.gfile.rename(tmp_path, gcs_path, overwrite=True)
      return
    except Exception as e:  # pylint: disable=broad-exception-caught
      if attempt < max_retries - 1:
        max_logging.log(f"Error fetching {fpath} to GCS: {e}. Retrying in 15 seconds... (Attempt {attempt+1}/{max_retries})")
        time.sleep(15)
      else:
        max_logging.log(f"Failed to fetch {fpath} to GCS after {max_retries} attempts.")
        raise
63+
64+
65+
def get_hf_config_and_mappings(maxtext_config):
  """Build the MaxText->HF parameter mapping and hook functions for a model.

  Args:
    maxtext_config: MaxText config; `model_name` and `scan_layers` are read.

  Returns:
    A `(param_map_mt_to_hf, hook_fn_map_mt)` tuple produced by the
    `PARAM_MAPPING` and `HOOK_FNS` entries registered for the model.

  Raises:
    KeyError: If the model key is missing from one of the lookup tables.
  """
  # "-Instruct" variants are looked up under the key of their base model.
  base_model_key = maxtext_config.model_name.replace("-Instruct", "")
  hf_config_dict = HF_MODEL_CONFIGS[base_model_key].to_dict()

  build_param_map = param_mapping.PARAM_MAPPING[base_model_key]
  param_map_mt_to_hf = build_param_map(hf_config_dict, maxtext_config, scan_layers=maxtext_config.scan_layers)

  build_hook_fns = param_mapping.HOOK_FNS[base_model_key]
  # saving_to_hf=False: the hooks are requested for the HF -> MaxText direction.
  hook_fn_map_mt = build_hook_fns(
      hf_config_dict,
      maxtext_config,
      scan_layers=maxtext_config.scan_layers,
      saving_to_hf=False,
  )
  return param_map_mt_to_hf, hook_fn_map_mt
80+
81+
82+
def load_sharded_hf_state(path):
  """Load a SafeTensors checkpoint as arrays sharded over all JAX devices.

  The checkpoint metadata is read first, every leaf is paired with a sharding
  built by `construct_maximal_shardings` over `jax.devices()`, and the
  resulting abstract tree is passed to `ocp_v1.load_pytree`. Sharding the load
  is intended to avoid host OOM.

  Args:
    path: Location of the checkpoint, passed unchanged to Orbax.

  Returns:
    The pytree returned by `ocp_v1.load_pytree`.
  """
  t0 = time.time()

  def _with_sharding(leaf, leaf_sharding):
    # Same shape/dtype as the metadata leaf, plus the sharding chosen for it.
    return jax.ShapeDtypeStruct(shape=leaf.shape, dtype=leaf.dtype, sharding=leaf_sharding)

  safetensors_layout = ocp_v1.options.CheckpointLayout.SAFETENSORS
  with ocp_v1.Context(checkpoint_layout=safetensors_layout):
    abstract_state = ocp_v1.pytree_metadata(path).metadata

    # Distributed sharded download: spread the SafeTensors read over every
    # device of the mesh instead of materializing it on a single host.
    all_devices = jax.devices()
    leaf_shardings = sharding_utils.construct_maximal_shardings(abstract_state, devices=all_devices)
    sharded_abstract_state = jax.tree.map(_with_sharding, abstract_state, leaf_shardings)

    max_logging.log("Reading raw Safetensors into memory (Distributed Sharded GCS Download)...")
    hf_state = ocp_v1.load_pytree(path, sharded_abstract_state)

  max_logging.log(f"load_sharded_hf_state took {time.time() - t0:.2f}s")
  return hf_state
104+
105+
106+
def transform_hf_state_to_mt_state(
    hf_state, target_tree, param_map_mt_to_hf, hook_fn_map_mt, maxtext_config
):
  """Transform HF state into MaxText state by applying param mappings and hooks.

  Args:
    hf_state: Mapping from HF tensor name to array. It is mutated: every
      tensor that gets consumed is popped from it.
    target_tree: Nested dict whose leaves expose `.shape` and `.sharding`.
    param_map_mt_to_hf: Mapping from MaxText key ("params-a-b" form) to the HF
      source description understood by `_get_hf_loading_function`.
    hook_fn_map_mt: Mapping from MaxText key to an optional hook.
    maxtext_config: MaxText config forwarded to the loading function.

  Returns:
    `{"params": tree}` where mapped leaves are arrays placed with the target
    leaf's sharding. Leaves without a mapping keep their `target_tree` value.
  """
  t0 = time.time()

  def tensor_getter(key):
    # pop() instead of indexing: hf_state drops its reference once the tensor is consumed.
    return hf_state.pop(key)

  flat_target = traverse_util.flatten_dict(target_tree, sep=".")
  flat_restored = flat_target.copy()

  mapped_count = 0
  keys_missed = []
  max_logging.log("Starting fast in-memory Distributed Transformations...")

  for mt_key, hf_source in param_map_mt_to_hf.items():
    mt_name = mt_key.replace("params-", "").replace("-", ".")

    # The target tree may or may not carry the "params" prefix; take the first spelling that exists.
    candidate_names = (mt_name, "params." + mt_name, mt_key.replace("-", "."))
    check_name = next((name for name in candidate_names if name in flat_target), None)
    if check_name is None:
      keys_missed.append(mt_name)
      continue

    target_leaf = flat_target[check_name]
    load_fn = _get_hf_loading_function(
        hf_source,
        tensor_getter,
        hook_fn_map_mt.get(mt_key),
        target_leaf.shape,
        maxtext_config,
    )

    # Execute the transformation and assign the result to flat_restored.
    t_layer = time.time()
    unsharded_array = load_fn()

    # Place the result with exactly the sharding the model expects for this leaf.
    flat_restored[check_name] = jax.device_put(unsharded_array, device=target_leaf.sharding, donate=True)

    max_logging.log(f"Transformed {check_name} from {hf_source} in {time.time() - t_layer:.4f}s")
    mapped_count += 1

  if mapped_count == 0:
    max_logging.log(f"All transformations missed! Sample missed mt_names: {keys_missed[:5]}")
    max_logging.log(f"Sample flat_target keys: {list(flat_target.keys())[:5]}")
  elif keys_missed:
    # Partial mismatches used to pass silently; report them so abstract leaves do not go unnoticed.
    max_logging.log(
        f"Skipped {len(keys_missed)} mapped parameters that are absent from the target tree. "
        f"Sample: {keys_missed[:5]}"
    )

  max_logging.log(f"Successfully mapped {mapped_count} parameters.")
  restored_params = traverse_util.unflatten_dict(flat_restored, sep=".")

  # Normalize: always hand back exactly one "params" level.
  if "params" in restored_params:
    restored_params = restored_params["params"]

  max_logging.log(f"transform_hf_state_to_mt_state took {time.time() - t0:.2f}s")

  return {"params": restored_params}
171+
172+
173+
def _select_host_byte_range(sorted_tensors, process_count, host_id):
  """Return the `(start, end)` data offsets of the tensors assigned to `host_id`.

  `sorted_tensors` is a list of `(name, info)` pairs ordered by
  `info["data_offsets"][0]`. The tensors are split into at most `process_count`
  contiguous bundles of roughly equal byte size: a new bundle starts before a
  tensor when cutting there lands closer to the ideal cumulative size than
  keeping the tensor in the current bundle would.

  Returns:
    `(start, end)` offsets relative to the start of the data section, or
    `None` when no tensor falls into the bundle of `host_id`.
  """
  if not sorted_tensors:
    return None

  total_size = sorted_tensors[-1][1]["data_offsets"][1]
  current_bundle = 0
  cumulative_size = 0
  host_start_offset = None
  host_end_offset = None

  for _, info in sorted_tensors:
    start, end = info["data_offsets"]
    tensor_size = end - start
    if current_bundle < process_count - 1:
      ideal = (current_bundle + 1) * (total_size / process_count)
      dist_if_cut = abs(cumulative_size - ideal)
      dist_if_keep = abs((cumulative_size + tensor_size) - ideal)
      if dist_if_cut < dist_if_keep and cumulative_size > 0:
        current_bundle += 1

    if current_bundle == host_id:
      if host_start_offset is None:
        host_start_offset = start
      host_end_offset = end

    cumulative_size += tensor_size

  if host_start_offset is None:
    return None
  return host_start_offset, host_end_offset


def load_safetensors_dynamic_state(path, abstract_unboxed_pre_state, maxtext_config):
  """Main entry point to dynamically build and load safetensors into MaxText format.

  Splits execution into:
  1. Deriving Mappings
  2. Loading Sharded arrays directly to TPUs
  3. Processing the transformations natively on TPUs

  When `path` is neither a `gs://` location nor an existing local directory it
  is treated as a HuggingFace repo id and the `*.safetensors` files are fetched
  first: into `<base_output_directory>/hf_cache/...` by host 0 if the output
  directory is on GCS, otherwise into `/tmp/hf_checkpoints/...`, where every
  host downloads each file's header plus only its own byte range.

  Args:
    path: Checkpoint location (`gs://...`, local directory, `hf://repo` or a
      bare repo id). If empty, the repo id is resolved from `HF_IDS`.
    abstract_unboxed_pre_state: Either an `nnx.State` or an object with a
      `.params` attribute describing the target tree.
    maxtext_config: MaxText config; must not be None.

  Returns:
    `(None, restored_params)` with `restored_params` being `{"params": tree}`.

  Raises:
    ValueError: If `maxtext_config` is None or the model name cannot be
      resolved to a HuggingFace repo.
  """
  if maxtext_config is None:
    raise ValueError("maxtext_config must be provided for safetensors_dynamic loading.")

  # NOTE: never `import time` (or `import jax`) inside this function. A local
  # import makes the name function-local, and every code path that skips the
  # import then fails with UnboundLocalError at the timing calls below.
  import os
  from maxtext.utils.globals import HF_IDS

  model_name = maxtext_config.model_name.replace("-Instruct", "")

  if not path:
    if model_name not in HF_IDS:
      raise ValueError(f"Unsupported model name for automatic HF repo resolution: {model_name}.")
    path = HF_IDS[model_name]

  if path.startswith("hf://"):
    path = path[5:]

  if not path.startswith("gs://") and not os.path.isdir(path):
    from huggingface_hub import HfFileSystem
    import concurrent.futures
    import json

    fs = HfFileSystem(token=maxtext_config.hf_access_token)
    repo_id = path

    files = fs.glob(f"{repo_id}/*.safetensors")

    process_count = max(1, jax.process_count())
    host_id = jax.process_index()
    HEADER_NUM_BYTES = 8  # length prefix in front of the JSON header

    if hasattr(maxtext_config, "base_output_directory") and maxtext_config.base_output_directory.startswith("gs://"):
      gcs_cache_dir = f"{maxtext_config.base_output_directory}/hf_cache/{repo_id.replace('/', '_')}"
      path = gcs_cache_dir

      # Only Host 0 downloads to the shared GCS cache
      if host_id == 0:
        import itertools
        import tensorflow as tf

        if not tf.io.gfile.exists(gcs_cache_dir):
          tf.io.gfile.makedirs(gcs_cache_dir)

        max_logging.log(f"Dynamic HF Hub Fast DL: Host 0 is downloading to shared GCS Cache: {gcs_cache_dir}")

        t_gcs_start = time.time()
        with concurrent.futures.ProcessPoolExecutor(max_workers=16) as executor:
          # list() drains the iterator so that worker exceptions are raised here.
          list(
              executor.map(
                  build_gcs_cache_worker,
                  files,
                  itertools.repeat(gcs_cache_dir),
                  itertools.repeat(maxtext_config.hf_access_token),
              )
          )
        t_gcs_end = time.time()
        max_logging.log(f"GCS caching complete in {t_gcs_end - t_gcs_start:.2f}s.")

      # Global barrier to ensure all hosts wait for Host 0 to finish downloading to the shared GCS bucket
      # NOTE(review): confirm `jax.distributed.barrier_wait` exists in the pinned JAX version.
      jax.distributed.barrier_wait()

    else:
      # Fallback to local /tmp caching across all hosts with distributed downloading
      import random

      local_dir = f"/tmp/hf_checkpoints/{repo_id.replace('/', '_')}"
      os.makedirs(local_dir, exist_ok=True)

      max_logging.log(f"Dynamic HF Hub Fast DL: Resolving metadata and partial chunks via HTTP Range Requests for Host {host_id}/{process_count}")

      def fetch_shard(fpath):
        """Download the header of `fpath` plus this host's byte range into `local_dir`."""
        local_path = os.path.join(local_dir, os.path.basename(fpath))
        # Write under a temporary name and publish with os.replace() on success.
        # Otherwise a failed attempt leaves a file behind that the existence
        # check of the next attempt (or the next run) accepts as complete.
        partial_path = f"{local_path}.partial-{os.getpid()}"
        buffer_size = 1024 * 1024 * 16
        max_retries = 5
        for attempt in range(max_retries):
          try:
            if os.path.exists(local_path):
              return

            # Random jitter so that hosts and threads do not hit the Hub simultaneously.
            time.sleep(random.uniform(0.0, 5.0))

            with fs.open(fpath, "rb") as remote_f:
              header_size_bytes = remote_f.read(HEADER_NUM_BYTES)
              header_size = int.from_bytes(header_size_bytes, byteorder="little")
              header_bytes = remote_f.read(header_size)
              header = json.loads(header_bytes)

              data_start_offset = HEADER_NUM_BYTES + header_size

              tensors = {k: v for k, v in header.items() if k != "__metadata__"}
              sorted_tensors = sorted(tensors.items(), key=lambda item: item[1]["data_offsets"][0])
              host_range = _select_host_byte_range(sorted_tensors, process_count, host_id)

              with open(partial_path, "wb") as local_f:
                # The header is always copied in full, even if no tensor data follows.
                local_f.write(header_size_bytes)
                local_f.write(header_bytes)

                if host_range is not None:
                  host_start_offset, host_end_offset = host_range
                  # Keep the file layout: the chunk lands at the same absolute offset as in the remote file.
                  remote_f.seek(data_start_offset + host_start_offset)
                  local_f.seek(data_start_offset + host_start_offset)

                  bytes_remaining = host_end_offset - host_start_offset
                  while bytes_remaining > 0:
                    buf = remote_f.read(min(buffer_size, bytes_remaining))
                    if not buf:
                      # A short read must not produce a silently truncated shard; fail so the retry logic runs.
                      raise EOFError(f"Unexpected end of stream with {bytes_remaining} bytes left to read")
                    local_f.write(buf)
                    bytes_remaining -= len(buf)

            os.replace(partial_path, local_path)
            return
          except Exception as e:  # pylint: disable=broad-exception-caught
            if attempt < max_retries - 1:
              max_logging.log(f"Error fetching {fpath}: {e}. Retrying in 15 seconds... (Attempt {attempt+1}/{max_retries})")
              time.sleep(15)
            else:
              max_logging.log(f"Failed to fetch {fpath} after {max_retries} attempts.")
              if os.path.exists(partial_path):
                os.remove(partial_path)
              raise

      with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        # list() drains the iterator so that exceptions from fetch_shard are raised here.
        list(executor.map(fetch_shard, files))

      path = local_dir

  t_total = time.time()
  param_map_mt_to_hf, hook_fn_map_mt = get_hf_config_and_mappings(maxtext_config)
  max_logging.log(f"[1/3] Mappings derived in {time.time() - t_total:.2f}s")

  target_tree = (
      abstract_unboxed_pre_state.to_pure_dict()
      if isinstance(abstract_unboxed_pre_state, nnx.State)
      else abstract_unboxed_pre_state.params
  )

  t1 = time.time()
  hf_state = load_sharded_hf_state(path)
  max_logging.log(f"[2/3] Distributed Sharded GCS load completed in {time.time() - t1:.2f}s")

  t2 = time.time()
  restored_params = transform_hf_state_to_mt_state(
      hf_state, target_tree, param_map_mt_to_hf, hook_fn_map_mt, maxtext_config
  )
  max_logging.log(f"[3/3] CPU Transformations completed in {time.time() - t2:.2f}s")
  max_logging.log(f"Total safetensors_dynamic duration: {time.time() - t_total:.2f}s")

  return None, restored_params

0 commit comments

Comments
 (0)