Skip to content

Commit 3c5b03b

Browse files
author
learned_optimization authors
committed
No public description
PiperOrigin-RevId: 592758118
1 parent 4c74865 commit 3c5b03b

6 files changed

Lines changed: 1349 additions & 0 deletions

File tree

Lines changed: 291 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,291 @@
1+
# coding=utf-8
2+
# Copyright 2021 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
# pylint: disable=invalid-name
17+
"""Train generalization predictor."""
18+
19+
import time
20+
21+
from absl import app
22+
from absl import flags
23+
from clu import metric_writers
24+
import flax.linen as nn
25+
import jax
26+
import jax.numpy as jnp
27+
import jax.tree_util as jtu
28+
from learned_optimization.research.univ_nfn.nfn import universal_layers
29+
import numpy as np
30+
import optax
31+
import scipy.stats
32+
from sklearn.metrics import r2_score
33+
import tensorflow.compat.v2 as tf
34+
35+
36+
FLAGS = flags.FLAGS
37+
flags.DEFINE_string('workdir', default='.', help='Where to store log output.')
38+
flags.DEFINE_string('data_root', default=None, help='Data path')
39+
flags.DEFINE_string('method', default='nfn', help='nfn or stat')
40+
flags.DEFINE_integer('bs', default=10, help='Batch size.')
41+
flags.DEFINE_integer('n_epochs', default=10, help='No. of training epochs.')
42+
flags.DEFINE_float('dropout', default=0.0, help='Dropout rate.')
43+
flags.DEFINE_bool('debug', default=False, help='Whether to run in debug mode.')
44+
45+
46+
def make_perm_spec_GRUCell(in_perm_num, h_perm_num):
47+
"""Make NFN permutation spec for a single cell."""
48+
spec = {}
49+
for layer in ['hn']:
50+
spec[layer] = {'kernel': (h_perm_num, h_perm_num), 'bias': (h_perm_num,)}
51+
for layer in ['hr', 'hz']:
52+
spec[layer] = {'kernel': (h_perm_num, h_perm_num)}
53+
for layer in ['in', 'ir', 'iz']:
54+
spec[layer] = {'kernel': (in_perm_num, h_perm_num), 'bias': (h_perm_num,)}
55+
return spec
56+
57+
58+
def make_perm_spec_Seq2Seq():
59+
"""Make NFN permutation spec for model."""
60+
# -1: input/output dimensions
61+
# 0: encoder side output
62+
perm_spec = {}
63+
perm_spec['GRUCell_0'] = make_perm_spec_GRUCell(-1, 0) # encoder
64+
perm_spec['DecoderGRUCell_0'] = {
65+
'GRUCell_0': make_perm_spec_GRUCell(-1, 0),
66+
'Dense_0': {'kernel': (0, -1), 'bias': (-1,)},
67+
}
68+
return {'params': perm_spec}
69+
70+
71+
def process_dset_example(example):
72+
"""Input is a pytree of tf tensors. Output is a pytree of nump arrays."""
73+
return jtu.tree_map(lambda x: x.numpy(), example)
74+
75+
76+
def make_flattened_perm_spec():
77+
perm_spec = make_perm_spec_Seq2Seq()['params']
78+
new_perm_spec = {}
79+
for path, arr in jtu.tree_flatten_with_path(
80+
perm_spec, is_leaf=lambda x: isinstance(x, tuple)
81+
)[0]:
82+
key = '/'.join([x.key for x in path])
83+
new_perm_spec[key] = arr
84+
return new_perm_spec
85+
86+
87+
def make_train_fns(opt, nfn, perm_spec):
88+
"""Produce training-related functions."""
89+
90+
def loss(theta, x, y, rngs):
91+
pred_logits = jnp.squeeze(
92+
nfn.apply(theta, x, perm_spec, train=True, rngs=rngs), -1
93+
)
94+
return jnp.mean(optax.sigmoid_binary_cross_entropy(pred_logits, y))
95+
96+
@jax.jit
97+
def step(opt_state, theta, x, y, rngs):
98+
loss_val, grad = jax.value_and_grad(loss)(theta, x, y, rngs)
99+
updates, opt_state = opt.update(grad, opt_state)
100+
theta = optax.apply_updates(theta, updates)
101+
return theta, opt_state, loss_val
102+
103+
@jax.jit
104+
def get_pred_logits(theta, x):
105+
return jnp.squeeze(nfn.apply(theta, x, perm_spec, train=False), -1)
106+
107+
return step, get_pred_logits
108+
109+
110+
def compute_stats(tensor):
111+
"""Computes the statistics of the given tensor."""
112+
C = tensor.shape[-1] # (..., C)
113+
flat_tensor = jnp.reshape(tensor, (-1, C))
114+
mean = jnp.mean(flat_tensor, 0)
115+
var = jnp.var(flat_tensor, 0)
116+
q = jnp.array([0.0, 0.25, 0.5, 0.75, 1.0])
117+
quantiles = jnp.quantile(flat_tensor, q, axis=0)
118+
return jnp.stack([mean, var, *quantiles], axis=0) # (7, C)
119+
120+
121+
class NFN(nn.Module):
122+
"""NFN gen predictor."""
123+
124+
dropout: float
125+
126+
@nn.compact
127+
def __call__(self, params, perm_spec, train):
128+
out = universal_layers.BatchNFLinear(16, 1)(params, perm_spec)
129+
out = universal_layers.nf_relu(out)
130+
out = universal_layers.NFDropout(self.dropout)(out, train)
131+
out = universal_layers.BatchNFLinear(16, 16)(out, perm_spec)
132+
out = universal_layers.nf_relu(out)
133+
out = universal_layers.NFDropout(self.dropout)(out, train)
134+
out = universal_layers.batch_nf_pool(out)
135+
out = jax.nn.relu(nn.Dense(512)(out))
136+
out = universal_layers.NFDropout(self.dropout)(out, train)
137+
out = nn.Dense(1)(out)
138+
return out
139+
140+
141+
class StatPred(nn.Module):
142+
"""Statistical gen predictor (Unterthiner et al)."""
143+
144+
dropout: float
145+
146+
@nn.compact
147+
def __call__(self, x, perm_spec, train):
148+
def pool_stats(_x):
149+
stats = jtu.tree_map(compute_stats, _x)
150+
return jnp.ravel(
151+
jnp.concatenate(jtu.tree_leaves(stats), axis=0)
152+
) # (num_outs,)
153+
154+
out = jax.vmap(pool_stats)(x)
155+
out = jax.nn.relu(nn.Dense(600)(out))
156+
out = universal_layers.NFDropout(self.dropout)(out, train)
157+
out = jax.nn.relu(nn.Dense(600)(out))
158+
out = universal_layers.NFDropout(self.dropout)(out, train)
159+
out = jax.nn.relu(nn.Dense(600)(out))
160+
out = universal_layers.NFDropout(self.dropout)(out, train)
161+
out = nn.Dense(1)(out)
162+
return out
163+
164+
165+
def make_predictor():
166+
if FLAGS.method == 'nfn':
167+
predictor = NFN(dropout=FLAGS.dropout)
168+
else:
169+
predictor = StatPred(dropout=FLAGS.dropout)
170+
return predictor
171+
172+
173+
def main(_):
174+
writer = metric_writers.create_default_writer(FLAGS.workdir)
175+
176+
train_indices = range(0, 8000)
177+
val_indices = range(8000, 9000)
178+
test_indices = range(9000, 10000)
179+
if FLAGS.debug:
180+
train_indices = range(1, FLAGS.bs * 3 + 1)
181+
val_indices = range(FLAGS.bs * 3 + 1, FLAGS.bs * 6 + 1)
182+
test_indices = range(FLAGS.bs * 6 + 1, FLAGS.bs * 9 + 1)
183+
print('Started loading data.')
184+
with tf.io.gfile.GFile(FLAGS.data_root, 'rb') as f:
185+
raw_data = np.load(f)
186+
print('Finished loading data.')
187+
test_srs = raw_data['test_srs']
188+
test_losses = raw_data['test_losses']
189+
params = {}
190+
for key in list(raw_data.keys()):
191+
if key not in ['test_srs', 'test_losses']:
192+
params[key] = raw_data[key]
193+
train_arrs = (
194+
{k: v[train_indices] for k, v in params.items()},
195+
test_srs[train_indices],
196+
test_losses[train_indices],
197+
)
198+
val_arrs = (
199+
{k: v[val_indices] for k, v in params.items()},
200+
test_srs[val_indices],
201+
test_losses[val_indices],
202+
)
203+
test_arrs = (
204+
{k: v[test_indices] for k, v in params.items()},
205+
test_srs[test_indices],
206+
test_losses[test_indices],
207+
)
208+
train_dset = (
209+
tf.data.Dataset.from_tensor_slices(train_arrs)
210+
.shuffle(1000)
211+
.repeat(10)
212+
.batch(FLAGS.bs)
213+
.prefetch(tf.data.AUTOTUNE)
214+
)
215+
val_dset = (
216+
tf.data.Dataset.from_tensor_slices(val_arrs)
217+
.batch(FLAGS.bs)
218+
.prefetch(tf.data.AUTOTUNE)
219+
)
220+
test_dset = (
221+
tf.data.Dataset.from_tensor_slices(test_arrs)
222+
.batch(FLAGS.bs)
223+
.prefetch(tf.data.AUTOTUNE)
224+
)
225+
del test_dset
226+
227+
test_inp, _, _ = process_dset_example(next(iter(train_dset)))
228+
perm_spec = make_flattened_perm_spec()
229+
230+
rng = jax.random.PRNGKey(0)
231+
rng, rng1 = jax.random.split(rng)
232+
233+
predictor = make_predictor()
234+
235+
opt = optax.adam(1e-3)
236+
step, get_pred_logits = make_train_fns(opt, predictor, perm_spec)
237+
238+
theta = predictor.init(rng1, test_inp, perm_spec, train=False)
239+
opt_state = opt.init(theta)
240+
param_count = sum(x.size for x in jtu.tree_leaves(theta))
241+
print(param_count)
242+
writer.write_hparams(
243+
{'param_count': param_count, 'predictor_method': FLAGS.method}
244+
)
245+
246+
def evaluate(dset):
247+
test_accs, preds = [], []
248+
for example in dset:
249+
example, test_acc, _ = process_dset_example(example)
250+
logit = get_pred_logits(theta, example)
251+
test_accs.append(test_acc)
252+
preds.append(np.asarray(jax.nn.sigmoid(logit)))
253+
test_accs = np.concatenate(test_accs, 0)
254+
preds = np.concatenate(preds, 0)
255+
tau = scipy.stats.kendalltau(preds, test_accs)
256+
rsq = r2_score(test_accs, preds)
257+
return tau.correlation, rsq, preds, test_accs
258+
259+
max_val_rsq, max_val_tau = float('-inf'), float('-inf')
260+
for epoch in range(FLAGS.n_epochs):
261+
steps = 0
262+
start_time = time.time()
263+
for example in train_dset:
264+
rng, rng1 = jax.random.split(rng)
265+
example, test_acc, _ = process_dset_example(example)
266+
rngs = {'dropout': rng1}
267+
theta, opt_state, loss_value = step(
268+
opt_state, theta, example, test_acc, rngs
269+
)
270+
del loss_value
271+
steps += 1
272+
train_tau, train_rsq, _, _ = evaluate(train_dset)
273+
val_tau, val_rsq, _, _ = evaluate(val_dset)
274+
max_val_tau = max(max_val_tau, val_tau)
275+
max_val_rsq = max(max_val_rsq, val_rsq)
276+
writer.write_scalars(
277+
epoch,
278+
{
279+
'train_tau': train_tau,
280+
'val_tau': val_tau,
281+
'train_rsq': train_rsq,
282+
'val_rsq': val_rsq,
283+
'max_val_tau': max_val_tau,
284+
'max_val_rsq': max_val_rsq,
285+
'steps_per_sec': steps / (time.time() - start_time),
286+
},
287+
)
288+
289+
290+
if __name__ == '__main__':
291+
app.run(main)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# coding=utf-8
2+
# Copyright 2021 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# https://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+

0 commit comments

Comments
 (0)