@@ -116,6 +116,8 @@ def __init__(self, model_id, token, revision=None):
116116 self .shard_map = {}
117117 self .current_shard_name = None
118118 self .current_shard_content = {}
119+ # Cache of resolved local shard paths, keyed by shard name
120+ self ._local_shard_paths = {}
119121 # Use a lock to serialize heavy RAM operations, but NOT downloads
120122 self ._ram_lock = threading .Lock ()
121123 self ._initialize_index ()
@@ -183,17 +185,21 @@ def get_tensor(self, key: str) -> np.ndarray:
183185 # You might need advanced fuzzy matching here if you encounter errors.
184186 raise ValueError (f"Key { key } not found in HF checkpoint index." )
185187
186- if self .is_local :
187- local_path = os . path . join ( self .model_id , shard_name )
188+ if shard_name in self ._local_shard_paths :
189+ local_path = self ._local_shard_paths [ shard_name ]
188190 else :
189- # STEP 1: Download outside the lock.
190- # multiple threads can download different shards at the same time.
191- local_path = hf_hub_download (
192- repo_id = self .model_id ,
193- filename = shard_name ,
194- token = self .token ,
195- revision = self .revision ,
196- )
191+ if self .is_local :
192+ local_path = os .path .join (self .model_id , shard_name )
193+ else :
194+ # STEP 1: Download outside the lock.
195+ # Multiple threads can download different shards at the same time.
196+ local_path = hf_hub_download (
197+ repo_id = self .model_id ,
198+ filename = shard_name ,
199+ token = self .token ,
200+ revision = self .revision ,
201+ )
202+ self ._local_shard_paths [shard_name ] = local_path
197203
198204 # STEP 2: Lock ONLY the reading into RAM.
199205 # This prevents multiple threads from simultaneously allocating large chunks of RAM.
0 commit comments