1- """
2- Forked from apple's ademamix jax implementation:
3- https://github.com/apple/ml-ademamix
4- for the purposes of submitting to the algoperf benchmark.
5- | Adapted from optax's implementation of AdamW:
6- | https://github.com/google-deepmind/optax/blob/b75644809f2f68fc11f42d4395a5753e11e92e80/optax/_src/alias.py#L548#L675
7- """
1+ """AlgoPerf AdEMAMix submission built on Optax."""
82import functools
93from typing import (
104 Any ,
11- Callable ,
125 Dict ,
136 Iterator ,
147 List ,
15- NamedTuple ,
168 Optional ,
179 Tuple ,
18- Union ,
1910 )
2011
21- import chex
2212import jax
23- from jax import tree_util as jtu
2413import jax .numpy as jnp
2514import optax
2615from flax import jax_utils
27- from flax import traverse_util as tu
28- from jax import lax
29-
30- from optax ._src import transform , combine , base , numerics , utils
31- from optax import tree_utils as otu
3216
3317from algoperf import spec , jax_sharding_utils
3418
3519
3620HPARAMS = {
21+ 'ademamix_variant' : 'simplified' ,
3722 'alpha' : 8.0 ,
3823 'alpha_start' : 0 ,
3924 'warmup' : 10 ,
@@ -83,106 +68,44 @@ def warmup_fn(step):
8368 schedule_fn = optax .join_schedules (schedules = [warmup_fn , constant_fn ], boundaries = [warmup ])
8469 return schedule_fn
8570
86-
87- class ScaleByAdemamixState (NamedTuple ):
88- """State for the AdEMAMix algorithm."""
89- count : chex .Array
90- count_m2 : chex .Array
91- m1 : optax .Updates
92- m2 : optax .Updates
93- nu : optax .Updates
94-
95-
96- def ademamix (lr , b1 = 0.9 , b2 = 0.999 , b3 = 0.9999 , alpha = 5.0 , b3_scheduler = None , alpha_scheduler = None ,
97- eps = 1e-8 , eps_root = 0.0 , weight_decay = 0.0 , mask = None ):
98- r"""AdEMAMix.
99-
100- Args:
101- lr: A global scaling factor, either fixed or evolving along
102- iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
103- b1: Exponential decay rate to track the fast EMA.
104- b2: Exponential decay rate to track the second moment of past gradients.
105- b3: Exponential decay rate to track the slow EMA.
106- alpha: Mixing coeficient use for the linear combination of the fast and slow EMAs.
107- b3_scheduler: an optional scheduler function, given a timestep, returns the
108- value of b3. Use `beta3_scheduler(b3,b1,T_b3)` to follow the AdEMAMix paper.
109- alpha_scheduler: an optional scheduler function, given a timestep, returns the
110- value of alpha. Use `alpha_scheduler(alpha,0,T_alpha)` to follow the
111- AdEMAMix paper.
112- eps: A small constant applied to denominator outside of the square root
113- (as in the Adam paper) to avoid dividing by zero when rescaling.
114- eps_root: A small constant applied to denominator inside the square root (as
115- in RMSProp), to avoid dividing by zero when rescaling. This is needed for
116- instance when computing (meta-)gradients through Adam.
117- mu_dtype: Optional `dtype` to be used for the first order accumulator; if
118- `None` then the `dtype` is inferred from `params` and `updates`.
119- weight_decay: Strength of the weight decay regularization. Note that this
120- weight decay is multiplied with the learning rate. This is consistent
121- with other frameworks such as PyTorch, but different from
122- (Loshchilov et al, 2019) where the weight decay is only multiplied with
123- the "schedule multiplier", but not the base learning rate.
124- mask: A tree with same structure as (or a prefix of) the params PyTree,
125- or a Callable that returns such a pytree given the params/updates.
126- The leaves should be booleans, `True` for leaves/subtrees you want to
127- apply the weight decay to, and `False` for those you want to skip. Note
128- that the Adam gradient transformations are applied to all parameters.
129-
130- Returns:
131- The corresponding `GradientTransformation`.
132- """
133- return combine .chain (
134- scale_by_ademamix (b1 , b2 , b3 , alpha , b3_scheduler , alpha_scheduler , eps , eps_root ),
135- transform .add_decayed_weights (weight_decay , mask ),
136- transform .scale_by_learning_rate (lr ),
137- )
138-
139-
140- def scale_by_ademamix (b1 , b2 , b3 , alpha , b3_scheduler , alpha_scheduler , eps = 1e-8 , eps_root = 0.0 ):
141-
142- def init_fn (params ):
143- m1 = jax .tree .map (jnp .zeros_like , params ) # fast EMA
144- m2 = jax .tree .map (jnp .zeros_like , params ) # slow EMA
145- nu = jax .tree .map (jnp .zeros_like , params ) # second moment estimate
146- return ScaleByAdemamixState (count = jnp .zeros ([], jnp .int32 ), count_m2 = jnp .zeros ([], jnp .int32 ), m1 = m1 , m2 = m2 , nu = nu )
147-
148- def update_fn (updates , state , params = None ):
149- del params
150- c_b3 = b3_scheduler (state .count_m2 ) if b3_scheduler is not None else b3
151- c_alpha = alpha_scheduler (state .count_m2 ) if alpha_scheduler is not None else alpha
152- m1 = _update_moment (updates , state .m1 , b1 , 1 ) # m1 = b1 * m1 + (1-b1) * updates
153- m2 = _update_moment (updates , state .m2 , c_b3 , 1 )
154- nu = _update_moment (updates , state .nu , b2 , 2 )
155- count = state .count + jnp .array (1 , dtype = jnp .int32 )
156- # count_inc = numerics.safe_int32_increment(state.count)
157- count_m2 = state .count_m2 + jnp .array (1 , dtype = jnp .int32 )
158- # count_m2_inc = numerics.safe_int32_increment(state.count_m2)
159- m1_hat = _bias_correction (m1 , b1 , count )
160- nu_hat = _bias_correction (nu , b2 , count )
161- updates = jax .tree .map (lambda m1_ , m2_ , v_ : (m1_ + c_alpha * m2_ )/ (jnp .sqrt (v_ + eps_root )+ eps ), m1_hat , m2 , nu_hat )
162- return updates , ScaleByAdemamixState (count = count , count_m2 = count_m2 , m1 = m1 , m2 = m2 , nu = nu )
163-
164- return base .GradientTransformation (init_fn , update_fn )
165-
166-
167- def _update_moment (updates , moments , decay , order ):
168- """Compute the exponential moving average of the `order`-th moment."""
169- return jax .tree .map (
170- lambda g , t : (1 - decay ) * (g ** order ) + decay * t , updates , moments )
171-
172-
173-
174- def _bias_correction (moment , decay , count ):
175- """Performs bias correction. It becomes a no-op as count goes to infinity."""
176- # The conversion to the data type of the moment ensures that bfloat16 remains
177- # bfloat16 in the optimizer state. This conversion has to be done after
178- # `bias_correction_` is calculated as calculating `decay**count` in low
179- # precision can result in it being rounded to 1 and subsequently a
180- # "division by zero" error.
181- bias_correction_ = 1 - decay ** count
182-
183- # Perform division in the original precision.
184- return jax .tree .map (
185- lambda t : t / bias_correction_ .astype (t .dtype ), moment )
71+ def build_ademamix_optimizer (
72+ lr ,
73+ variant = 'simplified' ,
74+ b1 = 0.9 ,
75+ b2 = 0.999 ,
76+ b3 = 0.9999 ,
77+ alpha = 5.0 ,
78+ b3_scheduler = None ,
79+ alpha_scheduler = None ,
80+ eps = 1e-8 ,
81+ eps_root = 0.0 ,
82+ weight_decay = 0.0 ,
83+ mask = None ,
84+ ):
85+ if variant == 'simplified' :
86+ return optax .contrib .simplified_ademamix (
87+ learning_rate = lr ,
88+ b1 = b1 ,
89+ b2 = b2 ,
90+ alpha = alpha_scheduler if alpha_scheduler is not None else alpha ,
91+ eps = eps ,
92+ eps_root = eps_root ,
93+ weight_decay = weight_decay ,
94+ mask = mask ,
95+ )
96+ if variant == 'full' :
97+ return optax .contrib .ademamix (
98+ learning_rate = lr ,
99+ b1 = b1 ,
100+ b2 = b2 ,
101+ b3 = b3_scheduler if b3_scheduler is not None else b3 ,
102+ alpha = alpha_scheduler if alpha_scheduler is not None else alpha ,
103+ eps = eps ,
104+ eps_root = eps_root ,
105+ weight_decay = weight_decay ,
106+ mask = mask ,
107+ )
108+ raise ValueError (f'Unsupported ademamix variant: { variant } ' )
186109
187110
188111def train_step (workload ,
@@ -410,21 +333,26 @@ def init_optimizer_state(
410333 b2 = HPARAMS ['b2' ]
411334 b3 = HPARAMS ['b3' ]
412335 alpha = HPARAMS ['alpha' ]
336+ variant = HPARAMS ['ademamix_variant' ]
413337 warmup = HPARAMS ['warmup' ]
414338 T = workload .step_hint
415339 f_b3 = beta3_scheduler (b3 , beta_start = b1 , warmup = T )
416340 f_a = alpha_scheduler (alpha , alpha_start = 0 , warmup = T )
417341 f_lr = lr_scheduler (lr , warmup , T )
418342 weight_decay = HPARAMS ['weight_decay' ]
419- opt_init_fn , opt_update_fn = ademamix (lr = f_lr ,
420- b1 = b1 ,
421- b2 = b2 ,
422- b3 = b3 ,
423- alpha = alpha ,
424- b3_scheduler = f_b3 ,
425- alpha_scheduler = f_a ,
426- weight_decay = weight_decay
427- )
343+ optimizer = build_ademamix_optimizer (
344+ lr = f_lr ,
345+ variant = variant ,
346+ b1 = b1 ,
347+ b2 = b2 ,
348+ b3 = b3 ,
349+ alpha = alpha ,
350+ b3_scheduler = f_b3 ,
351+ alpha_scheduler = f_a ,
352+ weight_decay = weight_decay ,
353+ )
354+ opt_init_fn = optimizer .init
355+ opt_update_fn = optimizer .update
428356 optimizer_state = opt_init_fn (params_zeros_like )
429357 return optimizer_state , opt_update_fn
430358
@@ -438,14 +366,17 @@ def f(x): return jnp.sum(x ** 2) # simple quadratic function
438366 f_a = alpha_scheduler (alpha , alpha_start = 0 , warmup = 10 )
439367 f_b3 = beta3_scheduler (b3 , beta_start = b1 , warmup = 10 )
440368
441- solver = ademamix (lr = 0.01 ,
442- b1 = b1 ,
443- b2 = b2 ,
444- b3 = b3 ,
445- alpha = alpha ,
446- b3_scheduler = f_b3 ,
447- alpha_scheduler = f_a ,
448- weight_decay = 0.01 )
369+ solver = build_ademamix_optimizer (
370+ lr = 0.01 ,
371+ variant = 'full' ,
372+ b1 = b1 ,
373+ b2 = b2 ,
374+ b3 = b3 ,
375+ alpha = alpha ,
376+ b3_scheduler = f_b3 ,
377+ alpha_scheduler = f_a ,
378+ weight_decay = 0.01 ,
379+ )
449380
450381 params = jnp .array ([1. , 2. , 3. ])
451382 print ('Objective function: {:.2f}' .format (f (params )))
0 commit comments