2323import torch
2424from os .path import join
2525from matplotlib import pyplot as plt
26- from argparse import ArgumentParser , ArgumentDefaultsHelpFormatter
26+ from argparse import ArgumentParser ,ArgumentDefaultsHelpFormatter
2727# TIME_STAMP = datetime.utcnow().isoformat()
2828
2929from config import Config
6868"""
6969
7070
71- def plot_occurance (losses :list , plot_name = 'val_loss.jpg' , clear = True , log = False ):
71+ def plot_occurance (losses : list ,plot_name = 'val_loss.jpg' ,clear = True ,log = False ):
7272 """
7373 Plots the validation loss against epochs.
7474
@@ -108,7 +108,7 @@ def main(args):
108108
109109 train_epoch_losses = []
110110 val_epoch_losses = []
111- val_p1s , val_p3s , val_p5s = [], [], []
111+ val_p1s ,val_p3s ,val_p5s = [],[],[]
112112 separator_length = 92
113113 for epoch in range (config ["sampling" ]["num_epochs" ]):
114114 train_epoch_loss = match_net .training (num_train_epoch = config ["sampling" ]["num_train_epoch" ])
@@ -118,7 +118,7 @@ def main(args):
118118 logger .info ("-" * separator_length )
119119
120120 # val_epoch_loss, val_p1, val_p3, val_p5 = match_net.testing()
121- val_epoch_loss , val_p1 , val_p3 , val_p5 = match_net .validating (epoch_count = epoch )
121+ val_epoch_loss ,val_p1 ,val_p3 ,val_p5 = match_net .validating (epoch_count = epoch )
122122 val_epoch_losses .append (val_epoch_loss )
123123 val_p1s .append (val_p1 )
124124 val_p3s .append (val_p3 )
@@ -129,7 +129,7 @@ def main(args):
129129 ## Storing trained model
130130 torch .save (match_net .match_net .state_dict (),
131131 join (config ["paths" ]["dataset_dir" ][plat ][user ],config ["data" ]["dataset_name" ],
132- config ["data" ]["dataset_name" ]+ '_' + str (config ["sampling" ]["num_epochs" ])))
132+ config ["data" ]["dataset_name" ] + '_' + str (config ["sampling" ]["num_epochs" ])))
133133
134134 logger .info ("#" * separator_length )
135135 # logger.info("Train losses: [{}]".format(train_epoch_losses))
@@ -145,7 +145,7 @@ def main(args):
145145 ## Inference Phase
146146 logger .info ("=" * separator_length )
147147 logger .info ("\n Starting Inference..." )
148- test_epoch_loss , test_p1 , test_p3 , test_p5 = match_net .testing ()
148+ test_epoch_loss ,test_p1 ,test_p3 ,test_p5 = match_net .testing ()
149149 logger .info ("Test losses: [{}]" .format (test_epoch_loss ))
150150 logger .info ("Test Precision 1: [{}]" .format (test_p1 ))
151151 logger .info ("Test Precision 3: [{}]" .format (test_p3 ))
@@ -174,43 +174,43 @@ def main(args):
174174 help = 'Config to read details' ,
175175 default = 'MNXC.config' )
176176 parser .add_argument ('--dataset_dir' ,
177- help = 'Path to dataset folder.' , type = str ,
177+ help = 'Path to dataset folder.' ,type = str ,
178178 default = "" )
179179 parser .add_argument ('--dataset_name' ,
180- help = 'Name of the dataset to use.' , type = str ,
180+ help = 'Name of the dataset to use.' ,type = str ,
181181 default = 'all' )
182182 parser .add_argument ('--train_path' ,
183- help = 'Path to train file (Absolute or Relative to [dataset_url]).' , type = str ,
183+ help = 'Path to train file (Absolute or Relative to [dataset_url]).' ,type = str ,
184184 default = 'train' )
185185 parser .add_argument ('--test_path' ,
186- help = 'Path to test file (Absolute or Relative to [dataset_url]).' , type = str ,
186+ help = 'Path to test file (Absolute or Relative to [dataset_url]).' ,type = str ,
187187 default = 'test' )
188188 parser .add_argument ('--solution_path' ,
189- help = 'Path to result folder (Absolute or Relative to [dataset_url]).' , type = str ,
189+ help = 'Path to result folder (Absolute or Relative to [dataset_url]).' ,type = str ,
190190 default = 'result' )
191191 parser .add_argument ('--pretrain_dir' ,
192- help = 'Path to pre-trained embedding file. Default: [dataset_url/pretrain].' , type = str ,
192+ help = 'Path to pre-trained embedding file. Default: [dataset_url/pretrain].' ,type = str ,
193193 default = 'pretrain' )
194194
195195 # Training configuration arguments
196- parser .add_argument ('--device' , type = str , default = 'cpu' ,
196+ parser .add_argument ('--device' ,type = str ,default = 'cpu' ,
197197 help = 'PyTorch device string <device_name>:<device_id>' )
198- parser .add_argument ('--seed' , type = int , default = None ,
198+ parser .add_argument ('--seed' ,type = int ,default = None ,
199199 help = 'Manually set the seed for the experiments for reproducibility.' )
200- parser .add_argument ('--batch_size' , type = int , default = 32 ,
200+ parser .add_argument ('--batch_size' ,type = int ,default = 32 ,
201201 help = 'Batch size for training.' )
202- parser .add_argument ('--epochs' , type = int , default = 20 ,
202+ parser .add_argument ('--epochs' ,type = int ,default = 20 ,
203203 help = 'Number of epochs to train.' )
204- parser .add_argument ('--interval' , type = int , default = - 1 ,
204+ parser .add_argument ('--interval' ,type = int ,default = - 1 ,
205205 help = 'Interval between two status updates during training.' )
206206
207207 # Optimizer arguments
208- parser .add_argument ('--optimizer_cfg' , type = str ,
208+ parser .add_argument ('--optimizer_cfg' ,type = str ,
209209 help = 'Optimizer configuration in YAML format for model.' )
210210
211211 # Post-training arguments
212- parser .add_argument ('--save_model' , type = str , default = None ,
213- choices = ['all' , 'inputAE' , 'outputAE' , 'regressor' ], nargs = '+' ,
212+ parser .add_argument ('--save_model' ,type = str ,default = None ,
213+ choices = ['all' ,'inputAE' ,'outputAE' ,'regressor' ],nargs = '+' ,
214214 help = 'Options to save the model partially or completely.' )
215215
216216 args = parser .parse_args ()
0 commit comments