@@ -108,7 +108,7 @@ def create_function(fun_name, args=None):
108108 elif fun_name == "elu" :
109109 fx = elu
110110 dfx = d_elu
111- elif fun_name == "silu" :
111+ elif fun_name == "silu" : # NOTE: this is also the swish function
112112 fx = silu
113113 dfx = d_silu
114114 elif fun_name == "gelu" :
@@ -122,22 +122,20 @@ def create_function(fun_name, args=None):
122122 dfx = d_softplus
123123 elif fun_name == "softmax" :
124124 fx = softmax
125- ## NOTE: below is an improper derivative proxy
126- ## correct dfx is a Jacobian of softmax (not currently supported!)
127- dfx = d_identity
125+ dfx = d_softmax ## NOTE: this yields a Jacobian tensor Jx
128126 elif fun_name == "unit_threshold" :
129127 fx = threshold ## default threshold is 1 (thus unit)
130128 dfx = d_threshold ## STE approximation
131129 elif "heaviside" in fun_name :
132130 fx = heaviside
133- dfx = d_heaviside ## STE approximation
131+ dfx = d_heaviside ## NOTE: this is an STE approximation
134132 elif fun_name == "identity" :
135133 fx = identity
136134 dfx = d_identity
137- else :
135+ else : ## throw exception for un-supported activation
138136 raise RuntimeError (
139137 "Activation function (" + fun_name + ") is not recognized/supported!"
140- )
138+ )
141139 return fx , dfx
142140
143141@partial (jit , static_argnums = [1 ])
@@ -214,7 +212,6 @@ def clamp_max(x, max_val):
214212 _x = x * mask + (1. - mask ) * max_val
215213 return _x
216214
217-
218215@jit
219216def one_hot (P ):
220217 """
@@ -514,31 +511,32 @@ def d_softplus(x):
514511 output (tensor) derivative value (with respect to input argument)
515512 """
516513 ## d/dx of softplus = logistic sigmoid
517- return nn .sigmoid (x )
514+ return sigmoid ( x ) # nn.sigmoid(x)
518515
519516@jit
520517def threshold (x , thr = 1. ):
521518 return (x >= thr ).astype (jnp .float32 )
522519
523520@jit
524521def d_threshold (x , thr = 1. ):
525- return x * 0. + 1. ## straight-thru estimator
522+ return x * 0. + 1. ## NOTE: straight-thru estimator (STE)
526523
527524@jit
528525def heaviside (x ):
529526 return (x >= 0. ).astype (jnp .float32 )
530527
531528@jit
532529def d_heaviside (x ):
533- return x * 0. + 1. ## straight-thru estimator
530+ return x * 0. + 1. ## NOTE: straight-thru estimator (STE)
534531
535532@jit
536533def sigmoid (x ):
537- return nn .sigmoid (x )
534+ sigm_x = 1. / (1. + jnp .exp (- x ))
535+ return sigm_x #nn.sigmoid(x)
538536
539537@jit
540538def d_sigmoid (x ):
541- sigm_x = nn .sigmoid (x ) ## pre-compute once
539+ sigm_x = sigmoid ( x ) # nn.sigmoid(x) ## pre-compute once
542540 return sigm_x * (1. - sigm_x )
543541
544542def inverse_sigmoid (x , clip_bound = 0.03 ): ## wrapper call for naming convention ease
@@ -590,7 +588,9 @@ def d_swish(x, beta):
590588@jit
591589def silu (x ):
592590 """
593- Applies the sigmoid-weighted linear unit (SiLU or SiL) activation.
591+ Applies the sigmoid-weighted linear unit (SiLU or SiL) activation.
592+ Note that this is primarily a convenience wrapper function for
593+ the `swish` activation.
594594
595595 Args:
596596 x: data to transform via inverse logistic function
@@ -607,7 +607,8 @@ def d_silu(x):
607607@jit
608608def gelu (x ):
609609 """
610- Applies the Gaussian Error Linear Unit (GeLU) activation (specifically, a fast approximation is used).
610+ Applies the Gaussian Error Linear Unit (GeLU) activation
611+ (specifically, a fast approximation is used via a weighted `swish`).
611612
612613 Args:
613614 x: data to transform via inverse logistic function
@@ -635,7 +636,7 @@ def elu(x, alpha=1.):
635636 Returns:
636637 output of the GeLU activation
637638 """
638- mask = x >= 0.
639+ mask = x >= 0. ## pre-compute mask
639640 return x * mask + ((jnp .exp (x ) - 1 ) * alpha ) * (1. - mask )
640641
641642@jit
@@ -653,7 +654,8 @@ def softmax(x, tau=0.0):
653654 Args:
654655 x: a (N x D) input argument (pre-activity) to the softmax operator
655656
656- tau: probability sharpening/softening factor
657+ tau: probability sharpening/softening factor, if > 0.; else, <= 0 disables
658+ this (Default: 0.)
657659
658660 Returns:
659661 a (N x D) probability distribution output block
@@ -664,28 +666,47 @@ def softmax(x, tau=0.0):
664666 exp_x = jnp .exp (x - max_x )
665667 return exp_x / jnp .sum (exp_x , axis = 1 , keepdims = True )
666668
667- @jit
668- def d_softmax (x ):
669+ @partial ( jit , static_argnums = [ 2 ])
670+ def d_softmax (x , tau = 0. , vmap_form = False ): ## temperature-controlled softmax derivative co-routine
669671 """
670672 Derivative of the softmax function.
671- Note that this returns specifically the Jacobian tensor of softmax(x) w.r.t.
673+ Note that this returns specifically the Jacobian tensor `Jx` of softmax(x) w.r.t.
672674 potential batch set of vectors (one per row).
673675
674676 Args:
675677 x: input (tensor) value (B x D)
676678
679+ vmap_form: optional algorithm switch flag; if True, `Jx` is computed using
680+ Jax vmap (Default: False)
681+
677682 Returns:
678683 output (tensor) derivative values (Jacobian with respect to input argument; B x D x D)
679684 """
685+ _m = tau > 0.
686+ _tau = tau * _m + (1. - _m ) ## sets _tau=1 if tau <= 0
687+ Jx = 0. ## d_softmax(x)/d_x is a Jacobian matrix per sample
680688 ## caclulate softmax along feature dimension (axis=-1)
681- s = jax .nn .softmax (x , axis = - 1 ) ## Shape: (B, D)
682- ## Batch-up diag(s); multiply s by 3D identity tensor
683- ## Shape: (B, D, 1) * (1, D, D) => (B, D, D)
684- diag_s = jnp .expand_dims (s , axis = - 1 ) * jnp .eye (s .shape [- 1 ])
685- ## Batched outer(s, s): Broadcasted multiplication
686- ## Shape: (B, D, 1) * (B, 1, D) => (B, D, D)
687- outer_s = jnp .expand_dims (s , axis = - 1 ) * jnp .expand_dims (s , axis = - 2 )
688- return diag_s - outer_s ## return full final Jacobian
689+ s = softmax (x , tau = _tau ) # nn.softmax(x, axis=-1) ## (BxD)
690+ if not vmap_form : ### use pure tensorized batch-identity trick algorithm
691+ diag_s = jnp .expand_dims (s , axis = - 1 ) * jnp .eye (s .shape [- 1 ]) ## Shape: (BxDx1) * (1xDxD) => (BxDxD)
692+ ## batched outer(s, s) ~> outer product for each batch vector
693+ outer_s = jnp .expand_dims (s , axis = - 1 ) * jnp .expand_dims (s , axis = - 2 ) ## (BxDx1) * (Bx1xD) => (BxDxD)
694+ Jx = (diag_s - outer_s ) * (1. / _tau )
695+ else : ### switch to vmap algorithm
696+ ## calc outer product using einsum (clean and readable)
697+ outer_s = jnp .einsum ('bi,bj->bij' , s , s ) ## (BxDxD)
698+ ## fast batched diagonal insertion via a diagonal mask
699+ d = s .shape [- 1 ]
700+ diag_indices = jnp .arange (d )
701+ ## jax.at subtracts outer product from diagonal
702+ ## (s - s^2) for diagonal, (-s_i s_j) for off-diagonal
703+ ## avoids constructing a giant identity matrix
704+ jacobian = - outer_s
705+ ## vmap over index updates across batch
706+ def add_diag (J_matrix , s_vector ):
707+ return J_matrix .at [diag_indices , diag_indices ].add (s_vector )
708+ Jx = ( jax .vmap (add_diag )(jacobian , s ) ) * (1. / _tau )
709+ return Jx ## return full, final Jacobian
689710
690711@jit
691712def threshold_soft (x , lmbda ):
0 commit comments