diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index f0731aaccc..7582a5676d 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -53,7 +53,7 @@ from nemo_rl.data.collate_fn import rl_collate_fn from nemo_rl.data.dataloader import MultipleDataloaderWrapper from nemo_rl.data.datasets import AllTaskProcessedDataset -from nemo_rl.data.interfaces import DatumSpec +from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType, VLMMessageLogType from nemo_rl.data.llm_message_utils import ( batched_message_log_to_flat_message, get_keys_from_message_log, @@ -1002,6 +1002,70 @@ def _scale(reward_tensor: torch.Tensor) -> torch.Tensor: return repeated_batch +def extract_initial_prompt_messages( + message_logs: list, + original_prompt_lengths: torch.Tensor, +) -> list: + """Extract the original prompt messages from message logs using token length. + + This function correctly identifies original prompt messages even when the prompt + contains assistant messages (e.g., multi-turn conversation history). + + Args: + message_logs: List of message logs, where each log is a list of messages. + original_prompt_lengths: Tensor of original prompt token lengths per sample. + + Returns: + List of message logs containing only the original prompt messages. + """ + initial_prompt_message_logs = [] + for i, message_log in enumerate(message_logs): + initial_prompt_log = [] + cumulative_length = 0 + target_length = original_prompt_lengths[i].item() + + for message in message_log: + if cumulative_length >= target_length: + break + initial_prompt_log.append(message) + cumulative_length += len(message["token_ids"]) + + initial_prompt_message_logs.append(initial_prompt_log) + + return initial_prompt_message_logs + + +def add_grpo_token_loss_masks_and_generation_logprobs( + message_logs: list[LLMMessageLogType | VLMMessageLogType], +) -> None: + """Add GRPO loss masks and ensure generation logprobs exist in message logs. + + Assistant messages can be part of the original multi-turn prompt history. Only + generated assistant messages have generation_logprobs, so use that field as the + trainable-token marker. This function mutates each message in-place by adding a + token_loss_mask and, when missing, a zero-valued generation_logprobs tensor. + + Args: + message_logs: Batch of tokenized message logs. Each message must contain a + ``role`` and ``token_ids`` field. Messages that already contain + ``generation_logprobs`` are treated as rollout-generated messages. + """ + for message_log in message_logs: + for message in message_log: + role = cast(str, message["role"]) + token_ids = cast(torch.Tensor, message["token_ids"]) + + if role == "assistant" and "generation_logprobs" in message: + message["token_loss_mask"] = torch.ones_like(token_ids) + else: + message["token_loss_mask"] = torch.zeros_like(token_ids) + + if "generation_logprobs" not in message: + message["generation_logprobs"] = torch.zeros_like( + token_ids, dtype=torch.float32 + ) + + def _should_use_async_rollouts(master_config: MasterConfig) -> bool: """Determine if async rollouts should be used based on the configuration. @@ -1098,28 +1162,6 @@ def _create_advantage_estimator(master_config: MasterConfig): return adv_estimator -def _extract_prompt_only_messages(message_logs: list) -> list: - """Extract only prompt messages (user/system) from message logs. - - This is used to get prompt IDs for advantage estimation, excluding - any assistant responses. - - Args: - message_logs: List of message logs, where each log is a list of messages. - - Returns: - List of message logs containing only user and system messages. - """ - prompt_only_message_logs = [] - for message_log in message_logs: - prompt_only_log = [] - for message in message_log: - if message["role"] == "user" or message["role"] == "system": - prompt_only_log.append(message) - prompt_only_message_logs.append(prompt_only_log) - return prompt_only_message_logs - - def refit_policy_generation( policy: ColocatablePolicyInterface, policy_generation: GenerationInterface, @@ -1694,16 +1736,20 @@ def grpo_train( # Save baseline for logging (before deletion) baseline_for_log = baseline.clone() - # Extract prompt-only messages for advantage estimation - prompt_only_message_logs = _extract_prompt_only_messages( - repeated_batch["message_log"] + # Extract original prompt messages using the length field + # This correctly handles multi-turn prompts that contain assistant messages + initial_prompt_message_logs = extract_initial_prompt_messages( + repeated_batch["message_log"], + repeated_batch["length"], ) - prompt_batched_flat, _ = batched_message_log_to_flat_message( - prompt_only_message_logs, - pad_value_dict={"token_ids": tokenizer.pad_token_id}, + prompt_batched_flat, prompt_input_lengths = ( + batched_message_log_to_flat_message( + initial_prompt_message_logs, + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + ) ) prompt_ids_for_adv = prompt_batched_flat["token_ids"] - del prompt_only_message_logs + del initial_prompt_message_logs del prompt_batched_flat del input_ids del baseline @@ -1720,21 +1766,9 @@ def grpo_train( loss_multiplier[truncated] = 0 repeated_batch["loss_multiplier"] = loss_multiplier - # Add loss mask to each message in LLMMessageLogType - for i, message_log in enumerate(repeated_batch["message_log"]): - for j, message in enumerate(message_log): - if message["role"] == "assistant": - message["token_loss_mask"] = torch.ones_like( - message["token_ids"] - ) - else: - message["token_loss_mask"] = torch.zeros_like( - message["token_ids"] - ) - if "generation_logprobs" not in message: - message["generation_logprobs"] = torch.zeros_like( - message["token_ids"], dtype=torch.float32 - ) + add_grpo_token_loss_masks_and_generation_logprobs( + repeated_batch["message_log"] + ) # Convert updated LLMMessageLogType to FlatMessagesType for training flat_messages, input_lengths = batched_message_log_to_flat_message( @@ -2828,16 +2862,21 @@ def async_grpo_train( print("▶ Processing rewards...") with timer.time("reward_calculation"): - # Extract prompt-only messages for advantage estimation - prompt_only_message_logs = _extract_prompt_only_messages( - repeated_batch["message_log"] + # Extract original prompt messages using the length field + # This correctly handles multi-turn prompts that contain assistant messages + initial_prompt_message_logs = extract_initial_prompt_messages( + repeated_batch["message_log"], + repeated_batch["length"], ) - prompt_batched_flat, _ = batched_message_log_to_flat_message( - prompt_only_message_logs, - pad_value_dict={"token_ids": tokenizer.pad_token_id}, + + prompt_batched_flat, prompt_input_lengths = ( + batched_message_log_to_flat_message( + initial_prompt_message_logs, + pad_value_dict={"token_ids": tokenizer.pad_token_id}, + ) ) prompt_ids_for_adv = prompt_batched_flat["token_ids"] - del prompt_only_message_logs + del initial_prompt_message_logs del prompt_batched_flat rewards = repeated_batch["total_reward"] @@ -2848,21 +2887,9 @@ def async_grpo_train( # Prepare training data (same as sync version) with timer.time("data_processing"): - # Add loss mask to each message - for i, message_log in enumerate(repeated_batch["message_log"]): - for j, message in enumerate(message_log): - if message["role"] == "assistant": - message["token_loss_mask"] = torch.ones_like( - message["token_ids"] - ) - else: - message["token_loss_mask"] = torch.zeros_like( - message["token_ids"] - ) - if "generation_logprobs" not in message: - message["generation_logprobs"] = torch.zeros_like( - message["token_ids"], dtype=torch.float32 - ) + add_grpo_token_loss_masks_and_generation_logprobs( + repeated_batch["message_log"] + ) # Convert to flat format for training flat_messages, input_lengths = batched_message_log_to_flat_message( diff --git a/tests/unit/algorithms/test_async_utils.py b/tests/unit/algorithms/test_async_utils.py index 3844c9f3dc..32c47cfa60 100644 --- a/tests/unit/algorithms/test_async_utils.py +++ b/tests/unit/algorithms/test_async_utils.py @@ -33,7 +33,11 @@ ReplayBuffer, ) from nemo_rl.algorithms.async_utils.replay_buffer import ReplayBufferNew -from nemo_rl.algorithms.grpo import MasterConfig +from nemo_rl.algorithms.grpo import ( + MasterConfig, + add_grpo_token_loss_masks_and_generation_logprobs, + extract_initial_prompt_messages, +) from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.environments.interfaces import ( @@ -834,3 +838,248 @@ def test_error_handling(self): assert sample_result is None ray.kill(buffer) + + +class TestPromptExtraction: + """Test cases for prompt extraction logic used in async GRPO advantage calculation. + + These tests verify that the length-based prompt extraction correctly handles + multi-turn conversation prompts where the original prompt itself contains + assistant messages (conversation history). + """ + + def test_prompt_extraction_with_multi_turn_history(self): + """Test that prompt extraction correctly handles prompts containing assistant messages. + + This tests the fix for multi-turn conversation prompts where the original prompt + from the dataset itself contains assistant messages (conversation history). + The extraction should use the length field to identify original prompt messages, + not break at the first assistant message. + """ + # Create a multi-turn prompt with assistant messages in the history + # Original prompt: user -> assistant -> user (3 messages, 15 tokens total) + original_prompt_messages = [ + { + "role": "user", + "content": "What is 2+2?", + "token_ids": torch.tensor([1, 2, 3, 4, 5]), + }, + { + "role": "assistant", + "content": "4", + "token_ids": torch.tensor([6, 7, 8, 9, 10]), + }, + { + "role": "user", + "content": "Now what is 3+3?", + "token_ids": torch.tensor([11, 12, 13, 14, 15]), + }, + ] + + # Generated response (added after original prompt) + generated_message = { + "role": "assistant", + "content": "6", + "token_ids": torch.tensor([16, 17, 18]), + } + + # Full message_log after generation + full_message_log = original_prompt_messages + [generated_message] + + # Original prompt length = sum of token_ids in original prompt + original_prompt_length = sum( + len(m["token_ids"]) for m in original_prompt_messages + ) # 15 + + message_logs = [full_message_log] + original_prompt_lengths = torch.tensor([original_prompt_length]) + + result = extract_initial_prompt_messages(message_logs, original_prompt_lengths) + initial_prompt_log = result[0] + + # Should extract all 3 original messages, NOT break at first assistant + assert len(initial_prompt_log) == 3, ( + f"Expected 3 messages (user, assistant, user), got {len(initial_prompt_log)}. " + "The extraction should NOT break at the first assistant message when it's part of the original prompt." + ) + + assert initial_prompt_log[0]["role"] == "user" + assert initial_prompt_log[1]["role"] == "assistant" + assert initial_prompt_log[2]["role"] == "user" + assert generated_message not in initial_prompt_log + + def test_prompt_extraction_with_single_turn(self): + """Test that prompt extraction works correctly for single-turn prompts (regression test).""" + original_prompt_messages = [ + { + "role": "user", + "content": "What is 2+2?", + "token_ids": torch.tensor([1, 2, 3, 4, 5]), + }, + ] + + generated_message = { + "role": "assistant", + "content": "4", + "token_ids": torch.tensor([6, 7, 8]), + } + + full_message_log = original_prompt_messages + [generated_message] + original_prompt_length = sum( + len(m["token_ids"]) for m in original_prompt_messages + ) + + result = extract_initial_prompt_messages( + [full_message_log], torch.tensor([original_prompt_length]) + ) + initial_prompt_log = result[0] + + assert len(initial_prompt_log) == 1 + assert initial_prompt_log[0]["role"] == "user" + assert generated_message not in initial_prompt_log + + def test_prompt_extraction_with_system_message(self): + """Test prompt extraction with system message included.""" + original_prompt_messages = [ + { + "role": "system", + "content": "You are a math tutor.", + "token_ids": torch.tensor([1, 2, 3]), + }, + { + "role": "user", + "content": "What is 2+2?", + "token_ids": torch.tensor([4, 5, 6, 7]), + }, + ] + + generated_message = { + "role": "assistant", + "content": "4", + "token_ids": torch.tensor([8, 9]), + } + + full_message_log = original_prompt_messages + [generated_message] + original_prompt_length = sum( + len(m["token_ids"]) for m in original_prompt_messages + ) + + result = extract_initial_prompt_messages( + [full_message_log], torch.tensor([original_prompt_length]) + ) + initial_prompt_log = result[0] + + assert len(initial_prompt_log) == 2 + assert initial_prompt_log[0]["role"] == "system" + assert initial_prompt_log[1]["role"] == "user" + assert generated_message not in initial_prompt_log + + def test_prompt_extraction_complex_multi_turn(self): + """Test prompt extraction with complex multi-turn history (multiple assistant turns).""" + original_prompt_messages = [ + { + "role": "system", + "content": "Math tutor", + "token_ids": torch.tensor([1, 2]), + }, + {"role": "user", "content": "2+2?", "token_ids": torch.tensor([3, 4])}, + {"role": "assistant", "content": "4", "token_ids": torch.tensor([5, 6])}, + {"role": "user", "content": "3+3?", "token_ids": torch.tensor([7, 8])}, + {"role": "assistant", "content": "6", "token_ids": torch.tensor([9, 10])}, + {"role": "user", "content": "4+4?", "token_ids": torch.tensor([11, 12])}, + ] + + generated_message = { + "role": "assistant", + "content": "8", + "token_ids": torch.tensor([13, 14]), + } + + full_message_log = original_prompt_messages + [generated_message] + original_prompt_length = sum( + len(m["token_ids"]) for m in original_prompt_messages + ) + + result = extract_initial_prompt_messages( + [full_message_log], torch.tensor([original_prompt_length]) + ) + initial_prompt_log = result[0] + + assert len(initial_prompt_log) == 6, ( + f"Expected 6 messages, got {len(initial_prompt_log)}. " + "All history messages should be included in the prompt." + ) + + expected_roles = ["system", "user", "assistant", "user", "assistant", "user"] + actual_roles = [m["role"] for m in initial_prompt_log] + assert actual_roles == expected_roles + assert generated_message not in initial_prompt_log + + def test_grpo_loss_mask_excludes_assistant_prompt_history(self): + """Test that assistant messages in the original prompt are not trained on.""" + original_prompt_messages = [ + { + "role": "user", + "content": "What is 2+2?", + "token_ids": torch.tensor([1, 2]), + }, + { + "role": "assistant", + "content": "4", + "token_ids": torch.tensor([3, 4]), + }, + { + "role": "user", + "content": "Now what is 3+3?", + "token_ids": torch.tensor([5, 6]), + }, + ] + generated_logprobs = torch.tensor([0.1, 0.2]) + generated_message = { + "role": "assistant", + "content": "6", + "token_ids": torch.tensor([7, 8]), + "generation_logprobs": generated_logprobs, + } + full_message_log = original_prompt_messages + [generated_message] + + add_grpo_token_loss_masks_and_generation_logprobs([full_message_log]) + + assert torch.equal(full_message_log[0]["token_loss_mask"], torch.tensor([0, 0])) + assert torch.equal(full_message_log[1]["token_loss_mask"], torch.tensor([0, 0])) + assert torch.equal(full_message_log[2]["token_loss_mask"], torch.tensor([0, 0])) + assert torch.equal(full_message_log[3]["token_loss_mask"], torch.tensor([1, 1])) + assert torch.equal( + full_message_log[3]["generation_logprobs"], generated_logprobs + ) + + def test_grpo_loss_mask_uses_generation_logprobs_marker(self): + """Test that only assistant messages with generation logprobs are trainable.""" + message_log = [ + { + "role": "assistant", + "content": "prompt history", + "token_ids": torch.tensor([1, 2]), + }, + { + "role": "user", + "content": "next question", + "token_ids": torch.tensor([3, 4]), + "generation_logprobs": torch.tensor([0.3, 0.4]), + }, + { + "role": "assistant", + "content": "generated response", + "token_ids": torch.tensor([5, 6]), + "generation_logprobs": torch.tensor([0.5, 0.6]), + }, + ] + + add_grpo_token_loss_masks_and_generation_logprobs([message_log]) + + assert torch.equal(message_log[0]["token_loss_mask"], torch.tensor([0, 0])) + assert torch.equal( + message_log[0]["generation_logprobs"], torch.tensor([0.0, 0.0]) + ) + assert torch.equal(message_log[1]["token_loss_mask"], torch.tensor([0, 0])) + assert torch.equal(message_log[2]["token_loss_mask"], torch.tensor([1, 1]))