fix: guard fp32 lm-head logits to contiguous to avoid vLLM NaN #2710
Open
mvanhorn wants to merge 1 commit into
Summary
The fp32 lm-head path in `src/prime_rl/inference/patches.py` slices the padded vocab dimension with `logits[..., : self.org_vocab_size]`, which returns a non-contiguous view when `padded_vocab > org_vocab_size`. Adding `.contiguous()` after the slice makes the physical row stride equal `org_vocab_size`, so vLLM's Triton top-k/top-p kernel reads the correct rows.
Why this matters
Issue #2497 reported NaN log-probs on Olmo3 (org vocab 100278, padded 100288). vLLM's native Triton top-k/top-p kernel indexes rows as `row_id * VOCAB_SIZE` rather than by `stride(0)`, so against the non-contiguous slice it read the wrong physical row, could mask a logical row to all `-inf`, and `processed_logprobs` then computed `log_softmax(all -inf) = NaN`. An upstream vLLM kernel fix was discussed but the issue remains open, and the reporter asked prime-rl to guard at this boundary since other logits processors can also produce non-contiguous views. The merged PR #2506 (`padded_input_scrub`) covers a different padded-decode-input path, not this lm-head slice. This fix keeps the fp32 dtype and math unchanged and only normalizes memory layout.
Testing
Covered by the new `tests/unit/inference/test_fp32_lmhead_contiguous.py`: a sliced padded-vocab tensor is contiguous with stride `(org_vocab_size, 1)` after the patch, the no-padding case is a no-op, per-row argmax/top values are preserved, and a synthetic top-p selection over the guarded logits no longer yields an all-`-inf` row / NaN log-softmax. Full suite runs in CI.
Fixes #2497
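The addressing mismatch described in the summary can be illustrated with a small pure-Python model of a row-major buffer. This is a sketch under stated assumptions, not the vLLM kernel or the prime-rl patch: the sizes are shrunk stand-ins for the Olmo3 values, and every function name here is hypothetical.

```python
# Toy model: a flat row-major buffer holding padded logits.
# ORG_VOCAB / PADDED_VOCAB are small stand-ins for 100278 / 100288.
ORG_VOCAB = 6
PADDED_VOCAB = 8
NUM_ROWS = 3


def make_padded_buffer():
    """Row r holds 10*r + col in its logical slots and -1.0 in the padding."""
    buf = []
    for r in range(NUM_ROWS):
        for c in range(PADDED_VOCAB):
            buf.append(float(10 * r + c) if c < ORG_VOCAB else -1.0)
    return buf


def read_row_by_stride(buf, row, row_stride, width):
    """Correct read of a strided view: the row offset uses the physical stride."""
    start = row * row_stride
    return buf[start:start + width]


def read_row_assuming_contiguous(buf, row, width):
    """What a kernel computing row_id * VOCAB_SIZE effectively does."""
    start = row * width
    return buf[start:start + width]


def make_contiguous(buf, row_stride, width, num_rows):
    """Copy the logical rows into a compact buffer whose row stride equals width."""
    out = []
    for r in range(num_rows):
        out.extend(read_row_by_stride(buf, r, row_stride, width))
    return out


padded = make_padded_buffer()

# The sliced view shares the padded buffer: shape (3, 6), strides (8, 1).
correct_row1 = read_row_by_stride(padded, 1, PADDED_VOCAB, ORG_VOCAB)
wrong_row1 = read_row_assuming_contiguous(padded, 1, ORG_VOCAB)

# After the compacting copy, the stride assumption holds again.
compact = make_contiguous(padded, PADDED_VOCAB, ORG_VOCAB, NUM_ROWS)
fixed_row1 = read_row_assuming_contiguous(compact, 1, ORG_VOCAB)
```

In this model `wrong_row1` begins in row 0's padding and ends partway through row 1, which is the kind of misread the PR attributes to the kernel, while `fixed_row1` matches `correct_row1`. In PyTorch the corresponding layout checks would be `Tensor.is_contiguous()` and `Tensor.stride()`.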
Note
Low Risk
Narrow layout fix at the lm-head boundary with unit tests; fp32 math and slice semantics unchanged aside from memory layout.
Overview
Fixes NaN log-probs when the fp32 lm-head path trims logits from a padded vocabulary to `org_vocab_size`.
The patch replaces a bare slice with `_trim_logits_to_org_vocab`, which keeps the same values but calls `.contiguous()` so each row’s stride matches the logical vocab width. vLLM’s Triton top-k/top-p path assumes contiguous rows; a non-contiguous view (common when `padded_vocab > org_vocab`) could read the wrong memory and mask a row to all `-inf`, yielding NaN `log_softmax`.
New unit tests in `test_fp32_lmhead_contiguous.py` lock in layout (stride/contiguity), unchanged argmax/top-k, and that guarded logits avoid the all-`-inf` / NaN failure mode under a synthetic top-p kernel.
Reviewed by Cursor Bugbot for commit a138b11.
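The last step of the failure mode, a fully masked row turning into NaN log-probs, follows from floating-point arithmetic alone. The sketch below uses the common max-subtraction formulation of log-softmax in plain Python; it is an illustrative assumption about how the NaN arises, not a reproduction of the PyTorch or vLLM implementation.

```python
import math

NEG_INF = float("-inf")


def log_softmax(row):
    """Log-softmax over a list of floats using the max-subtraction form."""
    m = max(row)
    # For an all -inf row, m is -inf and (-inf) - (-inf) evaluates to NaN.
    shifted = [x - m for x in row]
    log_sum = math.log(sum(math.exp(s) for s in shifted))
    return [s - log_sum for s in shifted]


# A row where top-k/top-p masking left at least one finite logit.
healthy = log_softmax([NEG_INF, 2.0, NEG_INF, 1.0])

# A row that was masked entirely, as in the reported bug.
broken = log_softmax([NEG_INF] * 4)
```

Here `healthy` keeps `-inf` for the masked entries and finite values elsewhere, whereas every entry of `broken` is NaN, so a single misread row is enough to corrupt its log-probs.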