|
| 1 | +# Copyright 2023–2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# https://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +"""Dynamic loading of HuggingFace checkpoints during training/eval workloads directly in the target format.""" |
| 16 | + |
| 17 | +import jax |
| 18 | +from flax import traverse_util |
| 19 | +from flax import nnx |
| 20 | +from orbax.checkpoint import v1 as ocp_v1 |
| 21 | +from orbax.checkpoint._src.arrays import sharding as sharding_utils |
| 22 | + |
| 23 | +from maxtext.utils import max_logging |
| 24 | +from maxtext.checkpoint_conversion.utils.tensor_handling import _get_hf_loading_function |
| 25 | +from maxtext.checkpoint_conversion.utils import param_mapping |
| 26 | +from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS |
| 27 | +import time |
| 28 | + |
def build_gcs_cache_worker(fpath, gcs_cache_dir, hf_access_token):
  """Streams a single HuggingFace Hub file into the shared GCS cache directory.

  The file is first written to a temporary object next to its final location and
  only renamed to the final name once the whole stream has been copied. This
  keeps a failed or interrupted transfer from leaving a truncated object under
  the final name, which the existence check below would otherwise treat as a
  valid cache entry on the next run.

  Args:
    fpath: Path of the remote file as understood by `HfFileSystem`.
    gcs_cache_dir: Destination directory; the file keeps its basename.
    hf_access_token: Token passed to `HfFileSystem`.

  Raises:
    Exception: Re-raises the last transfer error after all retries are exhausted.
  """
  # Imports stay function-local, as in the original worker, so the function is
  # self-contained when executed by a process pool.
  import os
  import time
  import random
  import tensorflow as tf
  from huggingface_hub import HfFileSystem
  from maxtext.utils import max_logging

  gcs_path = os.path.join(gcs_cache_dir, os.path.basename(fpath))

  # Cache hit: nothing to do, so skip the jitter sleep and the HF client setup.
  if tf.io.gfile.exists(gcs_path):
    return

  tmp_path = gcs_path + ".partial"
  fs = HfFileSystem(token=hf_access_token)

  # Random jitter so that concurrently started workers do not all hit the Hub at once.
  time.sleep(random.uniform(0.0, 5.0))

  buffer_size = 1024 * 1024 * 16
  max_retries = 5
  for attempt in range(max_retries):
    try:
      with fs.open(fpath, "rb") as remote_f, tf.io.gfile.GFile(tmp_path, "wb") as gcs_f:
        while True:
          buf = remote_f.read(buffer_size)
          if not buf:
            break
          gcs_f.write(buf)
      # Publish under the final name only after the full copy succeeded.
      tf.io.gfile.rename(tmp_path, gcs_path, overwrite=True)
      return
    except Exception as e:  # Retry boundary: the last failure is re-raised below.
      if attempt < max_retries - 1:
        max_logging.log(f"Error fetching {fpath} to GCS: {e}. Retrying in 15 seconds... (Attempt {attempt+1}/{max_retries})")
        time.sleep(15)
      else:
        max_logging.log(f"Failed to fetch {fpath} to GCS after {max_retries} attempts.")
        # Best-effort removal of the temporary object; never mask the real error.
        try:
          if tf.io.gfile.exists(tmp_path):
            tf.io.gfile.remove(tmp_path)
        except tf.errors.OpError as cleanup_error:
          max_logging.log(f"Could not remove temporary object {tmp_path}: {cleanup_error}")
        raise
| 63 | + |
| 64 | + |
def get_hf_config_and_mappings(maxtext_config):
  """Looks up the MaxText->HF parameter mapping and the hook functions for a model.

  The lookup key is `maxtext_config.model_name` with any "-Instruct" substring
  removed.

  Args:
    maxtext_config: MaxText config; `model_name` and `scan_layers` are read.

  Returns:
    A tuple `(param_map_mt_to_hf, hook_fn_map_mt)` built by the per-model
    factories registered in `param_mapping.PARAM_MAPPING` and
    `param_mapping.HOOK_FNS`.
  """
  # `str.replace` is a no-op when the suffix is absent, so no guard is needed.
  model_key = maxtext_config.model_name.replace("-Instruct", "")
  hf_config_dict = HF_MODEL_CONFIGS[model_key].to_dict()

  build_param_map = param_mapping.PARAM_MAPPING[model_key]
  param_map_mt_to_hf = build_param_map(
      hf_config_dict,
      maxtext_config,
      scan_layers=maxtext_config.scan_layers,
  )

  build_hook_fns = param_mapping.HOOK_FNS[model_key]
  hook_fn_map_mt = build_hook_fns(
      hf_config_dict,
      maxtext_config,
      scan_layers=maxtext_config.scan_layers,
      saving_to_hf=False,
  )

  return param_map_mt_to_hf, hook_fn_map_mt
| 80 | + |
| 81 | + |
def load_sharded_hf_state(path):
  """Reads a safetensors checkpoint as arrays sharded over all global devices.

  The checkpoint's abstract tree is read first, each leaf is paired with a
  sharding produced by `construct_maximal_shardings` over `jax.devices()`, and
  the resulting sharded abstract tree is handed to Orbax to perform the load.

  Args:
    path: Location of the safetensors checkpoint, passed to Orbax as-is.

  Returns:
    The pytree returned by `ocp_v1.load_pytree`.
  """
  started_at = time.time()
  safetensors_layout = ocp_v1.Context(checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS)

  # Both the metadata read and the load run under the SAFETENSORS layout context.
  with safetensors_layout:
    abstract_tree = ocp_v1.pytree_metadata(path).metadata

    # Distributed Sharded Download: shard the safetensors read across the whole
    # device mesh to avoid host OOM.
    maximal_shardings = sharding_utils.construct_maximal_shardings(abstract_tree, devices=jax.devices())

    def attach_sharding(leaf, leaf_sharding):
      return jax.ShapeDtypeStruct(shape=leaf.shape, dtype=leaf.dtype, sharding=leaf_sharding)

    sharded_abstract_state = jax.tree.map(attach_sharding, abstract_tree, maximal_shardings)

    max_logging.log("Reading raw Safetensors into memory (Distributed Sharded GCS Download)...")
    hf_state = ocp_v1.load_pytree(path, sharded_abstract_state)

  max_logging.log(f"load_sharded_hf_state took {time.time() - started_at:.2f}s")
  return hf_state
| 104 | + |
| 105 | + |
def transform_hf_state_to_mt_state(
    hf_state, target_tree, param_map_mt_to_hf, hook_fn_map_mt, maxtext_config
):
  """Builds the MaxText parameter tree from a loaded HF state.

  For every entry of `param_map_mt_to_hf` whose MaxText key can be located in the
  flattened `target_tree`, the matching HF tensor(s) are transformed through the
  loading function returned by `_get_hf_loading_function` and placed on devices
  with the sharding of the corresponding target leaf.

  Args:
    hf_state: Mapping of HF tensor names to arrays. NOTE: entries are popped from
      this mapping as they are consumed, so the caller's object is mutated.
    target_tree: Nested dict whose leaves expose `.shape` and `.sharding`.
    param_map_mt_to_hf: Mapping from MaxText keys ("params-a-b" style) to HF sources.
    hook_fn_map_mt: Mapping from MaxText keys to optional hook functions.
    maxtext_config: MaxText config forwarded to the loading function.

  Returns:
    `{"params": tree}` where `tree` is the restored nested dict. Target leaves
    that no mapping entry matched are carried over unchanged from `target_tree`.
  """
  t0 = time.time()

  def tensor_getter(key):
    # Pop rather than index so `hf_state` drops its reference to the source tensor.
    return hf_state.pop(key)

  flat_target = traverse_util.flatten_dict(target_tree, sep=".")
  flat_restored = flat_target.copy()

  mapped_count = 0
  keys_missed = []
  max_logging.log("Starting fast in-memory Distributed Transformations...")

  for mt_key, hf_source in param_map_mt_to_hf.items():
    mt_name = mt_key.replace("params-", "").replace("-", ".")

    # The target tree may or may not carry a leading "params." level; try the
    # bare name, the prefixed name, and finally the raw key with dots.
    candidates = (mt_name, "params." + mt_name, mt_key.replace("-", "."))
    check_name = next((name for name in candidates if name in flat_target), None)
    if check_name is None:
      keys_missed.append(mt_name)
      continue

    target_leaf = flat_target[check_name]
    load_fn = _get_hf_loading_function(
        hf_source,
        tensor_getter,
        hook_fn_map_mt.get(mt_key),
        target_leaf.shape,
        maxtext_config,
    )

    # Execute transformation and assign to flat_restored
    t_layer = time.time()
    unsharded_array = load_fn()

    # Place the result explicitly with the sharding the model expects.
    flat_restored[check_name] = jax.device_put(unsharded_array, device=target_leaf.sharding, donate=True)

    max_logging.log(f"Transformed {check_name} from {hf_source} in {time.time() - t_layer:.4f}s")
    mapped_count += 1

  if mapped_count == 0:
    max_logging.log(f"All transformations missed! Sample missed mt_names: {keys_missed[:5]}")
    max_logging.log(f"Sample flat_target keys: {list(flat_target.keys())[:5]}")
  elif keys_missed:
    # Partial misses used to be silent; surface them so skipped entries are visible.
    max_logging.log(
        f"{len(keys_missed)} mapped parameters were not found in the target tree and were skipped. "
        f"Sample missed mt_names: {keys_missed[:5]}"
    )

  max_logging.log(f"Successfully mapped {mapped_count} parameters.")
  restored_params = traverse_util.unflatten_dict(flat_restored, sep=".")

  # Normalise to a single top-level "params" level regardless of the input layout.
  if "params" in restored_params:
    restored_params = restored_params["params"]

  max_logging.log(f"transform_hf_state_to_mt_state took {time.time() - t0:.2f}s")

  return {"params": restored_params}
| 171 | + |
| 172 | + |
# Size of the little-endian length prefix that precedes the JSON header of a safetensors file.
_SAFETENSORS_HEADER_NUM_BYTES = 8
# Chunk size used when streaming tensor bytes to local disk.
_LOCAL_FETCH_BUFFER_BYTES = 1024 * 1024 * 16


def _host_byte_range(sorted_tensors, process_count, host_id):
  """Returns the `(start, end)` data offsets of the tensors assigned to `host_id`, or None.

  Tensors (sorted by start offset) are walked in order and grouped into
  `process_count` consecutive bundles; a new bundle is started when cutting before
  the current tensor lands closer to the ideal cumulative size than keeping it.

  Args:
    sorted_tensors: Non-empty list of `(name, info)` pairs sorted by
      `info["data_offsets"][0]`.
    process_count: Number of bundles to split the file into.
    host_id: Index of the bundle to return.
  """
  total_size = sorted_tensors[-1][1]["data_offsets"][1]
  current_bundle = 0
  cumulative_size = 0
  host_start_offset = None
  host_end_offset = None

  for _, info in sorted_tensors:
    start, end = info["data_offsets"]
    tensor_size = end - start
    if current_bundle < process_count - 1:
      ideal = (current_bundle + 1) * (total_size / process_count)
      dist_if_cut = abs(cumulative_size - ideal)
      dist_if_keep = abs((cumulative_size + tensor_size) - ideal)
      if dist_if_cut < dist_if_keep and cumulative_size > 0:
        current_bundle += 1

    if current_bundle == host_id:
      if host_start_offset is None:
        host_start_offset = start
      host_end_offset = end

    cumulative_size += tensor_size

  if host_start_offset is None:
    return None
  return host_start_offset, host_end_offset


def _fetch_host_slice(fs, fpath, local_dir, process_count, host_id):
  """Downloads the header plus this host's byte slice of one safetensors file.

  The data is written to a `.part` file and moved to its final name only after a
  fully successful attempt, so a failed attempt can never be mistaken for an
  already-downloaded file by the existence check.

  Raises:
    Exception: Re-raises the last error after all retries are exhausted.
  """
  import json
  import os
  import random

  local_path = os.path.join(local_dir, os.path.basename(fpath))
  if os.path.exists(local_path):
    return
  partial_path = local_path + ".part"

  max_retries = 5
  for attempt in range(max_retries):
    try:
      # Random jitter so concurrent fetches do not hit the Hub in lockstep.
      time.sleep(random.uniform(0.0, 5.0))

      with fs.open(fpath, "rb") as remote_f:
        header_size_bytes = remote_f.read(_SAFETENSORS_HEADER_NUM_BYTES)
        header_size = int.from_bytes(header_size_bytes, byteorder="little")
        header_bytes = remote_f.read(header_size)
        header = json.loads(header_bytes)

        data_start_offset = _SAFETENSORS_HEADER_NUM_BYTES + header_size

        tensors = {k: v for k, v in header.items() if k != "__metadata__"}
        sorted_tensors = sorted(tensors.items(), key=lambda item: item[1]["data_offsets"][0])
        host_range = _host_byte_range(sorted_tensors, process_count, host_id) if sorted_tensors else None

        with open(partial_path, "wb") as local_f:
          local_f.write(header_size_bytes)
          local_f.write(header_bytes)

          if host_range is not None:
            host_start_offset, host_end_offset = host_range
            # Keep the slice at the same absolute offset it has in the remote file.
            remote_f.seek(data_start_offset + host_start_offset)
            local_f.seek(data_start_offset + host_start_offset)

            bytes_remaining = host_end_offset - host_start_offset
            while bytes_remaining > 0:
              buf = remote_f.read(min(_LOCAL_FETCH_BUFFER_BYTES, bytes_remaining))
              if not buf:
                # A short read would otherwise leave silently truncated tensor data.
                raise IOError(f"Unexpected end of stream for {fpath}: {bytes_remaining} bytes missing.")
              local_f.write(buf)
              bytes_remaining -= len(buf)

      os.replace(partial_path, local_path)
      return
    except Exception as e:  # Retry boundary: the last failure is re-raised below.
      if attempt < max_retries - 1:
        max_logging.log(f"Error fetching {fpath}: {e}. Retrying in 15 seconds... (Attempt {attempt+1}/{max_retries})")
        time.sleep(15)
      else:
        max_logging.log(f"Failed to fetch {fpath} after {max_retries} attempts.")
        raise


def _stage_hf_repo(repo_id, maxtext_config):
  """Makes the `*.safetensors` files of an HF repo available and returns their directory.

  If `maxtext_config.base_output_directory` is a `gs://` location, process 0 copies
  the files into `<base_output_directory>/hf_cache/<repo>` and every process then
  meets at a global barrier. Otherwise each process downloads its own byte slice of
  every file into `/tmp/hf_checkpoints/<repo>`.
  """
  import concurrent.futures
  import os
  from huggingface_hub import HfFileSystem

  fs = HfFileSystem(token=maxtext_config.hf_access_token)
  files = fs.glob(f"{repo_id}/*.safetensors")

  process_count = max(1, jax.process_count())
  host_id = jax.process_index()

  if hasattr(maxtext_config, "base_output_directory") and maxtext_config.base_output_directory.startswith("gs://"):
    gcs_cache_dir = f"{maxtext_config.base_output_directory}/hf_cache/{repo_id.replace('/', '_')}"

    # Only Host 0 downloads to the shared GCS cache
    if host_id == 0:
      import itertools
      import tensorflow as tf

      if not tf.io.gfile.exists(gcs_cache_dir):
        tf.io.gfile.makedirs(gcs_cache_dir)

      max_logging.log(f"Dynamic HF Hub Fast DL: Host 0 is downloading to shared GCS Cache: {gcs_cache_dir}")

      t_gcs_start = time.time()
      with concurrent.futures.ProcessPoolExecutor(max_workers=16) as executor:
        # list(...) drains the iterator so worker exceptions propagate here.
        list(
            executor.map(
                build_gcs_cache_worker,
                files,
                itertools.repeat(gcs_cache_dir),
                itertools.repeat(maxtext_config.hf_access_token),
            )
        )
      t_gcs_end = time.time()
      max_logging.log(f"GCS caching complete in {t_gcs_end - t_gcs_start:.2f}s.")

    # Global barrier: every host waits here until Host 0 has finished populating the shared bucket.
    from jax.experimental import multihost_utils

    multihost_utils.sync_global_devices("GCS cache stream barrier")
    return gcs_cache_dir

  # Fallback to local /tmp caching across all hosts with distributed downloading
  local_dir = f"/tmp/hf_checkpoints/{repo_id.replace('/', '_')}"
  os.makedirs(local_dir, exist_ok=True)

  max_logging.log(f"Dynamic HF Hub Fast DL: Resolving metadata and partial chunks via HTTP Range Requests for Host {host_id}/{process_count}")

  with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
    list(executor.map(lambda fpath: _fetch_host_slice(fs, fpath, local_dir, process_count, host_id), files))

  return local_dir


def load_safetensors_dynamic_state(path, abstract_unboxed_pre_state, maxtext_config):
  """Main entry point to dynamically build and load safetensors into MaxText format.

  Splits execution into:
    1. Deriving mappings
    2. Loading sharded arrays
    3. Applying the per-parameter transformations

  Args:
    path: Checkpoint location. When empty, the HF repo id is looked up in `HF_IDS`
      using `maxtext_config.model_name` (with "-Instruct" removed). A leading
      "hf://" is stripped. A `gs://` path or an existing local directory is used
      directly; anything else is treated as an HF repo id and staged first.
    abstract_unboxed_pre_state: Either an `nnx.State` (converted with
      `to_pure_dict()`) or an object exposing `.params`.
    maxtext_config: MaxText config; must not be None.

  Returns:
    A tuple `(None, restored_params)` where `restored_params` is the value
    returned by `transform_hf_state_to_mt_state`.

  Raises:
    ValueError: If `maxtext_config` is None, or if `path` is empty and the model
      name has no entry in `HF_IDS`.
  """
  if maxtext_config is None:
    raise ValueError("maxtext_config must be provided for safetensors_dynamic loading.")

  # NOTE: `time` is deliberately NOT imported inside this function. A function-level
  # `import time` in any branch makes `time` a local name, and the timing code below
  # would raise UnboundLocalError whenever that branch was not taken.
  import os
  from maxtext.utils.globals import HF_IDS

  if not path:
    model_name = maxtext_config.model_name.replace("-Instruct", "")
    if model_name not in HF_IDS:
      raise ValueError(f"Unsupported model name for automatic HF repo resolution: {model_name}.")
    path = HF_IDS[model_name]

  if path.startswith("hf://"):
    path = path[5:]

  if not path.startswith("gs://") and not os.path.isdir(path):
    path = _stage_hf_repo(path, maxtext_config)

  t_total = time.time()
  param_map_mt_to_hf, hook_fn_map_mt = get_hf_config_and_mappings(maxtext_config)
  max_logging.log(f"[1/3] Mappings derived in {time.time() - t_total:.2f}s")

  target_tree = (
      abstract_unboxed_pre_state.to_pure_dict()
      if isinstance(abstract_unboxed_pre_state, nnx.State)
      else abstract_unboxed_pre_state.params
  )

  t1 = time.time()
  hf_state = load_sharded_hf_state(path)
  max_logging.log(f"[2/3] Distributed Sharded GCS load completed in {time.time() - t1:.2f}s")

  t2 = time.time()
  restored_params = transform_hf_state_to_mt_state(
      hf_state, target_tree, param_map_mt_to_hf, hook_fn_map_mt, maxtext_config
  )
  max_logging.log(f"[3/3] CPU Transformations completed in {time.time() - t2:.2f}s")
  max_logging.log(f"Total safetensors_dynamic duration: {time.time() - t_total:.2f}s")

  return None, restored_params
0 commit comments