Commit 9812b96

Cache already downloaded HuggingFace shards.
Currently, shards seem to be redownloaded every time they are required, causing slowdowns in conversion. I tried running the script with the changes and saw significant improvements.

Benchmark: 2-Layer Qwen3 MoE Checkpoint Conversion (Lazy Loading Enabled)

| Metric                     | Baseline (Cached) | Optimized (Phase 1 Only) | Speedup  |
|----------------------------|-------------------|--------------------------|----------|
| Sharding (Materialization) | 81.6s (1.36 min)  | 16.2s (0.27 min)         | **5.0x** |
| Overall Elapse             | 83.4s (1.39 min)  | 17.4s (0.29 min)         | **4.8x** |

Integration Tests (tests/integration/checkpoint_conversion_test.py):
- Baseline: 148.73s (2:28)
- Optimized: 77.33s (1:17) -> **1.9x speedup overall** (includes model download)
1 parent a8be563 commit 9812b96

1 file changed

Lines changed: 16 additions & 10 deletions

src/maxtext/checkpoint_conversion/to_maxtext.py

@@ -116,6 +116,8 @@ def __init__(self, model_id, token, revision=None):
     self.shard_map = {}
     self.current_shard_name = None
     self.current_shard_content = {}
+    # Cache for resolved local shard paths
+    self._local_shard_paths = {}
     # Use a lock to serialize heavy RAM operations, but NOT downloads
     self._ram_lock = threading.Lock()
     self._initialize_index()
@@ -183,17 +185,21 @@ def get_tensor(self, key: str) -> np.ndarray:
       # You might need advanced fuzzy matching here if you encounter errors.
       raise ValueError(f"Key {key} not found in HF checkpoint index.")

-    if self.is_local:
-      local_path = os.path.join(self.model_id, shard_name)
+    if shard_name in self._local_shard_paths:
+      local_path = self._local_shard_paths[shard_name]
     else:
-      # STEP 1: Download outside the lock.
-      # multiple threads can download different shards at the same time.
-      local_path = hf_hub_download(
-          repo_id=self.model_id,
-          filename=shard_name,
-          token=self.token,
-          revision=self.revision,
-      )
+      if self.is_local:
+        local_path = os.path.join(self.model_id, shard_name)
+      else:
+        # STEP 1: Download outside the lock.
+        # multiple threads can download different shards at the same time.
+        local_path = hf_hub_download(
+            repo_id=self.model_id,
+            filename=shard_name,
+            token=self.token,
+            revision=self.revision,
+        )
+      self._local_shard_paths[shard_name] = local_path

     # STEP 2: Lock ONLY the reading into RAM.
     # This prevents multiple threads from simultaneously allocating large chunks of RAM.
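
The caching pattern in this change can be tried in isolation. The following is a minimal sketch, not the code from `to_maxtext.py`: the `ShardPathResolver` class and the `fake_fetch` function are hypothetical stand-ins, with `fake_fetch` taking the place of `hf_hub_download` so the example runs without network access. It only illustrates that a path resolved once is served from the dictionary on later lookups.

```python
import os
from typing import Callable, Dict


class ShardPathResolver:
  """Hypothetical helper: memoizes shard name -> resolved local path."""

  def __init__(self, fetch: Callable[[str], str]):
    # `fetch` stands in for the expensive resolution step (e.g. a hub download).
    self._fetch = fetch
    self._local_shard_paths: Dict[str, str] = {}

  def resolve(self, shard_name: str) -> str:
    # Serve from the cache when this shard was already resolved.
    if shard_name in self._local_shard_paths:
      return self._local_shard_paths[shard_name]
    # Otherwise resolve once and remember the result.
    local_path = self._fetch(shard_name)
    self._local_shard_paths[shard_name] = local_path
    return local_path


calls = []


def fake_fetch(shard_name: str) -> str:
  # Records each invocation so the number of real fetches can be inspected.
  calls.append(shard_name)
  return os.path.join("hf-cache", shard_name)


resolver = ShardPathResolver(fake_fetch)
requested = [
    "model-00001.safetensors",
    "model-00002.safetensors",
    "model-00001.safetensors",
]
paths = [resolver.resolve(name) for name in requested]
print(calls)  # ['model-00001.safetensors', 'model-00002.safetensors']
```

As in the diff, the cache lookup itself takes no lock. Two threads asking for the same uncached shard at the same moment may both call the fetch function, which costs duplicate work but still leaves a valid path in the dictionary.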
