
Commit e509944

feat: add MSC cloud storage support for dcp checkpoints
Signed-off-by: Edison <edisonggacc@gmail.com>
1 parent b9a2154 commit e509944

5 files changed

Lines changed: 496 additions & 17 deletions

File tree

nemo_automodel/components/checkpoint/checkpointing.py

Lines changed: 169 additions & 17 deletions
@@ -25,13 +25,22 @@
 import torch.distributed.checkpoint as dcp
 import yaml
 
+try:
+    import multistorageclient as msc
+
+    MSC_AVAILABLE = True
+except ImportError:
+    msc = None
+    MSC_AVAILABLE = False
+
 # Safe import of HF_HUB_CACHE from huggingface_hub.constants
 try:
     from huggingface_hub.constants import HF_HUB_CACHE
 except ImportError:
     HF_HUB_CACHE = None
 
 from packaging.version import parse
+from safetensors.torch import load as safetensors_load
 from safetensors.torch import load_file, save_file
 from torch import nn
 from torch.distributed.device_mesh import DeviceMesh
@@ -59,6 +68,21 @@
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 
+def is_cloud_path(path: str) -> bool:
+    """Check if path is a cloud storage path (MSC)."""
+    return path.startswith("msc://")
+
+
+def _ensure_msc_available() -> None:
+    """Raise an error if MSC is not installed but a cloud path is used."""
+    if not MSC_AVAILABLE:
+        raise ImportError(
+            "multistorageclient is required for cloud storage paths. "
+            "Install it with: pip install multi-storage-client "
+            "--index-url https://pypi.nvidia.com"
+        )
+
+
 def _is_geq_torch_2_9() -> bool:
     """
     Check if the current torch version is greater than or equal to 2.9.0.
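
A quick sketch of how these two helpers gate the rest of the diff: any path beginning with `msc://` is routed to MSC-backed I/O, and the import guard fails fast when the optional dependency is missing. The example path is hypothetical; `is_cloud_path` and `_ensure_msc_available` are the functions defined in the hunk above.

```python
# Illustrative only; relies on the helpers defined in the hunk above.
ckpt_path = "msc://my-profile/checkpoints/step_10/model"  # hypothetical MSC URL

if is_cloud_path(ckpt_path):       # True for any "msc://..." path
    _ensure_msc_available()        # raises ImportError unless multistorageclient is installed
    # cloud branch: msc.open / msc.torch readers and writers (see the hunks below)
else:
    # local branch: plain open(), load_file(), os.makedirs() as before
    ...
```
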
@@ -267,7 +291,11 @@ def save_model(
 
         # Convert to HF format if using custom model implementations
         state_dict = _maybe_adapt_state_dict_to_hf(
-            model_state.model[0], state_dict, quantization=False, device_mesh=self.moe_mesh
+            model_state.model[0],
+            state_dict,
+            quantization=False,
+            device_mesh=self.moe_mesh,
+            v4_compatible=self.config.v4_compatible,
         )
         # Build the consolidated model.safetensors.index.json if needed
         fqn_to_file_index_mapping = self._maybe_build_consolidated_index(model_state, state_dict)
@@ -369,7 +397,7 @@ def load_model(
             key_mapping: Optional key remapping when reading from HF checkpoints.
         """
         # Validate checkpoint directory
-        if not os.path.exists(model_path):
+        if not os.path.exists(model_path) and not is_cloud_path(model_path):
            raise FileNotFoundError(f"Model path {model_path} does not exist")
         model_state = ModelState(
             model,
@@ -481,9 +509,15 @@ def initialize_model_weights(
             device: Target device for materialized parameters.
             peft_init_method: Initialization method for PEFT adapters (e.g. "xavier").
         """
-        to_empty_parameters_only(model, device=device)
+        # Only materialize parameters that are actually on the meta device.
+        # When the caller sets is_meta_device=True but the model was already
+        # constructed on a real device (e.g. ContextManagers was patched to
+        # a no-op), calling to_empty_parameters_only would replace valid
+        # weights with uninitialized CUDA memory.
+        has_meta_params = any(p.device.type == "meta" for p in model.parameters())
+        if has_meta_params:
+            to_empty_parameters_only(model, device=device)
 
-        # to_empty_parameters_only only materializes parameters, not buffers.
         # Buffers (e.g. RoPE inv_freq) may still be on meta device. Move them
         # to *device* with uninitialized storage so that the subsequent
         # initialize_weights() call can overwrite them with proper values
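
The new guard is the standard meta-device pattern: materializing empty storage is only safe when the parameters are actually meta tensors, otherwise real weights get clobbered. A small self-contained illustration with plain PyTorch (using `nn.Module.to_empty` rather than the repo's `to_empty_parameters_only` helper):

```python
import torch
from torch import nn

with torch.device("meta"):
    meta_model = nn.Linear(4, 4)   # parameters are meta tensors, no real storage

real_model = nn.Linear(4, 4)       # parameters already hold initialized values

for model in (meta_model, real_model):
    has_meta_params = any(p.device.type == "meta" for p in model.parameters())
    if has_meta_params:
        model.to_empty(device="cpu")   # safe: there was no data to lose
    # Calling to_empty() unconditionally would replace real_model's valid
    # weights with uninitialized memory, which is the bug the comment describes.
```
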
@@ -521,6 +555,17 @@ def initialize_model_weights(
             and getattr(model.config, "n_routed_experts", None)  # is Nemotron V3
             and hasattr(model, "backbone")  # is HF remote code
         )
+        # HF's _init_weights calls init.zeros_(weight[padding_idx]) on
+        # nn.Embedding layers. When the weight is a DTensor (TP-sharded),
+        # the integer index triggers a redistribute that fails. Temporarily
+        # clear padding_idx so the zeroing is skipped, then restore it and
+        # zero the row via local-tensor ops instead.
+        has_padding_idx = any(
+            isinstance(mod, nn.Embedding)
+            and type(mod.weight).__name__ == "DTensor"
+            and getattr(mod, "padding_idx", None) is not None
+            for mod in model.modules()
+        )
         skip_initialize_weights = (
             model_class
             in [
@@ -529,6 +574,7 @@
             ]
             or is_nemotron_v2
             or is_nemotron_v3_hf
+            or has_padding_idx
         )
         if not skip_initialize_weights:
             for _, module in model.named_modules():
@@ -539,7 +585,8 @@
             model.initialize_weights()
         else:
             logging.warning(
-                "Warning: Model does not have initialize_weights method. Requires custom initialization to be implemented."
+                "Warning: Model does not have initialize_weights method."
+                " Requires custom initialization to be implemented."
             )
 
         if peft_init_method is not None:
@@ -563,10 +610,11 @@ def load_base_model(
             model_name: Name of the model or an absolute path to a snapshot
             load_base_model: If True, restore from HF base checkpoint
         """
+        model_type = getattr(getattr(model, "config", None), "model_type", None)
+
         if load_base_model:
             assert model_name is not None, "model_name is required when loading base model"
             # Get combined key mapping from model attribute and model-type specific conversions
-            model_type = getattr(getattr(model, "config", None), "model_type", None)
             model_key_mapping = getattr(model, "_checkpoint_conversion_mapping", None)
             key_mapping = get_combined_key_mapping(model_type, model_key_mapping)
             # NemotronH remote code (trust_remote_code) uses backbone.* params matching checkpoint keys
@@ -582,7 +630,7 @@
                 key_mapping=key_mapping,
             )
 
-        _reinit_rope_buffers(model, device)
+        _reinit_non_persistent_buffers(model, device, model_type=model_type)
 
         is_tied_lm_head = is_tied_word_embeddings(model)
         self.config.original_model_root_dir = root_dir
@@ -677,8 +725,18 @@ def _do_load(
         is_model = True if "/model" in path else False
         # PEFT loading is broadcasted from rank0 so it is a special case
         if self.config.is_peft and is_model and (not is_init_step):
-            state_dict = load_file(os.path.join(path, "adapter_model.safetensors"))
+            if is_cloud_path(path):
+                _ensure_msc_available()
+                adapter_path = path.rstrip("/") + "/adapter_model.safetensors"
+                with msc.open(adapter_path, "rb") as f:
+                    data = f.read()
+                state_dict = safetensors_load(data)
+            else:
+                state_dict = load_file(os.path.join(path, "adapter_model.safetensors"))
         else:
+            if is_cloud_path(path) and storage_reader is None:
+                _ensure_msc_available()
+                storage_reader = msc.torch.MultiStorageFileSystemReader(path)
             dcp.load(state_dict, checkpoint_id=path, storage_reader=storage_reader)
         return state_dict
 
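For the PEFT branch, the cloud path reads the whole adapter file into memory and deserializes it with `safetensors.torch.load`, mirroring what `load_file` does for a local path. A hedged sketch of the same round trip outside the class (the `msc://` URL is hypothetical; assumes `msc.open` yields a binary file-like object, as the diff relies on):

```python
import multistorageclient as msc                      # optional dependency added by this commit
from safetensors.torch import load as safetensors_load

# Hypothetical adapter location on MSC-backed object storage.
adapter_path = "msc://my-profile/checkpoints/step_10/model/adapter_model.safetensors"

with msc.open(adapter_path, "rb") as f:               # read raw safetensors bytes
    data = f.read()
state_dict = safetensors_load(data)                   # bytes -> dict[str, torch.Tensor]
print(list(state_dict)[:5])                           # a few adapter tensor names
```
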
@@ -704,13 +762,25 @@ def _do_save(
         # PEFT saving is done on rank0 so it is a special case
         if self.config.is_peft and is_model:
             if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
-                save_file(state_dict, os.path.join(path, "adapter_model.safetensors"))
+                if is_cloud_path(path):
+                    _ensure_msc_available()
+                    adapter_path = path.rstrip("/") + "/adapter_model.safetensors"
+                    with msc.open(adapter_path, "wb") as f:
+                        save_file(state_dict, f)
+                else:
+                    save_file(state_dict, os.path.join(path, "adapter_model.safetensors"))
             if torch.distributed.is_initialized():
                 torch.distributed.barrier()
             return
 
         ret = None
         planner = dcp.DefaultSavePlanner(enable_plan_caching=True)
+
+        # Routes to MSC storage write for cloud paths
+        if is_cloud_path(path) and storage_writer is None:
+            _ensure_msc_available()
+            storage_writer = msc.torch.MultiStorageFileSystemWriter(path)
+
         if self.config.is_async:
             ctx = self._model_ctx if is_model else self._optim_ctx
             ret = dcp.async_save(
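
For non-PEFT state, the change is just routing: when the checkpoint id is an `msc://` URL and no reader or writer was passed in, DCP gets the MSC filesystem reader or writer referenced above. A minimal sketch with hypothetical paths (single process; the async and planner details from the method are omitted):

```python
import torch
import torch.distributed.checkpoint as dcp
import multistorageclient as msc

path = "msc://my-profile/checkpoints/step_10/model"   # hypothetical checkpoint location
state_dict = {"weight": torch.randn(4, 4)}

# Save: DCP writes its shards through the MSC filesystem writer.
dcp.save(state_dict, checkpoint_id=path, storage_writer=msc.torch.MultiStorageFileSystemWriter(path))

# Load: the same structure is filled in place through the MSC filesystem reader.
dcp.load(state_dict, checkpoint_id=path, storage_reader=msc.torch.MultiStorageFileSystemReader(path))
```
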
@@ -974,8 +1044,14 @@ def save_config(config: dict[str, Any], weights_path: str) -> None:
         config: Config to save
         weights_path: Path to save config
     """
-    with open(os.path.join(weights_path, "config.yaml"), "w") as f:
-        yaml.dump(config, f, sort_keys=False, default_flow_style=False)
+    config_path = os.path.join(weights_path, "config.yaml")
+    if is_cloud_path(weights_path):
+        _ensure_msc_available()
+        with msc.open(config_path, "w") as f:
+            yaml.dump(config, f, sort_keys=False, default_flow_style=False)
+    else:
+        with open(config_path, "w") as f:
+            yaml.dump(config, f, sort_keys=False, default_flow_style=False)
 
 
 def _ensure_dirs(*dirs: Optional[str]) -> None:
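
With this change, callers do not branch themselves; the same `save_config` call works for a local directory or an MSC URL (both destinations below are hypothetical):

```python
save_config({"step": 10, "lr": 3e-4}, "/tmp/checkpoints/step_10")               # plain open()
save_config({"step": 10, "lr": 3e-4}, "msc://my-profile/checkpoints/step_10")   # msc.open()
```
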
@@ -987,7 +1063,8 @@ def _ensure_dirs(*dirs: Optional[str]) -> None:
     """
     for d in dirs:
         if d:
-            os.makedirs(d, exist_ok=True)
+            if not is_cloud_path(d):
+                os.makedirs(d, exist_ok=True)
     if torch.distributed.is_initialized():
         torch.distributed.barrier()
 
@@ -1008,18 +1085,48 @@ def _init_peft_adapters(model: nn.Module, peft_init_method: str) -> None:
             logging.warning(f"Failed to initialize weights for PEFT adapter `{module.__class__.__name__}`: {e}")
 
 
-def _reinit_rope_buffers(model: nn.Module, device: torch.device) -> None:
+_MODELS_REQUIRING_BUFFER_REINIT: frozenset[str] = frozenset(
+    {
+        "gemma3",
+        "nemotron-nas",
+    }
+)
+
+
+def _reinit_non_persistent_buffers(model: nn.Module, device: torch.device, model_type: str | None = None) -> None:
     """
-    Recompute non-persistent RoPE ``inv_freq`` buffers for Nemotron-NAS models.
+    Recompute non-persistent buffers that are not saved in checkpoints.
+
+    Non-persistent buffers are not saved in checkpoints, so after meta-device
+    materialization they contain uninitialized CUDA memory. When
+    ``initialize_weights()`` is skipped (e.g. for Gemma3 to avoid DTensor
+    issues), these buffers must be recomputed explicitly.
+
+    Only runs for models listed in ``_MODELS_REQUIRING_BUFFER_REINIT`` to
+    avoid unexpected side-effects on arbitrary HF Hub models.
+
+    Handles four patterns:
+
+    1. **Standard RoPE** — single ``inv_freq`` buffer with ``rope_init_fn`` +
+       ``rope_kwargs`` (e.g. Nemotron-NAS).
+    2. **Per-layer-type RoPE** — ``{layer_type}_inv_freq`` buffers via
+       ``compute_default_rope_parameters`` (e.g. Gemma3RotaryEmbedding).
+    3. **Scaled embedding** — ``embed_scale`` buffer on ``ScaledWordEmbedding``
+       modules (Gemma family), recomputed from ``scalar_embed_scale``.
+    4. **Vision position IDs** — ``position_ids`` buffer on vision embedding
+       modules (SigLIP), recomputed from ``num_positions``.
+
     Args:
-        model: Model to reinitialize RoPE buffers for.
+        model: Model to reinitialize non-persistent buffers for.
         device: Device to create the new buffers on.
+        model_type: The ``config.model_type`` string. If not in
+            ``_MODELS_REQUIRING_BUFFER_REINIT`` the function is a no-op.
     """
-    model_type = getattr(getattr(model, "config", None), "model_type", None)
-    if model_type not in ("nemotron-nas",):
+    if model_type not in _MODELS_REQUIRING_BUFFER_REINIT:
         return
 
     for name, module in model.named_modules():
+        # Pattern 1: standard RoPE with rope_init_fn + rope_kwargs (Nemotron-NAS)
         if hasattr(module, "rope_init_fn") and hasattr(module, "inv_freq") and hasattr(module, "rope_kwargs"):
             try:
                 inv_freq, _ = module.rope_init_fn(module.config, device, **module.rope_kwargs)
@@ -1030,6 +1137,51 @@ def _reinit_rope_buffers(model: nn.Module, device: torch.device) -> None:
             except Exception as e:
                 logging.warning(f"Failed to reinitialize RoPE inv_freq for {name}: {e}")
 
+        # Pattern 2: per-layer-type RoPE (Gemma3RotaryEmbedding and similar)
+        elif hasattr(module, "layer_types") and hasattr(module, "rope_type") and hasattr(module, "config"):
+            rope_config = getattr(module, "config", None)
+            rope_parameters = getattr(rope_config, "rope_parameters", None)
+            if rope_parameters is None:
+                continue
+            for layer_type in getattr(module, "layer_types", []):
+                inv_freq_attr = f"{layer_type}_inv_freq"
+                if not hasattr(module, inv_freq_attr):
+                    continue
+                try:
+                    rope_init_fn = getattr(module, "compute_default_rope_parameters", None)
+                    if rope_init_fn is None:
+                        continue
+                    rope_type = module.rope_type.get(layer_type, "default")
+                    if rope_type != "default":
+                        from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+
+                        rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
+                    curr_inv_freq, curr_attention_scaling = rope_init_fn(rope_config, device, layer_type=layer_type)
+                    setattr(module, inv_freq_attr, curr_inv_freq)
+                    orig_attr = f"{layer_type}_original_inv_freq"
+                    if hasattr(module, orig_attr):
+                        setattr(module, orig_attr, curr_inv_freq.clone())
+                    setattr(module, f"{layer_type}_attention_scaling", curr_attention_scaling)
+                    logging.debug(f"Reinitialized RoPE {inv_freq_attr} for {name} on device {device}")
+                except Exception as e:
+                    logging.warning(f"Failed to reinitialize RoPE {inv_freq_attr} for {name}: {e}")
+
+        # Pattern 3: ScaledWordEmbedding embed_scale (Gemma family)
+        if hasattr(module, "scalar_embed_scale") and "embed_scale" in getattr(module, "_buffers", {}):
+            try:
+                module.embed_scale = torch.tensor(module.scalar_embed_scale, device=device)
+                logging.debug(f"Reinitialized embed_scale={module.scalar_embed_scale} for {name} on device {device}")
+            except Exception as e:
+                logging.warning(f"Failed to reinitialize embed_scale for {name}: {e}")
+
+        # Pattern 4: Vision embedding position_ids (SigLIP and similar)
+        if hasattr(module, "num_positions") and "position_ids" in getattr(module, "_buffers", {}):
+            try:
+                module.position_ids = torch.arange(module.num_positions, device=device).expand((1, -1))
+                logging.debug(f"Reinitialized position_ids (num_positions={module.num_positions}) for {name}")
+            except Exception as e:
+                logging.warning(f"Failed to reinitialize position_ids for {name}: {e}")
+
 
 def _apply(module, fn, recurse=True) -> nn.Module:
     """
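Background for the buffer-reinit logic above: buffers registered with `persistent=False` never appear in `state_dict()`, so a checkpoint cannot restore them, and after meta-device construction plus `to_empty()` they exist but hold uninitialized memory. A small self-contained demonstration of that property (plain PyTorch, not tied to any particular model):

```python
import torch
from torch import nn


class WithRope(nn.Module):
    def __init__(self):
        super().__init__()
        # persistent=False: excluded from state_dict(), like RoPE inv_freq buffers.
        self.register_buffer("inv_freq", 1.0 / torch.arange(1, 5, dtype=torch.float32), persistent=False)


print("inv_freq" in WithRope().state_dict())    # False -> nothing for a checkpoint to restore

with torch.device("meta"):
    m = WithRope()                               # buffer is a meta tensor
m.to_empty(device="cpu")                         # storage materialized, but values are garbage
print(m.inv_freq.shape, m.inv_freq.device)       # hence the explicit recomputation above
```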