diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 22b6373cf..dc574acd6 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -219,11 +219,6 @@ def fused_linear_cross_entropy_forward( alpha=1.0, ) - # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now. - # if reduction == "none": - # loss = loss_1d - # z_loss = z_loss_1d if return_z_loss else None - if reduction == "none": # Return per-token losses loss = loss_1d @@ -245,6 +240,8 @@ def fused_linear_cross_entropy_forward( def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias): + # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now. + assert grad_output.ndim == 0, 'Backward unsupported for reduction="none"' # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place