# Fixed hyperparameters for the AdEMAMix optimizer setup; read by
# init_optimizer_state when building the schedules and the optimizer.
HPARAMS = {
    'ademamix_variant': 'simplified',
    'alpha': 8.0,
    # Fraction of the total step budget spent on learning-rate warmup
    # (lr_scheduler converts it to a step count).
    'warmup_factor': 0.02,
    # Warmup lengths, in steps, passed to beta3_scheduler / alpha_scheduler.
    'beta3_warmup': 500e3,
    'alpha_warmup': 500e3,
    # Peak learning rate handed to lr_scheduler.
    'learning_rate': 2e-3,
    # beta1 is derived by the reader as 1.0 - one_minus_beta1.
    'one_minus_beta1': 0.2,
    'beta2': 0.995,
    'beta3': 0.9995,
    'eps': 1e-8,
    'eps_root': 0.0,
    'weight_decay': 0.1,
    'grad_clip': 0.5,
    'dropout_rate': 0.1,
}
3636
3737_GRAD_CLIP_EPS = 1e-6
3838
def lr_scheduler(learning_rate, warmup_factor, total_steps):
    """Build the learning-rate schedule: linear warmup, then cosine decay.

    Args:
        learning_rate: Peak learning rate reached at the end of warmup.
        warmup_factor: Fraction of ``total_steps`` spent warming up.
        total_steps: Total number of training steps the schedule spans.

    Returns:
        The schedule returned by ``optax.warmup_cosine_decay_schedule``.
    """
    warmup_steps = int(warmup_factor * total_steps)
    # Keep at least one post-warmup step so that decay_steps is always
    # strictly greater than warmup_steps, even if warmup covers everything.
    cosine_steps = max(total_steps - warmup_steps, 1)
    return optax.warmup_cosine_decay_schedule(
        # Start from a negligible (but non-zero) fraction of the peak.
        init_value=learning_rate * 1e-10,
        peak_value=learning_rate,
        warmup_steps=warmup_steps,
        decay_steps=warmup_steps + cosine_steps,
        end_value=0.0,
    )
4749
def alpha_scheduler(alpha, warmup=0):
    """Build the alpha schedule: linear ramp from 0 to ``alpha``, then constant.

    Args:
        alpha: Final value, held once the ramp is complete.
        warmup: Number of steps over which the value ramps up from 0.

    Returns:
        An optax schedule joining the linear ramp and the constant phase,
        switching from one to the other at step ``warmup``.
    """
    warmup_fn = optax.linear_schedule(
        init_value=0, end_value=alpha, transition_steps=warmup
    )
    constant_fn = optax.constant_schedule(alpha)
    schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, constant_fn], boundaries=[warmup]
    )
    return schedule_fn
5355
5456
55- def beta3_scheduler (beta_end , beta_start = 0 , warmup = 0 ):
57+ def beta3_scheduler (beta3 , beta1 = 0 , warmup = 0 ):
5658
5759 def f (beta ):
5860 return jnp .log (0.5 )/ jnp .log (beta )- 1
@@ -62,9 +64,9 @@ def f_inv(t):
6264
6365 def warmup_fn (step ):
6466 frac = 1 - step / warmup
65- return f_inv ( frac * f (beta_start ) + (1 - frac ) * f (beta_end ))
67+ return f_inv ( frac * f (beta1 ) + (1 - frac ) * f (beta3 ))
6668
67- constant_fn = optax .constant_schedule (beta_end )
69+ constant_fn = optax .constant_schedule (beta3 )
6870 schedule_fn = optax .join_schedules (schedules = [warmup_fn , constant_fn ], boundaries = [warmup ])
6971 return schedule_fn
7072
@@ -329,16 +331,19 @@ def init_optimizer_state(
329331 lambda s : jnp .zeros (s .shape_tuple ), workload .param_shapes
330332 )
331333 lr = HPARAMS ['learning_rate' ]
332- b1 = HPARAMS ['b1' ]
333- b2 = HPARAMS ['b2' ]
334- b3 = HPARAMS ['b3' ]
334+ one_minus_beta1 = HPARAMS ['one_minus_beta1' ]
335+ b1 = 1.0 - one_minus_beta1
336+ b2 = HPARAMS ['beta2' ]
337+ b3 = HPARAMS ['beta3' ]
335338 alpha = HPARAMS ['alpha' ]
336339 variant = HPARAMS ['ademamix_variant' ]
337- warmup = HPARAMS ['warmup' ]
340+ warmup_factor = HPARAMS ['warmup_factor' ]
341+ beta3_warmup = HPARAMS ['beta3_warmup' ]
342+ alpha_warmup = HPARAMS ['alpha_warmup' ]
338343 T = workload .step_hint
339- f_b3 = beta3_scheduler (b3 , beta_start = b1 , warmup = T )
340- f_a = alpha_scheduler (alpha , alpha_start = 0 , warmup = T )
341- f_lr = lr_scheduler (lr , warmup , T )
344+ f_b3 = beta3_scheduler (b3 , beta1 = b1 , warmup = beta3_warmup )
345+ f_a = alpha_scheduler (alpha , warmup = alpha_warmup )
346+ f_lr = lr_scheduler (lr , warmup_factor , T )
342347 weight_decay = HPARAMS ['weight_decay' ]
343348 optimizer = build_ademamix_optimizer (
344349 lr = f_lr ,
@@ -361,10 +366,12 @@ def init_optimizer_state(
361366 def f (x ): return jnp .sum (x ** 2 ) # simple quadratic function
362367
363368 alpha = 8.0
364- b1 , b2 , b3 = 0.9 , 0.999 , 0.9999
369+ one_minus_beta1 = 0.1
370+ b1 = 1.0 - one_minus_beta1
371+ b2 , b3 = 0.999 , 0.9999
365372
366- f_a = alpha_scheduler (alpha , alpha_start = 0 , warmup = 10 )
367- f_b3 = beta3_scheduler (b3 , beta_start = b1 , warmup = 10 )
373+ f_a = alpha_scheduler (alpha , warmup = 10 )
374+ f_b3 = beta3_scheduler (b3 , beta1 = b1 , warmup = 10 )
368375
369376 solver = build_ademamix_optimizer (
370377 lr = 0.01 ,
0 commit comments