@@ -664,10 +664,34 @@ def softmax(x, tau=0.0):
664664 exp_x = jnp .exp (x - max_x )
665665 return exp_x / jnp .sum (exp_x , axis = 1 , keepdims = True )
666666
@jit
def d_softmax(x):
    """
    Derivative of the softmax function.
    Note that this returns the full Jacobian of softmax(x) with respect to x,
    computed independently for each row-vector of a (potential) batch. For a
    row with softmax output s, the Jacobian is diag(s) - outer(s, s), i.e.,
    J[i, j] = s[i] * (delta_ij - s[j]).

    Args:
        x: input (tensor) value (B x D); the softmax is taken over the last axis

    Returns:
        output (tensor) derivative values (Jacobian with respect to input
        argument; B x D x D), in the same floating-point dtype as softmax(x)
    """
    ## calculate softmax along feature dimension (axis=-1)
    s = jax.nn.softmax(x, axis=-1)  ## Shape: (B, D)
    s_col = jnp.expand_dims(s, axis=-1)  ## Shape: (B, D, 1)
    s_row = jnp.expand_dims(s, axis=-2)  ## Shape: (B, 1, D)
    ## Identity is built in the dtype of s so that low-precision inputs
    ## (e.g., float16) are not silently promoted to float32
    eye = jnp.eye(s.shape[-1], dtype=s.dtype)  ## Shape: (D, D)
    ## Batch-up diag(s); broadcast s against the identity matrix
    ## Shape: (B, D, 1) * (D, D) => (B, D, D)
    diag_s = s_col * eye
    ## Batched outer(s, s): Broadcasted multiplication
    ## Shape: (B, D, 1) * (B, 1, D) => (B, D, D)
    outer_s = s_col * s_row
    return diag_s - outer_s  ## return full final Jacobian
689+
667690@jit
668691def threshold_soft (x , lmbda ):
669692 """
670- A soft threshold routine applied to each dimension of input
693+ A soft threshold routine applied to each dimension of input.
694+ (Note that this function does not contain a complementary derivative.)
671695
672696 Args:
673697 x: data to apply threshold function over
@@ -684,7 +708,8 @@ def threshold_soft(x, lmbda):
684708@jit
685709def threshold_cauchy (x , lmbda ):
686710 """
687- A Cauchy distributional threshold routine applied to each dimension of input
711+ A Cauchy distributional threshold routine applied to each dimension of input.
712+ (Note that this function does not contain a complementary derivative.)
688713
689714 Args:
690715 x: data to apply threshold function over
0 commit comments