From ed03e7d508ea149d5fedbc303ff3a05d52c3ad53 Mon Sep 17 00:00:00 2001 From: Yi-Fu Wu Date: Mon, 30 Mar 2026 16:11:46 -0700 Subject: [PATCH 1/2] fix: use prompt token length for advantage group extraction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous role-based extraction (`_extract_prompt_only_messages`) broke on multi-turn prompts containing assistant messages in the conversation history — it would strip them, corrupting the prompt IDs used for advantage estimation. Replace with `extract_initial_prompt_messages()` which uses the `length` field to identify the original prompt boundary. Applied to both sync and async GRPO paths. Closes https://github.com/NVIDIA-NeMo/RL/issues/1960 Co-Authored-By: Jiaqi Zeng Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Yi-Fu Wu --- nemo_rl/algorithms/grpo.py | 92 ++++++++------ tests/unit/algorithms/test_async_utils.py | 142 +++++++++++++++++++++- 2 files changed, 197 insertions(+), 37 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index f0731aaccc..c561be97bd 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1002,6 +1002,39 @@ def _scale(reward_tensor: torch.Tensor) -> torch.Tensor: return repeated_batch +def extract_initial_prompt_messages( + message_logs: list, + original_prompt_lengths: torch.Tensor, +) -> list: + """Extract the original prompt messages from message logs using token length. + + This function correctly identifies original prompt messages even when the prompt + contains assistant messages (e.g., multi-turn conversation history). + + Args: + message_logs: List of message logs, where each log is a list of messages. + original_prompt_lengths: Tensor of original prompt token lengths per sample. + + Returns: + List of message logs containing only the original prompt messages. + """ + initial_prompt_message_logs = [] + for i, message_log in enumerate(message_logs): + initial_prompt_log = [] + cumulative_length = 0 + target_length = original_prompt_lengths[i].item() + + for message in message_log: + if cumulative_length >= target_length: + break + initial_prompt_log.append(message) + cumulative_length += len(message["token_ids"]) + + initial_prompt_message_logs.append(initial_prompt_log) + + return initial_prompt_message_logs + + def _should_use_async_rollouts(master_config: MasterConfig) -> bool: """Determine if async rollouts should be used based on the configuration. @@ -1098,28 +1131,6 @@ def _create_advantage_estimator(master_config: MasterConfig): return adv_estimator -def _extract_prompt_only_messages(message_logs: list) -> list: - """Extract only prompt messages (user/system) from message logs. - - This is used to get prompt IDs for advantage estimation, excluding - any assistant responses. - - Args: - message_logs: List of message logs, where each log is a list of messages. - - Returns: - List of message logs containing only user and system messages. - """ - prompt_only_message_logs = [] - for message_log in message_logs: - prompt_only_log = [] - for message in message_log: - if message["role"] == "user" or message["role"] == "system": - prompt_only_log.append(message) - prompt_only_message_logs.append(prompt_only_log) - return prompt_only_message_logs - - def refit_policy_generation( policy: ColocatablePolicyInterface, policy_generation: GenerationInterface, @@ -1694,16 +1705,20 @@ def grpo_train( # Save baseline for logging (before deletion) baseline_for_log = baseline.clone() - # Extract prompt-only messages for advantage estimation - prompt_only_message_logs = _extract_prompt_only_messages( - repeated_batch["message_log"] + # Extract original prompt messages using the length field + # This correctly handles multi-turn prompts that contain assistant messages + initial_prompt_message_logs = extract_initial_prompt_messages( + repeated_batch["message_log"], + repeated_batch["length"], ) - prompt_batched_flat, _ = batched_message_log_to_flat_message( - prompt_only_message_logs, - pad_value_dict={"token_ids": tokenizer.pad_token_id}, + prompt_batched_flat, prompt_input_lengths = ( + batched_message_log_to_flat_message( + initial_prompt_message_logs, + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + ) ) prompt_ids_for_adv = prompt_batched_flat["token_ids"] - del prompt_only_message_logs + del initial_prompt_message_logs del prompt_batched_flat del input_ids del baseline @@ -2828,16 +2843,21 @@ def async_grpo_train( print("▶ Processing rewards...") with timer.time("reward_calculation"): - # Extract prompt-only messages for advantage estimation - prompt_only_message_logs = _extract_prompt_only_messages( - repeated_batch["message_log"] + # Extract original prompt messages using the length field + # This correctly handles multi-turn prompts that contain assistant messages + initial_prompt_message_logs = extract_initial_prompt_messages( + repeated_batch["message_log"], + repeated_batch["length"], ) - prompt_batched_flat, _ = batched_message_log_to_flat_message( - prompt_only_message_logs, - pad_value_dict={"token_ids": tokenizer.pad_token_id}, + + prompt_batched_flat, prompt_input_lengths = ( + batched_message_log_to_flat_message( + initial_prompt_message_logs, + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + ) ) prompt_ids_for_adv = prompt_batched_flat["token_ids"] - del prompt_only_message_logs + del initial_prompt_message_logs del prompt_batched_flat rewards = repeated_batch["total_reward"] diff --git a/tests/unit/algorithms/test_async_utils.py b/tests/unit/algorithms/test_async_utils.py index 3844c9f3dc..e34be28d41 100644 --- a/tests/unit/algorithms/test_async_utils.py +++ b/tests/unit/algorithms/test_async_utils.py @@ -33,7 +33,7 @@ ReplayBuffer, ) from nemo_rl.algorithms.async_utils.replay_buffer import ReplayBufferNew -from nemo_rl.algorithms.grpo import MasterConfig +from nemo_rl.algorithms.grpo import MasterConfig, extract_initial_prompt_messages from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.environments.interfaces import ( @@ -834,3 +834,143 @@ def test_error_handling(self): assert sample_result is None ray.kill(buffer) + + +class TestPromptExtraction: + """Test cases for prompt extraction logic used in async GRPO advantage calculation. + + These tests verify that the length-based prompt extraction correctly handles + multi-turn conversation prompts where the original prompt itself contains + assistant messages (conversation history). + """ + + def test_prompt_extraction_with_multi_turn_history(self): + """Test that prompt extraction correctly handles prompts containing assistant messages. + + This tests the fix for multi-turn conversation prompts where the original prompt + from the dataset itself contains assistant messages (conversation history). + The extraction should use the length field to identify original prompt messages, + not break at the first assistant message. + """ + # Create a multi-turn prompt with assistant messages in the history + # Original prompt: user -> assistant -> user (3 messages, 15 tokens total) + original_prompt_messages = [ + {"role": "user", "content": "What is 2+2?", "token_ids": torch.tensor([1, 2, 3, 4, 5])}, + {"role": "assistant", "content": "4", "token_ids": torch.tensor([6, 7, 8, 9, 10])}, + {"role": "user", "content": "Now what is 3+3?", "token_ids": torch.tensor([11, 12, 13, 14, 15])}, + ] + + # Generated response (added after original prompt) + generated_message = { + "role": "assistant", + "content": "6", + "token_ids": torch.tensor([16, 17, 18]), + } + + # Full message_log after generation + full_message_log = original_prompt_messages + [generated_message] + + # Original prompt length = sum of token_ids in original prompt + original_prompt_length = sum(len(m["token_ids"]) for m in original_prompt_messages) # 15 + + message_logs = [full_message_log] + original_prompt_lengths = torch.tensor([original_prompt_length]) + + result = extract_initial_prompt_messages(message_logs, original_prompt_lengths) + initial_prompt_log = result[0] + + # Should extract all 3 original messages, NOT break at first assistant + assert len(initial_prompt_log) == 3, ( + f"Expected 3 messages (user, assistant, user), got {len(initial_prompt_log)}. " + "The extraction should NOT break at the first assistant message when it's part of the original prompt." + ) + + assert initial_prompt_log[0]["role"] == "user" + assert initial_prompt_log[1]["role"] == "assistant" + assert initial_prompt_log[2]["role"] == "user" + assert generated_message not in initial_prompt_log + + def test_prompt_extraction_with_single_turn(self): + """Test that prompt extraction works correctly for single-turn prompts (regression test).""" + original_prompt_messages = [ + {"role": "user", "content": "What is 2+2?", "token_ids": torch.tensor([1, 2, 3, 4, 5])}, + ] + + generated_message = { + "role": "assistant", + "content": "4", + "token_ids": torch.tensor([6, 7, 8]), + } + + full_message_log = original_prompt_messages + [generated_message] + original_prompt_length = sum(len(m["token_ids"]) for m in original_prompt_messages) + + result = extract_initial_prompt_messages( + [full_message_log], torch.tensor([original_prompt_length]) + ) + initial_prompt_log = result[0] + + assert len(initial_prompt_log) == 1 + assert initial_prompt_log[0]["role"] == "user" + assert generated_message not in initial_prompt_log + + def test_prompt_extraction_with_system_message(self): + """Test prompt extraction with system message included.""" + original_prompt_messages = [ + {"role": "system", "content": "You are a math tutor.", "token_ids": torch.tensor([1, 2, 3])}, + {"role": "user", "content": "What is 2+2?", "token_ids": torch.tensor([4, 5, 6, 7])}, + ] + + generated_message = { + "role": "assistant", + "content": "4", + "token_ids": torch.tensor([8, 9]), + } + + full_message_log = original_prompt_messages + [generated_message] + original_prompt_length = sum(len(m["token_ids"]) for m in original_prompt_messages) + + result = extract_initial_prompt_messages( + [full_message_log], torch.tensor([original_prompt_length]) + ) + initial_prompt_log = result[0] + + assert len(initial_prompt_log) == 2 + assert initial_prompt_log[0]["role"] == "system" + assert initial_prompt_log[1]["role"] == "user" + assert generated_message not in initial_prompt_log + + def test_prompt_extraction_complex_multi_turn(self): + """Test prompt extraction with complex multi-turn history (multiple assistant turns).""" + original_prompt_messages = [ + {"role": "system", "content": "Math tutor", "token_ids": torch.tensor([1, 2])}, + {"role": "user", "content": "2+2?", "token_ids": torch.tensor([3, 4])}, + {"role": "assistant", "content": "4", "token_ids": torch.tensor([5, 6])}, + {"role": "user", "content": "3+3?", "token_ids": torch.tensor([7, 8])}, + {"role": "assistant", "content": "6", "token_ids": torch.tensor([9, 10])}, + {"role": "user", "content": "4+4?", "token_ids": torch.tensor([11, 12])}, + ] + + generated_message = { + "role": "assistant", + "content": "8", + "token_ids": torch.tensor([13, 14]), + } + + full_message_log = original_prompt_messages + [generated_message] + original_prompt_length = sum(len(m["token_ids"]) for m in original_prompt_messages) + + result = extract_initial_prompt_messages( + [full_message_log], torch.tensor([original_prompt_length]) + ) + initial_prompt_log = result[0] + + assert len(initial_prompt_log) == 6, ( + f"Expected 6 messages, got {len(initial_prompt_log)}. " + "All history messages should be included in the prompt." + ) + + expected_roles = ["system", "user", "assistant", "user", "assistant", "user"] + actual_roles = [m["role"] for m in initial_prompt_log] + assert actual_roles == expected_roles + assert generated_message not in initial_prompt_log From 20adf67cdf9432496738e0e6b80e9ae78efbd04f Mon Sep 17 00:00:00 2001 From: Anish Mahishi Date: Thu, 21 May 2026 16:03:41 -0700 Subject: [PATCH 2/2] fix: only train on generated assistant turns Signed-off-by: Anish Mahishi --- nemo_rl/algorithms/grpo.py | 69 ++++++----- tests/unit/algorithms/test_async_utils.py | 133 ++++++++++++++++++++-- 2 files changed, 159 insertions(+), 43 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index c561be97bd..7582a5676d 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -53,7 +53,7 @@ from nemo_rl.data.collate_fn import rl_collate_fn from nemo_rl.data.dataloader import MultipleDataloaderWrapper from nemo_rl.data.datasets import AllTaskProcessedDataset -from nemo_rl.data.interfaces import DatumSpec +from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType, VLMMessageLogType from nemo_rl.data.llm_message_utils import ( batched_message_log_to_flat_message, get_keys_from_message_log, @@ -1035,6 +1035,37 @@ def extract_initial_prompt_messages( return initial_prompt_message_logs +def add_grpo_token_loss_masks_and_generation_logprobs( + message_logs: list[LLMMessageLogType | VLMMessageLogType], +) -> None: + """Add GRPO loss masks and ensure generation logprobs exist in message logs. + + Assistant messages can be part of the original multi-turn prompt history. Only + generated assistant messages have generation_logprobs, so use that field as the + trainable-token marker. This function mutates each message in-place by adding a + token_loss_mask and, when missing, a zero-valued generation_logprobs tensor. + + Args: + message_logs: Batch of tokenized message logs. Each message must contain a + ``role`` and ``token_ids`` field. Messages that already contain + ``generation_logprobs`` are treated as rollout-generated messages. + """ + for message_log in message_logs: + for message in message_log: + role = cast(str, message["role"]) + token_ids = cast(torch.Tensor, message["token_ids"]) + + if role == "assistant" and "generation_logprobs" in message: + message["token_loss_mask"] = torch.ones_like(token_ids) + else: + message["token_loss_mask"] = torch.zeros_like(token_ids) + + if "generation_logprobs" not in message: + message["generation_logprobs"] = torch.zeros_like( + token_ids, dtype=torch.float32 + ) + + def _should_use_async_rollouts(master_config: MasterConfig) -> bool: """Determine if async rollouts should be used based on the configuration. @@ -1735,21 +1766,9 @@ def grpo_train( loss_multiplier[truncated] = 0 repeated_batch["loss_multiplier"] = loss_multiplier - # Add loss mask to each message in LLMMessageLogType - for i, message_log in enumerate(repeated_batch["message_log"]): - for j, message in enumerate(message_log): - if message["role"] == "assistant": - message["token_loss_mask"] = torch.ones_like( - message["token_ids"] - ) - else: - message["token_loss_mask"] = torch.zeros_like( - message["token_ids"] - ) - if "generation_logprobs" not in message: - message["generation_logprobs"] = torch.zeros_like( - message["token_ids"], dtype=torch.float32 - ) + add_grpo_token_loss_masks_and_generation_logprobs( + repeated_batch["message_log"] + ) # Convert updated LLMMessageLogType to FlatMessagesType for training flat_messages, input_lengths = batched_message_log_to_flat_message( @@ -2868,21 +2887,9 @@ def async_grpo_train( # Prepare training data (same as sync version) with timer.time("data_processing"): - # Add loss mask to each message - for i, message_log in enumerate(repeated_batch["message_log"]): - for j, message in enumerate(message_log): - if message["role"] == "assistant": - message["token_loss_mask"] = torch.ones_like( - message["token_ids"] - ) - else: - message["token_loss_mask"] = torch.zeros_like( - message["token_ids"] - ) - if "generation_logprobs" not in message: - message["generation_logprobs"] = torch.zeros_like( - message["token_ids"], dtype=torch.float32 - ) + add_grpo_token_loss_masks_and_generation_logprobs( + repeated_batch["message_log"] + ) # Convert to flat format for training flat_messages, input_lengths = batched_message_log_to_flat_message( diff --git a/tests/unit/algorithms/test_async_utils.py b/tests/unit/algorithms/test_async_utils.py index e34be28d41..32c47cfa60 100644 --- a/tests/unit/algorithms/test_async_utils.py +++ b/tests/unit/algorithms/test_async_utils.py @@ -33,7 +33,11 @@ ReplayBuffer, ) from nemo_rl.algorithms.async_utils.replay_buffer import ReplayBufferNew -from nemo_rl.algorithms.grpo import MasterConfig, extract_initial_prompt_messages +from nemo_rl.algorithms.grpo import ( + MasterConfig, + add_grpo_token_loss_masks_and_generation_logprobs, + extract_initial_prompt_messages, +) from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.environments.interfaces import ( @@ -855,9 +859,21 @@ def test_prompt_extraction_with_multi_turn_history(self): # Create a multi-turn prompt with assistant messages in the history # Original prompt: user -> assistant -> user (3 messages, 15 tokens total) original_prompt_messages = [ - {"role": "user", "content": "What is 2+2?", "token_ids": torch.tensor([1, 2, 3, 4, 5])}, - {"role": "assistant", "content": "4", "token_ids": torch.tensor([6, 7, 8, 9, 10])}, - {"role": "user", "content": "Now what is 3+3?", "token_ids": torch.tensor([11, 12, 13, 14, 15])}, + { + "role": "user", + "content": "What is 2+2?", + "token_ids": torch.tensor([1, 2, 3, 4, 5]), + }, + { + "role": "assistant", + "content": "4", + "token_ids": torch.tensor([6, 7, 8, 9, 10]), + }, + { + "role": "user", + "content": "Now what is 3+3?", + "token_ids": torch.tensor([11, 12, 13, 14, 15]), + }, ] # Generated response (added after original prompt) @@ -871,7 +887,9 @@ def test_prompt_extraction_with_multi_turn_history(self): full_message_log = original_prompt_messages + [generated_message] # Original prompt length = sum of token_ids in original prompt - original_prompt_length = sum(len(m["token_ids"]) for m in original_prompt_messages) # 15 + original_prompt_length = sum( + len(m["token_ids"]) for m in original_prompt_messages + ) # 15 message_logs = [full_message_log] original_prompt_lengths = torch.tensor([original_prompt_length]) @@ -893,7 +911,11 @@ def test_prompt_extraction_with_multi_turn_history(self): def test_prompt_extraction_with_single_turn(self): """Test that prompt extraction works correctly for single-turn prompts (regression test).""" original_prompt_messages = [ - {"role": "user", "content": "What is 2+2?", "token_ids": torch.tensor([1, 2, 3, 4, 5])}, + { + "role": "user", + "content": "What is 2+2?", + "token_ids": torch.tensor([1, 2, 3, 4, 5]), + }, ] generated_message = { @@ -903,7 +925,9 @@ def test_prompt_extraction_with_single_turn(self): } full_message_log = original_prompt_messages + [generated_message] - original_prompt_length = sum(len(m["token_ids"]) for m in original_prompt_messages) + original_prompt_length = sum( + len(m["token_ids"]) for m in original_prompt_messages + ) result = extract_initial_prompt_messages( [full_message_log], torch.tensor([original_prompt_length]) @@ -917,8 +941,16 @@ def test_prompt_extraction_with_single_turn(self): def test_prompt_extraction_with_system_message(self): """Test prompt extraction with system message included.""" original_prompt_messages = [ - {"role": "system", "content": "You are a math tutor.", "token_ids": torch.tensor([1, 2, 3])}, - {"role": "user", "content": "What is 2+2?", "token_ids": torch.tensor([4, 5, 6, 7])}, + { + "role": "system", + "content": "You are a math tutor.", + "token_ids": torch.tensor([1, 2, 3]), + }, + { + "role": "user", + "content": "What is 2+2?", + "token_ids": torch.tensor([4, 5, 6, 7]), + }, ] generated_message = { @@ -928,7 +960,9 @@ def test_prompt_extraction_with_system_message(self): } full_message_log = original_prompt_messages + [generated_message] - original_prompt_length = sum(len(m["token_ids"]) for m in original_prompt_messages) + original_prompt_length = sum( + len(m["token_ids"]) for m in original_prompt_messages + ) result = extract_initial_prompt_messages( [full_message_log], torch.tensor([original_prompt_length]) @@ -943,7 +977,11 @@ def test_prompt_extraction_with_system_message(self): def test_prompt_extraction_complex_multi_turn(self): """Test prompt extraction with complex multi-turn history (multiple assistant turns).""" original_prompt_messages = [ - {"role": "system", "content": "Math tutor", "token_ids": torch.tensor([1, 2])}, + { + "role": "system", + "content": "Math tutor", + "token_ids": torch.tensor([1, 2]), + }, {"role": "user", "content": "2+2?", "token_ids": torch.tensor([3, 4])}, {"role": "assistant", "content": "4", "token_ids": torch.tensor([5, 6])}, {"role": "user", "content": "3+3?", "token_ids": torch.tensor([7, 8])}, @@ -958,7 +996,9 @@ def test_prompt_extraction_complex_multi_turn(self): } full_message_log = original_prompt_messages + [generated_message] - original_prompt_length = sum(len(m["token_ids"]) for m in original_prompt_messages) + original_prompt_length = sum( + len(m["token_ids"]) for m in original_prompt_messages + ) result = extract_initial_prompt_messages( [full_message_log], torch.tensor([original_prompt_length]) @@ -974,3 +1014,72 @@ def test_prompt_extraction_complex_multi_turn(self): actual_roles = [m["role"] for m in initial_prompt_log] assert actual_roles == expected_roles assert generated_message not in initial_prompt_log + + def test_grpo_loss_mask_excludes_assistant_prompt_history(self): + """Test that assistant messages in the original prompt are not trained on.""" + original_prompt_messages = [ + { + "role": "user", + "content": "What is 2+2?", + "token_ids": torch.tensor([1, 2]), + }, + { + "role": "assistant", + "content": "4", + "token_ids": torch.tensor([3, 4]), + }, + { + "role": "user", + "content": "Now what is 3+3?", + "token_ids": torch.tensor([5, 6]), + }, + ] + generated_logprobs = torch.tensor([0.1, 0.2]) + generated_message = { + "role": "assistant", + "content": "6", + "token_ids": torch.tensor([7, 8]), + "generation_logprobs": generated_logprobs, + } + full_message_log = original_prompt_messages + [generated_message] + + add_grpo_token_loss_masks_and_generation_logprobs([full_message_log]) + + assert torch.equal(full_message_log[0]["token_loss_mask"], torch.tensor([0, 0])) + assert torch.equal(full_message_log[1]["token_loss_mask"], torch.tensor([0, 0])) + assert torch.equal(full_message_log[2]["token_loss_mask"], torch.tensor([0, 0])) + assert torch.equal(full_message_log[3]["token_loss_mask"], torch.tensor([1, 1])) + assert torch.equal( + full_message_log[3]["generation_logprobs"], generated_logprobs + ) + + def test_grpo_loss_mask_uses_generation_logprobs_marker(self): + """Test that only assistant messages with generation logprobs are trainable.""" + message_log = [ + { + "role": "assistant", + "content": "prompt history", + "token_ids": torch.tensor([1, 2]), + }, + { + "role": "user", + "content": "next question", + "token_ids": torch.tensor([3, 4]), + "generation_logprobs": torch.tensor([0.3, 0.4]), + }, + { + "role": "assistant", + "content": "generated response", + "token_ids": torch.tensor([5, 6]), + "generation_logprobs": torch.tensor([0.5, 0.6]), + }, + ] + + add_grpo_token_loss_masks_and_generation_logprobs([message_log]) + + assert torch.equal(message_log[0]["token_loss_mask"], torch.tensor([0, 0])) + assert torch.equal( + message_log[0]["generation_logprobs"], torch.tensor([0.0, 0.0]) + ) + assert torch.equal(message_log[1]["token_loss_mask"], torch.tensor([0, 0])) + assert torch.equal(message_log[2]["token_loss_mask"], torch.tensor([1, 1]))