From 22034339b28c96daf604f91c41b4c33c2ebfc825 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 20 Aug 2025 15:44:05 -0400 Subject: [PATCH 01/33] first commit for ademamix jax submission --- .../self_tuning/ademamix/submission.py | 429 ++++++++++++++++++ 1 file changed, 429 insertions(+) create mode 100644 submissions/self_tuning/ademamix/submission.py diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py new file mode 100644 index 000000000..bbc15fe07 --- /dev/null +++ b/submissions/self_tuning/ademamix/submission.py @@ -0,0 +1,429 @@ +""" +Forked from apple's ademamix jax implementation: +https://github.com/apple/ml-ademamix +for the purposes of submitting to the algoperf benchmark. +| Adapted from optax's implementation of AdamW: +| https://github.com/google-deepmind/optax/blob/b75644809f2f68fc11f42d4395a5753e11e92e80/optax/_src/alias.py#L548#L675 +""" +import functools +from typing import ( + Any, + Callable, + Dict, + Iterator, + List, + NamedTuple, + Optional, + Tuple, + Union, + ) + +import chex +import jax +from jax import tree_util as jtu +import jax.numpy as jnp +import optax +from flax import jax_utils +from jax import lax + +from optax._src import transform, combine, base, numerics, utils +from optax import tree_utils as otu + +from algoperf import spec + +HPARAMS = { + 'alpha': , + 'alpha_start': , + 'warmup': , + 'beta_end': , + 'beta_start': , + 'learning_rate': , + 'b1': , + 'b2': , + 'b3': , + 'eps': , + 'eps_root': , + 'weight_decay': , + } + +_GRAD_CLIP_EPS = 1e-6 + +def alpha_scheduler(alpha, alpha_start=0, warmup=0): + def schedule(step): + is_warmup = (step < warmup).astype(jnp.float32) + a = step / float(warmup) + return is_warmup * ((1.0-a) * alpha_start + a * alpha) + alpha * (1.0-is_warmup) + + return schedule + + +def beta3_scheduler(beta_end, beta_start=0, warmup=0): + + def f(beta): + return jnp.log(0.5)/jnp.log(beta)-1 + + def f_inv(t): + return jnp.power(0.5, 1/(t+1)) + + def schedule(step): + is_warmup = (step < warmup).astype(jnp.float32) + alpha = step / float(warmup) + return is_warmup * f_inv((1.0-alpha) * f(beta_start) + alpha * f(beta_end)) + beta_end * (1.0-is_warmup) + + return schedule + + +class ScaleByAdemamixState(NamedTuple): + """State for the AdEMAMix algorithm.""" + count: chex.Array + count_m2: chex.Array + m1: base.Updates + m2: base.Updates + nu: base.Updates + + +def ademamix(lr, b1=0.9, b2=0.999, b3=0.9999, alpha=5.0, b3_scheduler=None, alpha_scheduler=None, + eps=1e-8, eps_root=0.0, weight_decay=0.0, mu_dtype=None, mask=None): + r"""AdEMAMix. + + Args: + lr: A global scaling factor, either fixed or evolving along + iterations with a scheduler, see :func:`optax.scale_by_learning_rate`. + b1: Exponential decay rate to track the fast EMA. + b2: Exponential decay rate to track the second moment of past gradients. + b3: Exponential decay rate to track the slow EMA. + alpha: Mixing coeficient use for the linear combination of the fast and slow EMAs. + b3_scheduler: an optional scheduler function, given a timestep, returns the + value of b3. Use `beta3_scheduler(b3,b1,T_b3)` to follow the AdEMAMix paper. + alpha_scheduler: an optional scheduler function, given a timestep, returns the + value of alpha. Use `alpha_scheduler(alpha,0,T_alpha)` to follow the + AdEMAMix paper. + eps: A small constant applied to denominator outside of the square root + (as in the Adam paper) to avoid dividing by zero when rescaling. + eps_root: A small constant applied to denominator inside the square root (as + in RMSProp), to avoid dividing by zero when rescaling. This is needed for + instance when computing (meta-)gradients through Adam. + mu_dtype: Optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + weight_decay: Strength of the weight decay regularization. Note that this + weight decay is multiplied with the learning rate. This is consistent + with other frameworks such as PyTorch, but different from + (Loshchilov et al, 2019) where the weight decay is only multiplied with + the "schedule multiplier", but not the base learning rate. + mask: A tree with same structure as (or a prefix of) the params PyTree, + or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the weight decay to, and `False` for those you want to skip. Note + that the Adam gradient transformations are applied to all parameters. + + Returns: + The corresponding `GradientTransformation`. + """ + return combine.chain( + scale_by_ademamix(b1, b2, b3, alpha, b3_scheduler, alpha_scheduler, eps, eps_root, mu_dtype), + transform.add_decayed_weights(weight_decay, mask), + transform.scale_by_learning_rate(lr), + ) + + +def scale_by_ademamix(b1, b2, b3, alpha, b3_scheduler, alpha_scheduler, eps=1e-8, eps_root=0.0, mu_dtype=None): + + mu_dtype = utils.canonicalize_dtype(mu_dtype) + + def init_fn(params): + m1 = tree_zeros_like(params, dtype=mu_dtype) # fast EMA + m2 = tree_zeros_like(params, dtype=mu_dtype) # slow EMA + nu = tree_zeros_like(params, dtype=mu_dtype) # second moment estimate + return ScaleByAdemamixState(count=jnp.zeros([], jnp.int32), count_m2=jnp.zeros([], jnp.int32), m1=m1, m2=m2, nu=nu) + + def update_fn(updates, state, params=None): + del params + c_b3 = b3_scheduler(state.count_m2) if b3_scheduler is not None else b3 + c_alpha = alpha_scheduler(state.count_m2) if alpha_scheduler is not None else alpha + m1 = tree_update_moment(updates, state.m1, b1, 1) # m1 = b1 * m1 + (1-b1) * updates + m2 = tree_update_moment(updates, state.m2, c_b3, 1) + nu = tree_update_moment_per_elem_norm(updates, state.nu, b2, 2) + count_inc = numerics.safe_int32_increment(state.count) + count_m2_inc = numerics.safe_int32_increment(state.count_m2) + m1_hat = tree_bias_correction(m1, b1, count_inc) + nu_hat = tree_bias_correction(nu, b2, count_inc) + updates = jtu.tree_map(lambda m1_, m2_, v_: (m1_+c_alpha*m2_)/(jnp.sqrt(v_+eps_root)+eps), m1_hat, m2, nu_hat) + mu1 = tree_cast(m1, mu_dtype) + mu2 = tree_cast(m2, mu_dtype) + return updates, ScaleByAdemamixState(count=count_inc, count_m2=count_m2_inc, m1=m1, m2=m2, nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +def tree_cast(tree, dtype): + """Cast tree to given dtype, skip if None.""" + if dtype is not None: + return jtu.tree_map(lambda t: t.astype(dtype), tree) + else: + return tree + + +def tree_zeros_like( + tree, + dtype = None, +): + """Creates an all-zeros tree with the same structure. + + Args: + tree: pytree. + dtype: optional dtype to use for the tree of zeros. + + Returns: + an all-zeros tree with the same structure as ``tree``. + """ + return jtu.tree_map(lambda x: jnp.zeros_like(x, dtype=dtype), tree) + + +def tree_update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order`-th moment.""" + return jtu.tree_map( + lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments) + + +def tree_update_moment_per_elem_norm(updates, moments, decay, order): + """Compute the EMA of the `order`-th moment of the element-wise norm.""" + + def orderth_norm(g): + if jnp.isrealobj(g): + return g ** order + else: + half_order = order / 2 + # JAX generates different HLO for int and float `order` + if half_order.is_integer(): + half_order = int(half_order) + return numerics.abs_sq(g) ** half_order + + return jtu.tree_map( + lambda g, t: (1 - decay) * orderth_norm(g) + decay * t, updates, moments) + + +@functools.partial(jax.jit, inline=True) +def tree_bias_correction(moment, decay, count): + """Performs bias correction. It becomes a no-op as count goes to infinity.""" + # The conversion to the data type of the moment ensures that bfloat16 remains + # bfloat16 in the optimizer state. This conversion has to be done after + # `bias_correction_` is calculated as calculating `decay**count` in low + # precision can result in it being rounded to 1 and subsequently a + # "division by zero" error. + bias_correction_ = 1 - decay**count + + # Perform division in the original precision. + return jax.tree_util.tree_map( + lambda t: t / bias_correction_.astype(t.dtype), moment) + +@functools.partial( + jax.pmap, + axis_name='batch', + in_axes=(None, None, 0, 0, 0, None, 0, 0), + static_broadcasted_argnums=(0, 1), + ) +def pmapped_train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing + ): + def _loss_fn(params): + logits, new_model_state = workload.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.TRAIN, + rng, + update_batch_norm=True, + ) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container + ) + (summed_loss, n_valid_examples, grad) = lax.psum( + (summed_loss, n_valid_examples, grad), axis_name='batch') + ) + # change to lax.psum or something different for jit? + loss = summed_loss / n_valid_examples + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) + grad_norm = jnp.sqrt( + sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(grad)) + ) + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn( + grad, optimizer_state, current_param_container + ) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, + ) -> spec.UpdateReturn: + del current_params_types + del loss_type + del train_state + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters['label_smoothing'] + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters['grad_clip'] + else: + grad_clip = None + outputs = train_step( + workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + ) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, + global_step + ) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state + +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + ) -> spec.UpdateReturn: + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + +def get_batch_size(workload_name): + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, + ) -> Dict[str, spec.Tensor]: + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch + +if __name__ == "__main__": # dummy test + + def f(x): return jnp.sum(x ** 2) # simple quadratic function + + alpha = 8.0 + b1, b2, b3 = 0.9, 0.999, 0.9999 + + f_a = alpha_scheduler(alpha, alpha_start=0, warmup=10) + f_b3 = beta3_scheduler(b3, beta_start=b1, warmup=10) + + solver = ademamix(lr=0.01, + b1=b1, + b2=b2, + b3=b3, + alpha=alpha, + b3_scheduler=f_b3, + alpha_scheduler=f_a, + weight_decay=0.01) + + params = jnp.array([1., 2., 3.]) + print('Objective function: {:.2f}'.format(f(params))) + opt_state = solver.init(params) + for itr in range(100): + grad = jax.grad(f)(params) + updates, opt_state = solver.update(grad, opt_state, params) + params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates) + if itr % 5 == 0: + print('Objective function: {:.2f}'.format(f(params))) + print(params) From 9fce31e35c616ceb80f92843f5d174ebf03c0618 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 20 Aug 2025 16:20:48 -0400 Subject: [PATCH 02/33] updated hparams --- .../self_tuning/ademamix/submission.py | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index bbc15fe07..163842e2c 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -31,19 +31,20 @@ from algoperf import spec + HPARAMS = { - 'alpha': , - 'alpha_start': , - 'warmup': , - 'beta_end': , - 'beta_start': , - 'learning_rate': , - 'b1': , - 'b2': , - 'b3': , - 'eps': , - 'eps_root': , - 'weight_decay': , + 'alpha': 8.0, + 'alpha_start': 0, + 'warmup': 10, + 'beta_end': 0.9999, + 'beta_start': 0.9, + 'learning_rate': 0.01, + 'b1': 0.9, + 'b2': 0.999, + 'b3': 0.9999, + 'eps': 1e-8, + 'eps_root': 0.0, + 'weight_decay': 0.01, } _GRAD_CLIP_EPS = 1e-6 From f22e0c6f2af23a9ceb8580682ac5bd4cc24a2550 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 20 Aug 2025 19:09:55 -0400 Subject: [PATCH 03/33] typo --- submissions/self_tuning/ademamix/submission.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 163842e2c..236c19eb6 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -258,9 +258,8 @@ def _loss_fn(params): current_param_container ) (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch') + (summed_loss, n_valid_examples, grad), axis_name='batch' ) - # change to lax.psum or something different for jit? loss = summed_loss / n_valid_examples grad = jax.tree.map(lambda x: x / n_valid_examples, grad) grad_norm = jnp.sqrt( From 2953358267e6c3b9c460dd41dae60cab01e34e25 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 20 Aug 2025 19:20:40 -0400 Subject: [PATCH 04/33] added init_optimizer_state --- .../self_tuning/ademamix/submission.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 236c19eb6..82f6a15c2 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -398,6 +398,40 @@ def data_selection( batch = next(input_queue) return batch +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, + ) -> spec.OptimizerState: + del model_params + del model_state + del rng + params_zeros_like = jax.tree.map( + lambda s: jnp.zeros(s.shape_tuple, workload.param_shapes) + ) + lr = HPARAMS['learning_rate'] + b1 = HPARAMS['b1'] + b2 = HPARAMS['b2'] + b3 = HPARAMS['b3'] + alpha = HPARAMS['alpha'] + warmup = HPARAMS['warmup'] + f_b3 = beta3_scheduler(b3, beta_start=b1, warmup=warmup) + f_a = alpha_scheduler(alpha, alpha_start=0, warmup=warmup) + weight_decay = HPARAMS['weight_decay'] + opt_init_fn, opt_update_fn = ademamix(lr=lr, + b1=b1, + b2=b2, + b3=b3, + alpha=alpha, + b3_scheduler=f_b3, + alpha_scheduler=f_a, + weight_decay=weight_decay + ) + optimizer_state = opt_init_fn(params_zeros_like) + return jax_utils.replicate(optimizer_state), opt_update_fn + if __name__ == "__main__": # dummy test def f(x): return jnp.sum(x ** 2) # simple quadratic function From 1f86006441e883bebf79d30b49f49f465a85a230 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 20 Aug 2025 19:23:28 -0400 Subject: [PATCH 05/33] added cifar as workload --- submissions/self_tuning/ademamix/submission.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 82f6a15c2..f979cf4c1 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -375,6 +375,8 @@ def get_batch_size(workload_name): return 128 elif workload_name == 'mnist': return 16 + elif workload_name == 'cifar': + return 128 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From cd48528e9f1c251a5398ab9cafa2fd2b887b38be Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 20 Aug 2025 19:34:37 -0400 Subject: [PATCH 06/33] error in init_optimizer_state --- submissions/self_tuning/ademamix/submission.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index f979cf4c1..02936c561 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -411,7 +411,7 @@ def init_optimizer_state( del model_state del rng params_zeros_like = jax.tree.map( - lambda s: jnp.zeros(s.shape_tuple, workload.param_shapes) + lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes ) lr = HPARAMS['learning_rate'] b1 = HPARAMS['b1'] From 07510f8fae610e8c2bfe55048b8f3614ead836fb Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 20 Aug 2025 19:39:33 -0400 Subject: [PATCH 07/33] changed train_step to pmapped_train_step --- submissions/self_tuning/ademamix/submission.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 02936c561..83fda42bb 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -308,7 +308,7 @@ def update_params( grad_clip = hyperparameters['grad_clip'] else: grad_clip = None - outputs = train_step( + outputs = pmapped_train_step( workload, opt_update_fn, model_state, From f7918fedbfb320aeefcf191d4ee50f15e48d45ad Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 20 Aug 2025 19:45:11 -0400 Subject: [PATCH 08/33] pmap was incorrect --- submissions/self_tuning/ademamix/submission.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 83fda42bb..a87bb6898 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -220,7 +220,7 @@ def tree_bias_correction(moment, decay, count): @functools.partial( jax.pmap, axis_name='batch', - in_axes=(None, None, 0, 0, 0, None, 0, 0), + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), static_broadcasted_argnums=(0, 1), ) def pmapped_train_step( From 582fff135661691e5e183cc9bd67de46c2383801 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 20 Aug 2025 20:01:23 -0400 Subject: [PATCH 09/33] hparams --- submissions/self_tuning/ademamix/submission.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index a87bb6898..7f637c8a4 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -376,7 +376,7 @@ def get_batch_size(workload_name): elif workload_name == 'mnist': return 16 elif workload_name == 'cifar': - return 128 + return 1024 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From 69a94bcbc605a847c0acc68af9ed8602e44782a8 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 20 Aug 2025 20:27:11 -0400 Subject: [PATCH 10/33] batch size --- submissions/self_tuning/ademamix/submission.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 7f637c8a4..c62a763a7 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -376,7 +376,7 @@ def get_batch_size(workload_name): elif workload_name == 'mnist': return 16 elif workload_name == 'cifar': - return 1024 + return 2048 else: raise ValueError(f'Unsupported workload name: {workload_name}.') From 0f7f625fc79dec07dfab8672d1b2a118a7423952 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Fri, 12 Sep 2025 20:02:32 -0400 Subject: [PATCH 11/33] updated beta and alpha schedulers to use stephint --- submissions/self_tuning/ademamix/submission.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index c62a763a7..79bab1c70 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -418,9 +418,9 @@ def init_optimizer_state( b2 = HPARAMS['b2'] b3 = HPARAMS['b3'] alpha = HPARAMS['alpha'] - warmup = HPARAMS['warmup'] - f_b3 = beta3_scheduler(b3, beta_start=b1, warmup=warmup) - f_a = alpha_scheduler(alpha, alpha_start=0, warmup=warmup) + T = workload.step_hint + f_b3 = beta3_scheduler(b3, beta_start=b1, warmup=T) + f_a = alpha_scheduler(alpha, alpha_start=0, warmup=T) weight_decay = HPARAMS['weight_decay'] opt_init_fn, opt_update_fn = ademamix(lr=lr, b1=b1, From 4a05d5a928ccf0768bca3d65bbf03d44665a631c Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 17 Sep 2025 17:48:29 -0400 Subject: [PATCH 12/33] testing ademamix --- submissions/self_tuning/ademamix/submission.py | 1 + 1 file changed, 1 insertion(+) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 79bab1c70..b6b71f647 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -222,6 +222,7 @@ def tree_bias_correction(moment, decay, count): axis_name='batch', in_axes=(None, None, 0, 0, 0, 0, 0, None, None), static_broadcasted_argnums=(0, 1), + donate_argnums=(2, 3, 4) ) def pmapped_train_step( workload, From f381548d4688a0486edfaa953bc42e628269359a Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 24 Sep 2025 13:46:28 -0400 Subject: [PATCH 13/33] switch to jitted train step --- .../self_tuning/ademamix/submission.py | 248 ++++++++++-------- 1 file changed, 141 insertions(+), 107 deletions(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index b6b71f647..c7299cf03 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -217,120 +217,154 @@ def tree_bias_correction(moment, decay, count): return jax.tree_util.tree_map( lambda t: t / bias_correction_.astype(t.dtype), moment) -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None, None), - static_broadcasted_argnums=(0, 1), - donate_argnums=(2, 3, 4) - ) -def pmapped_train_step( - workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, +#@functools.partial( +# jax.pmap, +# axis_name='batch', +# in_axes=(None, None, 0, 0, 0, 0, 0, None, None), +# static_broadcasted_argnums=(0, 1), +# donate_argnums=(2, 3, 4) +# ) +def train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + rng, + grad_clip, + label_smoothing, + dropout_rate, + ): + + def _loss_fn(params): + """Loss function used for training.""" + logits, new_model_state = workload.model_fn( + params, batch, + model_state, + spec.ForwardPassMode.TRAIN, rng, - grad_clip, - label_smoothing - ): - def _loss_fn(params): - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - update_batch_norm=True, - ) - loss_dict = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits, - mask_batch=batch.get('weights'), - label_smoothing=label_smoothing, - ) - summed_loss = loss_dict['summed'] - n_valid_examples = loss_dict['n_valid_examples'] - return summed_loss, (n_valid_examples, new_model_state) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( - current_param_container - ) - (summed_loss, n_valid_examples, grad) = lax.psum( - (summed_loss, n_valid_examples, grad), axis_name='batch' - ) - loss = summed_loss / n_valid_examples - grad = jax.tree.map(lambda x: x / n_valid_examples, grad) - grad_norm = jnp.sqrt( - sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(grad)) - ) - if grad_clip is not None: - grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) - grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) - grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) + update_batch_norm=True, + dropout_rate=dropout_rate, + ) + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + return summed_loss, (n_valid_examples, new_model_state) + + grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) + (summed_loss, (n_valid_examples, new_model_state)), grad = grad_fn( + current_param_container) + # Get correct global mean loss and grad. + loss = summed_loss / n_valid_examples + grad = jax.tree.map(lambda x: x / n_valid_examples, grad) + + grad_norm = jnp.sqrt( + sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grad))) + + if grad_clip is not None: + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0 x=grad_scaling_factor, max=1.0) + grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) + + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, + current_param_container) + updated_params = optax.apply_updates(current_param_container, updates) + return new_optimizer_state, updated_params, new_model_state, loss, grad_norm - updates, new_optimizer_state = opt_update_fn( - grad, optimizer_state, current_param_container - ) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params, new_model_state, loss, grad_norm def update_params( - workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None, - ) -> spec.UpdateReturn: - del current_params_types - del loss_type - del train_state - del eval_results - del hyperparameters + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state, opt_update_fn = optimizer_state + per_device_rngs = jax.random.split(rng, jax.local_device_count()) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters['label_smoothing'] + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters['grad_clip'] + else: + grad_clip = None + dropout_rate = hyperparameters.dropout_rate + + mesh = jax.sharding.Mesh(jax.devices(), ('batch')) + replicated = jax_sharding_utils.get_replicate_sharding( + mesh + ) + sharded = jax_sharding_utils.get_batch_sharding( + mesh + ) + arg_shardings = ( + replicated, #model_state + replicated, #optimizer_state # change to optimizer sharding eventually + replicated, # current_param_container + sharded, # batch + replicated, # per_device_rngs + replicated, # grad_clip + replicated, #label_smoothing + replicated, #dropout_rate + ) + out_shardings = ( + replicated, # new_optimizer_state # maybe sharded eventually + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated, # grad_norm + ) + jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings, + ) + outputs = jitted_train_step(workload, + opt_update_fn, + model_state, + optimizer_state, + current_param_container, + batch, + per_device_rngs, + grad_clip, + label_smoothing, + dropout_rate, + ) + new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs + + # Log loss, grad_norm. + if global_step % 100 == 0 and workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss[0], + 'grad_norm': grad_norm[0], + }, global_step) + return (new_optimizer_state, opt_update_fn), new_params, new_model_state - hyperparameters = HPARAMS - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - if hasattr(hyperparameters, 'label_smoothing'): - label_smoothing = hyperparameters['label_smoothing'] - else: - label_smoothing = 0.0 - if hasattr(hyperparameters, 'grad_clip'): - grad_clip = hyperparameters['grad_clip'] - else: - grad_clip = None - outputs = pmapped_train_step( - workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - per_device_rngs, - grad_clip, - label_smoothing, - ) - new_optimizer_state, new_params, new_model_state, loss, grad_norm = outputs - - if global_step % 100 == 0 and workload.metrics_logger is not None: - workload.metrics_logger.append_scalar_metrics( - { - 'loss': loss[0], - 'grad_norm': grad_norm[0], - }, - global_step - ) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state def prepare_for_eval( workload: spec.Workload, From d44b1dec41b158d9810d8f0b88fc579419247126 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 24 Sep 2025 16:52:59 -0400 Subject: [PATCH 14/33] fixed typo --- submissions/self_tuning/ademamix/submission.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index c7299cf03..d3261dba1 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -268,7 +268,7 @@ def _loss_fn(params): if grad_clip is not None: grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) - grad_scaling_factor = jax.lax.clamp(min=0.0 x=grad_scaling_factor, max=1.0) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree.map(lambda x: x * grad_scaling_factor, grad) updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, From 2bbaf5666f247f9c3d0b2850485488df60226195 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 24 Sep 2025 16:55:04 -0400 Subject: [PATCH 15/33] added dropout --- submissions/self_tuning/ademamix/submission.py | 1 + 1 file changed, 1 insertion(+) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index d3261dba1..10405353c 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -45,6 +45,7 @@ 'eps': 1e-8, 'eps_root': 0.0, 'weight_decay': 0.01, + 'dropout_rate': 0.1, } _GRAD_CLIP_EPS = 1e-6 From b5d98c1a2dd84c01618d054ca5f9a9f71afafacf Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 24 Sep 2025 16:57:06 -0400 Subject: [PATCH 16/33] fixed typo --- submissions/self_tuning/ademamix/submission.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 10405353c..9a7f5a2b8 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -310,7 +310,7 @@ def update_params( grad_clip = hyperparameters['grad_clip'] else: grad_clip = None - dropout_rate = hyperparameters.dropout_rate + dropout_rate = hyperparameters['dropout_rate'] mesh = jax.sharding.Mesh(jax.devices(), ('batch')) replicated = jax_sharding_utils.get_replicate_sharding( From 8eac5195e2eaa319d1c46390e7f45cdfbecf78a7 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 24 Sep 2025 16:59:19 -0400 Subject: [PATCH 17/33] added import jax_sharding_utils --- submissions/self_tuning/ademamix/submission.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 9a7f5a2b8..d68a5374f 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -29,7 +29,7 @@ from optax._src import transform, combine, base, numerics, utils from optax import tree_utils as otu -from algoperf import spec +from algoperf import spec, jax_sharding_utils HPARAMS = { From 18dceb9c3df01ad909cc41e2c0d190964edd92a9 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Wed, 24 Sep 2025 17:02:05 -0400 Subject: [PATCH 18/33] debugging sharding for jit --- submissions/self_tuning/ademamix/submission.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index d68a5374f..5888c3320 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -313,9 +313,7 @@ def update_params( dropout_rate = hyperparameters['dropout_rate'] mesh = jax.sharding.Mesh(jax.devices(), ('batch')) - replicated = jax_sharding_utils.get_replicate_sharding( - mesh - ) + replicated = jax_sharding_utils.get_replicate_sharding() sharded = jax_sharding_utils.get_batch_sharding( mesh ) From a624df79d9eb3508c262cb3ffd179880d1f8f7be Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Thu, 25 Sep 2025 09:55:04 -0400 Subject: [PATCH 19/33] typo in batch sharding --- submissions/self_tuning/ademamix/submission.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 5888c3320..2d15346da 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -314,7 +314,7 @@ def update_params( mesh = jax.sharding.Mesh(jax.devices(), ('batch')) replicated = jax_sharding_utils.get_replicate_sharding() - sharded = jax_sharding_utils.get_batch_sharding( + sharded = jax_sharding_utils.get_batch_dim_sharding( mesh ) arg_shardings = ( From c1c9c26707b40c08534d16d8d393313684c17b34 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Thu, 25 Sep 2025 09:58:40 -0400 Subject: [PATCH 20/33] changed args to get_batch_dim_sharding --- submissions/self_tuning/ademamix/submission.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 2d15346da..78737ee4b 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -312,11 +312,9 @@ def update_params( grad_clip = None dropout_rate = hyperparameters['dropout_rate'] - mesh = jax.sharding.Mesh(jax.devices(), ('batch')) + # mesh = jax.sharding.Mesh(jax.devices(), ('batch')) replicated = jax_sharding_utils.get_replicate_sharding() - sharded = jax_sharding_utils.get_batch_dim_sharding( - mesh - ) + sharded = jax_sharding_utils.get_batch_dim_sharding() arg_shardings = ( replicated, #model_state replicated, #optimizer_state # change to optimizer sharding eventually From 72f317d6f96fe92dc838d1ab8ac3385719e05eeb Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Thu, 16 Oct 2025 14:22:49 -0400 Subject: [PATCH 21/33] Matched submission to nadamw style --- .../self_tuning/ademamix/submission.py | 130 ++++++------------ 1 file changed, 40 insertions(+), 90 deletions(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 78737ee4b..1e8e8c371 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -51,12 +51,10 @@ _GRAD_CLIP_EPS = 1e-6 def alpha_scheduler(alpha, alpha_start=0, warmup=0): - def schedule(step): - is_warmup = (step < warmup).astype(jnp.float32) - a = step / float(warmup) - return is_warmup * ((1.0-a) * alpha_start + a * alpha) + alpha * (1.0-is_warmup) - - return schedule + warmup_fn = optax.linear_schedule(init_value=alpha_start, end_value=alpha, transition_steps=warmup) + constant_fn = optax.constant_schedule(alpha) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, constant_fn], boundaries=[warmup]) + return schedule_fn def beta3_scheduler(beta_end, beta_start=0, warmup=0): @@ -67,21 +65,22 @@ def f(beta): def f_inv(t): return jnp.power(0.5, 1/(t+1)) - def schedule(step): - is_warmup = (step < warmup).astype(jnp.float32) - alpha = step / float(warmup) - return is_warmup * f_inv((1.0-alpha) * f(beta_start) + alpha * f(beta_end)) + beta_end * (1.0-is_warmup) + def warmup_fn(step): + frac = 1 - step / warmup + return f_inv( frac * f(beta_start) + (1 - frac) * f(beta_end)) - return schedule + constant_fn = optax.constant_schedule(beta_end) + schedule_fn = optax.join_schedules(schedules=[warmup_fn, constant_fn], boundaries=[warmup]) + return schedule_fn class ScaleByAdemamixState(NamedTuple): """State for the AdEMAMix algorithm.""" count: chex.Array count_m2: chex.Array - m1: base.Updates - m2: base.Updates - nu: base.Updates + m1: optax.Updates + m2: optax.Updates + nu: optax.Updates def ademamix(lr, b1=0.9, b2=0.999, b3=0.9999, alpha=5.0, b3_scheduler=None, alpha_scheduler=None, @@ -128,84 +127,41 @@ def ademamix(lr, b1=0.9, b2=0.999, b3=0.9999, alpha=5.0, b3_scheduler=None, alph ) -def scale_by_ademamix(b1, b2, b3, alpha, b3_scheduler, alpha_scheduler, eps=1e-8, eps_root=0.0, mu_dtype=None): - - mu_dtype = utils.canonicalize_dtype(mu_dtype) +def scale_by_ademamix(b1, b2, b3, alpha, b3_scheduler, alpha_scheduler, eps=1e-8, eps_root=0.0): def init_fn(params): - m1 = tree_zeros_like(params, dtype=mu_dtype) # fast EMA - m2 = tree_zeros_like(params, dtype=mu_dtype) # slow EMA - nu = tree_zeros_like(params, dtype=mu_dtype) # second moment estimate + m1 = jax.tree.map(jnp.zeros_like, params) # fast EMA + m2 = jax.tree.map(jnp.zeros_like, params) # slow EMA + nu = jax.tree.map(jnp.zeros_like, params) # second moment estimate return ScaleByAdemamixState(count=jnp.zeros([], jnp.int32), count_m2=jnp.zeros([], jnp.int32), m1=m1, m2=m2, nu=nu) def update_fn(updates, state, params=None): del params c_b3 = b3_scheduler(state.count_m2) if b3_scheduler is not None else b3 c_alpha = alpha_scheduler(state.count_m2) if alpha_scheduler is not None else alpha - m1 = tree_update_moment(updates, state.m1, b1, 1) # m1 = b1 * m1 + (1-b1) * updates - m2 = tree_update_moment(updates, state.m2, c_b3, 1) - nu = tree_update_moment_per_elem_norm(updates, state.nu, b2, 2) - count_inc = numerics.safe_int32_increment(state.count) - count_m2_inc = numerics.safe_int32_increment(state.count_m2) - m1_hat = tree_bias_correction(m1, b1, count_inc) - nu_hat = tree_bias_correction(nu, b2, count_inc) - updates = jtu.tree_map(lambda m1_, m2_, v_: (m1_+c_alpha*m2_)/(jnp.sqrt(v_+eps_root)+eps), m1_hat, m2, nu_hat) - mu1 = tree_cast(m1, mu_dtype) - mu2 = tree_cast(m2, mu_dtype) - return updates, ScaleByAdemamixState(count=count_inc, count_m2=count_m2_inc, m1=m1, m2=m2, nu=nu) + m1 = _update_moment(updates, state.m1, b1, 1) # m1 = b1 * m1 + (1-b1) * updates + m2 = _update_moment(updates, state.m2, c_b3, 1) + nu = _update_moment(updates, state.nu, b2, 2) + count = state.count + jnp.array(1, dtype=jnp.int32) + # count_inc = numerics.safe_int32_increment(state.count) + count_m2 = state.count_m2 + jnp.array(1, dtype=jnp.int32) + # count_m2_inc = numerics.safe_int32_increment(state.count_m2) + m1_hat = _bias_correction(m1, b1, count) + nu_hat = _bias_correction(nu, b2, count) + updates = jax.tree.map(lambda m1_, m2_, v_: (m1_+c_alpha*m2_)/(jnp.sqrt(v_+eps_root)+eps), m1_hat, m2, nu_hat) + return updates, ScaleByAdemamixState(count=count, count_m2=count_m2, m1=m1, m2=m2, nu=nu) return base.GradientTransformation(init_fn, update_fn) -def tree_cast(tree, dtype): - """Cast tree to given dtype, skip if None.""" - if dtype is not None: - return jtu.tree_map(lambda t: t.astype(dtype), tree) - else: - return tree - - -def tree_zeros_like( - tree, - dtype = None, -): - """Creates an all-zeros tree with the same structure. - - Args: - tree: pytree. - dtype: optional dtype to use for the tree of zeros. - - Returns: - an all-zeros tree with the same structure as ``tree``. - """ - return jtu.tree_map(lambda x: jnp.zeros_like(x, dtype=dtype), tree) - - -def tree_update_moment(updates, moments, decay, order): +def _update_moment(updates, moments, decay, order): """Compute the exponential moving average of the `order`-th moment.""" - return jtu.tree_map( + return jax.tree.map( lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments) -def tree_update_moment_per_elem_norm(updates, moments, decay, order): - """Compute the EMA of the `order`-th moment of the element-wise norm.""" - - def orderth_norm(g): - if jnp.isrealobj(g): - return g ** order - else: - half_order = order / 2 - # JAX generates different HLO for int and float `order` - if half_order.is_integer(): - half_order = int(half_order) - return numerics.abs_sq(g) ** half_order - - return jtu.tree_map( - lambda g, t: (1 - decay) * orderth_norm(g) + decay * t, updates, moments) - -@functools.partial(jax.jit, inline=True) -def tree_bias_correction(moment, decay, count): +def _bias_correction(moment, decay, count): """Performs bias correction. It becomes a no-op as count goes to infinity.""" # The conversion to the data type of the moment ensures that bfloat16 remains # bfloat16 in the optimizer state. This conversion has to be done after @@ -215,16 +171,9 @@ def tree_bias_correction(moment, decay, count): bias_correction_ = 1 - decay**count # Perform division in the original precision. - return jax.tree_util.tree_map( + return jax.tree.map( lambda t: t / bias_correction_.astype(t.dtype), moment) -#@functools.partial( -# jax.pmap, -# axis_name='batch', -# in_axes=(None, None, 0, 0, 0, 0, 0, None, None), -# static_broadcasted_argnums=(0, 1), -# donate_argnums=(2, 3, 4) -# ) def train_step(workload, opt_update_fn, model_state, @@ -301,7 +250,6 @@ def update_params( hyperparameters = HPARAMS optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters['label_smoothing'] else: @@ -314,7 +262,9 @@ def update_params( # mesh = jax.sharding.Mesh(jax.devices(), ('batch')) replicated = jax_sharding_utils.get_replicate_sharding() - sharded = jax_sharding_utils.get_batch_dim_sharding() + sharded = ( + jax_sharding_utils.get_batch_dim_sharding() + ) arg_shardings = ( replicated, #model_state replicated, #optimizer_state # change to optimizer sharding eventually @@ -345,7 +295,7 @@ def update_params( optimizer_state, current_param_container, batch, - per_device_rngs, + rng, grad_clip, label_smoothing, dropout_rate, @@ -356,8 +306,8 @@ def update_params( if global_step % 100 == 0 and workload.metrics_logger is not None: workload.metrics_logger.append_scalar_metrics( { - 'loss': loss[0], - 'grad_norm': grad_norm[0], + 'loss': loss, + 'grad_norm': grad_norm, }, global_step) return (new_optimizer_state, opt_update_fn), new_params, new_model_state @@ -408,7 +358,7 @@ def get_batch_size(workload_name): elif workload_name == 'mnist': return 16 elif workload_name == 'cifar': - return 2048 + return 128 else: raise ValueError(f'Unsupported workload name: {workload_name}.') @@ -464,7 +414,7 @@ def init_optimizer_state( weight_decay=weight_decay ) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return optimizer_state, opt_update_fn if __name__ == "__main__": # dummy test From 2a220d6c2e9d2cb5a76ab15587a78535ae42384e Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Thu, 23 Oct 2025 12:18:25 -0400 Subject: [PATCH 22/33] Trying to shard optimizer state --- submissions/self_tuning/ademamix/submission.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 1e8e8c371..93a85f368 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -267,7 +267,7 @@ def update_params( ) arg_shardings = ( replicated, #model_state - replicated, #optimizer_state # change to optimizer sharding eventually + sharded, #optimizer_state # change to optimizer sharding eventually replicated, # current_param_container sharded, # batch replicated, # per_device_rngs From a43e65ae8fa047e724017fa2dabf3ee212c59eaa Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Thu, 23 Oct 2025 12:30:19 -0400 Subject: [PATCH 23/33] typo --- submissions/self_tuning/ademamix/submission.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 93a85f368..720ba8a6d 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -84,7 +84,7 @@ class ScaleByAdemamixState(NamedTuple): def ademamix(lr, b1=0.9, b2=0.999, b3=0.9999, alpha=5.0, b3_scheduler=None, alpha_scheduler=None, - eps=1e-8, eps_root=0.0, weight_decay=0.0, mu_dtype=None, mask=None): + eps=1e-8, eps_root=0.0, weight_decay=0.0, mask=None): r"""AdEMAMix. Args: @@ -121,7 +121,7 @@ def ademamix(lr, b1=0.9, b2=0.999, b3=0.9999, alpha=5.0, b3_scheduler=None, alph The corresponding `GradientTransformation`. """ return combine.chain( - scale_by_ademamix(b1, b2, b3, alpha, b3_scheduler, alpha_scheduler, eps, eps_root, mu_dtype), + scale_by_ademamix(b1, b2, b3, alpha, b3_scheduler, alpha_scheduler, eps, eps_root), transform.add_decayed_weights(weight_decay, mask), transform.scale_by_learning_rate(lr), ) From 6b89aa82b62008c192cd2f816b96c61fcb557b41 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Thu, 6 Nov 2025 12:50:55 -0500 Subject: [PATCH 24/33] debugging optimizer state sharding --- .../self_tuning/ademamix/submission.py | 80 ++++++++++++++++++- 1 file changed, 78 insertions(+), 2 deletions(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 720ba8a6d..116d3f4d8 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -24,6 +24,7 @@ import jax.numpy as jnp import optax from flax import jax_utils +from flax import traverse_util as tu from jax import lax from optax._src import transform, combine, base, numerics, utils @@ -50,6 +51,38 @@ _GRAD_CLIP_EPS = 1e-6 +def _path_matches_embedding(path_segments: tuple) -> bool: + return any(("embedding_table" in s) or ("embedding" in s) for s in path_segments) + +def build_embedding_name_mask(params_tree): + flat = tu.flatten_dict(params_tree, keep_empty_nodes=True) + mask_flat = {} + for path, leaf in flat.items(): + mask_flat[path] = _path_matches_embedding(path) + return tu.unflatten_dict(mask_flat) + +def _choose_sharding(mask_tree, target_tree, sharded, replicated): + return jax.tree.map( + lambda m, _: sharded if m else replicated, mask_tree, target_tree) + +def create_ademamix_sharding_from_names( + optimizer_state: ScaleByAdemamixState, + params_tree, + replicated, + sharded + ) -> ScaleByAdemamixState: + embed_mask = build_embedding_name_mask(params_tree) + m1_sharding = _choose_sharding(embed_mask, optimizer_state.m1, sharded, replicated) + m2_sharding = _choose_sharding(embed_mask, optimizer_state.m2, sharded, replicated) + nu_sharding = _choose_sharding(embed_mask, optimizer_state.nu, sharded, replicated) + return ScaleByAdemamixState( + count=replicated, + count_m2=replicated, + m1=m1_sharding, + m2=m2_sharding, + nu=nu_sharding + ) + def alpha_scheduler(alpha, alpha_start=0, warmup=0): warmup_fn = optax.linear_schedule(init_value=alpha_start, end_value=alpha, transition_steps=warmup) constant_fn = optax.constant_schedule(alpha) @@ -174,6 +207,43 @@ def _bias_correction(moment, decay, count): return jax.tree.map( lambda t: t / bias_correction_.astype(t.dtype), moment) + def create_optimizer_sharding(optimizer_state, replicated, sharded): + """ + Create sharding spec for optimizer + + Args: + optimizer_state: The optimizer state structure + replicated: Sharding spec for replicated data + sharded: Sharding spec for batch sharded data + + Returns: + Sharding spec sharding rng key across batches and replicating + all other optimizer variables + """ + def shard_optimizer_component(state_component): + if isinstance(state_component, ScaleByLowRankOrthogonalUpdateState): + return ScaleByLowRankOrthogonalUpdateState( + step=replicated, + shape_info=replicated, + momentum=jax.tree.map(lambda _: replicated, state_component.momentum), + key=jax.tree.map(lambda _: sharded, state_component.key), + ) + else: + return jax.tree.map(lambda _: replicated, state_component) + + return jax.tree.map( + shard_optimizer_component, + optimizer_state, + is_leaf=lambda x: ( + isinstance(x, ScaleByLowRankOrthogonalUpdateState) or + ( + hasattr(x, '_fields') and + not isinstance(x, ScaleByLowRankOrthogonalUpdateState) + ) + ) + ) + + def train_step(workload, opt_update_fn, model_state, @@ -265,9 +335,15 @@ def update_params( sharded = ( jax_sharding_utils.get_batch_dim_sharding() ) + optimizer_state_sharding = create_ademamix_sharding_from_names( + optimizer_state=optimizer_state, + params_tree=current_param_container, + replicated=replicated, + sharded=sharded + ) arg_shardings = ( replicated, #model_state - sharded, #optimizer_state # change to optimizer sharding eventually + optimizer_state_sharding, #optimizer_state # change to optimizer sharding eventually replicated, # current_param_container sharded, # batch replicated, # per_device_rngs @@ -276,7 +352,7 @@ def update_params( replicated, #dropout_rate ) out_shardings = ( - replicated, # new_optimizer_state # maybe sharded eventually + optimizer_state_sharding, # new_optimizer_state # maybe sharded eventually replicated, # updated_params replicated, # new_model_state replicated, # loss From 1fed8651aab2cfe5f7d2f7aafef9046b59ca0e93 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Thu, 6 Nov 2025 13:07:15 -0500 Subject: [PATCH 25/33] wrong order for ScaleByAdemamixState --- .../self_tuning/ademamix/submission.py | 63 ++++++++++--------- 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 116d3f4d8..58fc84a0e 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -51,37 +51,6 @@ _GRAD_CLIP_EPS = 1e-6 -def _path_matches_embedding(path_segments: tuple) -> bool: - return any(("embedding_table" in s) or ("embedding" in s) for s in path_segments) - -def build_embedding_name_mask(params_tree): - flat = tu.flatten_dict(params_tree, keep_empty_nodes=True) - mask_flat = {} - for path, leaf in flat.items(): - mask_flat[path] = _path_matches_embedding(path) - return tu.unflatten_dict(mask_flat) - -def _choose_sharding(mask_tree, target_tree, sharded, replicated): - return jax.tree.map( - lambda m, _: sharded if m else replicated, mask_tree, target_tree) - -def create_ademamix_sharding_from_names( - optimizer_state: ScaleByAdemamixState, - params_tree, - replicated, - sharded - ) -> ScaleByAdemamixState: - embed_mask = build_embedding_name_mask(params_tree) - m1_sharding = _choose_sharding(embed_mask, optimizer_state.m1, sharded, replicated) - m2_sharding = _choose_sharding(embed_mask, optimizer_state.m2, sharded, replicated) - nu_sharding = _choose_sharding(embed_mask, optimizer_state.nu, sharded, replicated) - return ScaleByAdemamixState( - count=replicated, - count_m2=replicated, - m1=m1_sharding, - m2=m2_sharding, - nu=nu_sharding - ) def alpha_scheduler(alpha, alpha_start=0, warmup=0): warmup_fn = optax.linear_schedule(init_value=alpha_start, end_value=alpha, transition_steps=warmup) @@ -115,6 +84,38 @@ class ScaleByAdemamixState(NamedTuple): m2: optax.Updates nu: optax.Updates +def _path_matches_embedding(path_segments: tuple) -> bool: + return any(("embedding_table" in s) or ("embedding" in s) for s in path_segments) + +def build_embedding_name_mask(params_tree): + flat = tu.flatten_dict(params_tree, keep_empty_nodes=True) + mask_flat = {} + for path, leaf in flat.items(): + mask_flat[path] = _path_matches_embedding(path) + return tu.unflatten_dict(mask_flat) + +def _choose_sharding(mask_tree, target_tree, sharded, replicated): + return jax.tree.map( + lambda m, _: sharded if m else replicated, mask_tree, target_tree) + +def create_ademamix_sharding_from_names( + optimizer_state: ScaleByAdemamixState, + params_tree, + replicated, + sharded + ) -> ScaleByAdemamixState: + embed_mask = build_embedding_name_mask(params_tree) + m1_sharding = _choose_sharding(embed_mask, optimizer_state.m1, sharded, replicated) + m2_sharding = _choose_sharding(embed_mask, optimizer_state.m2, sharded, replicated) + nu_sharding = _choose_sharding(embed_mask, optimizer_state.nu, sharded, replicated) + return ScaleByAdemamixState( + count=replicated, + count_m2=replicated, + m1=m1_sharding, + m2=m2_sharding, + nu=nu_sharding + ) + def ademamix(lr, b1=0.9, b2=0.999, b3=0.9999, alpha=5.0, b3_scheduler=None, alpha_scheduler=None, eps=1e-8, eps_root=0.0, weight_decay=0.0, mask=None): From 5ac23b56bdd0d78dbef086d37870111031d467f9 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Thu, 6 Nov 2025 13:19:28 -0500 Subject: [PATCH 26/33] error in sharding of optimizer --- .../self_tuning/ademamix/submission.py | 54 ++++++------------- 1 file changed, 17 insertions(+), 37 deletions(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 58fc84a0e..e25ed20dd 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -116,6 +116,22 @@ def create_ademamix_sharding_from_names( nu=nu_sharding ) +def create_full_optimizer_sharding_from_names( + optimizer_chain_state, + params_tree, + replicated, + sharded + ): + if isinstance(optimizer_chain_state, tuple) and len(optimizer_chain_state) >= 1: + ademamix_state = optimizer_chain_state[0] + ademamix_sharding = _ademamix_component_sharding(ademamix_state, params_tree, replicated, sharded) + rest_shardings = tuple( + jax.tree.map(lambda _: replicated, s) for s in optimizer_chain_state[1:] + ) + return (ademamix_sharding, *rest_shardings) + else: + return _ademamix_component_sharding(optimizer_chain_state, params_tree, replicated, sharded) + def ademamix(lr, b1=0.9, b2=0.999, b3=0.9999, alpha=5.0, b3_scheduler=None, alpha_scheduler=None, eps=1e-8, eps_root=0.0, weight_decay=0.0, mask=None): @@ -208,42 +224,6 @@ def _bias_correction(moment, decay, count): return jax.tree.map( lambda t: t / bias_correction_.astype(t.dtype), moment) - def create_optimizer_sharding(optimizer_state, replicated, sharded): - """ - Create sharding spec for optimizer - - Args: - optimizer_state: The optimizer state structure - replicated: Sharding spec for replicated data - sharded: Sharding spec for batch sharded data - - Returns: - Sharding spec sharding rng key across batches and replicating - all other optimizer variables - """ - def shard_optimizer_component(state_component): - if isinstance(state_component, ScaleByLowRankOrthogonalUpdateState): - return ScaleByLowRankOrthogonalUpdateState( - step=replicated, - shape_info=replicated, - momentum=jax.tree.map(lambda _: replicated, state_component.momentum), - key=jax.tree.map(lambda _: sharded, state_component.key), - ) - else: - return jax.tree.map(lambda _: replicated, state_component) - - return jax.tree.map( - shard_optimizer_component, - optimizer_state, - is_leaf=lambda x: ( - isinstance(x, ScaleByLowRankOrthogonalUpdateState) or - ( - hasattr(x, '_fields') and - not isinstance(x, ScaleByLowRankOrthogonalUpdateState) - ) - ) - ) - def train_step(workload, opt_update_fn, @@ -336,7 +316,7 @@ def update_params( sharded = ( jax_sharding_utils.get_batch_dim_sharding() ) - optimizer_state_sharding = create_ademamix_sharding_from_names( + optimizer_state_sharding = create_full_optimizer_sharding_from_names( optimizer_state=optimizer_state, params_tree=current_param_container, replicated=replicated, From 8c1e2015623dcf382c765411a07439c55c4d9b88 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Thu, 6 Nov 2025 13:21:03 -0500 Subject: [PATCH 27/33] error in sharding --- submissions/self_tuning/ademamix/submission.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index e25ed20dd..bb7ff1596 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -317,7 +317,7 @@ def update_params( jax_sharding_utils.get_batch_dim_sharding() ) optimizer_state_sharding = create_full_optimizer_sharding_from_names( - optimizer_state=optimizer_state, + optimizer_chain_state=optimizer_state, params_tree=current_param_container, replicated=replicated, sharded=sharded From a04a81f2f1e574cc05b68d019fff37a2a01c813c Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Thu, 6 Nov 2025 13:27:35 -0500 Subject: [PATCH 28/33] fixed name of ademamix sharding --- submissions/self_tuning/ademamix/submission.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index bb7ff1596..86fd2622d 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -124,13 +124,13 @@ def create_full_optimizer_sharding_from_names( ): if isinstance(optimizer_chain_state, tuple) and len(optimizer_chain_state) >= 1: ademamix_state = optimizer_chain_state[0] - ademamix_sharding = _ademamix_component_sharding(ademamix_state, params_tree, replicated, sharded) + ademamix_sharding = create_ademamix_sharding_from_names(ademamix_state, params_tree, replicated, sharded) rest_shardings = tuple( jax.tree.map(lambda _: replicated, s) for s in optimizer_chain_state[1:] ) return (ademamix_sharding, *rest_shardings) else: - return _ademamix_component_sharding(optimizer_chain_state, params_tree, replicated, sharded) + return create_ademamix_sharding_from_names(optimizer_chain_state, params_tree, replicated, sharded) def ademamix(lr, b1=0.9, b2=0.999, b3=0.9999, alpha=5.0, b3_scheduler=None, alpha_scheduler=None, From 9503ddee57f947f19592037afa2b7ea2bf7f6e84 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Thu, 12 Feb 2026 10:19:29 -0500 Subject: [PATCH 29/33] removed sharding to test on a100s --- .../self_tuning/ademamix/submission.py | 28 +++++++------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 86fd2622d..f1de266b6 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -313,27 +313,19 @@ def update_params( # mesh = jax.sharding.Mesh(jax.devices(), ('batch')) replicated = jax_sharding_utils.get_replicate_sharding() - sharded = ( - jax_sharding_utils.get_batch_dim_sharding() - ) - optimizer_state_sharding = create_full_optimizer_sharding_from_names( - optimizer_chain_state=optimizer_state, - params_tree=current_param_container, - replicated=replicated, - sharded=sharded - ) + sharded = jax_sharding_utils.get_batch_dim_sharding() arg_shardings = ( - replicated, #model_state - optimizer_state_sharding, #optimizer_state # change to optimizer sharding eventually - replicated, # current_param_container - sharded, # batch - replicated, # per_device_rngs - replicated, # grad_clip - replicated, #label_smoothing - replicated, #dropout_rate + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # per_device_rngs + replicated, # grad_clip + replicated, # label_smoothing + replicated, # dropout_rate ) out_shardings = ( - optimizer_state_sharding, # new_optimizer_state # maybe sharded eventually + replicated, # new_optimizer_state replicated, # updated_params replicated, # new_model_state replicated, # loss From 484890899624bdd03de5563d6adaf2be2c982c5a Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Mon, 23 Feb 2026 21:43:36 -0500 Subject: [PATCH 30/33] updated ademamix to add learning rate scheduler, move jax.jit outside update params? --- .../self_tuning/ademamix/submission.py | 113 ++++++------------ 1 file changed, 38 insertions(+), 75 deletions(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index f1de266b6..78d2b10c2 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -51,6 +51,14 @@ _GRAD_CLIP_EPS = 1e-6 +def lr_scheduler(learning_rate, warmup_steps, total_steps): + return optax.warmup_cosine_decay_schedule( + init_value=0.0, + peak_value=learning_rate, + warmup_steps=warmup_steps, + decay_steps=total_steps, + end_value=learning_rate * 0.01 + ) def alpha_scheduler(alpha, alpha_start=0, warmup=0): warmup_fn = optax.linear_schedule(init_value=alpha_start, end_value=alpha, transition_steps=warmup) @@ -84,54 +92,6 @@ class ScaleByAdemamixState(NamedTuple): m2: optax.Updates nu: optax.Updates -def _path_matches_embedding(path_segments: tuple) -> bool: - return any(("embedding_table" in s) or ("embedding" in s) for s in path_segments) - -def build_embedding_name_mask(params_tree): - flat = tu.flatten_dict(params_tree, keep_empty_nodes=True) - mask_flat = {} - for path, leaf in flat.items(): - mask_flat[path] = _path_matches_embedding(path) - return tu.unflatten_dict(mask_flat) - -def _choose_sharding(mask_tree, target_tree, sharded, replicated): - return jax.tree.map( - lambda m, _: sharded if m else replicated, mask_tree, target_tree) - -def create_ademamix_sharding_from_names( - optimizer_state: ScaleByAdemamixState, - params_tree, - replicated, - sharded - ) -> ScaleByAdemamixState: - embed_mask = build_embedding_name_mask(params_tree) - m1_sharding = _choose_sharding(embed_mask, optimizer_state.m1, sharded, replicated) - m2_sharding = _choose_sharding(embed_mask, optimizer_state.m2, sharded, replicated) - nu_sharding = _choose_sharding(embed_mask, optimizer_state.nu, sharded, replicated) - return ScaleByAdemamixState( - count=replicated, - count_m2=replicated, - m1=m1_sharding, - m2=m2_sharding, - nu=nu_sharding - ) - -def create_full_optimizer_sharding_from_names( - optimizer_chain_state, - params_tree, - replicated, - sharded - ): - if isinstance(optimizer_chain_state, tuple) and len(optimizer_chain_state) >= 1: - ademamix_state = optimizer_chain_state[0] - ademamix_sharding = create_ademamix_sharding_from_names(ademamix_state, params_tree, replicated, sharded) - rest_shardings = tuple( - jax.tree.map(lambda _: replicated, s) for s in optimizer_chain_state[1:] - ) - return (ademamix_sharding, *rest_shardings) - else: - return create_ademamix_sharding_from_names(optimizer_chain_state, params_tree, replicated, sharded) - def ademamix(lr, b1=0.9, b2=0.999, b3=0.9999, alpha=5.0, b3_scheduler=None, alpha_scheduler=None, eps=1e-8, eps_root=0.0, weight_decay=0.0, mask=None): @@ -277,6 +237,33 @@ def _loss_fn(params): updated_params = optax.apply_updates(current_param_container, updates) return new_optimizer_state, updated_params, new_model_state, loss, grad_norm + +replicated = jax_sharding_utils.get_replicate_sharding() +sharded = jax_sharding_utils.get_batch_dim_sharding() +arg_shardings = ( + replicated, # model_state + replicated, # optimizer_state + replicated, # current_param_container + sharded, # batch + replicated, # per_device_rngs + replicated, # grad_clip + replicated, # label_smoothing + replicated, # dropout_rate + ) +out_shardings = ( + replicated, # new_optimizer_state + replicated, # updated_params + replicated, # new_model_state + replicated, # loss + replicated, # grad_norm + ) +jitted_train_step = jax.jit( + train_step, + static_argnums=(0, 1), + donate_argnums=(2, 3, 4), + in_shardings=arg_shardings, + out_shardings=out_shardings, + ) def update_params( workload: spec.Workload, @@ -312,32 +299,6 @@ def update_params( dropout_rate = hyperparameters['dropout_rate'] # mesh = jax.sharding.Mesh(jax.devices(), ('batch')) - replicated = jax_sharding_utils.get_replicate_sharding() - sharded = jax_sharding_utils.get_batch_dim_sharding() - arg_shardings = ( - replicated, # model_state - replicated, # optimizer_state - replicated, # current_param_container - sharded, # batch - replicated, # per_device_rngs - replicated, # grad_clip - replicated, # label_smoothing - replicated, # dropout_rate - ) - out_shardings = ( - replicated, # new_optimizer_state - replicated, # updated_params - replicated, # new_model_state - replicated, # loss - replicated, # grad_norm - ) - jitted_train_step = jax.jit( - train_step, - static_argnums=(0, 1), - donate_argnums=(2, 3, 4), - in_shardings=arg_shardings, - out_shardings=out_shardings, - ) outputs = jitted_train_step(workload, opt_update_fn, model_state, @@ -449,11 +410,13 @@ def init_optimizer_state( b2 = HPARAMS['b2'] b3 = HPARAMS['b3'] alpha = HPARAMS['alpha'] + warmup = HPARAMS['warmup'] T = workload.step_hint f_b3 = beta3_scheduler(b3, beta_start=b1, warmup=T) f_a = alpha_scheduler(alpha, alpha_start=0, warmup=T) + f_lr = learning_rate_scheduler(lr, warmup, T) weight_decay = HPARAMS['weight_decay'] - opt_init_fn, opt_update_fn = ademamix(lr=lr, + opt_init_fn, opt_update_fn = ademamix(lr=f_lr, b1=b1, b2=b2, b3=b3, From 4a314de6762be4683403f1d3647e8b64bc870f43 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Thu, 5 Mar 2026 11:47:26 -0500 Subject: [PATCH 31/33] typo --- submissions/self_tuning/ademamix/submission.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 78d2b10c2..1374439b8 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -414,7 +414,7 @@ def init_optimizer_state( T = workload.step_hint f_b3 = beta3_scheduler(b3, beta_start=b1, warmup=T) f_a = alpha_scheduler(alpha, alpha_start=0, warmup=T) - f_lr = learning_rate_scheduler(lr, warmup, T) + f_lr = lr_scheduler(lr, warmup, T) weight_decay = HPARAMS['weight_decay'] opt_init_fn, opt_update_fn = ademamix(lr=f_lr, b1=b1, From 532e5f0ac26c9b81bb894a9818496159e9a7b07b Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Fri, 10 Apr 2026 10:41:26 -0400 Subject: [PATCH 32/33] switched to simplified ademamix from optax --- .../self_tuning/ademamix/submission.py | 199 ++++++------------ 1 file changed, 65 insertions(+), 134 deletions(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 1374439b8..7c6139676 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -1,39 +1,24 @@ -""" -Forked from apple's ademamix jax implementation: -https://github.com/apple/ml-ademamix -for the purposes of submitting to the algoperf benchmark. -| Adapted from optax's implementation of AdamW: -| https://github.com/google-deepmind/optax/blob/b75644809f2f68fc11f42d4395a5753e11e92e80/optax/_src/alias.py#L548#L675 -""" +"""AlgoPerf AdEMAMix submission built on Optax.""" import functools from typing import ( Any, - Callable, Dict, Iterator, List, - NamedTuple, Optional, Tuple, - Union, ) -import chex import jax -from jax import tree_util as jtu import jax.numpy as jnp import optax from flax import jax_utils -from flax import traverse_util as tu -from jax import lax - -from optax._src import transform, combine, base, numerics, utils -from optax import tree_utils as otu from algoperf import spec, jax_sharding_utils HPARAMS = { + 'ademamix_variant': 'simplified', 'alpha': 8.0, 'alpha_start': 0, 'warmup': 10, @@ -83,106 +68,44 @@ def warmup_fn(step): schedule_fn = optax.join_schedules(schedules=[warmup_fn, constant_fn], boundaries=[warmup]) return schedule_fn - -class ScaleByAdemamixState(NamedTuple): - """State for the AdEMAMix algorithm.""" - count: chex.Array - count_m2: chex.Array - m1: optax.Updates - m2: optax.Updates - nu: optax.Updates - - -def ademamix(lr, b1=0.9, b2=0.999, b3=0.9999, alpha=5.0, b3_scheduler=None, alpha_scheduler=None, - eps=1e-8, eps_root=0.0, weight_decay=0.0, mask=None): - r"""AdEMAMix. - - Args: - lr: A global scaling factor, either fixed or evolving along - iterations with a scheduler, see :func:`optax.scale_by_learning_rate`. - b1: Exponential decay rate to track the fast EMA. - b2: Exponential decay rate to track the second moment of past gradients. - b3: Exponential decay rate to track the slow EMA. - alpha: Mixing coeficient use for the linear combination of the fast and slow EMAs. - b3_scheduler: an optional scheduler function, given a timestep, returns the - value of b3. Use `beta3_scheduler(b3,b1,T_b3)` to follow the AdEMAMix paper. - alpha_scheduler: an optional scheduler function, given a timestep, returns the - value of alpha. Use `alpha_scheduler(alpha,0,T_alpha)` to follow the - AdEMAMix paper. - eps: A small constant applied to denominator outside of the square root - (as in the Adam paper) to avoid dividing by zero when rescaling. - eps_root: A small constant applied to denominator inside the square root (as - in RMSProp), to avoid dividing by zero when rescaling. This is needed for - instance when computing (meta-)gradients through Adam. - mu_dtype: Optional `dtype` to be used for the first order accumulator; if - `None` then the `dtype` is inferred from `params` and `updates`. - weight_decay: Strength of the weight decay regularization. Note that this - weight decay is multiplied with the learning rate. This is consistent - with other frameworks such as PyTorch, but different from - (Loshchilov et al, 2019) where the weight decay is only multiplied with - the "schedule multiplier", but not the base learning rate. - mask: A tree with same structure as (or a prefix of) the params PyTree, - or a Callable that returns such a pytree given the params/updates. - The leaves should be booleans, `True` for leaves/subtrees you want to - apply the weight decay to, and `False` for those you want to skip. Note - that the Adam gradient transformations are applied to all parameters. - - Returns: - The corresponding `GradientTransformation`. - """ - return combine.chain( - scale_by_ademamix(b1, b2, b3, alpha, b3_scheduler, alpha_scheduler, eps, eps_root), - transform.add_decayed_weights(weight_decay, mask), - transform.scale_by_learning_rate(lr), - ) - - -def scale_by_ademamix(b1, b2, b3, alpha, b3_scheduler, alpha_scheduler, eps=1e-8, eps_root=0.0): - - def init_fn(params): - m1 = jax.tree.map(jnp.zeros_like, params) # fast EMA - m2 = jax.tree.map(jnp.zeros_like, params) # slow EMA - nu = jax.tree.map(jnp.zeros_like, params) # second moment estimate - return ScaleByAdemamixState(count=jnp.zeros([], jnp.int32), count_m2=jnp.zeros([], jnp.int32), m1=m1, m2=m2, nu=nu) - - def update_fn(updates, state, params=None): - del params - c_b3 = b3_scheduler(state.count_m2) if b3_scheduler is not None else b3 - c_alpha = alpha_scheduler(state.count_m2) if alpha_scheduler is not None else alpha - m1 = _update_moment(updates, state.m1, b1, 1) # m1 = b1 * m1 + (1-b1) * updates - m2 = _update_moment(updates, state.m2, c_b3, 1) - nu = _update_moment(updates, state.nu, b2, 2) - count = state.count + jnp.array(1, dtype=jnp.int32) - # count_inc = numerics.safe_int32_increment(state.count) - count_m2 = state.count_m2 + jnp.array(1, dtype=jnp.int32) - # count_m2_inc = numerics.safe_int32_increment(state.count_m2) - m1_hat = _bias_correction(m1, b1, count) - nu_hat = _bias_correction(nu, b2, count) - updates = jax.tree.map(lambda m1_, m2_, v_: (m1_+c_alpha*m2_)/(jnp.sqrt(v_+eps_root)+eps), m1_hat, m2, nu_hat) - return updates, ScaleByAdemamixState(count=count, count_m2=count_m2, m1=m1, m2=m2, nu=nu) - - return base.GradientTransformation(init_fn, update_fn) - - -def _update_moment(updates, moments, decay, order): - """Compute the exponential moving average of the `order`-th moment.""" - return jax.tree.map( - lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments) - - - -def _bias_correction(moment, decay, count): - """Performs bias correction. It becomes a no-op as count goes to infinity.""" - # The conversion to the data type of the moment ensures that bfloat16 remains - # bfloat16 in the optimizer state. This conversion has to be done after - # `bias_correction_` is calculated as calculating `decay**count` in low - # precision can result in it being rounded to 1 and subsequently a - # "division by zero" error. - bias_correction_ = 1 - decay**count - - # Perform division in the original precision. - return jax.tree.map( - lambda t: t / bias_correction_.astype(t.dtype), moment) +def build_ademamix_optimizer( + lr, + variant='simplified', + b1=0.9, + b2=0.999, + b3=0.9999, + alpha=5.0, + b3_scheduler=None, + alpha_scheduler=None, + eps=1e-8, + eps_root=0.0, + weight_decay=0.0, + mask=None, +): + if variant == 'simplified': + return optax.contrib.simplified_ademamix( + learning_rate=lr, + b1=b1, + b2=b2, + alpha=alpha_scheduler if alpha_scheduler is not None else alpha, + eps=eps, + eps_root=eps_root, + weight_decay=weight_decay, + mask=mask, + ) + if variant == 'full': + return optax.contrib.ademamix( + learning_rate=lr, + b1=b1, + b2=b2, + b3=b3_scheduler if b3_scheduler is not None else b3, + alpha=alpha_scheduler if alpha_scheduler is not None else alpha, + eps=eps, + eps_root=eps_root, + weight_decay=weight_decay, + mask=mask, + ) + raise ValueError(f'Unsupported ademamix variant: {variant}') def train_step(workload, @@ -410,21 +333,26 @@ def init_optimizer_state( b2 = HPARAMS['b2'] b3 = HPARAMS['b3'] alpha = HPARAMS['alpha'] + variant = HPARAMS['ademamix_variant'] warmup = HPARAMS['warmup'] T = workload.step_hint f_b3 = beta3_scheduler(b3, beta_start=b1, warmup=T) f_a = alpha_scheduler(alpha, alpha_start=0, warmup=T) f_lr = lr_scheduler(lr, warmup, T) weight_decay = HPARAMS['weight_decay'] - opt_init_fn, opt_update_fn = ademamix(lr=f_lr, - b1=b1, - b2=b2, - b3=b3, - alpha=alpha, - b3_scheduler=f_b3, - alpha_scheduler=f_a, - weight_decay=weight_decay - ) + optimizer = build_ademamix_optimizer( + lr=f_lr, + variant=variant, + b1=b1, + b2=b2, + b3=b3, + alpha=alpha, + b3_scheduler=f_b3, + alpha_scheduler=f_a, + weight_decay=weight_decay, + ) + opt_init_fn = optimizer.init + opt_update_fn = optimizer.update optimizer_state = opt_init_fn(params_zeros_like) return optimizer_state, opt_update_fn @@ -438,14 +366,17 @@ def f(x): return jnp.sum(x ** 2) # simple quadratic function f_a = alpha_scheduler(alpha, alpha_start=0, warmup=10) f_b3 = beta3_scheduler(b3, beta_start=b1, warmup=10) - solver = ademamix(lr=0.01, - b1=b1, - b2=b2, - b3=b3, - alpha=alpha, - b3_scheduler=f_b3, - alpha_scheduler=f_a, - weight_decay=0.01) + solver = build_ademamix_optimizer( + lr=0.01, + variant='full', + b1=b1, + b2=b2, + b3=b3, + alpha=alpha, + b3_scheduler=f_b3, + alpha_scheduler=f_a, + weight_decay=0.01, + ) params = jnp.array([1., 2., 3.]) print('Objective function: {:.2f}'.format(f(params))) From 106858f67972b0b8952570506845a66e102ccef4 Mon Sep 17 00:00:00 2001 From: David Tweedle Date: Fri, 10 Apr 2026 12:09:07 -0400 Subject: [PATCH 33/33] attempt to match pytorch submission behaviour --- .../self_tuning/ademamix/submission.py | 63 ++++++++++--------- 1 file changed, 35 insertions(+), 28 deletions(-) diff --git a/submissions/self_tuning/ademamix/submission.py b/submissions/self_tuning/ademamix/submission.py index 7c6139676..eec8d5b70 100644 --- a/submissions/self_tuning/ademamix/submission.py +++ b/submissions/self_tuning/ademamix/submission.py @@ -20,39 +20,41 @@ HPARAMS = { 'ademamix_variant': 'simplified', 'alpha': 8.0, - 'alpha_start': 0, - 'warmup': 10, - 'beta_end': 0.9999, - 'beta_start': 0.9, - 'learning_rate': 0.01, - 'b1': 0.9, - 'b2': 0.999, - 'b3': 0.9999, + 'warmup_factor': 0.02, + 'beta3_warmup': 500e3, + 'alpha_warmup': 500e3, + 'learning_rate': 2e-3, + 'one_minus_beta1': 0.2, + 'beta2': 0.995, + 'beta3': 0.9995, 'eps': 1e-8, 'eps_root': 0.0, - 'weight_decay': 0.01, + 'weight_decay': 0.1, + 'grad_clip': 0.5, 'dropout_rate': 0.1, } _GRAD_CLIP_EPS = 1e-6 -def lr_scheduler(learning_rate, warmup_steps, total_steps): +def lr_scheduler(learning_rate, warmup_factor, total_steps): + warmup_steps = int(warmup_factor * total_steps) + cosine_steps = max(total_steps - warmup_steps, 1) return optax.warmup_cosine_decay_schedule( - init_value=0.0, + init_value=learning_rate * 1e-10, peak_value=learning_rate, warmup_steps=warmup_steps, - decay_steps=total_steps, - end_value=learning_rate * 0.01 + decay_steps=warmup_steps + cosine_steps, + end_value=0.0 ) -def alpha_scheduler(alpha, alpha_start=0, warmup=0): - warmup_fn = optax.linear_schedule(init_value=alpha_start, end_value=alpha, transition_steps=warmup) +def alpha_scheduler(alpha, warmup=0): + warmup_fn = optax.linear_schedule(init_value=0, end_value=alpha, transition_steps=warmup) constant_fn = optax.constant_schedule(alpha) schedule_fn = optax.join_schedules(schedules=[warmup_fn, constant_fn], boundaries=[warmup]) return schedule_fn -def beta3_scheduler(beta_end, beta_start=0, warmup=0): +def beta3_scheduler(beta3, beta1=0, warmup=0): def f(beta): return jnp.log(0.5)/jnp.log(beta)-1 @@ -62,9 +64,9 @@ def f_inv(t): def warmup_fn(step): frac = 1 - step / warmup - return f_inv( frac * f(beta_start) + (1 - frac) * f(beta_end)) + return f_inv( frac * f(beta1) + (1 - frac) * f(beta3)) - constant_fn = optax.constant_schedule(beta_end) + constant_fn = optax.constant_schedule(beta3) schedule_fn = optax.join_schedules(schedules=[warmup_fn, constant_fn], boundaries=[warmup]) return schedule_fn @@ -329,16 +331,19 @@ def init_optimizer_state( lambda s: jnp.zeros(s.shape_tuple), workload.param_shapes ) lr = HPARAMS['learning_rate'] - b1 = HPARAMS['b1'] - b2 = HPARAMS['b2'] - b3 = HPARAMS['b3'] + one_minus_beta1 = HPARAMS['one_minus_beta1'] + b1 = 1.0 - one_minus_beta1 + b2 = HPARAMS['beta2'] + b3 = HPARAMS['beta3'] alpha = HPARAMS['alpha'] variant = HPARAMS['ademamix_variant'] - warmup = HPARAMS['warmup'] + warmup_factor = HPARAMS['warmup_factor'] + beta3_warmup = HPARAMS['beta3_warmup'] + alpha_warmup = HPARAMS['alpha_warmup'] T = workload.step_hint - f_b3 = beta3_scheduler(b3, beta_start=b1, warmup=T) - f_a = alpha_scheduler(alpha, alpha_start=0, warmup=T) - f_lr = lr_scheduler(lr, warmup, T) + f_b3 = beta3_scheduler(b3, beta1=b1, warmup=beta3_warmup) + f_a = alpha_scheduler(alpha, warmup=alpha_warmup) + f_lr = lr_scheduler(lr, warmup_factor, T) weight_decay = HPARAMS['weight_decay'] optimizer = build_ademamix_optimizer( lr=f_lr, @@ -361,10 +366,12 @@ def init_optimizer_state( def f(x): return jnp.sum(x ** 2) # simple quadratic function alpha = 8.0 - b1, b2, b3 = 0.9, 0.999, 0.9999 + one_minus_beta1 = 0.1 + b1 = 1.0 - one_minus_beta1 + b2, b3 = 0.999, 0.9999 - f_a = alpha_scheduler(alpha, alpha_start=0, warmup=10) - f_b3 = beta3_scheduler(b3, beta_start=b1, warmup=10) + f_a = alpha_scheduler(alpha, warmup=10) + f_b3 = beta3_scheduler(b3, beta1=b1, warmup=10) solver = build_ademamix_optimizer( lr=0.01,