Skip to content

Commit 1dab6a6

Browse files
committed
fix: address review comments - PEFT cloud path, optional dependencies, and unit test fixes
Signed-off-by: Edison <edisonggacc@gmail.com>
1 parent ba9d2f5 commit 1dab6a6

3 files changed

Lines changed: 36 additions & 11 deletions

File tree

nemo_automodel/components/checkpoint/checkpointing.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
HF_HUB_CACHE = None
4141

4242
from packaging.version import parse
43-
from safetensors.torch import load_file, save_file
43+
from safetensors.torch import load_file, save_file, load as safetensors_load
4444
from torch import nn
4545
from torch.distributed.device_mesh import DeviceMesh
4646

@@ -76,7 +76,11 @@ def _ensure_msc_available() -> None:
7676
"""Raise an error if MSC is not installed but a cloud path is used."""
7777
if not MSC_AVAILABLE:
7878
raise ImportError(
79+
<<<<<<< HEAD
7980
"multistorageclient is required for cloud storage paths. "
81+
=======
82+
"multistorageclient is required for cloud storage paths."
83+
>>>>>>> 91a223c (fix: address review comments - PEFT cloud path, optional dependencies, and unit test fixes)
8084
"Install it with: pip install multi-storage-client "
8185
"--index-url https://pypi.nvidia.com"
8286
)
@@ -700,7 +704,14 @@ def _do_load(
700704
is_model = True if "/model" in path else False
701705
# PEFT loading is broadcasted from rank0 so it is a special case
702706
if self.config.is_peft and is_model and (not is_init_step):
703-
state_dict = load_file(os.path.join(path, "adapter_model.safetensors"))
707+
if is_cloud_path(path):
708+
_ensure_msc_available()
709+
adapter_path = path.rstrip("/") + "/adapter_model.safetensors"
710+
with msc.open(adapter_path, "rb") as f:
711+
data = f.read()
712+
state_dict = safetensors_load(data)
713+
else:
714+
state_dict = load_file(os.path.join(path, "adapter_model.safetensors"))
704715
else:
705716
if is_cloud_path(path):
706717
_ensure_msc_available()
@@ -730,7 +741,13 @@ def _do_save(
730741
# PEFT saving is done on rank0 so it is a special case
731742
if self.config.is_peft and is_model:
732743
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
733-
save_file(state_dict, os.path.join(path, "adapter_model.safetensors"))
744+
if is_cloud_path(path):
745+
_ensure_msc_available()
746+
adapter_path = path.rstrip("/") + "/adapter_model.safetensors"
747+
with msc.open(adapter_path, "wb") as f:
748+
save_file(state_dict, f)
749+
else:
750+
save_file(state_dict, os.path.join(path, "adapter_model.safetensors"))
734751
if torch.distributed.is_initialized():
735752
torch.distributed.barrier()
736753
return

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,11 @@ dependencies = [
8787
"torchao",
8888
"mlflow",
8989
"flashoptim>=0.1.3",
90-
"localstack>=2026.3.0",
91-
"multistorageclient[aws]",
9290
]
9391

9492
[project.optional-dependencies]
93+
cloud = ["multi-storage-client",]
94+
dev = ["localstack>=2026.3.0",]
9595
diffusion = [
9696
"diffusers>=0.36.0",
9797
"ftfy",

tests/unit_tests/checkpoint/test_checkpointing.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,7 @@ def test_cloud_path_uses_msc_reader(self):
904904
patch("nemo_automodel.components.checkpoint.checkpointing.dcp"):
905905
Checkpointer._do_load(ckptr, state_dict, "msc://bucket/step-100")
906906

907-
mock_msc.torch. MultiStorageFileSystemReader.assert_called_once_with("msc://bucket/step-100")
907+
mock_msc.torch.MultiStorageFileSystemReader.assert_called_once_with("msc://bucket/step-100")
908908

909909
def test_local_path_does_not_use_msc_reader(self, tmp_path):
910910
ckptr = self._make_checkpointer()
@@ -914,17 +914,25 @@ def test_local_path_does_not_use_msc_reader(self, tmp_path):
914914
patch("nemo_automodel.components.checkpoint.checkpointing.dcp"):
915915
Checkpointer._do_load(ckptr, state_dict, str(tmp_path / "step-100"))
916916

917-
mock_msc.torch. MultiStorageFileSystemReader.assert_not_called()
917+
mock_msc.open.assert_not_called()
918918

919919
def test_peft_cloud_load_still_routes_through_msc_reader(self):
920920
ckptr = self._make_checkpointer(is_peft=True)
921921
state_dict = {"weight": torch.zeros(4)}
922+
mock_file = MagicMock()
923+
mock_file.read.return_value= b"fake bytes"
922924

923925
with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \
924-
patch("nemo_automodel.components.checkpoint.checkpointing.dcp"):
925-
Checkpointer._do_load(ckptr, state_dict, "msc://bucket/step-100")
926+
patch("nemo_automodel.components.checkpoint.checkpointing.dcp"), \
927+
patch("nemo_automodel.components.checkpoint.checkpointing.safetensors_load") as mock_load:
928+
mock_msc.open.return_value.__enter__=MagicMock(return_value=mock_file)
929+
mock_msc.open.return_value.__exit__=MagicMock(return_value=False)
930+
mock_load.return_value = state_dict
931+
Checkpointer._do_load(ckptr, state_dict, "msc://bucket/step-100/model")
932+
933+
mock_msc.open.assert_called_once()
926934

927-
mock_msc.torch. MultiStorageFileSystemReader.assert_called_once_with("msc://bucket/step-100")
935+
928936

929937
def test_save_and_load_use_same_path(self):
930938
config = MagicMock()
@@ -941,4 +949,4 @@ def test_save_and_load_use_same_path(self):
941949
Checkpointer._do_load(ckptr, state_dict, path)
942950

943951
mock_msc.torch.MultiStorageFileSystemWriter.assert_called_once_with(path)
944-
mock_msc.torch. MultiStorageFileSystemReader.assert_called_once_with(path)
952+
mock_msc.torch.MultiStorageFileSystemReader.assert_called_once_with(path)

0 commit comments

Comments
 (0)