diff --git a/src/liger_kernel/transformers/model/loss_utils.py b/src/liger_kernel/transformers/model/loss_utils.py index 508b3583d..34fbfc7c8 100644 --- a/src/liger_kernel/transformers/model/loss_utils.py +++ b/src/liger_kernel/transformers/model/loss_utils.py @@ -39,7 +39,10 @@ def fixed_fused_linear_cross_entropy( return_predicted_tokens: bool = False, **kwargs, ): - reduction = "sum" if num_items_in_batch is not None else "mean" + reduction = kwargs.pop("reduction", None) + if reduction is None: + reduction = "sum" if num_items_in_batch is not None else "mean" + result = F.liger_fused_linear_cross_entropy( hidden_states, lm_head_weight,