Skip to content

Commit 385ac64

Browse files
authored
Merge branch 'main' into cye/te-dcp-test
2 parents 0d68b94 + 0810e63 commit 385ac64

7 files changed

Lines changed: 71 additions & 24 deletions

File tree

megatron/core/dist_checkpointing/serialization.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,14 @@ def get_default_save_common_strategy(
453453
return get_default_strategy(StrategyAction.SAVE_COMMON, backend, version)
454454

455455

def get_default_load_sharded_strategy(
    checkpoint_dir: str, cache_metadata: bool = False
) -> LoadShardedStrategy:
    """Return the default sharded load strategy for a checkpoint directory.

    Delegates to `verify_checkpoint_and_load_strategy` and keeps only the
    first element of the result it returns (the sharded load strategy).

    Args:
        checkpoint_dir: Path to the checkpoint directory.
        cache_metadata: Forwarded unchanged to `verify_checkpoint_and_load_strategy`.
            If True and the checkpoint format is torch_dist, a strategy that caches
            metadata is used (e.g. when ckpt_assume_constant_structure is enabled).

    Returns:
        The sharded load strategy selected for the checkpoint.
    """
    strategies = verify_checkpoint_and_load_strategy(
        checkpoint_dir, cache_metadata=cache_metadata
    )
    # Index 0 is the sharded strategy; the remaining element is not needed here.
    return strategies[0]

megatron/core/dist_checkpointing/strategies/cached_metadata_filesystem_reader.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
""" FS Reader with metadata cached support. """
44

55
import os
6-
from typing import Union
6+
from typing import Dict, Union
77

88
from torch.distributed.checkpoint import FileSystemReader, Metadata
99

class CachedMetadataFileSystemReader(FileSystemReader):
    """
    Extends FileSystemReader to cache metadata for improved performance.

    Metadata is shared across all reader instances that use the same checkpoint
    directory (same absolute path), since the loaded metadata is identical.
    Readers created with ``cache_metadata=False`` never read from or write to
    the shared cache.

    Attributes:
        _metadata_cache (Dict[str, Metadata]): Class-level cache keyed by the
            absolute checkpoint path.
    """

    _metadata_cache: Dict[str, Metadata] = {}

    def __init__(self, path: Union[str, os.PathLike], cache_metadata: bool = True) -> None:
        """
        Initialize with file system path.

        Args:
            path (Union[str, os.PathLike]): Path to the checkpoint directory or file.
            cache_metadata (bool): If True (default), metadata read through this
                reader is stored in, and served from, the class-level cache. If
                False, every `read_metadata` call goes to the parent reader.
        """
        super().__init__(path=path)
        # None marks "caching disabled"; otherwise the key is the absolute path so
        # that different spellings of the same directory share one cache entry.
        self._cache_key = os.path.abspath(os.fspath(path)) if cache_metadata else None

    def read_metadata(self) -> Metadata:
        """
        Read metadata from file system, caching for subsequent calls.

        When caching is enabled the result is shared across instances that use
        the same checkpoint directory. When caching is disabled the shared cache
        is bypassed entirely.

        Returns:
            Metadata: Checkpoint metadata.
        """
        cache_key = self._cache_key
        if cache_key is None:
            # Caching disabled: do not touch the shared cache. Storing under the
            # key None would make every non-caching reader share one entry, no
            # matter which checkpoint directory it points to.
            return super().read_metadata()

        cache = CachedMetadataFileSystemReader._metadata_cache
        if cache_key not in cache:
            cache[cache_key] = super().read_metadata()
        return cache[cache_key]

    @classmethod
    def clear_metadata_cache(cls) -> None:
        """
        Clear the metadata cache.
        """
        cls._metadata_cache.clear()

megatron/core/dist_checkpointing/strategies/torch.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -763,16 +763,17 @@ def _get_filesystem_reader(
763763
return msc.torch.MultiStorageFileSystemReader(checkpoint_dir, thread_count=2)
764764

765765
if cache_metadata:
766-
return CachedMetadataFileSystemReader(checkpoint_dir)
766+
return CachedMetadataFileSystemReader(checkpoint_dir, cache_metadata=cache_metadata)
767767

768768
return FileSystemReader(checkpoint_dir)
769769

770770

771771
class TorchDistLoadShardedStrategy(LoadShardedStrategy):
772772
"""Basic load strategy for the PyT Distributed format."""
773773

774-
def __init__(self):
774+
def __init__(self, cache_metadata: bool = False):
775775
self.cached_global_metadata: Optional[Metadata] = None
776+
self.cache_metadata = cache_metadata
776777
super().__init__()
777778

778779
def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> StateDict:
@@ -803,7 +804,7 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
803804
)
804805
pyt_state_dict = mcore_to_pyt_state_dict(sharded_state_dict, True)
805806
# Load PyT Distributed format
806-
fsr = _get_filesystem_reader(checkpoint_dir, cache_metadata=True)
807+
fsr = _get_filesystem_reader(checkpoint_dir, cache_metadata=self.cache_metadata)
807808
checkpoint.load_state_dict(
808809
pyt_state_dict,
809810
fsr,
@@ -815,9 +816,10 @@ def load(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Path) -> St
815816
),
816817
)
817818

818-
self.cached_global_metadata = (
819-
fsr.read_metadata()
820-
) # no storage interaction thanks to caching
819+
if self.cache_metadata:
820+
self.cached_global_metadata = (
821+
fsr.read_metadata()
822+
) # no storage interaction thanks to caching
821823

822824
pyt_state_dict = cast(
823825
Dict[str, Union[TorchShardedTensor, List[io.BytesIO]]], pyt_state_dict

megatron/core/dist_checkpointing/validation.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def verify_checkpoint_and_load_strategy(
203203
checkpoint_dir: str,
204204
sharded_strategy: Union[LoadShardedStrategy, Tuple[str, int], None] = None,
205205
common_strategy: Union[LoadCommonStrategy, Tuple[str, int], None] = None,
206+
cache_metadata: bool = False,
206207
) -> Tuple[LoadShardedStrategy, LoadCommonStrategy]:
207208
"""Verifies if checkpoint metadata exists and matches given strategies.
208209
@@ -216,6 +217,8 @@ def verify_checkpoint_and_load_strategy(
216217
common_strategy (LoadCommonStrategy, Tuple[str, int], optional): common load strategy to be verified
217218
if compatible with the checkpoint content. If None, the default common load strategy
218219
for the checkpoint backend will be returned.
220+
cache_metadata (bool): if True and checkpoint backend is torch_dist, use a load strategy that caches
221+
metadata (e.g. when ckpt_assume_constant_structure is enabled). Ignored if sharded_strategy is set.
219222
"""
220223
isdir = True
221224
if MultiStorageClientFeature.is_enabled():
@@ -231,11 +234,18 @@ def verify_checkpoint_and_load_strategy(
231234
raise CheckpointingException(f"{checkpoint_dir} is not a distributed checkpoint")
232235

233236
if sharded_strategy is None:
234-
sharded_strategy = get_default_strategy(
235-
StrategyAction.LOAD_SHARDED,
236-
saved_config.sharded_backend,
237-
saved_config.sharded_backend_version,
238-
)
237+
if cache_metadata and saved_config.sharded_backend == 'torch_dist':
238+
from megatron.core.dist_checkpointing.strategies.torch import (
239+
TorchDistLoadShardedStrategy,
240+
)
241+
242+
sharded_strategy = TorchDistLoadShardedStrategy(cache_metadata=True)
243+
else:
244+
sharded_strategy = get_default_strategy(
245+
StrategyAction.LOAD_SHARDED,
246+
saved_config.sharded_backend,
247+
saved_config.sharded_backend_version,
248+
)
239249
elif isinstance(sharded_strategy, tuple):
240250
sharded_strategy = get_default_strategy(StrategyAction.LOAD_SHARDED, *sharded_strategy)
241251

megatron/training/async_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
import logging
88

99
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue, AsyncRequest
10+
from megatron.core.dist_checkpointing.strategies.cached_metadata_filesystem_reader import (
11+
CachedMetadataFileSystemReader,
12+
)
1013
from megatron.core.dist_checkpointing.strategies.filesystem_async import _results_queue
1114
from megatron.training import get_args
1215
from megatron.training.utils import print_rank_0
@@ -76,3 +79,4 @@ def reset_persistent_async_worker():
7679
del _results_queue
7780
_results_queue = None
7881
_async_calls_queue = None
82+
CachedMetadataFileSystemReader.clear_metadata_cache()

megatron/training/checkpointing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1130,7 +1130,9 @@ def _load_global_dist_base_checkpoint(
11301130
)
11311131

11321132
checkpoint_name = get_checkpoint_name(load_dir, iteration, release, return_base_dir=True)
1133-
load_strategy = get_default_load_sharded_strategy(checkpoint_name)
1133+
load_strategy = get_default_load_sharded_strategy(
1134+
checkpoint_name, cache_metadata=args.ckpt_assume_constant_structure
1135+
)
11341136
# NOTE: `args.ckpt_fully_parallel_load` applies to both persistent and non-persistent checkpoints.
11351137
if args.ckpt_fully_parallel_load:
11361138
if args.ckpt_fully_parallel_load_process_group == 'dp':

tests/unit_tests/dist_checkpointing/utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
import torch
88

9+
from megatron.core.dist_checkpointing.strategies.cached_metadata_filesystem_reader import (
10+
CachedMetadataFileSystemReader,
11+
)
912
from megatron.core.models.gpt import GPTModel
1013
from megatron.core.models.gpt.gpt_layer_specs import (
1114
get_gpt_layer_local_spec,
@@ -167,6 +170,9 @@ def init_checkpointing_mock_args(args, ckpt_dir, fully_parallel=False):
167170
args.dist_ckpt_optim_fully_reshardable = False
168171
args.distrib_optim_fully_reshardable_mem_efficient = False
169172
args.phase_transition_iterations = None
173+
# Clear the metadata cache to avoid contamination between tests
174+
175+
CachedMetadataFileSystemReader.clear_metadata_cache()
170176

171177

172178
def setup_model_and_optimizer(
@@ -224,7 +230,7 @@ def setup_model_and_optimizer(
224230
opt.init_state_fn(opt)
225231

226232
optimizer.reload_model_params()
227-
233+
CachedMetadataFileSystemReader.clear_metadata_cache()
228234
return unwrap_model(model), optimizer
229235

230236

@@ -322,5 +328,5 @@ def setup_moe_model_and_optimizer(
322328
opt.init_state_fn(opt)
323329

324330
optimizer.reload_model_params()
325-
331+
CachedMetadataFileSystemReader.clear_metadata_cache()
326332
return unwrap_model(model), optimizer

0 commit comments

Comments
 (0)