Skip to content

Commit 12f00a1

Browse files
committed
remove luspo changes
1 parent 3a84ebc commit 12f00a1

1 file changed

Lines changed: 0 additions & 9 deletions

File tree

test/chunked_loss/test_grpo_loss.py

Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -397,15 +397,6 @@ def test_correctness(
397397
if delta is not None and loss_type in ("cispo", "sapo"):
398398
pytest.skip(f"delta is not supported for loss_type='{loss_type}'")
399399

400-
# LUSPO's formula multiplies per_token_loss by seq_lens, amplifying torch.compile
401-
# numerical differences by O(T). Relax tolerances to account for this amplification.
402-
if loss_type == "luspo":
403-
if dtype == torch.bfloat16:
404-
atol = max(atol, 2.0)
405-
rtol = max(rtol, 8.0)
406-
else:
407-
atol = max(atol, 1e-4)
408-
rtol = max(rtol, 5e-3)
409400
# Reset torch compiler cache for each parameter of the test case
410401
torch.compiler.reset()
411402
max_completion_length = T if loss_type == "dr_grpo" else None

0 commit comments

Comments
 (0)