Skip to content

Commit 2519cc1

Browse files
committed
Type hints and spacing change only.
1 parent b4ded40 commit 2519cc1

24 files changed

Lines changed: 752 additions & 696 deletions

MNXC_main.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import torch
2424
from os.path import join
2525
from matplotlib import pyplot as plt
26-
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
26+
from argparse import ArgumentParser,ArgumentDefaultsHelpFormatter
2727
# TIME_STAMP = datetime.utcnow().isoformat()
2828

2929
from config import Config
@@ -68,7 +68,7 @@
6868
"""
6969

7070

71-
def plot_occurance(losses:list, plot_name='val_loss.jpg', clear=True, log=False):
71+
def plot_occurance(losses: list,plot_name='val_loss.jpg',clear=True,log=False):
7272
"""
7373
Plots the validation loss against epochs.
7474
@@ -108,7 +108,7 @@ def main(args):
108108

109109
train_epoch_losses = []
110110
val_epoch_losses = []
111-
val_p1s, val_p3s, val_p5s = [], [], []
111+
val_p1s,val_p3s,val_p5s = [],[],[]
112112
separator_length = 92
113113
for epoch in range(config["sampling"]["num_epochs"]):
114114
train_epoch_loss = match_net.training(num_train_epoch=config["sampling"]["num_train_epoch"])
@@ -118,7 +118,7 @@ def main(args):
118118
logger.info("-" * separator_length)
119119

120120
# val_epoch_loss, val_p1, val_p3, val_p5 = match_net.testing()
121-
val_epoch_loss, val_p1, val_p3, val_p5 = match_net.validating(epoch_count=epoch)
121+
val_epoch_loss,val_p1,val_p3,val_p5 = match_net.validating(epoch_count=epoch)
122122
val_epoch_losses.append(val_epoch_loss)
123123
val_p1s.append(val_p1)
124124
val_p3s.append(val_p3)
@@ -129,7 +129,7 @@ def main(args):
129129
## Storing trained model
130130
torch.save(match_net.match_net.state_dict(),
131131
join(config["paths"]["dataset_dir"][plat][user],config["data"]["dataset_name"],
132-
config["data"]["dataset_name"]+'_'+str(config["sampling"]["num_epochs"])))
132+
config["data"]["dataset_name"] + '_' + str(config["sampling"]["num_epochs"])))
133133

134134
logger.info("#" * separator_length)
135135
# logger.info("Train losses: [{}]".format(train_epoch_losses))
@@ -145,7 +145,7 @@ def main(args):
145145
## Inference Phase
146146
logger.info("=" * separator_length)
147147
logger.info("\nStarting Inference...")
148-
test_epoch_loss, test_p1, test_p3, test_p5 = match_net.testing()
148+
test_epoch_loss,test_p1,test_p3,test_p5 = match_net.testing()
149149
logger.info("Test losses: [{}]".format(test_epoch_loss))
150150
logger.info("Test Precision 1: [{}]".format(test_p1))
151151
logger.info("Test Precision 3: [{}]".format(test_p3))
@@ -174,43 +174,43 @@ def main(args):
174174
help='Config to read details',
175175
default='MNXC.config')
176176
parser.add_argument('--dataset_dir',
177-
help='Path to dataset folder.', type=str,
177+
help='Path to dataset folder.',type=str,
178178
default="")
179179
parser.add_argument('--dataset_name',
180-
help='Name of the dataset to use.', type=str,
180+
help='Name of the dataset to use.',type=str,
181181
default='all')
182182
parser.add_argument('--train_path',
183-
help='Path to train file (Absolute or Relative to [dataset_url]).', type=str,
183+
help='Path to train file (Absolute or Relative to [dataset_url]).',type=str,
184184
default='train')
185185
parser.add_argument('--test_path',
186-
help='Path to test file (Absolute or Relative to [dataset_url]).', type=str,
186+
help='Path to test file (Absolute or Relative to [dataset_url]).',type=str,
187187
default='test')
188188
parser.add_argument('--solution_path',
189-
help='Path to result folder (Absolute or Relative to [dataset_url]).', type=str,
189+
help='Path to result folder (Absolute or Relative to [dataset_url]).',type=str,
190190
default='result')
191191
parser.add_argument('--pretrain_dir',
192-
help='Path to pre-trained embedding file. Default: [dataset_url/pretrain].', type=str,
192+
help='Path to pre-trained embedding file. Default: [dataset_url/pretrain].',type=str,
193193
default='pretrain')
194194

195195
# Training configuration arguments
196-
parser.add_argument('--device', type=str, default='cpu',
196+
parser.add_argument('--device',type=str,default='cpu',
197197
help='PyTorch device string <device_name>:<device_id>')
198-
parser.add_argument('--seed', type=int, default=None,
198+
parser.add_argument('--seed',type=int,default=None,
199199
help='Manually set the seed for the experiments for reproducibility.')
200-
parser.add_argument('--batch_size', type=int, default=32,
200+
parser.add_argument('--batch_size',type=int,default=32,
201201
help='Batch size for training.')
202-
parser.add_argument('--epochs', type=int, default=20,
202+
parser.add_argument('--epochs',type=int,default=20,
203203
help='Number of epochs to train.')
204-
parser.add_argument('--interval', type=int, default=-1,
204+
parser.add_argument('--interval',type=int,default=-1,
205205
help='Interval between two status updates during training.')
206206

207207
# Optimizer arguments
208-
parser.add_argument('--optimizer_cfg', type=str,
208+
parser.add_argument('--optimizer_cfg',type=str,
209209
help='Optimizer configuration in YAML format for model.')
210210

211211
# Post-training arguments
212-
parser.add_argument('--save_model', type=str, default=None,
213-
choices=['all', 'inputAE', 'outputAE', 'regressor'], nargs='+',
212+
parser.add_argument('--save_model',type=str,default=None,
213+
choices=['all','inputAE','outputAE','regressor'],nargs='+',
214214
help='Options to save the model partially or completely.')
215215

216216
args = parser.parse_args()

MNXC_main_orig.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
"""
2222

2323
import torch
24-
from os.path import join, isfile
24+
from os.path import join,isfile
2525
from matplotlib import pyplot as plt
26-
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
26+
from argparse import ArgumentParser,ArgumentDefaultsHelpFormatter
2727
# TIME_STAMP = datetime.utcnow().isoformat()
2828

2929
from config import Config
@@ -118,7 +118,7 @@
118118
"""
119119

120120

121-
def plot_occurance(losses:list, plot_name='val_loss.jpg', clear=True, log=False):
121+
def plot_occurance(losses: list,plot_name='val_loss.jpg',clear=True,log=False):
122122
"""
123123
Plots the validation loss against epochs.
124124
@@ -170,7 +170,7 @@ def main(args):
170170
logger.info("-" * separator_length)
171171

172172
# val_epoch_loss, val_p1, val_p3, val_p5 = match_net.testing()
173-
val_epoch_loss, val_p1, val_p3, val_p5 = match_net.validating(epoch_count=epoch)
173+
val_epoch_loss,val_p1,val_p3,val_p5 = match_net.validating(epoch_count=epoch)
174174
val_losses.append(val_epoch_loss)
175175
val_p1s.append(val_p1)
176176
val_p3s.append(val_p3)
@@ -182,7 +182,7 @@ def main(args):
182182
## Storing trained model
183183
torch.save(match_net.match_net.state_dict(),
184184
join(config["paths"]["dataset_dir"][plat],config["data"]["dataset_name"],
185-
config["data"]["dataset_name"]+'_'+str(config["sampling"]["num_epochs"])))
185+
config["data"]["dataset_name"] + '_' + str(config["sampling"]["num_epochs"])))
186186

187187
logger.info("#" * separator_length)
188188
# logger.info("Train losses: [{}]".format(train_losses))
@@ -198,7 +198,7 @@ def main(args):
198198
## Inference Phase
199199
logger.info("=" * separator_length)
200200
logger.info("\nStarting Inference...")
201-
test_epoch_loss, test_p1, test_p3, test_p5 = match_net.testing()
201+
test_epoch_loss,test_p1,test_p3,test_p5 = match_net.testing()
202202
logger.info("Test losses: [{}]".format(test_epoch_loss))
203203
logger.info("Test Precision 1: [{}]".format(test_p1))
204204
logger.info("Test Precision 3: [{}]".format(test_p3))
@@ -227,43 +227,43 @@ def main(args):
227227
help='Config to read details',
228228
default='MNXC.config')
229229
parser.add_argument('--dataset_dir',
230-
help='Path to dataset folder.', type=str,
230+
help='Path to dataset folder.',type=str,
231231
default="")
232232
parser.add_argument('--dataset_name',
233-
help='Name of the dataset to use.', type=str,
233+
help='Name of the dataset to use.',type=str,
234234
default='all')
235235
parser.add_argument('--train_path',
236-
help='Path to train file (Absolute or Relative to [dataset_url]).', type=str,
236+
help='Path to train file (Absolute or Relative to [dataset_url]).',type=str,
237237
default='train')
238238
parser.add_argument('--test_path',
239-
help='Path to test file (Absolute or Relative to [dataset_url]).', type=str,
239+
help='Path to test file (Absolute or Relative to [dataset_url]).',type=str,
240240
default='test')
241241
parser.add_argument('--solution_path',
242-
help='Path to result folder (Absolute or Relative to [dataset_url]).', type=str,
242+
help='Path to result folder (Absolute or Relative to [dataset_url]).',type=str,
243243
default='result')
244244
parser.add_argument('--pretrain_dir',
245-
help='Path to pre-trained embedding file. Default: [dataset_url/pretrain].', type=str,
245+
help='Path to pre-trained embedding file. Default: [dataset_url/pretrain].',type=str,
246246
default='pretrain')
247247

248248
# Training configuration arguments
249-
parser.add_argument('--device', type=str, default='cpu',
249+
parser.add_argument('--device',type=str,default='cpu',
250250
help='PyTorch device string <device_name>:<device_id>')
251-
parser.add_argument('--seed', type=int, default=None,
251+
parser.add_argument('--seed',type=int,default=None,
252252
help='Manually set the seed for the experiments for reproducibility.')
253-
parser.add_argument('--batch_size', type=int, default=32,
253+
parser.add_argument('--batch_size',type=int,default=32,
254254
help='Batch size for training.')
255-
parser.add_argument('--epochs', type=int, default=20,
255+
parser.add_argument('--epochs',type=int,default=20,
256256
help='Number of epochs to train.')
257-
parser.add_argument('--interval', type=int, default=-1,
257+
parser.add_argument('--interval',type=int,default=-1,
258258
help='Interval between two status updates during training.')
259259

260260
# Optimizer arguments
261-
parser.add_argument('--optimizer_cfg', type=str,
261+
parser.add_argument('--optimizer_cfg',type=str,
262262
help='Optimizer configuration in YAML format for model.')
263263

264264
# Post-training arguments
265-
parser.add_argument('--save_model', type=str, default=None,
266-
choices=['all', 'inputAE', 'outputAE', 'regressor'], nargs='+',
265+
parser.add_argument('--save_model',type=str,default=None,
266+
choices=['all','inputAE','outputAE','regressor'],nargs='+',
267267
help='Options to save the model partially or completely.')
268268

269269
args = parser.parse_args()

0 commit comments

Comments
 (0)