Skip to content

Commit 106858f

Browse files
committed
Attempt to match PyTorch submission behaviour
1 parent 532e5f0 commit 106858f

1 file changed

Lines changed: 35 additions & 28 deletions

File tree

submissions/self_tuning/ademamix/submission.py

Lines changed: 35 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -20,39 +20,41 @@
2020
HPARAMS = {
2121
'ademamix_variant': 'simplified',
2222
'alpha': 8.0,
23-
'alpha_start': 0,
24-
'warmup': 10,
25-
'beta_end': 0.9999,
26-
'beta_start': 0.9,
27-
'learning_rate': 0.01,
28-
'b1': 0.9,
29-
'b2': 0.999,
30-
'b3': 0.9999,
23+
'warmup_factor': 0.02,
24+
'beta3_warmup': 500e3,
25+
'alpha_warmup': 500e3,
26+
'learning_rate': 2e-3,
27+
'one_minus_beta1': 0.2,
28+
'beta2': 0.995,
29+
'beta3': 0.9995,
3130
'eps': 1e-8,
3231
'eps_root': 0.0,
33-
'weight_decay': 0.01,
32+
'weight_decay': 0.1,
33+
'grad_clip': 0.5,
3434
'dropout_rate': 0.1,
3535
}
3636

3737
_GRAD_CLIP_EPS = 1e-6
3838

39-
def lr_scheduler(learning_rate, warmup_steps, total_steps):
39+
def lr_scheduler(learning_rate, warmup_factor, total_steps):
40+
warmup_steps = int(warmup_factor * total_steps)
41+
cosine_steps = max(total_steps - warmup_steps, 1)
4042
return optax.warmup_cosine_decay_schedule(
41-
init_value=0.0,
43+
init_value=learning_rate * 1e-10,
4244
peak_value=learning_rate,
4345
warmup_steps=warmup_steps,
44-
decay_steps=total_steps,
45-
end_value=learning_rate * 0.01
46+
decay_steps=warmup_steps + cosine_steps,
47+
end_value=0.0
4648
)
4749

48-
def alpha_scheduler(alpha, alpha_start=0, warmup=0):
49-
warmup_fn = optax.linear_schedule(init_value=alpha_start, end_value=alpha, transition_steps=warmup)
50+
def alpha_scheduler(alpha, warmup=0):
51+
warmup_fn = optax.linear_schedule(init_value=0, end_value=alpha, transition_steps=warmup)
5052
constant_fn = optax.constant_schedule(alpha)
5153
schedule_fn = optax.join_schedules(schedules=[warmup_fn, constant_fn], boundaries=[warmup])
5254
return schedule_fn
5355

5456

55-
def beta3_scheduler(beta_end, beta_start=0, warmup=0):
57+
def beta3_scheduler(beta3, beta1=0, warmup=0):
5658

5759
def f(beta):
5860
return jnp.log(0.5)/jnp.log(beta)-1
@@ -62,9 +64,9 @@ def f_inv(t):
6264

6365
def warmup_fn(step):
6466
frac = 1 - step / warmup
65-
return f_inv( frac * f(beta_start) + (1 - frac) * f(beta_end))
67+
return f_inv( frac * f(beta1) + (1 - frac) * f(beta3))
6668

67-
constant_fn = optax.constant_schedule(beta_end)
69+
constant_fn = optax.constant_schedule(beta3)
6870
schedule_fn = optax.join_schedules(schedules=[warmup_fn, constant_fn], boundaries=[warmup])
6971
return schedule_fn
7072

@@ -329,16 +331,19 @@ def init_optimizer_state(
329331
lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes
330332
)
331333
lr = HPARAMS['learning_rate']
332-
b1 = HPARAMS['b1']
333-
b2 = HPARAMS['b2']
334-
b3 = HPARAMS['b3']
334+
one_minus_beta1 = HPARAMS['one_minus_beta1']
335+
b1 = 1.0 - one_minus_beta1
336+
b2 = HPARAMS['beta2']
337+
b3 = HPARAMS['beta3']
335338
alpha = HPARAMS['alpha']
336339
variant = HPARAMS['ademamix_variant']
337-
warmup = HPARAMS['warmup']
340+
warmup_factor = HPARAMS['warmup_factor']
341+
beta3_warmup = HPARAMS['beta3_warmup']
342+
alpha_warmup = HPARAMS['alpha_warmup']
338343
T = workload.step_hint
339-
f_b3 = beta3_scheduler(b3, beta_start=b1, warmup=T)
340-
f_a = alpha_scheduler(alpha, alpha_start=0, warmup=T)
341-
f_lr = lr_scheduler(lr, warmup, T)
344+
f_b3 = beta3_scheduler(b3, beta1=b1, warmup=beta3_warmup)
345+
f_a = alpha_scheduler(alpha, warmup=alpha_warmup)
346+
f_lr = lr_scheduler(lr, warmup_factor, T)
342347
weight_decay = HPARAMS['weight_decay']
343348
optimizer = build_ademamix_optimizer(
344349
lr=f_lr,
@@ -361,10 +366,12 @@ def init_optimizer_state(
361366
def f(x): return jnp.sum(x ** 2) # simple quadratic function
362367

363368
alpha = 8.0
364-
b1, b2, b3 = 0.9, 0.999, 0.9999
369+
one_minus_beta1 = 0.1
370+
b1 = 1.0 - one_minus_beta1
371+
b2, b3 = 0.999, 0.9999
365372

366-
f_a = alpha_scheduler(alpha, alpha_start=0, warmup=10)
367-
f_b3 = beta3_scheduler(b3, beta_start=b1, warmup=10)
373+
f_a = alpha_scheduler(alpha, warmup=10)
374+
f_b3 = beta3_scheduler(b3, beta1=b1, warmup=10)
368375

369376
solver = build_ademamix_optimizer(
370377
lr=0.01,

0 commit comments

Comments
 (0)