From 27f48c4cd156849fbdf27f848e84988bf64e855a Mon Sep 17 00:00:00 2001 From: John Zakkam Date: Fri, 17 Feb 2023 16:08:53 +0530 Subject: [PATCH 1/3] gpu_index update in simclr.py --- simclr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/simclr.py b/simclr.py index e022dca6..c211c421 100644 --- a/simclr.py +++ b/simclr.py @@ -63,7 +63,7 @@ def train(self, train_loader): n_iter = 0 logging.info(f"Start SimCLR training for {self.args.epochs} epochs.") - logging.info(f"Training with gpu: {self.args.disable_cuda}.") + logging.info(f"Training with gpu: {self.args.gpu_index}.") for epoch_counter in range(self.args.epochs): for images, _ in tqdm(train_loader): From 9c0b46b2832c87792c24d052cc370f6f4be62163 Mon Sep 17 00:00:00 2001 From: John Zakkam Date: Fri, 17 Feb 2023 23:07:15 +0530 Subject: [PATCH 2/3] Update for training from checkpoint --- data_aug/contrastive_learning_dataset.py | 5 ++++- exceptions/exceptions.py | 3 +++ run.py | 5 ++++- simclr.py | 16 ++++++++++++---- utils.py | 8 ++++++++ 5 files changed, 31 insertions(+), 6 deletions(-) diff --git a/data_aug/contrastive_learning_dataset.py b/data_aug/contrastive_learning_dataset.py index e1777e4c..733ee842 100644 --- a/data_aug/contrastive_learning_dataset.py +++ b/data_aug/contrastive_learning_dataset.py @@ -32,7 +32,10 @@ def get_dataset(self, name, n_views): transform=ContrastiveLearningViewGenerator( self.get_simclr_pipeline_transform(96), n_views), - download=True)} + download=True), + 'facecrops': lambda: datasets.ImageFolder('/Users/johnzakkam/Research/SimCLR/data/facecrops', transform=ContrastiveLearningViewGenerator( + self.get_simclr_pipeline_transform(32), n_views)) + } try: dataset_fn = valid_datasets[name] diff --git a/exceptions/exceptions.py b/exceptions/exceptions.py index a7370841..48aa205c 100644 --- a/exceptions/exceptions.py +++ b/exceptions/exceptions.py @@ -8,3 +8,6 @@ class InvalidBackboneError(BaseSimCLRException): class InvalidDatasetSelection(BaseSimCLRException): """Raised when the choice of dataset is invalid.""" + +class InvalidCheckpointPath(BaseSimCLRException): + """Raised when the path of the checkpoint is invalid""" \ No newline at end of file diff --git a/run.py b/run.py index e2424c6a..ef260a9f 100644 --- a/run.py +++ b/run.py @@ -20,6 +20,8 @@ help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet50)') +parser.add_argument('-ckpt', default=None, type=str, metavar='CKPT', + help='the checkpoint to resume training') parser.add_argument('-j', '--workers', default=12, type=int, metavar='N', help='number of data loading workers (default: 32)') parser.add_argument('--epochs', default=200, type=int, metavar='N', @@ -73,6 +75,7 @@ def main(): num_workers=args.workers, pin_memory=True, drop_last=True) model = ResNetSimCLR(base_model=args.arch, out_dim=args.out_dim) + ckpt = args.ckpt optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay) @@ -81,7 +84,7 @@ def main(): # It’s a no-op if the 'gpu_index' argument is a negative integer or None. with torch.cuda.device(args.gpu_index): - simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, args=args) + simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, args=args, ckpt=ckpt) simclr.train(train_loader) diff --git a/simclr.py b/simclr.py index c211c421..f108e03a 100644 --- a/simclr.py +++ b/simclr.py @@ -7,7 +7,8 @@ from torch.cuda.amp import GradScaler, autocast from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm -from utils import save_config_file, accuracy, save_checkpoint +from exceptions.exceptions import InvalidCheckpointPath +from utils import save_config_file, accuracy, save_checkpoint, load_checkpoint torch.manual_seed(0) @@ -60,12 +61,19 @@ def train(self, train_loader): # save config file save_config_file(self.writer.log_dir, self.args) - - n_iter = 0 + n_iter, start_epochs, end_epochs = 0, 0, self.args.epochs + + if(self.args.ckpt): + try: + self.model, start_epochs = load_checkpoint(self.model, self.args.ckpt) + end_epochs = end_epochs + start_epochs + except: + InvalidCheckpointPath() + logging.info(f"Start SimCLR training for {self.args.epochs} epochs.") logging.info(f"Training with gpu: {self.args.gpu_index}.") - for epoch_counter in range(self.args.epochs): + for epoch_counter in range(start_epochs, end_epochs): for images, _ in tqdm(train_loader): images = torch.cat(images, dim=0) diff --git a/utils.py b/utils.py index cf92cbd8..cdf22aa2 100644 --- a/utils.py +++ b/utils.py @@ -5,6 +5,14 @@ import yaml +def load_checkpoint(model, filepath): + if(os.path.exists(filepath)): + ckpt = torch.load(filepath) + model.load_state_dict(ckpt['state_dict']) + epoch = ckpt['epoch'] + return model, epoch + + def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): torch.save(state, filename) if is_best: From ed0f6e7d449386a05814d3bd5c618a801720a276 Mon Sep 17 00:00:00 2001 From: John Zakkam Date: Sun, 19 Feb 2023 16:18:17 +0530 Subject: [PATCH 3/3] Update contrastive_learning_dataset.py --- data_aug/contrastive_learning_dataset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/data_aug/contrastive_learning_dataset.py b/data_aug/contrastive_learning_dataset.py index 733ee842..4094cd6b 100644 --- a/data_aug/contrastive_learning_dataset.py +++ b/data_aug/contrastive_learning_dataset.py @@ -32,9 +32,7 @@ def get_dataset(self, name, n_views): transform=ContrastiveLearningViewGenerator( self.get_simclr_pipeline_transform(96), n_views), - download=True), - 'facecrops': lambda: datasets.ImageFolder('/Users/johnzakkam/Research/SimCLR/data/facecrops', transform=ContrastiveLearningViewGenerator( - self.get_simclr_pipeline_transform(32), n_views)) + download=True) } try: