Skip to content

Commit d11185b

Browse files
Merge pull request #3972 from niting:conversion_perf
PiperOrigin-RevId: 921633048
2 parents (be19157 + 9812b96), commit d11185b

1 file changed

Lines changed: 16 additions & 10 deletions

File tree

src/maxtext/checkpoint_conversion/to_maxtext.py

Lines changed: 16 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -116,6 +116,8 @@ def __init__(self, model_id, token, revision=None):
116116
self.shard_map = {}
117117
self.current_shard_name = None
118118
self.current_shard_content = {}
119+
# Cache for resolved local shard paths
120+
self._local_shard_paths = {}
119121
# Use a lock to serialize heavy RAM operations, but NOT downloads
120122
self._ram_lock = threading.Lock()
121123
self._initialize_index()
@@ -183,17 +185,21 @@ def get_tensor(self, key: str) -> np.ndarray:
183185
# You might need advanced fuzzy matching here if you encounter errors.
184186
raise ValueError(f"Key {key} not found in HF checkpoint index.")
185187

186-
if self.is_local:
187-
local_path = os.path.join(self.model_id, shard_name)
188+
if shard_name in self._local_shard_paths:
189+
local_path = self._local_shard_paths[shard_name]
188190
else:
189-
# STEP 1: Download outside the lock.
190-
# multiple threads can download different shards at the same time.
191-
local_path = hf_hub_download(
192-
repo_id=self.model_id,
193-
filename=shard_name,
194-
token=self.token,
195-
revision=self.revision,
196-
)
191+
if self.is_local:
192+
local_path = os.path.join(self.model_id, shard_name)
193+
else:
194+
# STEP 1: Download outside the lock.
195+
# multiple threads can download different shards at the same time.
196+
local_path = hf_hub_download(
197+
repo_id=self.model_id,
198+
filename=shard_name,
199+
token=self.token,
200+
revision=self.revision,
201+
)
202+
self._local_shard_paths[shard_name] = local_path
197203

198204
# STEP 2: Lock ONLY the reading into RAM.
199205
# This prevents multiple threads from simultaneously allocating large chunks of RAM.

0 commit comments

Comments
 (0)