Skip to content

Commit 63f3f79

Browse files
author
Alexander Ororbia
committed
mod to model_utils
1 parent 9c3b1a0 commit 63f3f79

1 file changed

Lines changed: 49 additions & 28 deletions

File tree

ngclearn/utils/model_utils.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def create_function(fun_name, args=None):
108108
elif fun_name == "elu":
109109
fx = elu
110110
dfx = d_elu
111-
elif fun_name == "silu":
111+
elif fun_name == "silu": # NOTE: this is also the swish function
112112
fx = silu
113113
dfx = d_silu
114114
elif fun_name == "gelu":
@@ -122,22 +122,20 @@ def create_function(fun_name, args=None):
122122
dfx = d_softplus
123123
elif fun_name == "softmax":
124124
fx = softmax
125-
## NOTE: below is an improper derivative proxy
126-
## correct dfx is a Jacobian of softmax (not currently supported!)
127-
dfx = d_identity
125+
dfx = d_softmax ## NOTE: this yields a Jacobian tensor Jx
128126
elif fun_name == "unit_threshold":
129127
fx = threshold ## default threshold is 1 (thus unit)
130128
dfx = d_threshold ## STE approximation
131129
elif "heaviside" in fun_name:
132130
fx = heaviside
133-
dfx = d_heaviside ## STE approximation
131+
dfx = d_heaviside ## NOTE: this is an STE approximation
134132
elif fun_name == "identity":
135133
fx = identity
136134
dfx = d_identity
137-
else:
135+
else: ## throw exception for un-supported activation
138136
raise RuntimeError(
139137
"Activation function (" + fun_name + ") is not recognized/supported!"
140-
)
138+
)
141139
return fx, dfx
142140

143141
@partial(jit, static_argnums=[1])
@@ -214,7 +212,6 @@ def clamp_max(x, max_val):
214212
_x = x * mask + (1. - mask) * max_val
215213
return _x
216214

217-
218215
@jit
219216
def one_hot(P):
220217
"""
@@ -514,31 +511,32 @@ def d_softplus(x):
514511
output (tensor) derivative value (with respect to input argument)
515512
"""
516513
## d/dx of softplus = logistic sigmoid
517-
return nn.sigmoid(x)
514+
return sigmoid(x) #nn.sigmoid(x)
518515

519516
@jit
520517
def threshold(x, thr=1.):
521518
return (x >= thr).astype(jnp.float32)
522519

523520
@jit
524521
def d_threshold(x, thr=1.):
525-
return x * 0. + 1. ## straight-thru estimator
522+
return x * 0. + 1. ## NOTE: straight-thru estimator (STE)
526523

527524
@jit
528525
def heaviside(x):
529526
return (x >= 0.).astype(jnp.float32)
530527

531528
@jit
532529
def d_heaviside(x):
533-
return x * 0. + 1. ## straight-thru estimator
530+
return x * 0. + 1. ## NOTE: straight-thru estimator (STE)
534531

535532
@jit
536533
def sigmoid(x):
537-
return nn.sigmoid(x)
534+
sigm_x = 1./ (1. + jnp.exp(-x))
535+
return sigm_x #nn.sigmoid(x)
538536

539537
@jit
540538
def d_sigmoid(x):
541-
sigm_x = nn.sigmoid(x) ## pre-compute once
539+
sigm_x = sigmoid(x) #nn.sigmoid(x) ## pre-compute once
542540
return sigm_x * (1. - sigm_x)
543541

544542
def inverse_sigmoid(x, clip_bound=0.03): ## wrapper call for naming convention ease
@@ -590,7 +588,9 @@ def d_swish(x, beta):
590588
@jit
591589
def silu(x):
592590
"""
593-
Applies the sigmoid-weighted linear unit (SiLU or SiL) activation.
591+
Applies the sigmoid-weighted linear unit (SiLU or SiL) activation.
592+
Note that this is primarily a convenience wrapper function for
593+
the `swish` activation.
594594
595595
Args:
596596
x: data to transform via inverse logistic function
@@ -607,7 +607,8 @@ def d_silu(x):
607607
@jit
608608
def gelu(x):
609609
"""
610-
Applies the Gaussian Error Linear Unit (GeLU) activation (specifically, a fast approximation is used).
610+
Applies the Gaussian Error Linear Unit (GeLU) activation
611+
(specifically, a fast approximation is used via a weighted `swish`).
611612
612613
Args:
613614
x: data to transform via inverse logistic function
@@ -635,7 +636,7 @@ def elu(x, alpha=1.):
635636
Returns:
636637
output of the GeLU activation
637638
"""
638-
mask = x >= 0.
639+
mask = x >= 0. ## pre-compute mask
639640
return x * mask + ((jnp.exp(x) - 1) * alpha) * (1. - mask)
640641

641642
@jit
@@ -653,7 +654,8 @@ def softmax(x, tau=0.0):
653654
Args:
654655
x: a (N x D) input argument (pre-activity) to the softmax operator
655656
656-
tau: probability sharpening/softening factor
657+
tau: probability sharpening/softening factor, if > 0.; else, <= 0 disables
658+
this (Default: 0.)
657659
658660
Returns:
659661
a (N x D) probability distribution output block
@@ -664,28 +666,47 @@ def softmax(x, tau=0.0):
664666
exp_x = jnp.exp(x - max_x)
665667
return exp_x / jnp.sum(exp_x, axis=1, keepdims=True)
666668

667-
@jit
668-
def d_softmax(x):
669+
@partial(jit, static_argnums=[2])
670+
def d_softmax(x, tau=0., vmap_form=False): ## temperature-controlled softmax derivative co-routine
669671
"""
670672
Derivative of the softmax function.
671-
Note that this returns specifically the Jacobian tensor of softmax(x) w.r.t.
673+
Note that this returns specifically the Jacobian tensor `Jx` of softmax(x) w.r.t.
672674
potential batch set of vectors (one per row).
673675
674676
Args:
675677
x: input (tensor) value (B x D)
676678
679+
vmap_form: optional algorithm switch flag; if True, `Jx` is computed using
680+
Jax vmap (Default: False)
681+
677682
Returns:
678683
output (tensor) derivative values (Jacobian with respect to input argument; B x D x D)
679684
"""
685+
_m = tau > 0.
686+
_tau = tau * _m + (1. - _m) ## sets _tau=1 if tau <= 0
687+
Jx = 0. ## d_softmax(x)/d_x is a Jacobian matrix per sample
680688
## caclulate softmax along feature dimension (axis=-1)
681-
s = jax.nn.softmax(x, axis=-1) ## Shape: (B, D)
682-
## Batch-up diag(s); multiply s by 3D identity tensor
683-
## Shape: (B, D, 1) * (1, D, D) => (B, D, D)
684-
diag_s = jnp.expand_dims(s, axis=-1) * jnp.eye(s.shape[-1])
685-
## Batched outer(s, s): Broadcasted multiplication
686-
## Shape: (B, D, 1) * (B, 1, D) => (B, D, D)
687-
outer_s = jnp.expand_dims(s, axis=-1) * jnp.expand_dims(s, axis=-2)
688-
return diag_s - outer_s ## return full final Jacobian
689+
s = softmax(x, tau=_tau) # nn.softmax(x, axis=-1) ## (BxD)
690+
if not vmap_form: ### use pure tensorized batch-identity trick algorithm
691+
diag_s = jnp.expand_dims(s, axis=-1) * jnp.eye(s.shape[-1]) ## Shape: (BxDx1) * (1xDxD) => (BxDxD)
692+
## batched outer(s, s) ~> outer product for each batch vector
693+
outer_s = jnp.expand_dims(s, axis=-1) * jnp.expand_dims(s, axis=-2) ## (BxDx1) * (Bx1xD) => (BxDxD)
694+
Jx = (diag_s - outer_s) * (1. / _tau)
695+
else: ### switch to vmap algorithm
696+
## calc outer product using einsum (clean and readable)
697+
outer_s = jnp.einsum('bi,bj->bij', s, s) ## (BxDxD)
698+
## fast batched diagonal insertion via a diagonal mask
699+
d = s.shape[-1]
700+
diag_indices = jnp.arange(d)
701+
## jax.at subtracts outer product from diagonal
702+
## (s - s^2) for diagonal, (-s_i s_j) for off-diagonal
703+
## avoids constructing a giant identity matrix
704+
jacobian = -outer_s
705+
## vmap over index updates across batch
706+
def add_diag(J_matrix, s_vector):
707+
return J_matrix.at[diag_indices, diag_indices].add(s_vector)
708+
Jx = ( jax.vmap(add_diag)(jacobian, s) ) * (1. / _tau)
709+
return Jx ## return full, final Jacobian
689710

690711
@jit
691712
def threshold_soft(x, lmbda):

0 commit comments

Comments
 (0)