Skip to content

Commit e6ba816

Browse files
Merge pull request #2858 from AI-Hypercomputer:chengnuojin-fix-vjp
PiperOrigin-RevId: 846813597
2 parents 28d570a + 8668c46 commit e6ba816

1 file changed

Lines changed: 5 additions & 8 deletions

File tree

src/MaxText/max_utils.py

Lines changed: 5 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -546,7 +546,6 @@ def _cross_entropy_with_logits_fwd(logits: jnp.ndarray, targets: jnp.ndarray, z_
546546
jnp.ndarray,
547547
jnp.ndarray,
548548
jnp.ndarray,
549-
jnp.ndarray,
550549
],
551550
]:
552551
"""Forward-mode of `cross_entropy_with_logits`."""
@@ -566,7 +565,6 @@ def _cross_entropy_with_logits_fwd(logits: jnp.ndarray, targets: jnp.ndarray, z_
566565
z_loss,
567566
exp_shifted,
568567
sum_exp, # pytype: disable=bad-return-type #jax-ndarray
569-
log_softmax,
570568
log_z,
571569
)
572570

@@ -579,21 +577,20 @@ def _cross_entropy_with_logits_bwd(
579577
jnp.ndarray,
580578
jnp.ndarray,
581579
jnp.ndarray,
582-
jnp.ndarray,
583580
],
584581
g: tuple[jnp.ndarray, jnp.ndarray],
585-
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
582+
) -> tuple[jnp.ndarray, None, None]:
586583
"""Backward-mode of `cross_entropy_with_logits`."""
587584
g = g[0] # Ignore z_loss component as that is only used for logging.
588-
logits, targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z = res
585+
logits, targets, z_loss, exp_shifted, sum_exp, log_z = res
589586
# z-loss term adds the (2 * z_loss * log_z) factor.
590587
deriv = jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp - targets
591588
g_logits = jnp.expand_dims(g, axis=-1) * deriv
592-
g_targets = -jnp.expand_dims(g, axis=-1) * log_softmax
589+
593590
return (
594591
jnp.asarray(g_logits, logits.dtype),
595-
jnp.asarray(g_targets, targets.dtype),
596-
jnp.array(0.0),
592+
None, # we don't need gradients on targets
593+
None, # we don't need gradients on z_loss
597594
) # sets z-loss coeff gradient to 0
598595

599596

0 commit comments

Comments
 (0)