6464 # since PyTorch 2.3 the path has changed
6565 from torch .amp .grad_scaler import _refresh_per_optimizer_state
6666
67+ from concurrent .futures import ThreadPoolExecutor
68+
69+ import multistorageclient as msc
70+ from multistorageclient .types import MSC_PROTOCOL
71+
6772from nemo .collections .nlp .modules .common .megatron .module import Float16Module
6873from nemo .collections .nlp .modules .common .megatron .transformer import AutocastTransformerLayer , ParallelTransformerLayer
6974from nemo .collections .nlp .parts import utils_funcs
7378from nemo .utils import AppState , logging
7479from nemo .utils .model_utils import ckpt_to_dir , inject_model_parallel_rank , uninject_model_parallel_rank
7580
76- from concurrent .futures import ThreadPoolExecutor
77- from multistorageclient .types import MSC_PROTOCOL
78- import multistorageclient as msc
79-
8081try :
8182
8283 from nemo .core .optim .distributed_adam import MegatronDistributedFusedAdam
@@ -1042,25 +1043,26 @@ def msc_download_dir(url: str, local_path: str):
10421043 if not msc .os .path .exists (url ):
10431044 raise Exception (f"Download Path doesn't exist: { url } " )
10441045
1045- base_name = os .path .basename (url ) # url = "msc://my-profile/path/to/data", base_name = "data"
1046+ base_name = os .path .basename (url ) # url = "msc://my-profile/path/to/data", base_name = "data"
10461047 files = msc .list (url )
10471048
10481049 def download_file (item ):
10491050 """Helper function to download a single file."""
1050- file_name = item .key #item.key = "msc://profile/path/to/data/file1.txt"
1051- base_name_idx = file_name .find (base_name ) # base_name_idx = 23
1052- local_file_path = f"{ local_path } /{ file_name [base_name_idx :]} " #local_file_path = f"{local_path}/data/file1.txt"
1051+ file_name = item .key # item.key = "msc://profile/path/to/data/file1.txt"
1052+ base_name_idx = file_name .find (base_name ) # base_name_idx = 23
1053+ local_file_path = (
1054+ f"{ local_path } /{ file_name [base_name_idx :]} " # local_file_path = f"{local_path}/data/file1.txt"
1055+ )
10531056 os .makedirs (os .path .dirname (local_file_path ), exist_ok = True )
10541057 msc .download_file (item , local_file_path )
1055- #msc.download_file(f"{MSC_PROTOCOL}{get_profile()}/{file_name}", local_file_path)
1058+ # msc.download_file(f"{MSC_PROTOCOL}{get_profile()}/{file_name}", local_file_path)
10561059
10571060 # Use ThreadPoolExecutor for parallel downloads
10581061 with ThreadPoolExecutor (max_workers = 32 ) as executor : # Adjust max_workers as needed
10591062 executor .map (download_file , files )
10601063
10611064 logging .warning (f"msc_download_dir completed rank { torch .distributed .get_rank ()} " )
1062-
1063-
1065+
10641066
10651067class NLPSaveRestoreConnector (SaveRestoreConnector ):
10661068 """Custom connector to support saving and restoring states."""
@@ -1083,7 +1085,6 @@ def __init__(self) -> None:
10831085 )
10841086 super ().__init__ ()
10851087
1086-
10871088 def save_to (self , model , save_path : str ):
10881089 """Save model to save path."""
10891090 app_state = AppState ()
@@ -1102,17 +1103,16 @@ def save_to(self, model, save_path: str):
11021103 is_msc_enabled = False
11031104 if MSC_PROTOCOL in dir_name :
11041105 is_msc_enabled = True
1105-
1106+
11061107 # dist ckpt calls save on every rank
11071108 if dist_ckpt :
11081109 # model weights is a directory
11091110 dist_ckpt_dir = ckpt_to_dir (os .path .join (dir_name , self .model_weights_ckpt ))
11101111
11111112 if is_msc_enabled :
1112- filename = os .path .join (dir_name , self .model_weights_ckpt )
1113+ filename = os .path .join (dir_name , self .model_weights_ckpt )
11131114 dist_ckpt_dir = os .path .splitext (filename )[0 ]
1114-
1115-
1115+
11161116 # dist checkpoint needs torch.distributed to save the checkpoint
11171117 if not parallel_state .is_initialized ():
11181118
@@ -1185,8 +1185,10 @@ def dummy():
11851185
11861186 if is_msc_enabled :
11871187 print (f"Downloading { mp_model_weights } to { tmpdir } " )
1188- msc_dest = os .path .join (tmpdir , f'mp_rank_{ tp_rank :02d} ' , self .model_weights_ckpt )
1189- logging .warning (f"msc_download_dir mp_model_weights from { mp_model_weights } { msc_dest } rank { torch .distributed .get_rank ()} " )
1188+ msc_dest = os .path .join (tmpdir , f'mp_rank_{ tp_rank :02d} ' , self .model_weights_ckpt )
1189+ logging .warning (
1190+ f"msc_download_dir mp_model_weights from { mp_model_weights } { msc_dest } rank { torch .distributed .get_rank ()} "
1191+ )
11901192 msc_download_dir (mp_model_weights , msc_dest )
11911193 else :
11921194 shutil .move (
@@ -1206,8 +1208,12 @@ def dummy():
12061208
12071209 if is_msc_enabled :
12081210 print (f"Downloading { mp_model_weights } to { tmpdir } " )
1209- msc_dest = os .path .join (tmpdir , f'tp_rank_{ tp_rank :02d} _pp_rank_{ pp_rank :03d} ' , self .model_weights_ckpt )
1210- logging .warning (f"msc_download_dir mp_model_weights from { mp_model_weights } { msc_dest } rank { torch .distributed .get_rank ()} " )
1211+ msc_dest = os .path .join (
1212+ tmpdir , f'tp_rank_{ tp_rank :02d} _pp_rank_{ pp_rank :03d} ' , self .model_weights_ckpt
1213+ )
1214+ logging .warning (
1215+ f"msc_download_dir mp_model_weights from { mp_model_weights } { msc_dest } rank { torch .distributed .get_rank ()} "
1216+ )
12111217 msc_download_dir (mp_model_weights , msc_dest )
12121218 else :
12131219 shutil .move (
@@ -1368,28 +1374,24 @@ def _load_state_dict_from_disk(self, model_weights, map_location=None):
13681374 else :
13691375 raise ValueError (f'Expected { model_weights } to be a file or directory.' )
13701376
1371-
1372- def _download_nemo_file (self ,
1373- restore_path : str ,
1374- tmpdir : str ) -> str :
1375- # .nemo filename
1377+ def _download_nemo_file (self , restore_path : str , tmpdir : str ) -> str :
1378+ # .nemo filename
13761379 fname = os .path .basename (restore_path )
1377-
1378- #check if msc path exists
1380+
1381+ # check if msc path exists
13791382 if not msc .os .path .exists (restore_path ):
13801383 raise FileNotFoundError (f".nemo file doesn't exist at { restore_path } " )
1381-
1382- #download .nemo file to tempdir
1384+
1385+ # download .nemo file to tempdir
13831386 os .makedirs (tmpdir , exist_ok = True )
13841387 logging .warning (f"Starting .nemo download { restore_path } " )
13851388 msc .download_file (restore_path , f"{ tmpdir } /{ fname } " )
1386-
1387- #update restore_path to point to downloaded .nemo
1389+
1390+ # update restore_path to point to downloaded .nemo
13881391 updated_restore_path = os .path .join (tmpdir , fname )
13891392 logging .warning (f".nemo download complete; updated_restore_path to { updated_restore_path } " )
13901393 return updated_restore_path
13911394
1392-
13931395 def restore_from (
13941396 self ,
13951397 calling_cls ,
@@ -1459,7 +1461,7 @@ def dummy():
14591461 trainer .strategy .setup_environment ()
14601462
14611463 # with tempfile.TemporaryDirectory() as tmpdir:
1462- # Check if self.model_extracted_dir is set, and is a valid path
1464+ # Check if self.model_extracted_dir is set, and is a valid path
14631465 if self .model_extracted_dir is not None and os .path .isdir (self .model_extracted_dir ):
14641466 # Log that NeMo will use the provided `model_extracted_dir`
14651467 logging .info (
@@ -1512,7 +1514,7 @@ def dummy():
15121514 else :
15131515 state_dict = self .modify_state_dict (conf , state_dict )
15141516 super ().load_instance_with_state_dict (instance , state_dict , strict )
1515-
1517+
15161518 logging .info (f'Model { instance .__class__ .__name__ } was successfully restored from { restore_path } .' )
15171519 return instance
15181520
0 commit comments