Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 59 additions & 6 deletions nemo_automodel/components/checkpoint/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,22 @@
import torch.distributed.checkpoint as dcp
import yaml

try:
import multistorageclient as msc

MSC_AVAILABLE = True
except ImportError:
msc = None
MSC_AVAILABLE = False

# Safe import of HF_HUB_CACHE from huggingface_hub.constants
try:
from huggingface_hub.constants import HF_HUB_CACHE
except ImportError:
HF_HUB_CACHE = None

from packaging.version import parse
from safetensors.torch import load as safetensors_load
from safetensors.torch import load_file, save_file
from torch import nn
from torch.distributed.device_mesh import DeviceMesh
Expand Down Expand Up @@ -63,6 +72,21 @@
from transformers.tokenization_utils_base import PreTrainedTokenizerBase


def is_cloud_path(path: str) -> bool:
    """Return True when ``path`` targets MSC cloud storage (``msc://`` scheme)."""
    scheme = "msc://"
    return path[: len(scheme)] == scheme


def _ensure_msc_available() -> None:
    """Raise ImportError if a cloud path is used without multistorageclient installed."""
    if MSC_AVAILABLE:
        return
    raise ImportError(
        "multistorageclient is required for cloud storage paths. "
        "Install it with: pip install multi-storage-client "
        "--index-url https://pypi.nvidia.com"
    )


def _is_geq_torch_2_9() -> bool:
"""
Check if the current torch version is greater than or equal to 2.9.0.
Expand Down Expand Up @@ -404,7 +428,7 @@ def load_model(
key_mapping: Optional key remapping when reading from HF checkpoints.
"""
# Validate checkpoint directory
if not os.path.exists(model_path):
if not os.path.exists(model_path) and not is_cloud_path(model_path):
raise FileNotFoundError(f"Model path {model_path} does not exist")
model_state = ModelState(
model,
Expand Down Expand Up @@ -807,8 +831,18 @@ def _do_load(
is_model = True if "/model" in path else False
# PEFT loading is broadcasted from rank0 so it is a special case
if self.config.is_peft and is_model and (not is_init_step):
state_dict = load_file(os.path.join(path, "adapter_model.safetensors"))
if is_cloud_path(path):
_ensure_msc_available()
adapter_path = path.rstrip("/") + "/adapter_model.safetensors"
with msc.open(adapter_path, "rb") as f:
data = f.read()
state_dict = safetensors_load(data)
else:
state_dict = load_file(os.path.join(path, "adapter_model.safetensors"))
else:
if is_cloud_path(path) and storage_reader is None:
_ensure_msc_available()
storage_reader = msc.torch.MultiStorageFileSystemReader(path)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

storage_reader gets overridden here in place of _HuggingFaceStorageReader. Have you tested this to work with combinations of the following:

  • PEFT & full-SFT
  • Hugging Face safetensors format (.safetensors) & PyT DCP format (.distcp)
  • Sync & Async

My main worry with this change is that it won't be able to load the SFT + .safetensors path

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, thank you for your feedback. Yes you are correct, the code was overriding the storage reader for any cloud load paths. I have changed the condition to "if is_cloud_path(path) and storage_reader is None:" to resolve the issue. I have also included extra test cases to address the mentioned combinations.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @edjson , apologies for the late response as we were busy pushing for the next Automodel release.

If I understand this correctly, you need storage_reader to be an msc.torch.MultiStorageFileSystemReader instance for remote save/load. At the same time, we need storage_reader to be a _HuggingFaceStorageReader instance for local save/load. My question then is: is there a way for us to do remote save/load for the .safetensors format? It seems like this is not possible since they both occupy the same argument. If indeed this is not possible, we need to raise an error to the user much earlier on before training starts that safetensors + remote storage is not possible currently. You can put this in the post init of the checkpointing config (https://github.com/edjson/Automodel/blob/c983bd616082b0380dcdb378568454bfbff4431a/nemo_automodel/components/checkpoint/checkpointing.py#L201)

dcp.load(state_dict, checkpoint_id=path, storage_reader=storage_reader)
Comment thread
edjson marked this conversation as resolved.
return state_dict

Expand All @@ -834,13 +868,25 @@ def _do_save(
# PEFT saving is done on rank0 so it is a special case
if self.config.is_peft and is_model:
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
save_file(state_dict, os.path.join(path, "adapter_model.safetensors"))
if is_cloud_path(path):
_ensure_msc_available()
adapter_path = path.rstrip("/") + "/adapter_model.safetensors"
with msc.open(adapter_path, "wb") as f:
save_file(state_dict, f)
else:
save_file(state_dict, os.path.join(path, "adapter_model.safetensors"))
if torch.distributed.is_initialized():
torch.distributed.barrier()
return

ret = None
planner = dcp.DefaultSavePlanner(enable_plan_caching=True)

# Routes to MSC storage write for cloud paths
if is_cloud_path(path) and storage_writer is None:
_ensure_msc_available()
storage_writer = msc.torch.MultiStorageFileSystemWriter(path)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same concern here with the Hugging Face safetensors format

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello, thank you for your feedback. Yes you are correct, the code was overriding the storage writer for any cloud save paths. I have changed the condition to "if is_cloud_path(path) and storage_writer is None:" to resolve the issue. I have also included extra test cases to address the mentioned combinations.


if self.config.is_async:
ctx = self._model_ctx if is_model else self._optim_ctx
ret = dcp.async_save(
Expand Down Expand Up @@ -1107,8 +1153,14 @@ def save_config(config: dict[str, Any], weights_path: str) -> None:
config: Config to save
weights_path: Path to save config
"""
with open(os.path.join(weights_path, "config.yaml"), "w") as f:
yaml.dump(config, f, sort_keys=False, default_flow_style=False)
config_path = os.path.join(weights_path, "config.yaml")
if is_cloud_path(weights_path):
_ensure_msc_available()
with msc.open(config_path, "w") as f:
yaml.dump(config, f, sort_keys=False, default_flow_style=False)
else:
with open(config_path, "w") as f:
yaml.dump(config, f, sort_keys=False, default_flow_style=False)


def _ensure_dirs(*dirs: Optional[str]) -> None:
Expand All @@ -1120,7 +1172,8 @@ def _ensure_dirs(*dirs: Optional[str]) -> None:
"""
for d in dirs:
if d:
os.makedirs(d, exist_ok=True)
if not is_cloud_path(d):
Comment thread
adil-a marked this conversation as resolved.
os.makedirs(d, exist_ok=True)
if torch.distributed.is_initialized():
torch.distributed.barrier()

Expand Down
Loading
Loading