We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 3a84ebc commit 12f00a1Copy full SHA for 12f00a1
1 file changed
test/chunked_loss/test_grpo_loss.py
@@ -397,15 +397,6 @@ def test_correctness(
397
if delta is not None and loss_type in ("cispo", "sapo"):
398
pytest.skip(f"delta is not supported for loss_type='{loss_type}'")
399
400
- # LUSPO's formula multiplies per_token_loss by seq_lens, amplifying torch.compile
401
- # numerical differences by O(T). Relax tolerances to account for this amplification.
402
- if loss_type == "luspo":
403
- if dtype == torch.bfloat16:
404
- atol = max(atol, 2.0)
405
- rtol = max(rtol, 8.0)
406
- else:
407
- atol = max(atol, 1e-4)
408
- rtol = max(rtol, 5e-3)
409
# Reset torch compiler cache for each parameter of the test case
410
torch.compiler.reset()
411
max_completion_length = T if loss_type == "dr_grpo" else None
0 commit comments