# feat: add MSC cloud storage support for dcp checkpoints #1709
**`nemo_automodel/components/checkpoint/checkpointing.py`**

```diff
@@ -25,13 +25,22 @@
 import torch.distributed.checkpoint as dcp
 import yaml
 
+try:
+    import multistorageclient as msc
+
+    MSC_AVAILABLE = True
+except ImportError:
+    msc = None
+    MSC_AVAILABLE = False
+
 # Safe import of HF_HUB_CACHE from huggingface_hub.constants
 try:
     from huggingface_hub.constants import HF_HUB_CACHE
 except ImportError:
     HF_HUB_CACHE = None
 
 from packaging.version import parse
+from safetensors.torch import load as safetensors_load
 from safetensors.torch import load_file, save_file
 from torch import nn
 from torch.distributed.device_mesh import DeviceMesh
```
```diff
@@ -63,6 +72,21 @@
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 
+def is_cloud_path(path: str) -> bool:
+    """Check if path is a cloud storage path (MSC)."""
+    return path.startswith("msc://")
+
+
+def _ensure_msc_available() -> None:
+    """Raise an error if MSC is not installed but a cloud path is used."""
+    if not MSC_AVAILABLE:
+        raise ImportError(
+            "multistorageclient is required for cloud storage paths. "
+            "Install it with: pip install multi-storage-client "
+            "--index-url https://pypi.nvidia.com"
+        )
+
+
 def _is_geq_torch_2_9() -> bool:
     """
     Check if the current torch version is greater than or equal to 2.9.0.
```
```diff
@@ -404,7 +428,7 @@ def load_model(
         key_mapping: Optional key remapping when reading from HF checkpoints.
     """
     # Validate checkpoint directory
-    if not os.path.exists(model_path):
+    if not os.path.exists(model_path) and not is_cloud_path(model_path):
         raise FileNotFoundError(f"Model path {model_path} does not exist")
     model_state = ModelState(
         model,
```
```diff
@@ -807,8 +831,18 @@ def _do_load(
         is_model = True if "/model" in path else False
         # PEFT loading is broadcasted from rank0 so it is a special case
         if self.config.is_peft and is_model and (not is_init_step):
-            state_dict = load_file(os.path.join(path, "adapter_model.safetensors"))
+            if is_cloud_path(path):
+                _ensure_msc_available()
+                adapter_path = path.rstrip("/") + "/adapter_model.safetensors"
+                with msc.open(adapter_path, "rb") as f:
+                    data = f.read()
+                state_dict = safetensors_load(data)
+            else:
+                state_dict = load_file(os.path.join(path, "adapter_model.safetensors"))
         else:
+            if is_cloud_path(path) and storage_reader is None:
+                _ensure_msc_available()
+                storage_reader = msc.torch.MultiStorageFileSystemReader(path)
             dcp.load(state_dict, checkpoint_id=path, storage_reader=storage_reader)
 
         return state_dict
```

*edjson marked this conversation as resolved.*
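Note that the cloud branch above builds the adapter file path with `path.rstrip("/") + "/..."` rather than `os.path.join`: on Windows, `os.path.join` would insert a backslash separator, which is wrong inside an `msc://` URL. A sketch of that dispatch (the helper name `adapter_path` is hypothetical; the PR inlines this logic):

```python
import os

ADAPTER_FILE = "adapter_model.safetensors"


def adapter_path(path: str) -> str:
    """Join the adapter filename onto a checkpoint path, using URL-style
    string concatenation for msc:// paths (os.path.join would use the OS
    path separator, which is wrong for URLs on Windows) and os.path.join
    for local filesystem paths."""
    if path.startswith("msc://"):
        return path.rstrip("/") + "/" + ADAPTER_FILE
    return os.path.join(path, ADAPTER_FILE)
```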
```diff
@@ -834,13 +868,25 @@ def _do_save(
         # PEFT saving is done on rank0 so it is a special case
         if self.config.is_peft and is_model:
             if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
-                save_file(state_dict, os.path.join(path, "adapter_model.safetensors"))
+                if is_cloud_path(path):
+                    _ensure_msc_available()
+                    adapter_path = path.rstrip("/") + "/adapter_model.safetensors"
+                    with msc.open(adapter_path, "wb") as f:
+                        save_file(state_dict, f)
+                else:
+                    save_file(state_dict, os.path.join(path, "adapter_model.safetensors"))
             if torch.distributed.is_initialized():
                 torch.distributed.barrier()
             return
 
         ret = None
         planner = dcp.DefaultSavePlanner(enable_plan_caching=True)
 
+        # Routes to MSC storage write for cloud paths
+        if is_cloud_path(path) and storage_writer is None:
+            _ensure_msc_available()
+            storage_writer = msc.torch.MultiStorageFileSystemWriter(path)
+
         if self.config.is_async:
             ctx = self._model_ctx if is_model else self._optim_ctx
             ret = dcp.async_save(
```

**Collaborator** commented:

> Same concern here with the Hugging Face safetensors format.

**Author** replied:

> Hello, thank you for your feedback. Yes, you are correct: the code was overriding the storage writer for any cloud load paths. I have changed the condition to `if is_cloud_path(path) and storage_writer is None:` to resolve the issue. I have also included extra test cases to address the mentioned combinations.
```diff
@@ -1107,8 +1153,14 @@ def save_config(config: dict[str, Any], weights_path: str) -> None:
         config: Config to save
         weights_path: Path to save config
     """
-    with open(os.path.join(weights_path, "config.yaml"), "w") as f:
-        yaml.dump(config, f, sort_keys=False, default_flow_style=False)
+    config_path = os.path.join(weights_path, "config.yaml")
+    if is_cloud_path(weights_path):
+        _ensure_msc_available()
+        with msc.open(config_path, "w") as f:
+            yaml.dump(config, f, sort_keys=False, default_flow_style=False)
+    else:
+        with open(config_path, "w") as f:
+            yaml.dump(config, f, sort_keys=False, default_flow_style=False)
 
 
 def _ensure_dirs(*dirs: Optional[str]) -> None:
```
```diff
@@ -1120,7 +1172,8 @@ def _ensure_dirs(*dirs: Optional[str]) -> None:
     """
     for d in dirs:
         if d:
-            os.makedirs(d, exist_ok=True)
+            if not is_cloud_path(d):
+                os.makedirs(d, exist_ok=True)
     if torch.distributed.is_initialized():
         torch.distributed.barrier()
```

*adil-a marked this conversation as resolved.*
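The `_ensure_dirs` change skips directory creation for object-store paths, since `msc://` URLs have no local directories to pre-create. A standalone sketch (without the `torch.distributed` barrier):

```python
import os


def ensure_dirs(*dirs):
    """Create local directories, skipping None entries and msc:// paths:
    object stores have no directories to pre-create, and os.makedirs would
    choke on the URL scheme. Sketch of the PR's _ensure_dirs change, minus
    the torch.distributed barrier."""
    for d in dirs:
        if d and not d.startswith("msc://"):
            os.makedirs(d, exist_ok=True)
```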
**Collaborator** commented:

> `storage_reader` gets overridden here in place of `_HuggingFaceStorageReader`. Have you tested this to work with combinations of the following: Hugging Face safetensors format (`.safetensors`) & PyT DCP format (`.distcp`)? My main worry with this change is that it won't be able to load the SFT + `.safetensors` path.

**Author** replied:

> Hello, thank you for your feedback. Yes, you are correct: the code was overriding the storage reader for any cloud load paths. I have changed the condition to `if is_cloud_path(path) and storage_reader is None:` to resolve the issue. I have also included extra test cases to address the mentioned combinations.

**Collaborator** replied:

> Hi @edjson, apologies for the late response, as we were busy pushing for the next Automodel release.
>
> If I understand this correctly, you need `storage_reader` to be an `msc.torch.MultiStorageFileSystemReader` instance for remote save/load. At the same time, we need `storage_reader` to be a `_HuggingFaceStorageReader` instance for local save/load. My question then is: is there a way for us to do remote save/load for the `.safetensors` format? It seems like this is not possible, since they both occupy the same argument. If indeed this is not possible, we need to raise an error to the user much earlier on, before training starts, stating that safetensors + remote storage is not currently possible. You can put this in the post init of the checkpointing config (https://github.com/edjson/Automodel/blob/c983bd616082b0380dcdb378568454bfbff4431a/nemo_automodel/components/checkpoint/checkpointing.py#L201).
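The early-validation idea in that last comment could be sketched as a `__post_init__` check. Note this is an illustrative sketch only: the field names `checkpoint_dir` and `model_save_format` are assumptions, not necessarily the actual Automodel config fields.

```python
from dataclasses import dataclass


def is_cloud_path(path: str) -> bool:
    return str(path).startswith("msc://")


@dataclass
class CheckpointingConfig:
    """Hypothetical slice of a checkpointing config; the field names here
    are illustrative, not the actual Automodel ones."""

    checkpoint_dir: str = "checkpoints/"
    model_save_format: str = "torch_save"

    def __post_init__(self):
        # Fail fast, before training starts: the safetensors path needs a
        # _HuggingFaceStorageReader/Writer, which occupies the same argument
        # as the MSC reader/writer, so safetensors + remote storage cannot
        # both be honored at once.
        if self.model_save_format == "safetensors" and is_cloud_path(self.checkpoint_dir):
            raise ValueError(
                "safetensors format is not supported with msc:// checkpoint "
                "paths; use the DCP (.distcp) format for remote storage."
            )
```

Because `dataclass.__init__` calls `__post_init__` automatically, the conflicting combination is rejected at config construction rather than mid-training.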