30 commits
d583ac5
fix: update ensure_contiguous decorator for functorch compatibility
roycho96 Apr 3, 2026
38d1951
fix: add setup_context for torch.func compatibility for cross_entropy
roycho96 Apr 3, 2026
c1ea9a5
fix: add setup_context for torch.func compatibility for geglu
roycho96 Apr 3, 2026
2be18a3
fix: add setup_context for torch.func compatibility for swiglu
roycho96 Apr 3, 2026
1d11347
fix: add setup_context for torch.func compatibility for relu_squared
roycho96 Apr 3, 2026
44a5841
fix: add setup_context for torch.func compatibility for rms_norm
roycho96 Apr 3, 2026
4605610
fix: add setup_context for torch.func compatibility for layer_norm
roycho96 Apr 3, 2026
4d9d0b5
fix: add setup_context for torch.func compatibility for group_norm
roycho96 Apr 3, 2026
d8fc0be
fix: add setup_context for torch.func compatibility for poly_norm
roycho96 Apr 3, 2026
86daf11
fix: add setup_context for torch.func compatibility for fused_add_rms…
roycho96 Apr 3, 2026
e10136a
fix: add setup_context for torch.func compatibility for dyt
roycho96 Apr 3, 2026
4b60039
fix: add setup_context for torch.func compatibility for kl_div
roycho96 Apr 3, 2026
5920201
fix: add setup_context for torch.func compatibility for jsd
roycho96 Apr 3, 2026
f6034dc
fix: add setup_context for torch.func compatibility for tvd
roycho96 Apr 3, 2026
a9f9a87
fix: add setup_context for torch.func compatibility for grpo_loss
roycho96 Apr 3, 2026
00fb8bf
fix: add setup_context for torch.func compatibility for rope
roycho96 Apr 3, 2026
bf55f24
fix: add setup_context for torch.func compatibility for qwen2vl_mrope
roycho96 Apr 3, 2026
7e77066
fix: add setup_context for torch.func compatibility for llama4_rope
roycho96 Apr 3, 2026
f42f07d
fix: add setup_context for torch.func compatibility for softmax
roycho96 Apr 3, 2026
21c14d6
fix: add setup_context for torch.func compatibility for sparsemax
roycho96 Apr 3, 2026
a983686
fix: add setup_context for torch.func compatibility for multi_token_a…
roycho96 Apr 3, 2026
dd8725a
fix: add setup_context for torch.func compatibility for fused_neighbo…
roycho96 Apr 3, 2026
e8f963d
fix: add setup_context for torch.func compatibility for attn_res
roycho96 Apr 3, 2026
f171ad7
fix: add setup_context for torch.func compatibility for mhc
roycho96 Apr 3, 2026
f545681
fix: add setup_context for torch.func compatibility for tiled_mlp
roycho96 Apr 3, 2026
6b9ff44
fix: add setup_context for torch.func compatibility for embedding
roycho96 Apr 3, 2026
2c9cfe5
fix: add setup_context for torch.func compatibility for chunked_loss …
roycho96 Apr 3, 2026
be9b705
fix: update chunked_loss subclasses for torch.func compatibility
roycho96 Apr 3, 2026
bd95917
fix: add setup_context for torch.func compatibility for ascend backends
roycho96 Apr 3, 2026
72f6713
Merge branch 'linkedin:main' into functorch-support
roycho96 Apr 3, 2026
2 changes: 1 addition & 1 deletion benchmark/scripts/benchmark_attn_res.py
@@ -48,7 +48,7 @@ def _setup_attn_res(input: SingleBenchmarkRunInput):
eps = cfg.get("eps", 1e-6)

if input.kernel_provider == "liger":
- fn = lambda: LigerAttnResFunction.apply(V, w_query, w_norm, eps)
+ fn = lambda: LigerAttnResFunction.apply(V, w_query, w_norm, eps)[0]
elif input.kernel_provider == "pytorch":
from test.transformers.test_attn_res import pytorch_attn_res

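The `[0]` in the benchmark above is the calling convention the rest of this PR installs: a new-style autograd.Function keeps `forward` free of `ctx`, returns its precomputed gradient buffers as extra outputs, and does all `ctx` bookkeeping in `setup_context`, so `apply()` now yields a tuple whose first element is the only user-facing value. A minimal self-contained sketch of that shape, with illustrative names rather than Liger code:

import torch

class ScaledSquare(torch.autograd.Function):
    @staticmethod
    def forward(x, scale):
        # New-style forward: no ctx parameter and no side effects, which is
        # what torch.func transforms need in order to trace the Function.
        out = scale * x * x
        grad_x = 2.0 * scale * x  # gradient buffer, returned as an extra output
        return out, grad_x

    @staticmethod
    def setup_context(ctx, inputs, output):
        # The only place allowed to touch ctx: `inputs` is the apply() argument
        # tuple, `output` is whatever forward returned.
        _, grad_x = output
        ctx.save_for_backward(grad_x)

    @staticmethod
    def backward(ctx, grad_out, _grad_for_buffer):  # one grad slot per output
        (grad_x,) = ctx.saved_tensors
        return grad_out * grad_x, None  # one entry per forward input

x = torch.randn(4, requires_grad=True)
loss = ScaledSquare.apply(x, 0.5)[0].sum()  # callers keep [0], as above
loss.backward()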
8 changes: 5 additions & 3 deletions src/liger_kernel/chunked_loss/cosine_similarity_loss.py
@@ -35,7 +35,6 @@ def distillation_loss_fn(
@classmethod
def forward(
cls,
- ctx,
student_input: torch.Tensor,
student_weight: torch.Tensor,
teacher_input: torch.Tensor,
@@ -54,7 +53,6 @@ def forward(
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
return super().forward(
cls=cls,
- ctx=ctx,
student_input=student_input,
student_weight=student_weight,
teacher_input=teacher_input,
@@ -123,7 +121,7 @@ def forward(
student_bias: torch.Tensor = None,
teacher_bias: torch.Tensor = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
- return LigerFusedLinearCosineSimilarityFunction.apply(
+ result = LigerFusedLinearCosineSimilarityFunction.apply(
student_input,
student_weight,
teacher_input,
@@ -140,3 +138,7 @@ def forward(
self.chunk_size,
self.return_soft_hard_loss,
)
+ # Return only loss (and optionally soft/hard losses), not the grad tensors
+ if self.return_soft_hard_loss:
+     return result[0], result[1], result[2]
+ return result[0]
6 changes: 3 additions & 3 deletions src/liger_kernel/chunked_loss/cpo_loss.py
@@ -42,7 +42,6 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, labe
@classmethod
def forward(
cls,
- ctx,
_input,
weight,
target,
@@ -76,7 +75,6 @@ def forward(
"""
return super().forward(
cls=cls,
- ctx=ctx,
_input=_input,
weight=weight,
target=target,
@@ -141,7 +139,7 @@ def forward(
target,
bias=None,
):
- return LigerFusedLinearCPOFunction.apply(
+ result = LigerFusedLinearCPOFunction.apply(
_input,
lin_weight,
target,
@@ -155,3 +153,5 @@ def forward(
self.average_log_prob,
self.chunk_size,
)
+ # Return only loss and aux outputs, not the grad tensors
+ return result[0], result[1]
6 changes: 3 additions & 3 deletions src/liger_kernel/chunked_loss/dpo_loss.py
@@ -98,7 +98,6 @@ def preference_loss_fn(
@classmethod
def forward(
cls,
- ctx,
_input,
weight,
target,
@@ -137,7 +136,6 @@ def forward(
"""
return super().forward(
cls=cls,
- ctx=ctx,
_input=_input,
weight=weight,
target=target,
@@ -210,7 +208,7 @@ def forward(
ref_weight=None,
ref_bias=None,
):
- return LigerFusedLinearDPOFunction.apply(
+ result = LigerFusedLinearDPOFunction.apply(
_input,
lin_weight,
target,
@@ -227,3 +225,5 @@ def forward(
self.chunk_size,
self.loss_type,
)
+ # Return only loss and aux outputs, not the grad tensors
+ return result[0], result[1]
19 changes: 10 additions & 9 deletions src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -150,7 +150,6 @@ def _compute_loss(
@staticmethod
def forward(
cls,
- ctx,
student_input,
student_weight,
teacher_input,
@@ -279,14 +278,16 @@ def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
grad_input = accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk)
grad_inputs.append(grad_input)

- ctx.save_for_backward(
-     torch.cat(grad_inputs, dim=0),
-     grad_weight,
-     grad_bias,
- )
- if return_soft_hard_loss:
-     return loss_acc, soft_loss_acc, hard_loss_acc
- return loss_acc
+ # Return grad tensors as extra outputs for setup_context
+ grad_input_cat = torch.cat(grad_inputs, dim=0)
+ # Always return 6 values for consistent setup_context unpacking
+ # When return_soft_hard_loss=False, soft_loss_acc and hard_loss_acc are None
+ return loss_acc, soft_loss_acc, hard_loss_acc, grad_input_cat, grad_weight, grad_bias

+ @staticmethod
+ def setup_context(ctx, inputs, output):
+     _, _, _, grad_input_cat, grad_weight, grad_bias = output
+     ctx.save_for_backward(grad_input_cat, grad_weight, grad_bias)

@staticmethod
def backward(ctx, grad_output, *args):
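The payoff of the restructuring above: torch.func transforms reject old-style Functions whose forward takes ctx, but compose with the split forward/setup_context form. A hedged, self-contained check with a toy Function, not the distillation kernel:

import torch
from torch.func import grad

class Square(torch.autograd.Function):
    @staticmethod
    def forward(x):
        return x * x, 2.0 * x  # (real output, gradient buffer)

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(output[1])

    @staticmethod
    def backward(ctx, grad_out, _grad_for_buffer):
        (gx,) = ctx.saved_tensors
        return grad_out * gx

x = torch.randn(5)
g = grad(lambda t: Square.apply(t)[0].sum())(x)  # traceable: no ctx in forward
torch.testing.assert_close(g, 2.0 * x)  # matches the analytic gradient of sum(x^2)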
13 changes: 7 additions & 6 deletions src/liger_kernel/chunked_loss/fused_linear_ppo.py
@@ -17,7 +17,6 @@ def ppo_loss_fn(*args, **kwargs):
@staticmethod
def forward(
cls,
- ctx,
_input,
weight,
selected_token_ids,
@@ -50,7 +49,6 @@ def forward(

Args:
cls: The class
- ctx: Context for backward
_input: Input tensor
weight: Weight tensor
selected_token_ids: Selected token ids tensor
@@ -271,9 +269,6 @@ def accumulate_chunk(
# Combine gradients
grad_input = torch.cat(grad_inputs, dim=0)

- # Save for backward
- ctx.save_for_backward(grad_input, grad_weight, grad_bias)

# Finalize metrics
final_metrics = []
for metric in aggregated_metrics:
@@ -282,7 +277,13 @@ def accumulate_chunk(
else:
final_metrics.append(metric)

- return loss_acc, tuple(final_metrics)
+ # Return grad tensors as extra outputs for setup_context
+ return loss_acc, tuple(final_metrics), grad_input, grad_weight, grad_bias

+ @staticmethod
+ def setup_context(ctx, inputs, output):
+     loss_acc, final_metrics, grad_input, grad_weight, grad_bias = output
+     ctx.save_for_backward(grad_input, grad_weight, grad_bias)

@staticmethod
def _compute_dapo_normalizer(attention_mask):
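setup_context is also the gateway to batching: once a Function is new-style, PyTorch can derive a vmap rule for it on request via generate_vmap_rule. This is a general PyTorch mechanism sketched here under that assumption; the diffs in this PR only add setup_context and do not necessarily opt into vmap:

import torch

class Scale(torch.autograd.Function):
    generate_vmap_rule = True  # ask PyTorch to derive the batching rule

    @staticmethod
    def forward(x):
        return 2.0 * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # the trivial backward below needs nothing saved

    @staticmethod
    def backward(ctx, grad_out):
        return 2.0 * grad_out

batched = torch.vmap(Scale.apply)(torch.randn(8, 4))  # applied per row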
15 changes: 8 additions & 7 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -17,7 +17,6 @@ def preference_loss_fn(*args, **kwargs):
@staticmethod
def forward(
cls,
- ctx,
_input,
weight,
target,
@@ -249,19 +248,21 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None, chosen_nll
if isinstance(aux, list):
aggregated_aux_outputs[i] = torch.cat(aux, dim=0)

- ctx.save_for_backward(
-     torch.cat(grad_inputs, dim=0),
-     grad_weight,
-     grad_bias,
- )
+ # Return grad tensors as extra outputs for setup_context
+ grad_input_cat = torch.cat(grad_inputs, dim=0)
return_vars = (
policy_chosen_logps,
policy_rejected_logps,
policy_chosen_logits_mean,
policy_rejected_logits_mean,
policy_nll_loss,
)
- return loss_acc, (*return_vars, *aggregated_aux_outputs)
+ return loss_acc, (*return_vars, *aggregated_aux_outputs), grad_input_cat, grad_weight, grad_bias

+ @staticmethod
+ def setup_context(ctx, inputs, output):
+     loss_acc, aux_tuple, grad_input_cat, grad_weight, grad_bias = output
+     ctx.save_for_backward(grad_input_cat, grad_weight, grad_bias)

@staticmethod
def backward(ctx, *grad_output):
@@ -17,7 +17,6 @@ def preference_loss_fn(*args, **kwargs):
@staticmethod
def forward(
cls,
- ctx,
_input,
weight,
target,
@@ -193,11 +192,8 @@ def accumulate_chunk(
if isinstance(aux, list):
aggregated_aux_outputs[i] = torch.cat(aux, dim=0)

- ctx.save_for_backward(
-     torch.cat(grad_inputs, dim=0),
-     grad_weight,
-     grad_bias,
- )
+ # Return grad tensors as extra outputs for setup_context
+ grad_input_cat = torch.cat(grad_inputs, dim=0)

return_vars = (
chosen_logps_sum,
@@ -206,7 +202,12 @@ def accumulate_chunk(
rejected_logits_sum,
)

- return loss_acc, (*return_vars, *aggregated_aux_outputs)
+ return loss_acc, (*return_vars, *aggregated_aux_outputs), grad_input_cat, grad_weight, grad_bias

+ @staticmethod
+ def setup_context(ctx, inputs, output):
+     loss_acc, aux_tuple, grad_input_cat, grad_weight, grad_bias = output
+     ctx.save_for_backward(grad_input_cat, grad_weight, grad_bias)

@staticmethod
def backward(ctx, *grad_output):
6 changes: 3 additions & 3 deletions src/liger_kernel/chunked_loss/grpo_loss.py
@@ -224,7 +224,6 @@ def ppo_loss_fn(
@classmethod
def forward(
cls,
- ctx,
_input,
weight,
selected_token_ids,
@@ -290,7 +289,6 @@ def forward(

return super().forward(
cls=cls,
- ctx=ctx,
_input=_input,
weight=weight,
selected_token_ids=selected_token_ids,
@@ -432,7 +430,7 @@ def forward(
ref_bias=None,
vllm_is_ratio=None,
):
- return LigerFusedLinearGRPOFunction.apply(
+ result = LigerFusedLinearGRPOFunction.apply(
_input,
lin_weight,
selected_token_ids,
Expand Down Expand Up @@ -460,3 +458,5 @@ def forward(
self.delta,
self.use_bias_correction_kl,
)
+ # Return only loss and metrics, not the grad tensors
+ return result[0], result[1]
8 changes: 5 additions & 3 deletions src/liger_kernel/chunked_loss/jsd_loss.py
@@ -59,7 +59,6 @@ def distillation_loss_fn(student_logits, teacher_logits, beta=0.5, target=None,
@classmethod
def forward(
cls,
- ctx,
student_input: torch.Tensor,
student_weight: torch.Tensor,
teacher_input: torch.Tensor,
@@ -97,7 +96,6 @@ def forward(
"""
return super().forward(
cls=cls,
- ctx=ctx,
student_input=student_input,
student_weight=student_weight,
teacher_input=teacher_input,
@@ -196,7 +194,7 @@ def forward(
If return_soft_hard_loss is False: Computed combined loss
If return_soft_hard_loss is True: Tuple of (combined_loss, soft_loss, hard_loss)
"""
- return LigerFusedLinearJSDFunction.apply(
+ result = LigerFusedLinearJSDFunction.apply(
student_input,
student_weight,
teacher_input,
@@ -213,3 +211,7 @@ def forward(
self.chunk_size,
self.return_soft_hard_loss,
)
+ # Return only loss (and optionally soft/hard losses), not the grad tensors
+ if self.return_soft_hard_loss:
+     return result[0], result[1], result[2]
+ return result[0]
6 changes: 3 additions & 3 deletions src/liger_kernel/chunked_loss/kto_loss.py
@@ -71,7 +71,6 @@ def preference_loss_fn(
@classmethod
def forward(
cls,
- ctx,
_input,
weight,
target,
@@ -111,7 +110,6 @@ def forward(
"""
return super().forward(
cls=cls,
- ctx=ctx,
_input=_input,
weight=weight,
target=target,
@@ -191,7 +189,7 @@ def forward(
ref_bias=None,
kl=None,
):
- return LigerFusedLinearKTOFunction.apply(
+ result = LigerFusedLinearKTOFunction.apply(
_input,
lin_weight,
target,
@@ -208,3 +206,5 @@ def forward(
self.average_log_prob,
self.chunk_size,
)
+ # Return only loss and aux outputs, not the grad tensors
+ return result[0], result[1]
6 changes: 3 additions & 3 deletions src/liger_kernel/chunked_loss/orpo_loss.py
@@ -45,7 +45,6 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
@classmethod
def forward(
cls,
- ctx,
_input,
weight,
target,
@@ -75,7 +74,6 @@ def forward(
"""
return super().forward(
cls=cls,
- ctx=ctx,
_input=_input,
weight=weight,
target=target,
@@ -130,7 +128,7 @@ def forward(
bias=None,
nll_target=None,
):
- return LigerFusedLinearORPOFunction.apply(
+ result = LigerFusedLinearORPOFunction.apply(
_input,
lin_weight,
target,
@@ -142,3 +140,5 @@ def forward(
self.compiled,
self.chunk_size,
)
+ # Return only loss and aux outputs, not the grad tensors
+ return result[0], result[1]
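Every wrapper module in this PR now ends the same way: the Function's apply() returns (loss, aux, *grad_buffers) and the nn.Module forward re-exposes only the user-facing values. The pattern in miniature, with hypothetical names rather than a Liger API:

import torch

class ChunkedLossModule(torch.nn.Module):
    def __init__(self, loss_fn_cls):
        super().__init__()
        # loss_fn_cls: an autograd.Function whose apply() returns
        # (loss, aux_outputs, grad_input, grad_weight, grad_bias)
        self.loss_fn_cls = loss_fn_cls

    def forward(self, _input, weight, target):
        result = self.loss_fn_cls.apply(_input, weight, target)
        # Return only loss and aux outputs, not the grad tensors
        return result[0], result[1]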