 import torch.distributed.checkpoint as dcp
 import yaml
 
+try:
+    import multistorageclient as msc
+
+    MSC_AVAILABLE = True
+except ImportError:
+    msc = None
+    MSC_AVAILABLE = False
+
 # Safe import of HF_HUB_CACHE from huggingface_hub.constants
 try:
     from huggingface_hub.constants import HF_HUB_CACHE
 except ImportError:
     HF_HUB_CACHE = None
 
 from packaging.version import parse
+from safetensors.torch import load as safetensors_load
 from safetensors.torch import load_file, save_file
 from torch import nn
 from torch.distributed.device_mesh import DeviceMesh
     from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 
+def is_cloud_path(path: str) -> bool:
+    """Check if path is a cloud storage path (MSC)."""
+    return path.startswith("msc://")
+
+
+def _ensure_msc_available() -> None:
+    """Raise an error if MSC is not installed but a cloud path is used."""
+    if not MSC_AVAILABLE:
+        raise ImportError(
+            "multistorageclient is required for cloud storage paths. "
+            "Install it with: pip install multi-storage-client "
+            "--index-url https://pypi.nvidia.com"
+        )
+
+
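Aside (not part of the patch): a minimal sketch of how these two helpers are meant to gate cloud I/O. The `read_bytes` wrapper and the `msc://` profile below are illustrative only; `is_cloud_path`, `_ensure_msc_available`, and `msc` come from the module above.

```python
# Illustrative helper, not in the patch: route reads through MSC only when the
# path uses the msc:// scheme, falling back to the local filesystem otherwise.
def read_bytes(path: str) -> bytes:
    if is_cloud_path(path):
        _ensure_msc_available()
        with msc.open(path, "rb") as f:  # msc.open is used the same way later in this patch
            return f.read()
    with open(path, "rb") as f:
        return f.read()

# read_bytes("msc://my-profile/ckpts/adapter_model.safetensors")  # hypothetical object path
```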
 def _is_geq_torch_2_9() -> bool:
     """
     Check if the current torch version is greater than or equal to 2.9.0.
@@ -267,7 +291,11 @@ def save_model(
 
         # Convert to HF format if using custom model implementations
         state_dict = _maybe_adapt_state_dict_to_hf(
-            model_state.model[0], state_dict, quantization=False, device_mesh=self.moe_mesh
+            model_state.model[0],
+            state_dict,
+            quantization=False,
+            device_mesh=self.moe_mesh,
+            v4_compatible=self.config.v4_compatible,
         )
         # Build the consolidated model.safetensors.index.json if needed
         fqn_to_file_index_mapping = self._maybe_build_consolidated_index(model_state, state_dict)
@@ -369,7 +397,7 @@ def load_model(
             key_mapping: Optional key remapping when reading from HF checkpoints.
         """
         # Validate checkpoint directory
-        if not os.path.exists(model_path):
+        if not os.path.exists(model_path) and not is_cloud_path(model_path):
             raise FileNotFoundError(f"Model path {model_path} does not exist")
         model_state = ModelState(
             model,
@@ -481,9 +509,15 @@ def initialize_model_weights(
             device: Target device for materialized parameters.
             peft_init_method: Initialization method for PEFT adapters (e.g. "xavier").
         """
-        to_empty_parameters_only(model, device=device)
+        # Only materialize parameters that are actually on the meta device.
+        # When the caller sets is_meta_device=True but the model was already
+        # constructed on a real device (e.g. ContextManagers was patched to
+        # a no-op), calling to_empty_parameters_only would replace valid
+        # weights with uninitialized CUDA memory.
+        has_meta_params = any(p.device.type == "meta" for p in model.parameters())
+        if has_meta_params:
+            to_empty_parameters_only(model, device=device)
 
-        # to_empty_parameters_only only materializes parameters, not buffers.
         # Buffers (e.g. RoPE inv_freq) may still be on meta device. Move them
         # to *device* with uninitialized storage so that the subsequent
         # initialize_weights() call can overwrite them with proper values
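The meta-device guard above matters because re-materializing an already-initialized module discards its values. A small sketch using the built-in `nn.Module.to_empty` (presumably similar in effect to this codebase's `to_empty_parameters_only` helper) illustrates the failure mode being avoided:

```python
import torch
from torch import nn

m = nn.Linear(4, 4)                                         # already materialized, real weights
has_meta = any(p.device.type == "meta" for p in m.parameters())
print(has_meta)                                             # False -> the patched code skips materialization

m.to_empty(device="cpu")                                    # what skipping avoids: storage is re-allocated
print(m.weight)                                             # uninitialized values, previous weights are lost
```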
@@ -521,6 +555,17 @@ def initialize_model_weights(
             and getattr(model.config, "n_routed_experts", None)  # is Nemotron V3
             and hasattr(model, "backbone")  # is HF remote code
         )
+        # HF's _init_weights calls init.zeros_(weight[padding_idx]) on
+        # nn.Embedding layers. When the weight is a DTensor (TP-sharded),
+        # the integer index triggers a redistribute that fails. Temporarily
+        # clear padding_idx so the zeroing is skipped, then restore it and
+        # zero the row via local-tensor ops instead.
+        has_padding_idx = any(
+            isinstance(mod, nn.Embedding)
+            and type(mod.weight).__name__ == "DTensor"
+            and getattr(mod, "padding_idx", None) is not None
+            for mod in model.modules()
+        )
         skip_initialize_weights = (
             model_class
             in [
@@ -529,6 +574,7 @@ def initialize_model_weights(
             ]
             or is_nemotron_v2
             or is_nemotron_v3_hf
+            or has_padding_idx
         )
         if not skip_initialize_weights:
             for _, module in model.named_modules():
@@ -539,7 +585,8 @@ def initialize_model_weights(
             model.initialize_weights()
         else:
             logging.warning(
-                "Warning: Model does not have initialize_weights method. Requires custom initialization to be implemented."
+                "Warning: Model does not have initialize_weights method."
+                " Requires custom initialization to be implemented."
             )
 
         if peft_init_method is not None:
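For context on the `has_padding_idx` guard added above, this is roughly what HF's `_init_weights` does for embedding layers (a paraphrase, not the exact transformers code); the integer row index on the last line is the operation that forces a gather/redistribute once the weight is a TP-sharded DTensor:

```python
import torch
from torch import nn

def hf_style_embedding_init(emb: nn.Embedding, std: float = 0.02) -> None:
    # Paraphrase of transformers' default embedding initialization.
    emb.weight.data.normal_(mean=0.0, std=std)
    if emb.padding_idx is not None:
        emb.weight.data[emb.padding_idx].zero_()  # integer row index: problematic for a sharded DTensor

hf_style_embedding_init(nn.Embedding(10, 4, padding_idx=0))  # fine for a plain tensor
```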
@@ -563,10 +610,11 @@ def load_base_model(
             model_name: Name of the model or an absolute path to a snapshot
             load_base_model: If True, restore from HF base checkpoint
         """
+        model_type = getattr(getattr(model, "config", None), "model_type", None)
+
         if load_base_model:
             assert model_name is not None, "model_name is required when loading base model"
             # Get combined key mapping from model attribute and model-type specific conversions
-            model_type = getattr(getattr(model, "config", None), "model_type", None)
             model_key_mapping = getattr(model, "_checkpoint_conversion_mapping", None)
             key_mapping = get_combined_key_mapping(model_type, model_key_mapping)
             # NemotronH remote code (trust_remote_code) uses backbone.* params matching checkpoint keys
@@ -582,7 +630,7 @@ def load_base_model(
                 key_mapping=key_mapping,
             )
 
-        _reinit_rope_buffers(model, device)
+        _reinit_non_persistent_buffers(model, device, model_type=model_type)
 
         is_tied_lm_head = is_tied_word_embeddings(model)
         self.config.original_model_root_dir = root_dir
@@ -677,8 +725,18 @@ def _do_load(
         is_model = True if "/model" in path else False
         # PEFT loading is broadcasted from rank0 so it is a special case
         if self.config.is_peft and is_model and (not is_init_step):
-            state_dict = load_file(os.path.join(path, "adapter_model.safetensors"))
+            if is_cloud_path(path):
+                _ensure_msc_available()
+                adapter_path = path.rstrip("/") + "/adapter_model.safetensors"
+                with msc.open(adapter_path, "rb") as f:
+                    data = f.read()
+                state_dict = safetensors_load(data)
+            else:
+                state_dict = load_file(os.path.join(path, "adapter_model.safetensors"))
         else:
+            if is_cloud_path(path) and storage_reader is None:
+                _ensure_msc_available()
+                storage_reader = msc.torch.MultiStorageFileSystemReader(path)
             dcp.load(state_dict, checkpoint_id=path, storage_reader=storage_reader)
         return state_dict
 
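The cloud branch above works because `safetensors.torch.load` accepts an in-memory byte string, so the adapter file never has to touch local disk. A self-contained round-trip (tensor names are made up) showing the same mechanism:

```python
import torch
from safetensors.torch import load as safetensors_load
from safetensors.torch import save as safetensors_save

tensors = {"lora_A": torch.randn(8, 16), "lora_B": torch.zeros(16, 8)}
data = safetensors_save(tensors)        # bytes, standing in for msc.open(path, "rb").read()
restored = safetensors_load(data)       # dict[str, Tensor], no temporary file involved
assert torch.equal(restored["lora_B"], tensors["lora_B"])
```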
@@ -704,13 +762,25 @@ def _do_save(
         # PEFT saving is done on rank0 so it is a special case
         if self.config.is_peft and is_model:
             if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
-                save_file(state_dict, os.path.join(path, "adapter_model.safetensors"))
+                if is_cloud_path(path):
+                    _ensure_msc_available()
+                    adapter_path = path.rstrip("/") + "/adapter_model.safetensors"
+                    with msc.open(adapter_path, "wb") as f:
+                        save_file(state_dict, f)
+                else:
+                    save_file(state_dict, os.path.join(path, "adapter_model.safetensors"))
             if torch.distributed.is_initialized():
                 torch.distributed.barrier()
             return
 
         ret = None
         planner = dcp.DefaultSavePlanner(enable_plan_caching=True)
+
+        # Route to the MSC storage writer for cloud paths
+        if is_cloud_path(path) and storage_writer is None:
+            _ensure_msc_available()
+            storage_writer = msc.torch.MultiStorageFileSystemWriter(path)
+
         if self.config.is_async:
             ctx = self._model_ctx if is_model else self._optim_ctx
             ret = dcp.async_save(
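Taken together, the `_do_load`/`_do_save` branches resolve to roughly the following for an `msc://` destination. This is a sketch only; the profile/bucket prefix is hypothetical and must exist in the MSC configuration:

```python
import torch
import torch.distributed.checkpoint as dcp
import multistorageclient as msc

ckpt_path = "msc://my-profile/checkpoints/step_100/model"   # hypothetical MSC profile + prefix
state_dict = {"weight": torch.randn(4, 4)}

# Save: DCP shards are written through the MSC filesystem writer instead of local disk.
dcp.save(state_dict, storage_writer=msc.torch.MultiStorageFileSystemWriter(ckpt_path))

# Load: the same path routes reads through the MSC filesystem reader.
dcp.load(
    state_dict,
    checkpoint_id=ckpt_path,
    storage_reader=msc.torch.MultiStorageFileSystemReader(ckpt_path),
)
```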
@@ -974,8 +1044,14 @@ def save_config(config: dict[str, Any], weights_path: str) -> None:
         config: Config to save
         weights_path: Path to save config
     """
-    with open(os.path.join(weights_path, "config.yaml"), "w") as f:
-        yaml.dump(config, f, sort_keys=False, default_flow_style=False)
+    config_path = os.path.join(weights_path, "config.yaml")
+    if is_cloud_path(weights_path):
+        _ensure_msc_available()
+        with msc.open(config_path, "w") as f:
+            yaml.dump(config, f, sort_keys=False, default_flow_style=False)
+    else:
+        with open(config_path, "w") as f:
+            yaml.dump(config, f, sort_keys=False, default_flow_style=False)
 
 
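A hypothetical round-trip for the cloud branch of `save_config`; the `msc://` prefix is made up and assumes a matching MSC profile, and the read-back with `msc.open(..., "r")` is shown only to illustrate the symmetry:

```python
import yaml
import multistorageclient as msc

# Assumes save_config from this module is in scope and that an MSC profile
# matching the (made-up) "msc://my-profile" prefix is configured.
cfg = {"model_name": "meta-llama/Llama-3-8B", "lr": 1.0e-4}
save_config(cfg, "msc://my-profile/checkpoints/step_100")

with msc.open("msc://my-profile/checkpoints/step_100/config.yaml", "r") as f:
    assert yaml.safe_load(f) == cfg
```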
 def _ensure_dirs(*dirs: Optional[str]) -> None:
@@ -987,7 +1063,8 @@ def _ensure_dirs(*dirs: Optional[str]) -> None:
     """
     for d in dirs:
         if d:
-            os.makedirs(d, exist_ok=True)
+            if not is_cloud_path(d):
+                os.makedirs(d, exist_ok=True)
     if torch.distributed.is_initialized():
         torch.distributed.barrier()
 
@@ -1008,18 +1085,48 @@ def _init_peft_adapters(model: nn.Module, peft_init_method: str) -> None:
             logging.warning(f"Failed to initialize weights for PEFT adapter `{module.__class__.__name__}`: {e}")
 
 
-def _reinit_rope_buffers(model: nn.Module, device: torch.device) -> None:
+_MODELS_REQUIRING_BUFFER_REINIT: frozenset[str] = frozenset(
+    {
+        "gemma3",
+        "nemotron-nas",
+    }
+)
+
+
+def _reinit_non_persistent_buffers(model: nn.Module, device: torch.device, model_type: str | None = None) -> None:
     """
-    Recompute non-persistent RoPE ``inv_freq`` buffers for Nemotron-NAS models.
+    Recompute non-persistent buffers that are not saved in checkpoints.
+
+    Non-persistent buffers are not saved in checkpoints, so after meta-device
+    materialization they contain uninitialized CUDA memory. When
+    ``initialize_weights()`` is skipped (e.g. for Gemma3 to avoid DTensor
+    issues), these buffers must be recomputed explicitly.
+
+    Only runs for models listed in ``_MODELS_REQUIRING_BUFFER_REINIT`` to
+    avoid unexpected side-effects on arbitrary HF Hub models.
+
+    Handles four patterns:
+
+    1. **Standard RoPE** — single ``inv_freq`` buffer with ``rope_init_fn`` +
+       ``rope_kwargs`` (e.g. Nemotron-NAS).
+    2. **Per-layer-type RoPE** — ``{layer_type}_inv_freq`` buffers via
+       ``compute_default_rope_parameters`` (e.g. Gemma3RotaryEmbedding).
+    3. **Scaled embedding** — ``embed_scale`` buffer on ``ScaledWordEmbedding``
+       modules (Gemma family), recomputed from ``scalar_embed_scale``.
+    4. **Vision position IDs** — ``position_ids`` buffer on vision embedding
+       modules (SigLIP), recomputed from ``num_positions``.
+
     Args:
-        model: Model to reinitialize RoPE buffers for.
+        model: Model to reinitialize non-persistent buffers for.
         device: Device to create the new buffers on.
+        model_type: The ``config.model_type`` string. If not in
+            ``_MODELS_REQUIRING_BUFFER_REINIT`` the function is a no-op.
     """
-    model_type = getattr(getattr(model, "config", None), "model_type", None)
-    if model_type not in ("nemotron-nas",):
+    if model_type not in _MODELS_REQUIRING_BUFFER_REINIT:
         return
 
     for name, module in model.named_modules():
+        # Pattern 1: standard RoPE with rope_init_fn + rope_kwargs (Nemotron-NAS)
         if hasattr(module, "rope_init_fn") and hasattr(module, "inv_freq") and hasattr(module, "rope_kwargs"):
             try:
                 inv_freq, _ = module.rope_init_fn(module.config, device, **module.rope_kwargs)
@@ -1030,6 +1137,51 @@ def _reinit_rope_buffers(model: nn.Module, device: torch.device) -> None:
             except Exception as e:
                 logging.warning(f"Failed to reinitialize RoPE inv_freq for {name}: {e}")
 
+        # Pattern 2: per-layer-type RoPE (Gemma3RotaryEmbedding and similar)
+        elif hasattr(module, "layer_types") and hasattr(module, "rope_type") and hasattr(module, "config"):
+            rope_config = getattr(module, "config", None)
+            rope_parameters = getattr(rope_config, "rope_parameters", None)
+            if rope_parameters is None:
+                continue
+            for layer_type in getattr(module, "layer_types", []):
+                inv_freq_attr = f"{layer_type}_inv_freq"
+                if not hasattr(module, inv_freq_attr):
+                    continue
+                try:
+                    rope_init_fn = getattr(module, "compute_default_rope_parameters", None)
+                    if rope_init_fn is None:
+                        continue
+                    rope_type = module.rope_type.get(layer_type, "default")
+                    if rope_type != "default":
+                        from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+
+                        rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
+                    curr_inv_freq, curr_attention_scaling = rope_init_fn(rope_config, device, layer_type=layer_type)
+                    setattr(module, inv_freq_attr, curr_inv_freq)
+                    orig_attr = f"{layer_type}_original_inv_freq"
+                    if hasattr(module, orig_attr):
+                        setattr(module, orig_attr, curr_inv_freq.clone())
+                    setattr(module, f"{layer_type}_attention_scaling", curr_attention_scaling)
+                    logging.debug(f"Reinitialized RoPE {inv_freq_attr} for {name} on device {device}")
+                except Exception as e:
+                    logging.warning(f"Failed to reinitialize RoPE {inv_freq_attr} for {name}: {e}")
+
+        # Pattern 3: ScaledWordEmbedding embed_scale (Gemma family)
+        if hasattr(module, "scalar_embed_scale") and "embed_scale" in getattr(module, "_buffers", {}):
+            try:
+                module.embed_scale = torch.tensor(module.scalar_embed_scale, device=device)
+                logging.debug(f"Reinitialized embed_scale={module.scalar_embed_scale} for {name} on device {device}")
+            except Exception as e:
+                logging.warning(f"Failed to reinitialize embed_scale for {name}: {e}")
+
+        # Pattern 4: Vision embedding position_ids (SigLIP and similar)
+        if hasattr(module, "num_positions") and "position_ids" in getattr(module, "_buffers", {}):
+            try:
+                module.position_ids = torch.arange(module.num_positions, device=device).expand((1, -1))
+                logging.debug(f"Reinitialized position_ids (num_positions={module.num_positions}) for {name}")
+            except Exception as e:
+                logging.warning(f"Failed to reinitialize position_ids for {name}: {e}")
+
 
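A toy module, not from the patch, showing why non-persistent buffers need this explicit recomputation: they are excluded from `state_dict()`, so a checkpoint load after meta-device materialization leaves them as uninitialized memory.

```python
import torch
from torch import nn

class ToyPositionIds(nn.Module):
    def __init__(self, num_positions: int = 16):
        super().__init__()
        self.num_positions = num_positions
        # persistent=False -> excluded from state_dict(), so no checkpoint ever restores it
        self.register_buffer(
            "position_ids", torch.arange(num_positions).expand((1, -1)), persistent=False
        )

with torch.device("meta"):
    m = ToyPositionIds()
m.to_empty(device="cpu")                       # buffer now holds uninitialized memory
print("position_ids" in m.state_dict())        # False -> loading weights cannot fix it
m.position_ids = torch.arange(m.num_positions).expand((1, -1))  # what pattern 4 above re-derives
```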
 def _apply(module, fn, recurse=True) -> nn.Module:
     """