Skip to content

Commit 532e5f0

Browse files
committed
switched to simplified ademamix from optax
1 parent 4a314de commit 532e5f0

1 file changed

Lines changed: 65 additions & 134 deletions

File tree

submissions/self_tuning/ademamix/submission.py

Lines changed: 65 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,24 @@
1-
"""
2-
Forked from apple's ademamix jax implementation:
3-
https://github.com/apple/ml-ademamix
4-
for the purposes of submitting to the algoperf benchmark.
5-
| Adapted from optax's implementation of AdamW:
6-
| https://github.com/google-deepmind/optax/blob/b75644809f2f68fc11f42d4395a5753e11e92e80/optax/_src/alias.py#L548#L675
7-
"""
1+
"""AlgoPerf AdEMAMix submission built on Optax."""
82
import functools
93
from typing import (
104
Any,
11-
Callable,
125
Dict,
136
Iterator,
147
List,
15-
NamedTuple,
168
Optional,
179
Tuple,
18-
Union,
1910
)
2011

21-
import chex
2212
import jax
23-
from jax import tree_util as jtu
2413
import jax.numpy as jnp
2514
import optax
2615
from flax import jax_utils
27-
from flax import traverse_util as tu
28-
from jax import lax
29-
30-
from optax._src import transform, combine, base, numerics, utils
31-
from optax import tree_utils as otu
3216

3317
from algoperf import spec, jax_sharding_utils
3418

3519

3620
HPARAMS = {
21+
'ademamix_variant': 'simplified',
3722
'alpha': 8.0,
3823
'alpha_start': 0,
3924
'warmup': 10,
@@ -83,106 +68,44 @@ def warmup_fn(step):
8368
schedule_fn = optax.join_schedules(schedules=[warmup_fn, constant_fn], boundaries=[warmup])
8469
return schedule_fn
8570

86-
87-
class ScaleByAdemamixState(NamedTuple):
88-
"""State for the AdEMAMix algorithm."""
89-
count: chex.Array
90-
count_m2: chex.Array
91-
m1: optax.Updates
92-
m2: optax.Updates
93-
nu: optax.Updates
94-
95-
96-
def ademamix(lr, b1=0.9, b2=0.999, b3=0.9999, alpha=5.0, b3_scheduler=None, alpha_scheduler=None,
97-
eps=1e-8, eps_root=0.0, weight_decay=0.0, mask=None):
98-
r"""AdEMAMix.
99-
100-
Args:
101-
lr: A global scaling factor, either fixed or evolving along
102-
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
103-
b1: Exponential decay rate to track the fast EMA.
104-
b2: Exponential decay rate to track the second moment of past gradients.
105-
b3: Exponential decay rate to track the slow EMA.
106-
alpha: Mixing coeficient use for the linear combination of the fast and slow EMAs.
107-
b3_scheduler: an optional scheduler function, given a timestep, returns the
108-
value of b3. Use `beta3_scheduler(b3,b1,T_b3)` to follow the AdEMAMix paper.
109-
alpha_scheduler: an optional scheduler function, given a timestep, returns the
110-
value of alpha. Use `alpha_scheduler(alpha,0,T_alpha)` to follow the
111-
AdEMAMix paper.
112-
eps: A small constant applied to denominator outside of the square root
113-
(as in the Adam paper) to avoid dividing by zero when rescaling.
114-
eps_root: A small constant applied to denominator inside the square root (as
115-
in RMSProp), to avoid dividing by zero when rescaling. This is needed for
116-
instance when computing (meta-)gradients through Adam.
117-
mu_dtype: Optional `dtype` to be used for the first order accumulator; if
118-
`None` then the `dtype` is inferred from `params` and `updates`.
119-
weight_decay: Strength of the weight decay regularization. Note that this
120-
weight decay is multiplied with the learning rate. This is consistent
121-
with other frameworks such as PyTorch, but different from
122-
(Loshchilov et al, 2019) where the weight decay is only multiplied with
123-
the "schedule multiplier", but not the base learning rate.
124-
mask: A tree with same structure as (or a prefix of) the params PyTree,
125-
or a Callable that returns such a pytree given the params/updates.
126-
The leaves should be booleans, `True` for leaves/subtrees you want to
127-
apply the weight decay to, and `False` for those you want to skip. Note
128-
that the Adam gradient transformations are applied to all parameters.
129-
130-
Returns:
131-
The corresponding `GradientTransformation`.
132-
"""
133-
return combine.chain(
134-
scale_by_ademamix(b1, b2, b3, alpha, b3_scheduler, alpha_scheduler, eps, eps_root),
135-
transform.add_decayed_weights(weight_decay, mask),
136-
transform.scale_by_learning_rate(lr),
137-
)
138-
139-
140-
def scale_by_ademamix(b1, b2, b3, alpha, b3_scheduler, alpha_scheduler, eps=1e-8, eps_root=0.0):
141-
142-
def init_fn(params):
143-
m1 = jax.tree.map(jnp.zeros_like, params) # fast EMA
144-
m2 = jax.tree.map(jnp.zeros_like, params) # slow EMA
145-
nu = jax.tree.map(jnp.zeros_like, params) # second moment estimate
146-
return ScaleByAdemamixState(count=jnp.zeros([], jnp.int32), count_m2=jnp.zeros([], jnp.int32), m1=m1, m2=m2, nu=nu)
147-
148-
def update_fn(updates, state, params=None):
149-
del params
150-
c_b3 = b3_scheduler(state.count_m2) if b3_scheduler is not None else b3
151-
c_alpha = alpha_scheduler(state.count_m2) if alpha_scheduler is not None else alpha
152-
m1 = _update_moment(updates, state.m1, b1, 1) # m1 = b1 * m1 + (1-b1) * updates
153-
m2 = _update_moment(updates, state.m2, c_b3, 1)
154-
nu = _update_moment(updates, state.nu, b2, 2)
155-
count = state.count + jnp.array(1, dtype=jnp.int32)
156-
# count_inc = numerics.safe_int32_increment(state.count)
157-
count_m2 = state.count_m2 + jnp.array(1, dtype=jnp.int32)
158-
# count_m2_inc = numerics.safe_int32_increment(state.count_m2)
159-
m1_hat = _bias_correction(m1, b1, count)
160-
nu_hat = _bias_correction(nu, b2, count)
161-
updates = jax.tree.map(lambda m1_, m2_, v_: (m1_+c_alpha*m2_)/(jnp.sqrt(v_+eps_root)+eps), m1_hat, m2, nu_hat)
162-
return updates, ScaleByAdemamixState(count=count, count_m2=count_m2, m1=m1, m2=m2, nu=nu)
163-
164-
return base.GradientTransformation(init_fn, update_fn)
165-
166-
167-
def _update_moment(updates, moments, decay, order):
168-
"""Compute the exponential moving average of the `order`-th moment."""
169-
return jax.tree.map(
170-
lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments)
171-
172-
173-
174-
def _bias_correction(moment, decay, count):
175-
"""Performs bias correction. It becomes a no-op as count goes to infinity."""
176-
# The conversion to the data type of the moment ensures that bfloat16 remains
177-
# bfloat16 in the optimizer state. This conversion has to be done after
178-
# `bias_correction_` is calculated as calculating `decay**count` in low
179-
# precision can result in it being rounded to 1 and subsequently a
180-
# "division by zero" error.
181-
bias_correction_ = 1 - decay**count
182-
183-
# Perform division in the original precision.
184-
return jax.tree.map(
185-
lambda t: t / bias_correction_.astype(t.dtype), moment)
71+
def build_ademamix_optimizer(
72+
lr,
73+
variant='simplified',
74+
b1=0.9,
75+
b2=0.999,
76+
b3=0.9999,
77+
alpha=5.0,
78+
b3_scheduler=None,
79+
alpha_scheduler=None,
80+
eps=1e-8,
81+
eps_root=0.0,
82+
weight_decay=0.0,
83+
mask=None,
84+
):
85+
if variant == 'simplified':
86+
return optax.contrib.simplified_ademamix(
87+
learning_rate=lr,
88+
b1=b1,
89+
b2=b2,
90+
alpha=alpha_scheduler if alpha_scheduler is not None else alpha,
91+
eps=eps,
92+
eps_root=eps_root,
93+
weight_decay=weight_decay,
94+
mask=mask,
95+
)
96+
if variant == 'full':
97+
return optax.contrib.ademamix(
98+
learning_rate=lr,
99+
b1=b1,
100+
b2=b2,
101+
b3=b3_scheduler if b3_scheduler is not None else b3,
102+
alpha=alpha_scheduler if alpha_scheduler is not None else alpha,
103+
eps=eps,
104+
eps_root=eps_root,
105+
weight_decay=weight_decay,
106+
mask=mask,
107+
)
108+
raise ValueError(f'Unsupported ademamix variant: {variant}')
186109

187110

188111
def train_step(workload,
@@ -410,21 +333,26 @@ def init_optimizer_state(
410333
b2 = HPARAMS['b2']
411334
b3 = HPARAMS['b3']
412335
alpha = HPARAMS['alpha']
336+
variant = HPARAMS['ademamix_variant']
413337
warmup = HPARAMS['warmup']
414338
T = workload.step_hint
415339
f_b3 = beta3_scheduler(b3, beta_start=b1, warmup=T)
416340
f_a = alpha_scheduler(alpha, alpha_start=0, warmup=T)
417341
f_lr = lr_scheduler(lr, warmup, T)
418342
weight_decay = HPARAMS['weight_decay']
419-
opt_init_fn, opt_update_fn = ademamix(lr=f_lr,
420-
b1=b1,
421-
b2=b2,
422-
b3=b3,
423-
alpha=alpha,
424-
b3_scheduler=f_b3,
425-
alpha_scheduler=f_a,
426-
weight_decay=weight_decay
427-
)
343+
optimizer = build_ademamix_optimizer(
344+
lr=f_lr,
345+
variant=variant,
346+
b1=b1,
347+
b2=b2,
348+
b3=b3,
349+
alpha=alpha,
350+
b3_scheduler=f_b3,
351+
alpha_scheduler=f_a,
352+
weight_decay=weight_decay,
353+
)
354+
opt_init_fn = optimizer.init
355+
opt_update_fn = optimizer.update
428356
optimizer_state = opt_init_fn(params_zeros_like)
429357
return optimizer_state, opt_update_fn
430358

@@ -438,14 +366,17 @@ def f(x): return jnp.sum(x ** 2) # simple quadratic function
438366
f_a = alpha_scheduler(alpha, alpha_start=0, warmup=10)
439367
f_b3 = beta3_scheduler(b3, beta_start=b1, warmup=10)
440368

441-
solver = ademamix(lr=0.01,
442-
b1=b1,
443-
b2=b2,
444-
b3=b3,
445-
alpha=alpha,
446-
b3_scheduler=f_b3,
447-
alpha_scheduler=f_a,
448-
weight_decay=0.01)
369+
solver = build_ademamix_optimizer(
370+
lr=0.01,
371+
variant='full',
372+
b1=b1,
373+
b2=b2,
374+
b3=b3,
375+
alpha=alpha,
376+
b3_scheduler=f_b3,
377+
alpha_scheduler=f_a,
378+
weight_decay=0.01,
379+
)
449380

450381
params = jnp.array([1., 2., 3.])
451382
print('Objective function: {:.2f}'.format(f(params)))

0 commit comments

Comments
 (0)