diff --git a/nemo_automodel/components/checkpoint/checkpointing.py b/nemo_automodel/components/checkpoint/checkpointing.py index a1d1f1ebe1..72795c35ef 100644 --- a/nemo_automodel/components/checkpoint/checkpointing.py +++ b/nemo_automodel/components/checkpoint/checkpointing.py @@ -25,6 +25,14 @@ import torch.distributed.checkpoint as dcp import yaml +try: + import multistorageclient as msc + + MSC_AVAILABLE = True +except ImportError: + msc = None + MSC_AVAILABLE = False + # Safe import of HF_HUB_CACHE from huggingface_hub.constants try: from huggingface_hub.constants import HF_HUB_CACHE @@ -32,6 +40,7 @@ HF_HUB_CACHE = None from packaging.version import parse +from safetensors.torch import load as safetensors_load from safetensors.torch import load_file, save_file from torch import nn from torch.distributed.device_mesh import DeviceMesh @@ -63,6 +72,21 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase +def is_cloud_path(path: str) -> bool: + """Check if path is a cloud storage path (MSC).""" + return path.startswith("msc://") + + +def _ensure_msc_available() -> None: + """Raise an error if MSC is not installed but a cloud path is used.""" + if not MSC_AVAILABLE: + raise ImportError( + "multistorageclient is required for cloud storage paths. " + "Install it with: pip install multi-storage-client " + "--index-url https://pypi.nvidia.com" + ) + + def _is_geq_torch_2_9() -> bool: """ Check if the current torch version is greater than or equal to 2.9.0. @@ -404,7 +428,7 @@ def load_model( key_mapping: Optional key remapping when reading from HF checkpoints. """ # Validate checkpoint directory - if not os.path.exists(model_path): + if not os.path.exists(model_path) and not is_cloud_path(model_path): raise FileNotFoundError(f"Model path {model_path} does not exist") model_state = ModelState( model, @@ -807,8 +831,18 @@ def _do_load( is_model = True if "/model" in path else False # PEFT loading is broadcasted from rank0 so it is a special case if self.config.is_peft and is_model and (not is_init_step): - state_dict = load_file(os.path.join(path, "adapter_model.safetensors")) + if is_cloud_path(path): + _ensure_msc_available() + adapter_path = path.rstrip("/") + "/adapter_model.safetensors" + with msc.open(adapter_path, "rb") as f: + data = f.read() + state_dict = safetensors_load(data) + else: + state_dict = load_file(os.path.join(path, "adapter_model.safetensors")) else: + if is_cloud_path(path) and storage_reader is None: + _ensure_msc_available() + storage_reader = msc.torch.MultiStorageFileSystemReader(path) dcp.load(state_dict, checkpoint_id=path, storage_reader=storage_reader) return state_dict @@ -834,13 +868,25 @@ def _do_save( # PEFT saving is done on rank0 so it is a special case if self.config.is_peft and is_model: if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: - save_file(state_dict, os.path.join(path, "adapter_model.safetensors")) + if is_cloud_path(path): + _ensure_msc_available() + adapter_path = path.rstrip("/") + "/adapter_model.safetensors" + with msc.open(adapter_path, "wb") as f: + save_file(state_dict, f) + else: + save_file(state_dict, os.path.join(path, "adapter_model.safetensors")) if torch.distributed.is_initialized(): torch.distributed.barrier() return ret = None planner = dcp.DefaultSavePlanner(enable_plan_caching=True) + + # Routes to MSC storage write for cloud paths + if is_cloud_path(path) and storage_writer is None: + _ensure_msc_available() + storage_writer = msc.torch.MultiStorageFileSystemWriter(path) + if self.config.is_async: ctx = self._model_ctx if is_model else self._optim_ctx ret = dcp.async_save( @@ -1107,8 +1153,14 @@ def save_config(config: dict[str, Any], weights_path: str) -> None: config: Config to save weights_path: Path to save config """ - with open(os.path.join(weights_path, "config.yaml"), "w") as f: - yaml.dump(config, f, sort_keys=False, default_flow_style=False) + config_path = os.path.join(weights_path, "config.yaml") + if is_cloud_path(weights_path): + _ensure_msc_available() + with msc.open(config_path, "w") as f: + yaml.dump(config, f, sort_keys=False, default_flow_style=False) + else: + with open(config_path, "w") as f: + yaml.dump(config, f, sort_keys=False, default_flow_style=False) def _ensure_dirs(*dirs: Optional[str]) -> None: @@ -1120,7 +1172,8 @@ def _ensure_dirs(*dirs: Optional[str]) -> None: """ for d in dirs: if d: - os.makedirs(d, exist_ok=True) + if not is_cloud_path(d): + os.makedirs(d, exist_ok=True) if torch.distributed.is_initialized(): torch.distributed.barrier() diff --git a/tests/unit_tests/checkpoint/test_checkpointing.py b/tests/unit_tests/checkpoint/test_checkpointing.py index d21a3168e8..e88aae3960 100644 --- a/tests/unit_tests/checkpoint/test_checkpointing.py +++ b/tests/unit_tests/checkpoint/test_checkpointing.py @@ -14,12 +14,13 @@ import json import os -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - import pytest import torch +import yaml +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from contextlib import ExitStack from nemo_automodel.components.checkpoint._backports.hf_storage import _DIFFUSERS_INDEX_FN from nemo_automodel.components.checkpoint.checkpointing import ( Checkpointer, @@ -28,6 +29,10 @@ _is_custom_model, _model_has_dtensors, _reinit_non_persistent_buffers, + is_cloud_path, + _ensure_msc_available, + _ensure_dirs, + save_config, _summarize_state_dict_key_diff, ) from nemo_automodel.components.checkpoint.stateful_wrappers import ModelState, _get_lm_head_weight_and_name @@ -36,6 +41,10 @@ materialize_missing_tied_lm_head, ) +CLOUD_PATH_MODEL = "msc://bucket/step-100/model" +CLOUD_PATH_OPTIM = "msc://bucket/step-100/optim" +LOCAL_PATH_MODEL = "/ckpts/step-100/model" + def _make_keys(count: int) -> list[str]: return [f"layer.{i}" for i in range(count)] @@ -868,3 +877,797 @@ def _fake_consolidate(**kwargs): assert (consolidated_dir / "model.safetensors.index.json").exists() assert not (consolidated_dir / _DIFFUSERS_INDEX_FN).exists() + + +# ============================================================================= +# Tests for cloud storage path support (MSC integration) +# ============================================================================= + + +@pytest.mark.parametrize("path,expected", [ + ("msc://my-bucket/checkpoints", True), + ("msc://", True), + ("/local/path/checkpoints", False), + ("", False), + ("s3://my-bucket/checkpoints", False), + ("msc:/missing-slash", False), + ("/msc://tricky", False), +]) + +def test_is_cloud_path(path, expected): + """Returns True if path starts with 'msc://', False for all other paths. Only msc:// is supported.""" + assert is_cloud_path(path) is expected + +def _make_ckptr(is_peft=False, is_async=False): + """Returns a minimal mock Checkpointer for testing _do_save and _do_load without a real config or distributed setup.""" + config = MagicMock() + config.is_peft = is_peft + config.is_async = is_async + ckptr = MagicMock(spec=Checkpointer) + ckptr.config = config + ckptr._model_ctx = MagicMock(staging_active=False) + ckptr._optim_ctx = MagicMock(staging_active=False) + return ckptr + +def _cloud_patches(extra_patches=()): + """Returns an ExitStack that patches MSC_AVAILABLE=True and stubs AsyncCheckpointerType for cloud path tests.""" + stack = ExitStack() + stack.enter_context(patch("nemo_automodel.components.checkpoint.checkpointing.MSC_AVAILABLE", True)) + stack.enter_context(patch( "nemo_automodel.components.checkpoint.checkpointing.AsyncCheckpointerType", MagicMock(), create=True,)) + for i in extra_patches: + stack.enter_context(i) + return stack + + +class TestEnsureDirs: + """Ensures that _ensure_dirs creates local directories and skips cloud path creation.""" + + def test_creates_nested_local_dirs(self, tmp_path): + """Calling _ensure_dirs called on a non-existing path creates it will all intermediate directories.""" + target = str(tmp_path / "a" / "b" / "c") + assert not os.path.exists(target) + _ensure_dirs(target) + assert os.path.isdir(target) + + def test_existing_dir_does_not_raise(self, tmp_path): + """Calling _ensure_dirs on a pre-existing directory does not raise error.""" + _ensure_dirs(str(tmp_path)) + + def test_cloud_path_never_touches_filesystem(self): + """For a msc:// path, os.makedirs is never called.""" + with patch("os.makedirs") as mock_makedirs: + _ensure_dirs("msc://bucket/some/deep/path") + mock_makedirs.assert_not_called() + + def test_local_path_passes_exist_ok_true(self, tmp_path): + """os.makedirs is called exactly, use exist_ok=True to avoid errors on existing directories.""" + target = str(tmp_path / "new") + with patch("os.makedirs") as mock_makedirs: + _ensure_dirs(target) + mock_makedirs.assert_called_once_with(target, exist_ok=True) + + +class TestSaveConfig: + """Ensures that save_config writes valid YAML to local paths and uses msc.open for cloud paths.""" + + def test_local_path_writes_valid_yaml(self, tmp_path): + """Writes a config dict to a local path and verifies the file exist and contains the correct values when loaded back.""" + config = {"model": "llama3", "lr": 3e-4, "steps": 1000} + save_config(config, str(tmp_path)) + cfg_file = tmp_path / "config.yaml" + assert cfg_file.exists() + loaded = yaml.safe_load(cfg_file.read_text()) + assert loaded["lr"] == pytest.approx(3e-4) + assert loaded["steps"] == 1000 + + def test_cloud_path_uses_msc_open_not_builtin(self): + """Verifies that for an msc:// path, msc.open is used instead of python's open.""" + config = {"model": "llama3", "lr": 3e-4} + mock_file = MagicMock() + mock_ctx = MagicMock( + __enter__=MagicMock(return_value=mock_file), + __exit__=MagicMock(return_value=False), + ) + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.MSC_AVAILABLE", True), \ + patch("builtins.open") as mock_builtin_open: + mock_msc.open.return_value = mock_ctx + save_config(config, "msc://bucket/checkpoints") + + mock_msc.open.assert_called_once() + mock_builtin_open.assert_not_called() + + def test_config_written_inside_checkpoint_dir(self): + """Confirms the config file lands inside the checkpoint directory""" + config = {"x": 1} + mock_file = MagicMock() + mock_ctx = MagicMock( + __enter__=MagicMock(return_value=mock_file), + __exit__=MagicMock(return_value=False), + ) + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.MSC_AVAILABLE", True): + mock_msc.open.return_value = mock_ctx + save_config(config, "msc://bucket/run42") + + opened_path = mock_msc.open.call_args[0][0] + assert opened_path.startswith("msc://bucket/run42") + + +class TestDoLoad: + """Tests that _do_load routes to the correct storage writer based on path and format.""" + + def _make_checkpointer(self, is_peft=False): + config = MagicMock() + config.is_peft = is_peft + ckptr = MagicMock(spec=Checkpointer) + ckptr.config = config + return ckptr + + def test_cloud_path_uses_msc_reader(self): + """Cloud path: MSC writer is injected and used for saving.""" + ckptr = self._make_checkpointer() + state_dict = {"weight": torch.zeros(4)} + + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.MSC_AVAILABLE", True), \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp"): + Checkpointer._do_load(ckptr, state_dict, "msc://bucket/step-100") + + mock_msc.torch.MultiStorageFileSystemReader.assert_called_once_with("msc://bucket/step-100") + + def test_local_path_does_not_use_msc_reader(self, tmp_path): + """Local path: MSC writer is never used.""" + ckptr = self._make_checkpointer() + state_dict = {"weight": torch.zeros(4)} + + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp"): + Checkpointer._do_load(ckptr, state_dict, str(tmp_path / "step-100")) + + mock_msc.open.assert_not_called() + + def test_peft_cloud_load_still_routes_through_msc_reader(self): + """MSC writer is called with the exact checkpoint path, not a modified subpath.""" + ckptr = self._make_checkpointer(is_peft=True) + state_dict = {"weight": torch.zeros(4)} + mock_file = MagicMock() + mock_file.read.return_value= b"fake bytes" + + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.MSC_AVAILABLE", True), \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp"), \ + patch("nemo_automodel.components.checkpoint.checkpointing.safetensors_load") as mock_load: + mock_msc.open.return_value.__enter__=MagicMock(return_value=mock_file) + mock_msc.open.return_value.__exit__=MagicMock(return_value=False) + mock_load.return_value = state_dict + Checkpointer._do_load(ckptr, state_dict, "msc://bucket/step-100/model") + + mock_msc.open.assert_called_once() + + def test_save_and_load_use_same_path(self): + """Async mode: MSC writer is still injected for cloud paths.""" + config = MagicMock() + config.is_peft = False + config.is_async = False + ckptr = MagicMock(spec=Checkpointer) + ckptr.config = config + state_dict = {"weight": torch.ones(4)} + path = "msc://bucket/step-300" + + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.MSC_AVAILABLE", True), \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp"): + Checkpointer._do_save(ckptr, state_dict, path) + Checkpointer._do_load(ckptr, state_dict, path) + + mock_msc.torch.MultiStorageFileSystemWriter.assert_called_once_with(path) + mock_msc.torch.MultiStorageFileSystemReader.assert_called_once_with(path) + +class TestDoSaveFullSFT: + """Tests that _do_save correctly routes full-SFT saves for DCP and safetensors formats on cloud and local paths.""" + + def test_dcp_cloud_sync_uses_msc_writer(self): + """DCP + cloud + sync: MSC writer injected, and dcp.save is called""" + ckptr = _make_ckptr(is_peft=False, is_async=False) + sd = {"w": torch.ones(4)} + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + Checkpointer._do_save(ckptr, sd, CLOUD_PATH_OPTIM, storage_writer=None) + + mock_msc.torch.MultiStorageFileSystemWriter.assert_called_once_with(CLOUD_PATH_OPTIM) + mock_dcp.save.assert_called_once() + + def test_safetensors_cloud_sync_does_not_override_hf_writer(self): + """Safetensors + cloud + sync: existing HF writer NOT replaced by MSC writer.""" + ckptr = _make_ckptr(is_peft=False, is_async=False) + sd = {"w": torch.ones(4)} + hf_writer = MagicMock(name="HFStorageWriter") + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + Checkpointer._do_save(ckptr, sd, CLOUD_PATH_MODEL, storage_writer=hf_writer) + + mock_msc.torch.MultiStorageFileSystemWriter.assert_not_called() + mock_dcp.save.assert_called_once() + _, kwargs = mock_dcp.save.call_args + assert kwargs["storage_writer"] is hf_writer + + def test_safetensors_cloud_async_does_not_override_hf_writer(self): + """Safetensors + cloud + async: existing HF writer NOT replaced by MSC writer.""" + ckptr = _make_ckptr(is_peft=False, is_async=True) + sd = {"w": torch.ones(4)} + hf_writer = MagicMock(name="HFStorageWriter") + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + Checkpointer._do_save(ckptr, sd, CLOUD_PATH_MODEL, storage_writer=hf_writer) + + mock_msc.torch.MultiStorageFileSystemWriter.assert_not_called() + mock_dcp.async_save.assert_called_once() + _, kwargs = mock_dcp.async_save.call_args + assert kwargs["storage_writer"] is hf_writer + + def test_local_dcp_sync_no_msc(self): + """Local + DCP + sync: MSC writer never used.""" + + ckptr = _make_ckptr(is_peft=False, is_async=False) + sd = {"w": torch.ones(4)} + + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp"): + Checkpointer._do_save(ckptr, sd, LOCAL_PATH_MODEL, storage_writer=None) + + mock_msc.torch.MultiStorageFileSystemWriter.assert_not_called() + +class TestDoSavePEFT: + """Tests that _do_save correctly handles PEFT adapter saves using msc.open for cloud paths and save_file for local paths.""" + + def test_peft_cloud_sync_uses_msc_open(self): + """PEFT + cloud + sync: msc.open used for adapter file, dcp never called.""" + ckptr = _make_ckptr(is_peft=True, is_async=False) + sd = {"lora.weight": torch.ones(4)} + mock_file = MagicMock() + mock_ctx = MagicMock(__enter__=MagicMock(return_value=mock_file), + __exit__=MagicMock(return_value=False)) + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp, \ + patch("nemo_automodel.components.checkpoint.checkpointing.save_file"), \ + patch("torch.distributed.is_initialized", return_value=False): + mock_msc.open.return_value = mock_ctx + Checkpointer._do_save(ckptr, sd, CLOUD_PATH_MODEL) + + mock_msc.open.assert_called_once() + mock_dcp.save.assert_not_called() + mock_dcp.async_save.assert_not_called() + + def test_peft_cloud_async_still_uses_msc_open_not_dcp(self): + """PEFT + cloud + async: adapter written sync via msc.open, dcp never called.""" + ckptr = _make_ckptr(is_peft=True, is_async=True) + sd = {"lora.weight": torch.ones(4)} + mock_file = MagicMock() + mock_ctx = MagicMock(__enter__=MagicMock(return_value=mock_file), + __exit__=MagicMock(return_value=False)) + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp, \ + patch("nemo_automodel.components.checkpoint.checkpointing.save_file"), \ + patch("torch.distributed.is_initialized", return_value=False): + mock_msc.open.return_value = mock_ctx + Checkpointer._do_save(ckptr, sd, CLOUD_PATH_MODEL) + + mock_msc.open.assert_called_once() + mock_dcp.async_save.assert_not_called() + mock_dcp.save.assert_not_called() + + def test_peft_local_sync_uses_save_file_not_msc(self): + """PEFT + local + sync: save_file called, msc.open NOT called.""" + ckptr = _make_ckptr(is_peft=True, is_async=False) + sd = {"lora.weight": torch.ones(4)} + + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.save_file") as mock_sf, \ + patch("torch.distributed.is_initialized", return_value=False): + Checkpointer._do_save(ckptr, sd, LOCAL_PATH_MODEL) + + mock_msc.open.assert_not_called() + mock_sf.assert_called_once() + + def test_peft_adapter_path_appended_correctly(self): + """PEFT cloud save opens exactly '/adapter_model.safetensors'.""" + ckptr = _make_ckptr(is_peft=True) + sd = {"lora.weight": torch.ones(4)} + path = "msc://mybucket/run7/step-500/model" + mock_file = MagicMock() + mock_ctx = MagicMock(__enter__=MagicMock(return_value=mock_file), + __exit__=MagicMock(return_value=False)) + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.save_file"), \ + patch("torch.distributed.is_initialized", return_value=False): + mock_msc.open.return_value = mock_ctx + Checkpointer._do_save(ckptr, sd, path) + + opened_path = mock_msc.open.call_args[0][0] + assert opened_path == "msc://mybucket/run7/step-500/model/adapter_model.safetensors" + +class TestDoLoadFullSFT: + """Tests that _do_load correctly routes full-SFT loads for DCP and safetensors formats on cloud and local paths.""" + + def test_dcp_cloud_uses_msc_reader(self): + """DCP + cloud: MSC reader injected when no reader provided.""" + ckptr = _make_ckptr(is_peft=False) + sd = {"w": torch.zeros(4)} + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + Checkpointer._do_load(ckptr, sd, CLOUD_PATH_OPTIM, storage_reader=None) + + mock_msc.torch.MultiStorageFileSystemReader.assert_called_once_with(CLOUD_PATH_OPTIM) + mock_dcp.load.assert_called_once() + + def test_safetensors_cloud_does_not_override_hf_reader(self): + """Safetensors + cloud: existing HF reader NOT replaced by MSC reader.""" + ckptr = _make_ckptr(is_peft=False) + sd = {"w": torch.zeros(4)} + hf_reader = MagicMock(name="HFStorageReader") + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + Checkpointer._do_load(ckptr, sd, CLOUD_PATH_MODEL, storage_reader=hf_reader) + + mock_msc.torch.MultiStorageFileSystemReader.assert_not_called() + mock_dcp.load.assert_called_once() + _, kwargs = mock_dcp.load.call_args + assert kwargs["storage_reader"] is hf_reader + + def test_local_dcp_no_msc(self): + """Local + DCP: MSC reader never used.""" + ckptr = _make_ckptr(is_peft=False) + sd = {"w": torch.zeros(4)} + + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp"): + Checkpointer._do_load(ckptr, sd, LOCAL_PATH_MODEL, storage_reader=None) + + mock_msc.torch.MultiStorageFileSystemReader.assert_not_called() + + def test_safetensors_local_does_not_use_msc(self): + """Safetensors + local: MSC reader never used.""" + ckptr = _make_ckptr(is_peft=False) + sd = {"w": torch.zeros(4)} + hf_reader = MagicMock(name="HFStorageReader") + + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp"): + Checkpointer._do_load(ckptr, sd, LOCAL_PATH_MODEL, storage_reader=hf_reader) + + mock_msc.torch.MultiStorageFileSystemReader.assert_not_called() + +class TestDoLoadPEFT: + """Tests that _do_load correctly handles PEFT adapter loads using msc.open for cloud paths and load_file for local paths.""" + + def test_peft_cloud_uses_msc_open_not_dcp(self): + """PEFT + cloud: msc.open used for adapter, dcp.load NOT called.""" + ckptr = _make_ckptr(is_peft=True) + sd = {"lora.weight": torch.zeros(4)} + mock_file = MagicMock() + mock_file.read.return_value = b"fake_safetensors_bytes" + mock_ctx = MagicMock(__enter__=MagicMock(return_value=mock_file), + __exit__=MagicMock(return_value=False)) + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp, \ + patch("nemo_automodel.components.checkpoint.checkpointing.safetensors_load", return_value=sd): + mock_msc.open.return_value = mock_ctx + Checkpointer._do_load(ckptr, sd, CLOUD_PATH_MODEL) + + mock_msc.open.assert_called_once() + mock_dcp.load.assert_not_called() + + def test_peft_cloud_adapter_path_correct(self): + """PEFT + cloud: opens exactly '/adapter_model.safetensors'.""" + ckptr = _make_ckptr(is_peft=True) + sd = {} + path = "msc://bucket/run3/step-200/model" + mock_file = MagicMock() + mock_file.read.return_value = b"bytes" + mock_ctx = MagicMock(__enter__=MagicMock(return_value=mock_file), + __exit__=MagicMock(return_value=False)) + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.safetensors_load", return_value=sd): + mock_msc.open.return_value = mock_ctx + Checkpointer._do_load(ckptr, sd, path) + + opened_path = mock_msc.open.call_args[0][0] + assert opened_path == "msc://bucket/run3/step-200/model/adapter_model.safetensors" + + def test_peft_local_uses_load_file_not_msc(self): + """PEFT + local: load_file called, msc.open NOT called.""" + ckptr = _make_ckptr(is_peft=True) + sd = {} + + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.load_file", return_value=sd) as mock_lf: + Checkpointer._do_load(ckptr, sd, LOCAL_PATH_MODEL) + + mock_msc.open.assert_not_called() + mock_lf.assert_called_once() + + def test_peft_load_at_init_step_skips_peft_branch_uses_dcp(self): + """PEFT + cloud + is_init_step=True: DCP path used, not PEFT adapter path.""" + ckptr = _make_ckptr(is_peft=True) + sd = {"w": torch.zeros(4)} + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + Checkpointer._do_load(ckptr, sd, CLOUD_PATH_MODEL, is_init_step=True) + + mock_msc.torch.MultiStorageFileSystemReader.assert_called_once_with(CLOUD_PATH_MODEL) + mock_dcp.load.assert_called_once() + mock_msc.open.assert_not_called() + + +class TestFormatSave: + """Tests that _get_storage_writer returns the correct writer for each format, and that _do_save routes correctly based on whether a writer is provided.""" + + def _make_checkpointer(self, model_save_format, is_peft=False): + with patch("torch.distributed.is_initialized", return_value=False): + config = CheckpointingConfig( + enabled=True, + checkpoint_dir="/tmp/test", + model_save_format=model_save_format, + model_cache_dir="/tmp/cache", + model_repo_id="test/model", + save_consolidated=False, + is_peft=is_peft, + ) + return Checkpointer(config, dp_rank=0, tp_rank=0, pp_rank=0) + + def test_safetensors_format_produces_hf_writer(self): + """safetensors format: _get_storage_writer returns _HuggingFaceStorageWriter.""" + ckptr = self._make_checkpointer("safetensors") + writer = ckptr._get_storage_writer( + consolidated_output_path=None, + fqn_to_index_mapping={"w": 1}, + model_path="/tmp/model", + ) + from nemo_automodel.components.checkpoint._backports.hf_storage import _HuggingFaceStorageWriter + assert isinstance(writer, _HuggingFaceStorageWriter) + + def test_dcp_format_produces_no_writer(self): + """torch_save (DCP) format: _get_storage_writer returns None.""" + ckptr = self._make_checkpointer("torch_save") + writer = ckptr._get_storage_writer( + consolidated_output_path=None, + fqn_to_index_mapping=None, + model_path="/tmp/model", + ) + assert writer is None + + def test_safetensors_cloud_save_uses_hf_writer_not_msc(self): + """safetensors + cloud: HF writer passed to dcp.save, MSC writer never created.""" + ckptr = self._make_checkpointer("safetensors") + sd = {"w": torch.ones(4)} + hf_writer = MagicMock(name="HFStorageWriter") + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + ckptr._do_save(sd, "msc://bucket/step-100/model", storage_writer=hf_writer) + + mock_msc.torch.MultiStorageFileSystemWriter.assert_not_called() + mock_dcp.save.assert_called_once() + _, kwargs = mock_dcp.save.call_args + assert kwargs["storage_writer"] is hf_writer + + def test_dcp_cloud_save_uses_msc_writer(self): + """torch_save (DCP) + cloud: no HF writer provided, so MSC writer injected.""" + ckptr = self._make_checkpointer("torch_save") + sd = {"w": torch.ones(4)} + writer = ckptr._get_storage_writer( + consolidated_output_path=None, + fqn_to_index_mapping=None, + model_path="msc://bucket/step-100/optim", + ) + assert writer is None + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + ckptr._do_save(sd, "msc://bucket/step-100/optim", storage_writer=writer) + + mock_msc.torch.MultiStorageFileSystemWriter.assert_called_once_with("msc://bucket/step-100/optim") + mock_dcp.save.assert_called_once() + + def test_safetensors_local_save_uses_hf_writer(self): + """safetensors + local: HF writer used, MSC never involved.""" + ckptr = self._make_checkpointer("safetensors") + sd = {"w": torch.ones(4)} + hf_writer = ckptr._get_storage_writer( + consolidated_output_path=None, + fqn_to_index_mapping={"w": 1}, + model_path="/tmp/step-100/model", + ) + + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + ckptr._do_save(sd, "/tmp/step-100/model", storage_writer=hf_writer) + + mock_msc.torch.MultiStorageFileSystemWriter.assert_not_called() + mock_dcp.save.assert_called_once() + + def test_dcp_local_save_no_writer_no_msc(self): + """torch_save (DCP) + local: no writer, no MSC, plain dcp.save.""" + ckptr = self._make_checkpointer("torch_save") + sd = {"w": torch.ones(4)} + + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + ckptr._do_save(sd, "/tmp/step-100/optim", storage_writer=None) + + mock_msc.torch.MultiStorageFileSystemWriter.assert_not_called() + mock_dcp.save.assert_called_once() + + +class TestFormatLoad: + """Tests that _get_storage_reader returns the correct reader for each format, and that _do_load routes correctly based on whether a reader is provided.""" + + def _make_checkpointer(self, model_save_format, is_peft=False): + with patch("torch.distributed.is_initialized", return_value=False): + config = CheckpointingConfig( + enabled=True, + checkpoint_dir="/tmp/test", + model_save_format=model_save_format, + model_cache_dir="/tmp/cache", + model_repo_id="test/model", + save_consolidated=False, + is_peft=is_peft, + ) + return Checkpointer(config, dp_rank=0, tp_rank=0, pp_rank=0) + + def test_safetensors_format_produces_hf_reader(self): + """safetensors format: _get_storage_reader returns an HF reader.""" + ckptr = self._make_checkpointer("safetensors") + reader = ckptr._get_storage_reader("/tmp/model", key_mapping=None) + assert reader is not None + + def test_dcp_format_produces_no_reader(self): + """torch_save (DCP) format: _get_storage_reader returns None.""" + ckptr = self._make_checkpointer("torch_save") + reader = ckptr._get_storage_reader("/tmp/model", key_mapping=None) + assert reader is None + + def test_safetensors_cloud_load_uses_hf_reader_not_msc(self): + """safetensors + cloud: HF reader passed to dcp.load, MSC reader never created.""" + ckptr = self._make_checkpointer("safetensors") + sd = {"w": torch.zeros(4)} + hf_reader = ckptr._get_storage_reader("/tmp/model", key_mapping=None) + assert hf_reader is not None + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + ckptr._do_load(sd, "msc://bucket/step-100/model", storage_reader=hf_reader) + + mock_msc.torch.MultiStorageFileSystemReader.assert_not_called() + mock_dcp.load.assert_called_once() + _, kwargs = mock_dcp.load.call_args + assert kwargs["storage_reader"] is hf_reader + + def test_dcp_cloud_load_uses_msc_reader(self): + """torch_save (DCP) + cloud: no HF reader, so MSC reader injected.""" + ckptr = self._make_checkpointer("torch_save") + sd = {"w": torch.zeros(4)} + reader = ckptr._get_storage_reader("/tmp/model", key_mapping=None) + assert reader is None + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + ckptr._do_load(sd, "msc://bucket/step-100/optim", storage_reader=reader) + + mock_msc.torch.MultiStorageFileSystemReader.assert_called_once_with("msc://bucket/step-100/optim") + mock_dcp.load.assert_called_once() + + def test_safetensors_local_load_uses_hf_reader(self): + """safetensors + local: HF reader used, MSC never involved.""" + ckptr = self._make_checkpointer("safetensors") + sd = {"w": torch.zeros(4)} + hf_reader = ckptr._get_storage_reader("/tmp/model", key_mapping=None) + + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + ckptr._do_load(sd, "/tmp/step-100/model", storage_reader=hf_reader) + + mock_msc.torch.MultiStorageFileSystemReader.assert_not_called() + mock_dcp.load.assert_called_once() + + def test_dcp_local_load_no_reader_no_msc(self): + """torch_save (DCP) + local: no reader, no MSC, plain dcp.load.""" + ckptr = self._make_checkpointer("torch_save") + sd = {"w": torch.zeros(4)} + + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + ckptr._do_load(sd, "/tmp/step-100/optim", storage_reader=None) + + mock_msc.torch.MultiStorageFileSystemReader.assert_not_called() + mock_dcp.load.assert_called_once() + + +class TestSyncAsyncSave: + """Tests that _do_save calls dcp.save for sync and dcp.async_save for async, across DCP, safetensors, and PEFT formats on both cloud and local paths.""" + + def _make_ckptr(self, is_async, is_peft=False): + config = MagicMock() + config.is_peft = is_peft + config.is_async = is_async + ckptr = MagicMock(spec=Checkpointer) + ckptr.config = config + ckptr._model_ctx = MagicMock(staging_active=False) + ckptr._optim_ctx = MagicMock(staging_active=False) + return ckptr + + def test_dcp_cloud_sync_calls_dcp_save(self): + """DCP + cloud + sync: dcp.save called, dcp.async_save NOT called.""" + ckptr = self._make_ckptr(is_async=False) + sd = {"w": torch.ones(4)} + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc"), \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + Checkpointer._do_save(ckptr, sd, "msc://bucket/step-100/optim") + + mock_dcp.save.assert_called_once() + mock_dcp.async_save.assert_not_called() + + def test_dcp_cloud_async_calls_dcp_async_save(self): + """DCP + cloud + async: dcp.async_save called, dcp.save NOT called.""" + ckptr = self._make_ckptr(is_async=True) + sd = {"w": torch.ones(4)} + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc"), \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + Checkpointer._do_save(ckptr, sd, "msc://bucket/step-100/optim") + + mock_dcp.async_save.assert_called_once() + mock_dcp.save.assert_not_called() + + def test_dcp_cloud_async_msc_writer_passed_to_async_save(self): + """DCP + cloud + async: MSC writer is passed as storage_writer to dcp.async_save.""" + ckptr = self._make_ckptr(is_async=True) + sd = {"w": torch.ones(4)} + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + Checkpointer._do_save(ckptr, sd, "msc://bucket/step-100/optim") + + msc_writer = mock_msc.torch.MultiStorageFileSystemWriter.return_value + _, kwargs = mock_dcp.async_save.call_args + assert kwargs["storage_writer"] is msc_writer + + def test_dcp_cloud_sync_msc_writer_passed_to_save(self): + """DCP + cloud + sync: MSC writer is passed as storage_writer to dcp.save.""" + ckptr = self._make_ckptr(is_async=False) + sd = {"w": torch.ones(4)} + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + Checkpointer._do_save(ckptr, sd, "msc://bucket/step-100/optim") + + msc_writer = mock_msc.torch.MultiStorageFileSystemWriter.return_value + _, kwargs = mock_dcp.save.call_args + assert kwargs["storage_writer"] is msc_writer + + def test_safetensors_cloud_sync_calls_dcp_save(self): + """safetensors + cloud + sync: dcp.save called with HF writer, dcp.async_save NOT called.""" + ckptr = self._make_ckptr(is_async=False) + sd = {"w": torch.ones(4)} + hf_writer = MagicMock(name="HFStorageWriter") + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + Checkpointer._do_save(ckptr, sd, "msc://bucket/step-100/model", storage_writer=hf_writer) + + mock_dcp.save.assert_called_once() + mock_dcp.async_save.assert_not_called() + mock_msc.torch.MultiStorageFileSystemWriter.assert_not_called() + + def test_safetensors_cloud_async_calls_dcp_async_save(self): + """safetensors + cloud + async: dcp.async_save called with HF writer, dcp.save NOT called.""" + ckptr = self._make_ckptr(is_async=True) + sd = {"w": torch.ones(4)} + hf_writer = MagicMock(name="HFStorageWriter") + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + Checkpointer._do_save(ckptr, sd, "msc://bucket/step-100/model", storage_writer=hf_writer) + + mock_dcp.async_save.assert_called_once() + mock_dcp.save.assert_not_called() + mock_msc.torch.MultiStorageFileSystemWriter.assert_not_called() + _, kwargs = mock_dcp.async_save.call_args + assert kwargs["storage_writer"] is hf_writer + + def test_peft_cloud_sync_uses_msc_open_not_dcp(self): + """PEFT + cloud + sync: adapter written via msc.open, dcp never called.""" + ckptr = self._make_ckptr(is_async=False, is_peft=True) + sd = {"lora.weight": torch.ones(4)} + mock_file = MagicMock() + mock_ctx = MagicMock(__enter__=MagicMock(return_value=mock_file), + __exit__=MagicMock(return_value=False)) + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp, \ + patch("nemo_automodel.components.checkpoint.checkpointing.save_file"), \ + patch("torch.distributed.is_initialized", return_value=False): + mock_msc.open.return_value = mock_ctx + Checkpointer._do_save(ckptr, sd, "msc://bucket/step-100/model") + + mock_msc.open.assert_called_once() + mock_dcp.save.assert_not_called() + mock_dcp.async_save.assert_not_called() + + def test_peft_cloud_async_still_uses_msc_open_not_dcp(self): + """PEFT + cloud + async: adapter still written sync via msc.open, dcp never called.""" + ckptr = self._make_ckptr(is_async=True, is_peft=True) + sd = {"lora.weight": torch.ones(4)} + mock_file = MagicMock() + mock_ctx = MagicMock(__enter__=MagicMock(return_value=mock_file), + __exit__=MagicMock(return_value=False)) + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.msc") as mock_msc, \ + patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp, \ + patch("nemo_automodel.components.checkpoint.checkpointing.save_file"), \ + patch("torch.distributed.is_initialized", return_value=False): + mock_msc.open.return_value = mock_ctx + Checkpointer._do_save(ckptr, sd, "msc://bucket/step-100/model") + + mock_msc.open.assert_called_once() + mock_dcp.async_save.assert_not_called() + mock_dcp.save.assert_not_called() + + def test_local_sync_calls_dcp_save(self): + """Local + sync: dcp.save called, dcp.async_save NOT called.""" + ckptr = self._make_ckptr(is_async=False) + sd = {"w": torch.ones(4)} + + with patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + Checkpointer._do_save(ckptr, sd, "/tmp/step-100/optim") + + mock_dcp.save.assert_called_once() + mock_dcp.async_save.assert_not_called() + + def test_local_async_calls_dcp_async_save(self): + """Local + async: dcp.async_save called, dcp.save NOT called.""" + ckptr = self._make_ckptr(is_async=True) + sd = {"w": torch.ones(4)} + + with _cloud_patches(): + with patch("nemo_automodel.components.checkpoint.checkpointing.dcp") as mock_dcp: + Checkpointer._do_save(ckptr, sd, "/tmp/step-100/optim") + + mock_dcp.async_save.assert_called_once() + mock_dcp.save.assert_not_called() \ No newline at end of file