@@ -203,6 +203,7 @@ def verify_checkpoint_and_load_strategy(
203203 checkpoint_dir : str ,
204204 sharded_strategy : Union [LoadShardedStrategy , Tuple [str , int ], None ] = None ,
205205 common_strategy : Union [LoadCommonStrategy , Tuple [str , int ], None ] = None ,
206+ cache_metadata : bool = False ,
206207) -> Tuple [LoadShardedStrategy , LoadCommonStrategy ]:
207208 """Verifies if checkpoint metadata exists and matches given strategies.
208209
@@ -216,6 +217,8 @@ def verify_checkpoint_and_load_strategy(
216217 common_strategy (LoadCommonStrategy, Tuple[str, int], optional): common load strategy to be verified
217218 if compatible with the checkpoint content. If None, the default common load strategy
218219 for the checkpoint backend will be returned.
220+ cache_metadata (bool): if True and checkpoint backend is torch_dist, use a load strategy that caches
221+ metadata (e.g. when ckpt_assume_constant_structure is enabled). Ignored if sharded_strategy is set.
219222 """
220223 isdir = True
221224 if MultiStorageClientFeature .is_enabled ():
@@ -231,11 +234,18 @@ def verify_checkpoint_and_load_strategy(
231234 raise CheckpointingException (f"{ checkpoint_dir } is not a distributed checkpoint" )
232235
233236 if sharded_strategy is None :
234- sharded_strategy = get_default_strategy (
235- StrategyAction .LOAD_SHARDED ,
236- saved_config .sharded_backend ,
237- saved_config .sharded_backend_version ,
238- )
237+ if cache_metadata and saved_config .sharded_backend == 'torch_dist' :
238+ from megatron .core .dist_checkpointing .strategies .torch import (
239+ TorchDistLoadShardedStrategy ,
240+ )
241+
242+ sharded_strategy = TorchDistLoadShardedStrategy (cache_metadata = True )
243+ else :
244+ sharded_strategy = get_default_strategy (
245+ StrategyAction .LOAD_SHARDED ,
246+ saved_config .sharded_backend ,
247+ saved_config .sharded_backend_version ,
248+ )
239249 elif isinstance (sharded_strategy , tuple ):
240250 sharded_strategy = get_default_strategy (StrategyAction .LOAD_SHARDED , * sharded_strategy )
241251
0 commit comments