Skip to content

Commit f47549a

Browse files
author
Alexander Ororbia
committed
mod to model_utils
1 parent abe7dfa commit f47549a

1 file changed

Lines changed: 27 additions & 2 deletions

File tree

ngclearn/utils/model_utils.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -664,10 +664,34 @@ def softmax(x, tau=0.0):
664664
exp_x = jnp.exp(x - max_x)
665665
return exp_x / jnp.sum(exp_x, axis=1, keepdims=True)
666666

667+
@jit
668+
def d_softmax(x):
669+
"""
670+
Derivative of the softmax function.
671+
Note that this returns specifically the Jacobian tensor of softmax(x) w.r.t.
672+
potential batch set of vectors (one per row).
673+
674+
Args:
675+
x: input (tensor) value (B x D)
676+
677+
Returns:
678+
output (tensor) derivative values (Jacobian with respect to input argument; B x D x D)
679+
"""
680+
## caclulate softmax along feature dimension (axis=-1)
681+
s = jax.nn.softmax(x, axis=-1) ## Shape: (B, D)
682+
## Batch-up diag(s); multiply s by 3D identity tensor
683+
## Shape: (B, D, 1) * (1, D, D) => (B, D, D)
684+
diag_s = jnp.expand_dims(s, axis=-1) * jnp.eye(s.shape[-1])
685+
## Batched outer(s, s): Broadcasted multiplication
686+
## Shape: (B, D, 1) * (B, 1, D) => (B, D, D)
687+
outer_s = jnp.expand_dims(s, axis=-1) * jnp.expand_dims(s, axis=-2)
688+
return diag_s - outer_s ## return full final Jacobian
689+
667690
@jit
668691
def threshold_soft(x, lmbda):
669692
"""
670-
A soft threshold routine applied to each dimension of input
693+
A soft threshold routine applied to each dimension of input.
694+
(Note that this function does not contain a complementary derivative.)
671695
672696
Args:
673697
x: data to apply threshold function over
@@ -684,7 +708,8 @@ def threshold_soft(x, lmbda):
684708
@jit
685709
def threshold_cauchy(x, lmbda):
686710
"""
687-
A Cauchy distributional threshold routine applied to each dimension of input
711+
A Cauchy distributional threshold routine applied to each dimension of input.
712+
(Note that this function does not contain a complementary derivative.)
688713
689714
Args:
690715
x: data to apply threshold function over

0 commit comments

Comments
 (0)