Skip to content

fix: use prompt token length for advantage group extraction and fix token mask#2176

Open
yfw wants to merge 2 commits into
mainfrom
yifu/fix-prompt-extraction-multi-turn
Open

fix: use prompt token length for advantage group extraction and fix token mask#2176
yfw wants to merge 2 commits into
mainfrom
yifu/fix-prompt-extraction-multi-turn

Conversation

@yfw
Copy link
Copy Markdown
Contributor

@yfw yfw commented Mar 30, 2026

This PR fixes two multi-turn GRPO training issues:

  1. The previous role-based extraction (_extract_prompt_only_messages) broke on multi-turn prompts containing assistant messages in the conversation history — it would strip them, corrupting the prompt IDs used for advantage estimation.
    Replace with extract_initial_prompt_messages() which uses the length field to identify the original prompt boundary. Applied to both sync and async GRPO paths.

  2. GRPO token loss masks previously unmasked every message with role == "assistant". In multi-turn data, assistant messages can be part of the prompt history, not generated rollout output, so those tokens should not contribute to the policy loss. This PR updates masking so only assistant messages produced by generation, identified by existing generation_logprobs, are trainable. Missing generation_logprobs are still filled with zeros for downstream tensorization.

Closes #1960 and #1956

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

@yfw yfw requested a review from a team as a code owner March 30, 2026 23:17
@yfw yfw added the super-v3 label Mar 30, 2026
@yfw yfw requested a review from a team as a code owner March 30, 2026 23:17
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Mar 30, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

yuki-97
yuki-97 previously approved these changes Mar 31, 2026
@yuki-97 yuki-97 added the CI:Lfast Runs a fast test suite and re-use nightly `main` container (but sync dependencies to PRs version) label Mar 31, 2026
@yuki-97
Copy link
Copy Markdown
Contributor

yuki-97 commented Mar 31, 2026

/ok to test 628a248

@macandro96 macandro96 force-pushed the yifu/fix-prompt-extraction-multi-turn branch from 628a248 to 7d96ad3 Compare May 21, 2026 23:07
@macandro96 macandro96 requested a review from yuki-97 May 21, 2026 23:18
@macandro96 macandro96 force-pushed the yifu/fix-prompt-extraction-multi-turn branch from 7d96ad3 to d5961e5 Compare May 21, 2026 23:20
@macandro96
Copy link
Copy Markdown

/ok to test 7d96ad3

@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 21, 2026

/ok to test 7d96ad3

@macandro96, there was an error processing your request: E2

See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/

@macandro96
Copy link
Copy Markdown

/ok to test d5961e5

@macandro96 macandro96 changed the title fix: use prompt token length for advantage group extraction fix: use prompt token length for advantage group extraction and fix token mask May 21, 2026
yfw and others added 2 commits May 22, 2026 10:30
The previous role-based extraction (`_extract_prompt_only_messages`)
broke on multi-turn prompts containing assistant messages in the
conversation history — it would strip them, corrupting the prompt IDs
used for advantage estimation.

Replace with `extract_initial_prompt_messages()` which uses the
`length` field to identify the original prompt boundary. Applied to
both sync and async GRPO paths.

Closes #1960

Co-Authored-By: Jiaqi Zeng <jiaqiz@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
Signed-off-by: Anish Mahishi <amahishi@cw-dfw-cs-001-vscode-02.cm.cluster>
@macandro96 macandro96 force-pushed the yifu/fix-prompt-extraction-multi-turn branch from d5961e5 to 20adf67 Compare May 22, 2026 14:30
role = cast(str, message["role"])
token_ids = cast(torch.Tensor, message["token_ids"])

if role == "assistant" and "generation_logprobs" in message:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we didn't check generation_logprobs before, is there a reason we need to check it now?

Suggested change
if role == "assistant" and "generation_logprobs" in message:
if role == "assistant":

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes - I think we want to set token_mask = 1 for assistant part of messages where generation logprobs are available. If its not available, it means - that assistant text was part of input prompt for a multi-turn conversation and should be excluded while computing gradients.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a separate commit for super. I combined it into this PR as it was related.

prompt_only_message_logs,
pad_value_dict={"token_ids": tokenizer.pad_token_id},

prompt_batched_flat, prompt_input_lengths = (
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: looks prompt_input_lengths is never used.

Suggested change
prompt_batched_flat, prompt_input_lengths = (
prompt_batched_flat, _ = (

prompt_batched_flat, _ = batched_message_log_to_flat_message(
prompt_only_message_logs,
pad_value_dict={"token_ids": tokenizer.pad_token_id},
prompt_batched_flat, prompt_input_lengths = (
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: same as above

Suggested change
prompt_batched_flat, prompt_input_lengths = (
prompt_batched_flat, _ = (

@yuki-97
Copy link
Copy Markdown
Contributor

yuki-97 commented May 22, 2026

@yfw @HeyyyyyyG could you help to take a review as well?

@yuki-97 yuki-97 requested a review from HeyyyyyyG May 22, 2026 14:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CI:Lfast Runs a fast test suite and re-use nightly `main` container (but sync dependencies to PRs version) super-v3

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[super-pr] Use prompt length to find groups for advantage calculation

3 participants