@@ -546,7 +546,6 @@ def _cross_entropy_with_logits_fwd(logits: jnp.ndarray, targets: jnp.ndarray, z_
546546 jnp .ndarray ,
547547 jnp .ndarray ,
548548 jnp .ndarray ,
549- jnp .ndarray ,
550549 ],
551550]:
552551 """Forward-mode of `cross_entropy_with_logits`."""
@@ -566,7 +565,6 @@ def _cross_entropy_with_logits_fwd(logits: jnp.ndarray, targets: jnp.ndarray, z_
566565 z_loss ,
567566 exp_shifted ,
568567 sum_exp , # pytype: disable=bad-return-type #jax-ndarray
569- log_softmax ,
570568 log_z ,
571569 )
572570
@@ -579,21 +577,20 @@ def _cross_entropy_with_logits_bwd(
579577 jnp .ndarray ,
580578 jnp .ndarray ,
581579 jnp .ndarray ,
582- jnp .ndarray ,
583580 ],
584581 g : tuple [jnp .ndarray , jnp .ndarray ],
585- ) -> tuple [jnp .ndarray , jnp . ndarray , jnp . ndarray ]:
582+ ) -> tuple [jnp .ndarray , None , None ]:
586583 """Backward-mode of `cross_entropy_with_logits`."""
587584 g = g [0 ] # Ignore z_loss component as that is only used for logging.
588- logits , targets , z_loss , exp_shifted , sum_exp , log_softmax , log_z = res
585+ logits , targets , z_loss , exp_shifted , sum_exp , log_z = res
589586 # z-loss term adds the (2 * z_loss * log_z) factor.
590587 deriv = jnp .expand_dims (1 + 2 * z_loss * log_z , - 1 ) * exp_shifted / sum_exp - targets
591588 g_logits = jnp .expand_dims (g , axis = - 1 ) * deriv
592- g_targets = - jnp . expand_dims ( g , axis = - 1 ) * log_softmax
589+
593590 return (
594591 jnp .asarray (g_logits , logits .dtype ),
595- jnp . asarray ( g_targets , targets . dtype ),
596- jnp . array ( 0.0 ),
592+ None , # we don't need gradients on targets
593+ None , # we don't need gradients on z_loss
597594 ) # sets z-loss coeff gradient to 0
598595
599596
0 commit comments