-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexperiments.py
More file actions
116 lines (100 loc) · 3.38 KB
/
experiments.py
File metadata and controls
116 lines (100 loc) · 3.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import click
import torch
from pideq.deq.solvers import anderson, broyden, forward_iteration
from pideq.net import PINN, PIDEQ
from pideq.trainer import PINNTrainer, PIDEQTrainer
def experiment_1():
    """Train the baseline PINN (``n_nodes=20``) once.

    Prints a banner, then builds ``PINN(2., n_nodes=20)`` and runs a
    ``PINNTrainer`` for ``epochs=5e4`` under the ``'PINN-baseline'`` group.
    """
    print('=== EXPERIMENT 1 ===')
    PINNTrainer(
        PINN(2., n_nodes=20),
        # NOTE(review): epochs is the float 5e4 here, as in every experiment in
        # this file; presumably the trainer casts it -- confirm.
        epochs=5e4,
        wandb_group='PINN-baseline',  # plain string: the f-prefix had no fields
    ).run()
def experiment_2():
    """Train the baseline PIDEQ (``n_states=80``) once.

    Prints a banner, then builds ``PIDEQ(2., n_states=80)`` and runs a
    ``PIDEQTrainer`` for ``epochs=5e4`` under the ``'PIDEQ-baseline'`` group.
    """
    print('=== EXPERIMENT 2 ===')
    PIDEQTrainer(
        PIDEQ(2., n_states=80),
        epochs=5e4,
        wandb_group='PIDEQ-baseline',  # plain string: the f-prefix had no fields
    ).run()
def experiment_3(ns_states=(40, 20, 10, 5, 2)):
    """Sweep the PIDEQ state size, training one model per value.

    Args:
        ns_states: Iterable of ``n_states`` values to try. The default is a
            tuple rather than a list so that it cannot be mutated and shared
            across calls.
    """
    print('=== EXPERIMENT 3 ===')
    for n_states in ns_states:
        PIDEQTrainer(
            PIDEQ(2., n_states=n_states),
            epochs=5e4,
            wandb_group=f'PIDEQ-#z={n_states}',
        ).run()
def experiment_4(ns_hidden=(1, 2)):
    """Sweep the PIDEQ ``n_hidden`` setting at ``n_states=5``.

    Args:
        ns_hidden: Iterable of ``n_hidden`` values to try, one training run
            each. The default is an immutable tuple to avoid a shared mutable
            default.
    """
    print('=== EXPERIMENT 4 ===')
    for n_hidden in ns_hidden:
        PIDEQTrainer(
            PIDEQ(2., n_states=5, n_hidden=n_hidden),
            epochs=5e4,
            wandb_group=f'PIDEQ-#hidden={n_hidden}',
        ).run()
def experiment_5(jac_lambdas=(0.1, 2)):
    """Sweep the trainer's ``jac_lamb`` coefficient for a 5-state PIDEQ.

    Args:
        jac_lambdas: Iterable of values forwarded to ``PIDEQTrainer`` as
            ``jac_lamb``, one training run each. The default is an immutable
            tuple to avoid a shared mutable default.
    """
    print('=== EXPERIMENT 5 ===')
    for jac_lambda in jac_lambdas:
        PIDEQTrainer(
            PIDEQ(2., n_states=5),
            epochs=5e4,
            jac_lamb=jac_lambda,
            wandb_group=f'PIDEQ-#jac_lamb={jac_lambda}',
        ).run()
def experiment_6(solvers=(forward_iteration, broyden)):
    """Compare fixed-point solvers for a 5-state PIDEQ.

    Args:
        solvers: Iterable of solver callables passed to ``PIDEQ(solver=...)``.
            Each run's group name uses the callable's ``__name__``. The default
            is an immutable tuple to avoid a shared mutable default.
    """
    print('=== EXPERIMENT 6 ===')
    for solver in solvers:
        PIDEQTrainer(
            PIDEQ(2., n_states=5, solver=solver),
            epochs=5e4,
            wandb_group=f'PIDEQ-#solver={solver.__name__}',
        ).run()
def experiment_7(epss=(1e-2, 1e-6)):
    """Sweep the forward-iteration solver tolerance for a 5-state PIDEQ.

    Args:
        epss: Iterable of ``eps`` values placed in ``solver_kwargs`` alongside
            ``threshold=200``, one training run each. Each group name formats
            ``eps`` with ``.0e`` (e.g. ``1e-02``). The default is an immutable
            tuple to avoid a shared mutable default.
    """
    print('=== EXPERIMENT 7 ===')
    for eps in epss:
        PIDEQTrainer(
            PIDEQ(2., n_states=5, solver=forward_iteration,
                  solver_kwargs={'threshold': 200, 'eps': eps}),
            epochs=5e4,
            wandb_group=f'PIDEQ-#eps={eps:.0e}',
        ).run()
def experiment_8():
    """Train a 5-state Broyden PIDEQ with a step-decay learning-rate schedule.

    The trainer is given ``lr_scheduler='MultiStepLR'`` with milestones at
    iterations/epochs 30000 and 40000 (unit as interpreted by the trainer).
    """
    print('=== EXPERIMENT 8 ===')
    PIDEQTrainer(
        PIDEQ(2., n_states=5, solver=broyden),
        epochs=5e4,
        lr_scheduler='MultiStepLR',
        lr_scheduler_params={'milestones': [30000, 40000]},
        wandb_group='PIDEQ-#step_decay',  # plain string: the f-prefix had no fields
    ).run()
def experiment_9():
    """Train a small PINN baseline (``n_hidden=2``, ``n_nodes=5``) once."""
    print('=== EXPERIMENT 9 ===')
    PINNTrainer(
        PINN(2., n_hidden=2, n_nodes=5),
        epochs=5e4,
        wandb_group='PINN-baseline-small',  # plain string: the f-prefix had no fields
    ).run()
def _parse_experiment_spec(spec):
    """Expand one `EXPERIMENT` token into the experiment numbers it names.

    Accepts a single integer (``'3'`` -> ``[3]``) or an inclusive dash
    interval (``'2-5'`` -> ``[2, 3, 4, 5]``).

    Raises:
        click.BadParameter: if `spec` is malformed (e.g. ``'3-'`` or
            ``'1-2-3'``), or if the interval is reversed (``'5-2'``). A
            reversed interval would otherwise silently select nothing.
    """
    try:
        return [int(spec)]
    except ValueError:
        pass  # not a single number, so it must be an interval
    start, _, stop = spec.partition('-')
    try:
        first, last = int(start), int(stop)
    except ValueError:
        raise click.BadParameter(
            f'{spec!r} is neither a number nor an interval like 2-5',
            param_hint="'EXPERIMENT'",
        ) from None
    if first > last:
        raise click.BadParameter(
            f'interval {spec!r} is reversed; did you mean {last}-{first}?',
            param_hint="'EXPERIMENT'",
        )
    return list(range(first, last + 1))


@click.command()
@click.option('-n', '--n-runs', default=1, show_default=True, type=click.INT,
              help=("Number of times each experiment is run, i.e., number of "
                    "networks trained with each experiment's configuration."))
@click.argument('experiment', nargs=-1)
def main(n_runs, experiment):
    """Runs `EXPERIMENT` for `--n-runs` time(s).

    `EXPERIMENT` can be either a single number, an interval using dash notation
    (`2-5` means that experiments 2, 3, 4 and 5 will be executed), or any
    sequence of numbers and intervals, like `1 3-5 8`.
    """
    # Resolve every requested experiment before training anything. A typo in a
    # later argument is then reported immediately, rather than after the
    # earlier (long) training runs have already finished.
    experiment_fns = []
    for spec in experiment:
        for exp_n in _parse_experiment_spec(spec):
            # Look the function up by name instead of eval()-ing a string that
            # was assembled from command-line input.
            experiment_fn = globals().get(f'experiment_{exp_n}')
            if not callable(experiment_fn):
                raise click.BadParameter(
                    f'there is no experiment {exp_n}',
                    param_hint="'EXPERIMENT'",
                )
            experiment_fns.append(experiment_fn)

    # Same execution order as before: each experiment runs n_runs times
    # back-to-back, in the order given on the command line.
    for experiment_fn in experiment_fns:
        for _ in range(n_runs):
            experiment_fn()
# Script entry point: invoke the click command `main` when this file is run
# directly (not when it is imported).
if __name__ == '__main__':
    main()