Skip to content

Commit 1ed0036

Browse files
author
learned_optimization authors
committed
No public description
PiperOrigin-RevId: 592937780
1 parent 3c5b03b commit 1ed0036

1 file changed

Lines changed: 10 additions & 2 deletions

File tree

learned_optimization/research/univ_nfn/gen_pred/train_pred.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
flags.DEFINE_integer('n_epochs', default=10, help='No. of training epochs.')
4242
flags.DEFINE_float('dropout', default=0.0, help='Dropout rate.')
4343
flags.DEFINE_bool('debug', default=False, help='Whether to run in debug mode.')
44+
flags.DEFINE_integer('seed', default=0, help='Jax PRNG seed.')
4445

4546

4647
def make_perm_spec_GRUCell(in_perm_num, h_perm_num):
@@ -222,12 +223,11 @@ def main(_):
222223
.batch(FLAGS.bs)
223224
.prefetch(tf.data.AUTOTUNE)
224225
)
225-
del test_dset
226226

227227
test_inp, _, _ = process_dset_example(next(iter(train_dset)))
228228
perm_spec = make_flattened_perm_spec()
229229

230-
rng = jax.random.PRNGKey(0)
230+
rng = jax.random.PRNGKey(FLAGS.seed)
231231
rng, rng1 = jax.random.split(rng)
232232

233233
predictor = make_predictor()
@@ -257,6 +257,7 @@ def evaluate(dset):
257257
return tau.correlation, rsq, preds, test_accs
258258

259259
max_val_rsq, max_val_tau = float('-inf'), float('-inf')
260+
max_test_rsq, max_test_tau = float('-inf'), float('-inf')
260261
for epoch in range(FLAGS.n_epochs):
261262
steps = 0
262263
start_time = time.time()
@@ -271,17 +272,24 @@ def evaluate(dset):
271272
steps += 1
272273
train_tau, train_rsq, _, _ = evaluate(train_dset)
273274
val_tau, val_rsq, _, _ = evaluate(val_dset)
275+
test_tau, test_rsq, _, _ = evaluate(test_dset)
274276
max_val_tau = max(max_val_tau, val_tau)
275277
max_val_rsq = max(max_val_rsq, val_rsq)
278+
max_test_tau = max(max_test_tau, test_tau)
279+
max_test_rsq = max(max_test_rsq, test_rsq)
276280
writer.write_scalars(
277281
epoch,
278282
{
279283
'train_tau': train_tau,
280284
'val_tau': val_tau,
281285
'train_rsq': train_rsq,
282286
'val_rsq': val_rsq,
287+
'test_tau': test_tau,
288+
'test_rsq': test_rsq,
283289
'max_val_tau': max_val_tau,
284290
'max_val_rsq': max_val_rsq,
291+
'max_test_tau': max_test_tau,
292+
'max_test_rsq': max_test_rsq,
285293
'steps_per_sec': steps / (time.time() - start_time),
286294
},
287295
)

0 commit comments

Comments
 (0)