Skip to content

Commit 11f5710

Browse files
committed
luspo is not valid for token level
1 parent 12f00a1 commit 11f5710

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

test/chunked_loss/test_grpo_loss.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,8 @@ def test_correctness(
394394
):
395395
if importance_sampling_level == "sequence" and loss_type in ("cispo", "sapo"):
396396
pytest.skip(f"Sequence-level importance sampling is not supported for loss_type='{loss_type}'")
397+
if importance_sampling_level == "token" and loss_type == "luspo":
398+
pytest.skip("Token-level importance sampling is not supported for loss_type='luspo'")
397399
if delta is not None and loss_type in ("cispo", "sapo"):
398400
pytest.skip(f"delta is not supported for loss_type='{loss_type}'")
399401

@@ -622,6 +624,8 @@ def test_correctness_with_bias_correction_kl(loss_type, dtype, atol, rtol):
622624
@pytest.mark.parametrize("beta", [0.0, 0.1])
623625
def test_correctness_with_vllm_is_ratio(loss_type, beta):
624626
"""Test vllm_is_ratio correctness against torch reference, and 1D/2D shape equivalence."""
627+
if loss_type == "luspo":
628+
pytest.skip("Token-level importance sampling is not supported for loss_type='luspo'")
625629
torch.compiler.reset()
626630
B, T, H, V = 4, 32, 64, 128
627631
dtype = torch.float32

0 commit comments

Comments
 (0)