Skip to content

Commit 0248d72

Browse files
committed
ignore luspo for token level
1 parent 11f5710 commit 0248d72

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

test/transformers/test_grpo_loss.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,8 @@ def trl_reference_grpo_loss(
561561
)
562562
def test_grpo_loss_vs_trl(B, T, V, beta, loss_type, importance_sampling_level, delta):
563563
"""Test that triton_grpo_loss matches TRL's exact implementation."""
564+
if importance_sampling_level == "token" and loss_type == "luspo":
565+
pytest.skip("Token-level importance sampling is not supported for loss_type='luspo'")
564566
torch.manual_seed(42)
565567

566568
logits = torch.randn(B, T + 1, V, device=device, dtype=torch.float32)
@@ -769,6 +771,8 @@ def torch_grpo_loss_with_vllm_is(
769771
)
770772
def test_grpo_loss_with_vllm_is_ratio_reduced(B, T, V, beta, loss_type, importance_sampling_level):
771773
"""Test that triton_grpo_loss with vllm_is_ratio matches TRL's behavior with reduce=True."""
774+
if importance_sampling_level == "token" and loss_type == "luspo":
775+
pytest.skip("Token-level importance sampling is not supported for loss_type='luspo'")
772776
torch.manual_seed(42)
773777

774778
logits = torch.randn(B, T + 1, V, device=device, dtype=torch.float32)

0 commit comments

Comments
 (0)