4141flags .DEFINE_integer ('n_epochs' , default = 10 , help = 'No. of training epochs.' )
4242flags .DEFINE_float ('dropout' , default = 0.0 , help = 'Dropout rate.' )
4343flags .DEFINE_bool ('debug' , default = False , help = 'Whether to run in debug mode.' )
44+ flags .DEFINE_integer ('seed' , default = 0 , help = 'Jax PRNG seed.' )
4445
4546
4647def make_perm_spec_GRUCell (in_perm_num , h_perm_num ):
@@ -222,12 +223,11 @@ def main(_):
222223 .batch (FLAGS .bs )
223224 .prefetch (tf .data .AUTOTUNE )
224225 )
225- del test_dset
226226
227227 test_inp , _ , _ = process_dset_example (next (iter (train_dset )))
228228 perm_spec = make_flattened_perm_spec ()
229229
230- rng = jax .random .PRNGKey (0 )
230+ rng = jax .random .PRNGKey (FLAGS . seed )
231231 rng , rng1 = jax .random .split (rng )
232232
233233 predictor = make_predictor ()
@@ -257,6 +257,7 @@ def evaluate(dset):
257257 return tau .correlation , rsq , preds , test_accs
258258
259259 max_val_rsq , max_val_tau = float ('-inf' ), float ('-inf' )
260+ max_test_rsq , max_test_tau = float ('-inf' ), float ('-inf' )
260261 for epoch in range (FLAGS .n_epochs ):
261262 steps = 0
262263 start_time = time .time ()
@@ -271,17 +272,24 @@ def evaluate(dset):
271272 steps += 1
272273 train_tau , train_rsq , _ , _ = evaluate (train_dset )
273274 val_tau , val_rsq , _ , _ = evaluate (val_dset )
275+ test_tau , test_rsq , _ , _ = evaluate (test_dset )
274276 max_val_tau = max (max_val_tau , val_tau )
275277 max_val_rsq = max (max_val_rsq , val_rsq )
278+ max_test_tau = max (max_test_tau , test_tau )
279+ max_test_rsq = max (max_test_rsq , test_rsq )
276280 writer .write_scalars (
277281 epoch ,
278282 {
279283 'train_tau' : train_tau ,
280284 'val_tau' : val_tau ,
281285 'train_rsq' : train_rsq ,
282286 'val_rsq' : val_rsq ,
287+ 'test_tau' : test_tau ,
288+ 'test_rsq' : test_rsq ,
283289 'max_val_tau' : max_val_tau ,
284290 'max_val_rsq' : max_val_rsq ,
291+ 'max_test_tau' : max_test_tau ,
292+ 'max_test_rsq' : max_test_rsq ,
285293 'steps_per_sec' : steps / (time .time () - start_time ),
286294 },
287295 )
0 commit comments