|
| 1 | +# %% |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +from jax import jit, numpy as jnp, random, nn, lax |
| 5 | +from functools import partial |
| 6 | +import time |
| 7 | + |
| 8 | + |
def step_update(param, update, phi_old, lr, mu, time_step):
    """
    Performs a single Nesterov accelerated gradient (NAG) step for one parameter tensor.

    The dynamics for any parameter tensor are:

    | phi   = param - update * lr
    | param = phi + (phi - phi_previous) * mu, where mu = 0 iff t <= 1 (first iteration)

    Args:
        param: parameter tensor to change/adjust

        update: update tensor to be applied to the parameter tensor (must be
            the same shape as ``param``)

        phi_old: previous friction/momentum variable

        lr: global step size applied to the update

        mu: friction/momentum control factor

        time_step: current time t, i.e. iteration step/call to this NAG update

    Returns:
        adjusted parameter tensor (same shape as ``param``), adjusted
        momentum/friction variable
    """
    phi = param - update * lr  ## phantom gradient adjustment step
    ## momentum is gated off on the first iteration (t <= 1)
    momentum_gate = mu * (time_step > 1.)
    next_param = phi + (phi - phi_old) * momentum_gate  ## NAG-step
    return next_param, phi
| 38 | + |
@jit
def nag_step(opt_params, theta, updates, eta=0.01, mu=0.9): ## apply adjustment to theta
    """
    Implements Nesterov's accelerated gradient (NAG) algorithm as a decoupled
    update rule given adjustments produced by a credit assignment
    algorithm/process.

    Args:
        opt_params: (ArrayLike) parameters of the optimization algorithm, i.e.
            the (phi, time_step) pair produced by ``nag_init`` / a prior call

        theta: (ArrayLike) the weights of the neural network

        updates: (ArrayLike) the updates of the neural network (same structure
            as ``theta``)

        eta: (float, optional) step size coefficient for NAG update (Default: 0.01)

        mu: (float, optional) friction/momentum control factor. (Default: 0.9)

    Returns:
        ArrayLike: opt_params. New opt params, ArrayLike: theta. The updated weights
    """
    phi, time_step = opt_params
    time_step = time_step + 1  ## advance iteration count; momentum is off while t <= 1
    new_theta = []
    new_phi = []
    ## apply the NAG update leaf-by-leaf across the parallel parameter lists
    for theta_i, update_i, phi_i in zip(theta, updates, phi):
        px_i, phi_i = step_update(theta_i, update_i, phi_i, eta, mu, time_step)
        new_theta.append(px_i)
        new_phi.append(phi_i)
    return (new_phi, time_step), new_theta
| 68 | + |
@jit
def nag_init(theta):
    """
    Builds the initial NAG optimizer state for a list of parameter tensors.

    Args:
        theta: list of parameter tensors the optimizer will adjust

    Returns:
        tuple (phi, time_step): zero tensors matching each shape in ``theta``
        (the initial momentum/friction variables) and a scalar iteration
        counter starting at 0
    """
    time_step = jnp.asarray(0.0)
    phi = [jnp.zeros(leaf.shape) for leaf in theta]
    return phi, time_step
| 74 | + |
if __name__ == '__main__':
    ## quick smoke test: run two NAG steps over a pair of toy weight vectors
    weights = [jnp.asarray([3.0, 3.0]), jnp.asarray([3.0, 3.0])]
    updates = [jnp.asarray([3.0, 3.0]), jnp.asarray([3.0, 3.0])]
    opt_params = nag_init(weights)
    for step in range(2):
        if step > 0:
            print("##################")
        opt_params, theta = nag_step(opt_params, weights, updates)
        print(f"opt_params: {opt_params}, theta: {theta}")
        weights = theta  ## feed updated weights back in for the next step
0 commit comments