Skip to content

Commit 4cfd761

Browse files
committed
Checkpoint func should also accept epoch number.
1 parent eb16e78 commit 4cfd761

2 files changed

Lines changed: 3 additions & 3 deletions

File tree

nasbench_pytorch/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def train(net, train_loader, loss=None, optimizer=None, scheduler=None, grad_cli
2323
device: Device to train on, default is cpu.
2424
print_frequency: How often to print info about batches.
2525
checkpoint_every_k: Every k epochs, save a checkpoint.
26-
checkpoint_func: Custom function to save the checkpoint, signature: func(net, metric_dict)
26+
checkpoint_func: Custom function to save the checkpoint, signature: func(net, metric_dict, epoch num)
2727
2828
Returns:
2929
Final train (and validation) metrics.
@@ -55,7 +55,7 @@ def train(net, train_loader, loss=None, optimizer=None, scheduler=None, grad_cli
5555
for epoch in range(num_epochs):
5656
# checkpoint using a user defined function
5757
if checkpoint_every_k is not None and (epoch + 1) % checkpoint_every_k == 0:
58-
checkpoint_func(net, metric_dict)
58+
checkpoint_func(net, metric_dict, epoch + 1)
5959

6060
net.train()
6161

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setuptools.setup(
44
name='nasbench_pytorch',
5-
version='1.0',
5+
version='1.1',
66
license='Apache License 2.0',
77
author='Romulus Hong, Gabriela Suchopárová',
88
packages=setuptools.find_packages()

0 commit comments

Comments
 (0)