Skip to content

Commit d0a62e9

Browse files
committed
Reproducibility finished - changed hyperparameters to those from the paper, updated the readme with info about reproducibility, new version.
1 parent 18660d4 commit d0a62e9

7 files changed

Lines changed: 38 additions & 23 deletions

File tree

Contributors.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
- Package structure, reproducibility
77
---------
88
- [@abhash-er](https://github.com/abhash-er/) (Abhash Jha)
9-
- modified the model code so that cast to double is possible
9+
- Modified the model code so that cast to double is possible
1010
- [@longerHost](https://github.com/longerHost)
11-
- reproducibility of the original NAS-Bench-101 results
11+
- Reproducibility of the original NAS-Bench-101
12+
- Comparison of training results and API results

README.md

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ implementation is written in TensorFlow, and this projects contains
55
some files from the original repository (in the directory
66
`nasbench_pytorch/model/`).
77

8+
**Important:** if you want to reproduce the original results, please refer to the
9+
[Reproducibility](#repro) section.
10+
811
# Overview
912
A PyTorch implementation of *training* of NAS-Bench-101 dataset: [NAS-Bench-101: Towards Reproducible Neural Architecture Search](https://arxiv.org/abs/1902.09635).
1013
The dataset contains 423,624 unique neural networks exhaustively generated and evaluated from a fixed graph-based search space.
@@ -64,13 +67,25 @@ Then, you can train it just like the example network in `main.py`.
6467
Example architecture (picture from the original repository)
6568
![archtecture](./assets/architecture.png)
6669

70+
# Reproducibility <a id="repro"></a>
71+
The code should closely match the TensorFlow version (including the hyperparameters), but there are some differences:
72+
- RMSProp implementation in TensorFlow and PyTorch is **different**
73+
- For more information refer to [here](https://github.com/pytorch/pytorch/issues/32545) and [here](https://github.com/pytorch/pytorch/issues/23796).
74+
- Optionally, you can install pytorch-image-models where a [TensorFlow-like RMSProp](https://github.com/rwightman/pytorch-image-models/blob/main/timm/optim/rmsprop_tf.py#L5) is implemented
75+
- `pip install timm`
76+
- Then, pass `--optimizer rmsprop_tf` to `main.py` to use it
77+
78+
79+
- The original training was on TPUs, this code enables only GPU and CPU training
80+
- Input data augmentation methods are the same, but due to randomness they are not applied in the same manner
81+
- Cause: Batches and images cannot be shuffled as in the original TPU training, and the augmentation seed is also different
82+
- Results may still differ due to TensorFlow/PyTorch implementation differences
83+
84+
Refer to this [issue](https://github.com/romulus0914/NASBench-PyTorch/issues/6) for more information and for comparison with API results.
85+
6786
# Disclaimer
6887
Modified from [NASBench: A Neural Architecture Search Dataset and Benchmark](https://github.com/google-research/nasbench).
6988
*graph_util.py* and *model_spec.py* are directly copied from the original repo. Original license can be found [here](https://github.com/google-research/nasbench/blob/master/LICENSE).
7089

7190
<a id="note"></a>
7291
**Please note that this repo is only used to train one possible architecture in the search space, not to generate all possible graphs and train them.
73-
74-
**Important information:** The code should closely match the TensorFlow version, but
75-
you may still get slightly different results due to differences in TensorFlow/PyTorch implementation.
76-
Moreover, input data augmentation is the same, but due to randomness isn't exactly the same.

main.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,27 +41,24 @@ def reload_checkpoint(path, device=None):
4141
parser = argparse.ArgumentParser(description='NASBench')
4242
parser.add_argument('--random_state', default=1, type=int, help='Random seed.')
4343
parser.add_argument('--data_root', default='./data/', type=str, help='Path where cifar will be downloaded.')
44-
parser.add_argument('--module_vertices', default=7, type=int, help='#vertices in graph')
45-
parser.add_argument('--max_edges', default=9, type=int, help='max edges in graph')
46-
parser.add_argument('--available_ops', default=['conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3'],
47-
type=list, help='available operations performed on vertex')
4844
parser.add_argument('--in_channels', default=3, type=int, help='Number of input channels.')
4945
parser.add_argument('--stem_out_channels', default=128, type=int, help='output channels of stem convolution')
5046
parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules')
5147
parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack')
52-
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
53-
parser.add_argument('--test_batch_size', default=100, type=int, help='test set batch size')
54-
parser.add_argument('--epochs', default=100, type=int, help='#epochs of training')
48+
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
49+
parser.add_argument('--test_batch_size', default=256, type=int, help='test set batch size')
50+
parser.add_argument('--epochs', default=108, type=int, help='#epochs of training')
5551
parser.add_argument('--validation_size', default=10000, type=int, help="Size of the validation set to split off.")
5652
parser.add_argument('--num_workers', default=0, type=int, help="Number of parallel workers for the train dataset.")
57-
parser.add_argument('--learning_rate', default=0.025, type=float, help='base learning rate')
53+
parser.add_argument('--learning_rate', default=0.02, type=float, help='base learning rate')
5854
parser.add_argument('--lr_decay_method', default='COSINE_BY_STEP', type=str, help='learning decay method')
59-
parser.add_argument('--optimizer', default='sgd', type=str, help='Optimizer (sgd or rmsprop)')
60-
parser.add_argument('--rmsprop_eps', default=1e-08, type=float, help='RMSProp eps parameter.')
55+
parser.add_argument('--optimizer', default='rmsprop', type=str, help='Optimizer (sgd, rmsprop or rmsprop_tf)')
56+
parser.add_argument('--rmsprop_eps', default=1.0, type=float, help='RMSProp eps parameter.')
6157
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
6258
parser.add_argument('--weight_decay', default=1e-4, type=float, help='L2 regularization weight')
6359
parser.add_argument('--grad_clip', default=5, type=float, help='gradient clipping')
64-
parser.add_argument('--batch_norm_momentum', default=0.1, type=float, help='Batch normalization momentum')
60+
parser.add_argument('--grad_clip_off', default=False, type=bool, help='If True, turn off gradient clipping.')
61+
parser.add_argument('--batch_norm_momentum', default=0.997, type=float, help='Batch normalization momentum')
6562
parser.add_argument('--batch_norm_eps', default=1e-5, type=float, help='Batch normalization epsilon')
6663
parser.add_argument('--load_checkpoint', default='', type=str, help='Reload model from checkpoint')
6764
parser.add_argument('--num_labels', default=10, type=int, help='#classes')
@@ -110,7 +107,8 @@ def reload_checkpoint(path, device=None):
110107
weight_decay=args.weight_decay, **optimizer_kwargs)
111108
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
112109

113-
result = train(net, train_loader, loss=criterion, optimizer=optimizer, scheduler=scheduler, grad_clip=args.grad_clip,
110+
result = train(net, train_loader, loss=criterion, optimizer=optimizer, scheduler=scheduler,
111+
grad_clip=args.grad_clip if not args.grad_clip_off else None,
114112
num_epochs=args.epochs, num_validation=args.validation_size, validation_loader=valid_loader,
115113
device=args.device, print_frequency=args.print_freq)
116114

nasbench_pytorch/datasets/cifar10.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def seed_worker(seed, worker_id):
2828
random.seed(worker_seed)
2929

3030

31-
def prepare_dataset(batch_size, test_batch_size=100, root='./data/', use_validation=True, split_from_end=True,
31+
def prepare_dataset(batch_size, test_batch_size=256, root='./data/', use_validation=True, split_from_end=True,
3232
validation_size=10000, random_state=None, set_global_seed=False, no_valid_transform=True,
3333
num_workers=0, num_val_workers=0, num_test_workers=0):
3434
"""

nasbench_pytorch/model/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
class Network(nn.Module):
2828
def __init__(self, spec, num_labels=10, in_channels=3, stem_out_channels=128, num_stacks=3, num_modules_per_stack=3,
29-
momentum=0.1, eps=1e-5, tf_like=False):
29+
momentum=0.997, eps=1e-5, tf_like=False):
3030
"""
3131
3232
Args:

nasbench_pytorch/trainer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ def train(net, train_loader, loss=None, optimizer=None, scheduler=None, grad_cli
7474
optimizer.zero_grad()
7575
curr_loss = loss(outputs, targets)
7676
curr_loss.backward()
77-
nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
77+
if grad_clip is not None:
78+
nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
7879
optimizer.step()
7980

8081
# metrics

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22

33
setuptools.setup(
44
name='nasbench_pytorch',
5-
version='1.2.3',
5+
version='1.3',
66
license='Apache License 2.0',
7-
author='Romulus Hong, Gabriela Suchopárová',
7+
author='Romulus Hong, Gabriela Kadlecová',
88
packages=setuptools.find_packages()
99
)

0 commit comments

Comments
 (0)