
Commit 62e6bae

Apply isort and black reformatting

Signed-off-by: ankitmaster08 <ankitmaster08@users.noreply.github.com>
Parent: 84e0a8b

5 files changed: 49 additions & 41 deletions

File tree

nemo/collections/nlp/parts/nlp_overrides.py
nemo/core/connectors/save_restore_connector.py
nemo/utils/callbacks/dist_ckpt_io.py
nemo/utils/callbacks/nemo_model_checkpoint.py
nemo/utils/exp_manager.py

nemo/collections/nlp/parts/nlp_overrides.py
Lines changed: 36 additions & 34 deletions
@@ -64,6 +64,11 @@
 # since PyTorch 2.3 the path has changed
 from torch.amp.grad_scaler import _refresh_per_optimizer_state

+from concurrent.futures import ThreadPoolExecutor
+
+import multistorageclient as msc
+from multistorageclient.types import MSC_PROTOCOL
+
 from nemo.collections.nlp.modules.common.megatron.module import Float16Module
 from nemo.collections.nlp.modules.common.megatron.transformer import AutocastTransformerLayer, ParallelTransformerLayer
 from nemo.collections.nlp.parts import utils_funcs
@@ -73,10 +78,6 @@
 from nemo.utils import AppState, logging
 from nemo.utils.model_utils import ckpt_to_dir, inject_model_parallel_rank, uninject_model_parallel_rank

-from concurrent.futures import ThreadPoolExecutor
-from multistorageclient.types import MSC_PROTOCOL
-import multistorageclient as msc
-
 try:

     from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam
@@ -1042,25 +1043,26 @@ def msc_download_dir(url: str, local_path: str):
     if not msc.os.path.exists(url):
         raise Exception(f"Download Path doesn't exist: {url}")

-    base_name = os.path.basename(url) #url = "msc://my-profile/path/to/data", base_name = "data"
+    base_name = os.path.basename(url)  # url = "msc://my-profile/path/to/data", base_name = "data"
     files = msc.list(url)

     def download_file(item):
         """Helper function to download a single file."""
-        file_name = item.key #item.key = "msc://profile/path/to/data/file1.txt"
-        base_name_idx = file_name.find(base_name) # base_name_idx = 23
-        local_file_path = f"{local_path}/{file_name[base_name_idx:]}" #local_file_path = f"{local_path}/data/file1.txt"
+        file_name = item.key  # item.key = "msc://profile/path/to/data/file1.txt"
+        base_name_idx = file_name.find(base_name)  # base_name_idx = 23
+        local_file_path = (
+            f"{local_path}/{file_name[base_name_idx:]}"  # local_file_path = f"{local_path}/data/file1.txt"
+        )
         os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
         msc.download_file(item, local_file_path)
-        #msc.download_file(f"{MSC_PROTOCOL}{get_profile()}/{file_name}", local_file_path)
+        # msc.download_file(f"{MSC_PROTOCOL}{get_profile()}/{file_name}", local_file_path)

     # Use ThreadPoolExecutor for parallel downloads
     with ThreadPoolExecutor(max_workers=32) as executor:  # Adjust max_workers as needed
         executor.map(download_file, files)

     logging.warning(f"msc_download_dir completed rank {torch.distributed.get_rank()}")
-
-
+

 class NLPSaveRestoreConnector(SaveRestoreConnector):
     """Custom connector to support saving and restoring states."""
@@ -1083,7 +1085,6 @@ def __init__(self) -> None:
         )
         super().__init__()

-
     def save_to(self, model, save_path: str):
         """Save model to save path."""
         app_state = AppState()
@@ -1102,17 +1103,16 @@ def save_to(self, model, save_path: str):
         is_msc_enabled = False
         if MSC_PROTOCOL in dir_name:
             is_msc_enabled = True
-
+
        # dist ckpt calls save on every rank
         if dist_ckpt:
             # model weights is a directory
             dist_ckpt_dir = ckpt_to_dir(os.path.join(dir_name, self.model_weights_ckpt))

             if is_msc_enabled:
-                filename = os.path.join(dir_name, self.model_weights_ckpt)
+                filename = os.path.join(dir_name, self.model_weights_ckpt)
                 dist_ckpt_dir = os.path.splitext(filename)[0]
-
-
+
             # dist checkpoint needs torch.distributed to save the checkpoint
             if not parallel_state.is_initialized():

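To make the path arithmetic above concrete, a small worked example; the MSC profile and directory are hypothetical, and model_weights_ckpt is assumed to be "model_weights.ckpt":

import os

filename = os.path.join("msc://my-profile/ckpts", "model_weights.ckpt")
# filename == "msc://my-profile/ckpts/model_weights.ckpt"
dist_ckpt_dir = os.path.splitext(filename)[0]
# dist_ckpt_dir == "msc://my-profile/ckpts/model_weights" (directory-style dist-ckpt path)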
@@ -1185,8 +1185,10 @@ def dummy():

             if is_msc_enabled:
                 print(f"Downloading {mp_model_weights} to {tmpdir}")
-                msc_dest=os.path.join(tmpdir, f'mp_rank_{tp_rank:02d}', self.model_weights_ckpt)
-                logging.warning(f"msc_download_dir mp_model_weights from {mp_model_weights} {msc_dest} rank {torch.distributed.get_rank()}")
+                msc_dest = os.path.join(tmpdir, f'mp_rank_{tp_rank:02d}', self.model_weights_ckpt)
+                logging.warning(
+                    f"msc_download_dir mp_model_weights from {mp_model_weights} {msc_dest} rank {torch.distributed.get_rank()}"
+                )
                 msc_download_dir(mp_model_weights, msc_dest)
             else:
                 shutil.move(
@@ -1206,8 +1208,12 @@ def dummy():

             if is_msc_enabled:
                 print(f"Downloading {mp_model_weights} to {tmpdir}")
-                msc_dest = os.path.join(tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}', self.model_weights_ckpt)
-                logging.warning(f"msc_download_dir mp_model_weights from {mp_model_weights} {msc_dest} rank {torch.distributed.get_rank()}")
+                msc_dest = os.path.join(
+                    tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}', self.model_weights_ckpt
+                )
+                logging.warning(
+                    f"msc_download_dir mp_model_weights from {mp_model_weights} {msc_dest} rank {torch.distributed.get_rank()}"
+                )
                 msc_download_dir(mp_model_weights, msc_dest)
             else:
                 shutil.move(
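
The two branches above differ only in the per-rank directory name. A quick sketch of the local layout they produce, with hypothetical ranks:

import os

tmpdir, tp_rank, pp_rank, ckpt = "/tmp/restore", 0, 1, "model_weights.ckpt"
# Tensor-parallel only:
print(os.path.join(tmpdir, f'mp_rank_{tp_rank:02d}', ckpt))
# /tmp/restore/mp_rank_00/model_weights.ckpt
# Tensor + pipeline parallel:
print(os.path.join(tmpdir, f'tp_rank_{tp_rank:02d}_pp_rank_{pp_rank:03d}', ckpt))
# /tmp/restore/tp_rank_00_pp_rank_001/model_weights.ckpt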
@@ -1368,28 +1374,24 @@ def _load_state_dict_from_disk(self, model_weights, map_location=None):
         else:
             raise ValueError(f'Expected {model_weights} to be a file or directory.')

-
-    def _download_nemo_file(self,
-                            restore_path: str,
-                            tmpdir: str) -> str:
-        # .nemo filename
+    def _download_nemo_file(self, restore_path: str, tmpdir: str) -> str:
+        # .nemo filename
         fname = os.path.basename(restore_path)
-
-        #check if msc path exists
+
+        # check if msc path exists
         if not msc.os.path.exists(restore_path):
             raise FileNotFoundError(f".nemo file doesn't exist at {restore_path}")
-
-        #download .nemo file to tempdir
+
+        # download .nemo file to tempdir
         os.makedirs(tmpdir, exist_ok=True)
         logging.warning(f"Starting .nemo download {restore_path}")
         msc.download_file(restore_path, f"{tmpdir}/{fname}")
-
-        #update restore_path to point to downloaded .nemo
+
+        # update restore_path to point to downloaded .nemo
         updated_restore_path = os.path.join(tmpdir, fname)
         logging.warning(f".nemo download complete; updated_restore_path to {updated_restore_path}")
         return updated_restore_path

-
     def restore_from(
         self,
         calling_cls,
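
A minimal sketch of how _download_nemo_file slots into a restore, assuming a hypothetical MSC path and a temporary staging directory:

import tempfile

connector = NLPSaveRestoreConnector()
with tempfile.TemporaryDirectory() as tmpdir:
    # Stages the remote .nemo file locally and returns the local path.
    local_nemo = connector._download_nemo_file("msc://my-profile/models/megatron_gpt.nemo", tmpdir)
    # local_nemo == f"{tmpdir}/megatron_gpt.nemo"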
@@ -1459,7 +1461,7 @@ def dummy():
             trainer.strategy.setup_environment()

         # with tempfile.TemporaryDirectory() as tmpdir:
-        # Check if self.model_extracted_dir is set, and is a valid path
+        # Check if self.model_extracted_dir is set, and is a valid path
         if self.model_extracted_dir is not None and os.path.isdir(self.model_extracted_dir):
             # Log that NeMo will use the provided `model_extracted_dir`
             logging.info(
@@ -1512,7 +1514,7 @@ def dummy():
             else:
                 state_dict = self.modify_state_dict(conf, state_dict)
                 super().load_instance_with_state_dict(instance, state_dict, strict)
-
+
             logging.info(f'Model {instance.__class__.__name__} was successfully restored from {restore_path}.')
             return instance

nemo/core/connectors/save_restore_connector.py
Lines changed: 2 additions & 1 deletion
@@ -20,9 +20,9 @@
 import tempfile
 import time
 import uuid
-import time
 from contextlib import contextmanager
 from typing import Callable, Generator, Optional, Set, Union
+
 import torch
 from lightning.pytorch.trainer.trainer import Trainer
 from omegaconf import DictConfig, OmegaConf
@@ -42,6 +42,7 @@
 except (ImportError, ModuleNotFoundError):
     MULTISTORAGECLIENT_AVAILABLE = False

+
 class SaveRestoreConnector:
     def __init__(self) -> None:
         self._model_config_yaml = "model_config.yaml"

nemo/utils/callbacks/dist_ckpt_io.py
Lines changed: 1 addition & 1 deletion

@@ -25,9 +25,9 @@
 from lightning.fabric.utilities.types import _PATH
 from lightning.pytorch import Callback
 from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
+from multistorageclient.types import MSC_PROTOCOL

 from nemo.utils import logging
-from multistorageclient.types import MSC_PROTOCOL

 try:
     from megatron.core import dist_checkpointing

nemo/utils/callbacks/nemo_model_checkpoint.py
Lines changed: 9 additions & 5 deletions

@@ -26,7 +26,7 @@
 from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint, _is_local_file_protocol
 from lightning.pytorch.trainer import call
 from lightning.pytorch.utilities import rank_zero_info
-from torch import Tensor
+from torch import Tensor

 from nemo.collections.common.callbacks import EMA
 from nemo.utils import logging
@@ -232,7 +232,7 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):
         if self.multistorageclient_enabled:
             if not multistorageclient.os.path.exists(maybe_injected_best_model_path):
                 return
-
+
         if not os.path.exists(maybe_injected_best_model_path):
             return

@@ -242,7 +242,9 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint):

         self.previous_best_path = self.best_model_path
         old_state_dict = deepcopy(pl_module.state_dict())
-        checkpoint = multistorageclient.torch.load(maybe_injected_best_model_path, map_location='cpu', weights_only=False)
+        checkpoint = multistorageclient.torch.load(
+            maybe_injected_best_model_path, map_location='cpu', weights_only=False
+        )
         if 'state_dict' in checkpoint:
             checkpoint = checkpoint['state_dict']
         # get a new instance of the model
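
The 'state_dict' check above handles both shapes multistorageclient.torch.load can return: a full Lightning checkpoint dict or a bare weights dict. A sketch with hypothetical contents:

# Full Lightning checkpoint (hypothetical contents): unwrap to the weights.
checkpoint = {"state_dict": {"layer.weight": 0.0}, "epoch": 3}
if 'state_dict' in checkpoint:
    checkpoint = checkpoint['state_dict']
# checkpoint is now just the weights dict, matching the bare state_dict case.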
@@ -295,7 +297,9 @@ def on_train_end(self, trainer, pl_module):
             )
         else:
             if self.multistorageclient_enabled:
-                if multistorageclient.os.path.exists(self.best_model_path) and multistorageclient.os.path.isdir(self.best_model_path):
+                if multistorageclient.os.path.exists(self.best_model_path) and multistorageclient.os.path.isdir(
+                    self.best_model_path
+                ):
                     self.best_model_path = self.best_model_path.split('.ckpt')[0]

             else:
@@ -540,7 +544,7 @@ def file_exists(
     ) -> bool:
         """Checks if a file or a file without a suffix (distributed checkpoint) exists."""
         if self.multistorageclient_enabled:
-            exists = self._fs.exists(filepath) # todo(avm): unsure if we need this check
+            exists = self._fs.exists(filepath)  # todo(avm): unsure if we need this check
         else:
             exists = self._fs.exists(filepath) or (check_dist_ckpt and self._fs.exists(ckpt_to_dir(filepath)))

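In the non-MSC branch, the extra ckpt_to_dir probe covers distributed checkpoints saved as directories. A sketch assuming ckpt_to_dir strips the .ckpt suffix, with a hypothetical path:

from nemo.utils.model_utils import ckpt_to_dir

filepath = "/ckpts/megatron--step=100.ckpt"
# Falls back to the directory form used by distributed checkpoints:
print(ckpt_to_dir(filepath))  # /ckpts/megatron--step=100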

nemo/utils/exp_manager.py
Lines changed: 1 addition & 0 deletions

@@ -164,6 +164,7 @@ class CallbackParams:
     save_last_n_optim_states: Optional[int] = -1
     multistorageclient_enabled: Optional[bool] = False

+
 @dataclass
 class StepTimingParams:
     """StepTimingParams POD"""
