diff --git a/data_aug/contrastive_learning_dataset.py b/data_aug/contrastive_learning_dataset.py
index e1777e4c..4094cd6b 100644
--- a/data_aug/contrastive_learning_dataset.py
+++ b/data_aug/contrastive_learning_dataset.py
@@ -32,7 +32,8 @@ def get_dataset(self, name, n_views):
                                                           transform=ContrastiveLearningViewGenerator(
                                                               self.get_simclr_pipeline_transform(96),
                                                               n_views),
-                                                          download=True)}
+                                                          download=True)
+                          }
 
         try:
             dataset_fn = valid_datasets[name]
diff --git a/exceptions/exceptions.py b/exceptions/exceptions.py
index a7370841..48aa205c 100644
--- a/exceptions/exceptions.py
+++ b/exceptions/exceptions.py
@@ -8,3 +8,7 @@ class InvalidBackboneError(BaseSimCLRException):
 
 class InvalidDatasetSelection(BaseSimCLRException):
     """Raised when the choice of dataset is invalid."""
+
+
+class InvalidCheckpointPath(BaseSimCLRException):
+    """Raised when the path of the checkpoint is invalid."""
diff --git a/run.py b/run.py
index e2424c6a..ef260a9f 100644
--- a/run.py
+++ b/run.py
@@ -20,6 +20,8 @@
                     help='model architecture: ' + ' | '.join(model_names) +
                          ' (default: resnet50)')
+parser.add_argument('-ckpt', default=None, type=str, metavar='CKPT',
+                    help='the checkpoint to resume training')
 parser.add_argument('-j', '--workers', default=12, type=int, metavar='N',
                     help='number of data loading workers (default: 32)')
 parser.add_argument('--epochs', default=200, type=int, metavar='N',
                     help='number of total epochs to run')
@@ -73,6 +75,7 @@ def main():
         num_workers=args.workers, pin_memory=True, drop_last=True)
 
     model = ResNetSimCLR(base_model=args.arch, out_dim=args.out_dim)
+    ckpt = args.ckpt
 
     optimizer = torch.optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
 
@@ -81,7 +84,7 @@ def main():
 
     #  It’s a no-op if the 'gpu_index' argument is a negative integer or None.
     with torch.cuda.device(args.gpu_index):
-        simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, args=args)
+        simclr = SimCLR(model=model, optimizer=optimizer, scheduler=scheduler, args=args, ckpt=ckpt)
         simclr.train(train_loader)
 
 
diff --git a/simclr.py b/simclr.py
index e022dca6..f108e03a 100644
--- a/simclr.py
+++ b/simclr.py
@@ -7,7 +7,8 @@
 from torch.cuda.amp import GradScaler, autocast
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-from utils import save_config_file, accuracy, save_checkpoint
+from exceptions.exceptions import InvalidCheckpointPath
+from utils import save_config_file, accuracy, save_checkpoint, load_checkpoint
 
 torch.manual_seed(0)
 
@@ -60,12 +61,21 @@ def train(self, train_loader):
 
         # save config file
         save_config_file(self.writer.log_dir, self.args)
-
-        n_iter = 0
+        n_iter, start_epochs, end_epochs = 0, 0, self.args.epochs
+
+        # Resume from a checkpoint when one was supplied via -ckpt.
+        if self.args.ckpt:
+            try:
+                self.model, start_epochs = load_checkpoint(self.model, self.args.ckpt)
+                end_epochs = end_epochs + start_epochs
+            except (OSError, KeyError, RuntimeError) as e:
+                # Raise (not just construct) the project error so failures are visible.
+                raise InvalidCheckpointPath(self.args.ckpt) from e
+
         logging.info(f"Start SimCLR training for {self.args.epochs} epochs.")
-        logging.info(f"Training with gpu: {self.args.disable_cuda}.")
+        logging.info(f"Training with gpu: {self.args.gpu_index}.")
 
-        for epoch_counter in range(self.args.epochs):
+        for epoch_counter in range(start_epochs, end_epochs):
             for images, _ in tqdm(train_loader):
                 images = torch.cat(images, dim=0)
 
diff --git a/utils.py b/utils.py
index cf92cbd8..cdf22aa2 100644
--- a/utils.py
+++ b/utils.py
@@ -5,6 +5,16 @@
 import yaml
 
 
+def load_checkpoint(model, filepath):
+    """Load weights from *filepath* into *model*; return (model, saved epoch)."""
+    if not os.path.exists(filepath):
+        raise FileNotFoundError(f"checkpoint not found: {filepath}")
+    # NOTE(review): consider torch.load(..., map_location=...) for CPU-only resume.
+    ckpt = torch.load(filepath)
+    model.load_state_dict(ckpt['state_dict'])
+    return model, ckpt['epoch']
+
+
 def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
     torch.save(state, filename)
     if is_best: