Skip to content

Commit ed03e7d

Browse files
yfw, HeyyyyyyG, and claude authored and committed
fix: use prompt token length for advantage group extraction
The previous role-based extraction (`_extract_prompt_only_messages`) broke on multi-turn prompts containing assistant messages in the conversation history — it would strip them, corrupting the prompt IDs used for advantage estimation.

Replace it with `extract_initial_prompt_messages()`, which uses the `length` field to identify the original prompt boundary. Applied to both sync and async GRPO paths.

Closes #1960

Co-Authored-By: Jiaqi Zeng <jiaqiz@nvidia.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com>
1 parent 501cd12 commit ed03e7d

2 files changed

Lines changed: 197 additions & 37 deletions

File tree

nemo_rl/algorithms/grpo.py

Lines changed: 56 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,6 +1002,39 @@ def _scale(reward_tensor: torch.Tensor) -> torch.Tensor:
10021002
return repeated_batch
10031003

10041004

1005+
def extract_initial_prompt_messages(
    message_logs: list,
    original_prompt_lengths: torch.Tensor,
) -> list:
    """Extract the original prompt messages from message logs using token length.

    For each message log, messages are taken from the front for as long as the
    number of tokens collected so far is below that sample's original prompt
    length. Roles are never inspected, so assistant messages that belong to the
    prompt (e.g., multi-turn conversation history) are kept.

    A message is included whole whenever the running token count *before* it is
    still below the target. If the target length falls inside a message, that
    message is therefore included in full; a target of 0 yields an empty log.

    Args:
        message_logs: List of message logs, where each log is a list of
            messages. Each message must provide a sized ``"token_ids"`` entry.
        original_prompt_lengths: Original prompt token length per sample, in
            the same order as ``message_logs``. A tensor or a plain sequence
            of ints is accepted.

    Returns:
        List of message logs containing only the original prompt messages.
        The message objects are the same objects as in the input (not copies).

    Raises:
        IndexError: If ``original_prompt_lengths`` has fewer entries than
            ``message_logs``.
    """
    # Convert all lengths in one bulk call instead of one `.item()` per
    # sample; `as_tensor` also lets callers pass a plain sequence of ints.
    prompt_lengths = torch.as_tensor(original_prompt_lengths).reshape(-1).tolist()

    initial_prompt_message_logs = []
    for i, message_log in enumerate(message_logs):
        target_length = prompt_lengths[i]
        initial_prompt_log = []
        cumulative_length = 0

        for message in message_log:
            # Stop once the collected messages cover the original prompt.
            if cumulative_length >= target_length:
                break
            initial_prompt_log.append(message)
            cumulative_length += len(message["token_ids"])

        initial_prompt_message_logs.append(initial_prompt_log)

    return initial_prompt_message_logs
1036+
1037+
10051038
def _should_use_async_rollouts(master_config: MasterConfig) -> bool:
10061039
"""Determine if async rollouts should be used based on the configuration.
10071040
@@ -1098,28 +1131,6 @@ def _create_advantage_estimator(master_config: MasterConfig):
10981131
return adv_estimator
10991132

11001133

1101-
def _extract_prompt_only_messages(message_logs: list) -> list:
    """Filter every message log down to its user and system messages.

    Used to obtain prompt IDs for advantage estimation with assistant
    responses left out. The filter is purely role-based, so a message is
    dropped wherever it occurs in a log unless its role is "user" or "system".

    Args:
        message_logs: List of message logs, where each log is a list of messages.

    Returns:
        One list per input log, holding that log's user and system messages
        in their original order.
    """
    prompt_roles = ("user", "system")
    return [
        [msg for msg in log if msg["role"] in prompt_roles]
        for log in message_logs
    ]
1121-
1122-
11231134
def refit_policy_generation(
11241135
policy: ColocatablePolicyInterface,
11251136
policy_generation: GenerationInterface,
@@ -1694,16 +1705,20 @@ def grpo_train(
16941705
# Save baseline for logging (before deletion)
16951706
baseline_for_log = baseline.clone()
16961707

1697-
# Extract prompt-only messages for advantage estimation
1698-
prompt_only_message_logs = _extract_prompt_only_messages(
1699-
repeated_batch["message_log"]
1708+
# Extract original prompt messages using the length field
1709+
# This correctly handles multi-turn prompts that contain assistant messages
1710+
initial_prompt_message_logs = extract_initial_prompt_messages(
1711+
repeated_batch["message_log"],
1712+
repeated_batch["length"],
17001713
)
1701-
prompt_batched_flat, _ = batched_message_log_to_flat_message(
1702-
prompt_only_message_logs,
1703-
pad_value_dict={"token_ids": tokenizer.pad_token_id},
1714+
prompt_batched_flat, prompt_input_lengths = (
1715+
batched_message_log_to_flat_message(
1716+
initial_prompt_message_logs,
1717+
pad_value_dict={"token_ids": tokenizer.pad_token_id},
1718+
)
17041719
)
17051720
prompt_ids_for_adv = prompt_batched_flat["token_ids"]
1706-
del prompt_only_message_logs
1721+
del initial_prompt_message_logs
17071722
del prompt_batched_flat
17081723
del input_ids
17091724
del baseline
@@ -2828,16 +2843,21 @@ def async_grpo_train(
28282843

28292844
print("▶ Processing rewards...")
28302845
with timer.time("reward_calculation"):
2831-
# Extract prompt-only messages for advantage estimation
2832-
prompt_only_message_logs = _extract_prompt_only_messages(
2833-
repeated_batch["message_log"]
2846+
# Extract original prompt messages using the length field
2847+
# This correctly handles multi-turn prompts that contain assistant messages
2848+
initial_prompt_message_logs = extract_initial_prompt_messages(
2849+
repeated_batch["message_log"],
2850+
repeated_batch["length"],
28342851
)
2835-
prompt_batched_flat, _ = batched_message_log_to_flat_message(
2836-
prompt_only_message_logs,
2837-
pad_value_dict={"token_ids": tokenizer.pad_token_id},
2852+
2853+
prompt_batched_flat, prompt_input_lengths = (
2854+
batched_message_log_to_flat_message(
2855+
initial_prompt_message_logs,
2856+
pad_value_dict={"token_ids": tokenizer.pad_token_id},
2857+
)
28382858
)
28392859
prompt_ids_for_adv = prompt_batched_flat["token_ids"]
2840-
del prompt_only_message_logs
2860+
del initial_prompt_message_logs
28412861
del prompt_batched_flat
28422862

28432863
rewards = repeated_batch["total_reward"]

tests/unit/algorithms/test_async_utils.py

Lines changed: 141 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
ReplayBuffer,
3434
)
3535
from nemo_rl.algorithms.async_utils.replay_buffer import ReplayBufferNew
36-
from nemo_rl.algorithms.grpo import MasterConfig
36+
from nemo_rl.algorithms.grpo import MasterConfig, extract_initial_prompt_messages
3737
from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType
3838
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
3939
from nemo_rl.environments.interfaces import (
@@ -834,3 +834,143 @@ def test_error_handling(self):
834834
assert sample_result is None
835835

836836
ray.kill(buffer)
837+
838+
839+
class TestPromptExtraction:
    """Checks for the length-based prompt extraction used in async GRPO advantage calculation.

    Every case builds a conversation made of an original prompt followed by one
    generated assistant reply, then verifies that extracting by token length
    returns exactly the prompt messages, including assistant turns that are
    part of the prompt's own conversation history.
    """

    @staticmethod
    def _extract_prompt(prompt_messages, response_message):
        """Run the extraction on a batch of one and return the extracted log.

        The conversation is the prompt followed by the response, and the
        prompt length passed in is the total token count of the prompt messages.
        """
        conversation = prompt_messages + [response_message]
        prompt_token_count = sum(len(msg["token_ids"]) for msg in prompt_messages)
        extracted_logs = extract_initial_prompt_messages(
            [conversation], torch.tensor([prompt_token_count])
        )
        return extracted_logs[0]

    def test_prompt_extraction_with_multi_turn_history(self):
        """A prompt shaped user -> assistant -> user (15 tokens) is returned in full.

        The assistant turn inside the prompt must not end the extraction; only
        the generated reply appended afterwards is left out.
        """
        prompt_messages = [
            {"role": "user", "content": "What is 2+2?", "token_ids": torch.tensor([1, 2, 3, 4, 5])},
            {"role": "assistant", "content": "4", "token_ids": torch.tensor([6, 7, 8, 9, 10])},
            {"role": "user", "content": "Now what is 3+3?", "token_ids": torch.tensor([11, 12, 13, 14, 15])},
        ]
        response_message = {
            "role": "assistant",
            "content": "6",
            "token_ids": torch.tensor([16, 17, 18]),
        }

        extracted = self._extract_prompt(prompt_messages, response_message)

        assert len(extracted) == 3, (
            f"Expected 3 messages (user, assistant, user), got {len(extracted)}. "
            "The extraction should NOT break at the first assistant message when it's part of the original prompt."
        )
        assert [msg["role"] for msg in extracted] == ["user", "assistant", "user"]
        assert response_message not in extracted

    def test_prompt_extraction_with_single_turn(self):
        """Regression check: a single user message followed by a reply yields just the user message."""
        prompt_messages = [
            {"role": "user", "content": "What is 2+2?", "token_ids": torch.tensor([1, 2, 3, 4, 5])},
        ]
        response_message = {
            "role": "assistant",
            "content": "4",
            "token_ids": torch.tensor([6, 7, 8]),
        }

        extracted = self._extract_prompt(prompt_messages, response_message)

        assert len(extracted) == 1
        assert extracted[0]["role"] == "user"
        assert response_message not in extracted

    def test_prompt_extraction_with_system_message(self):
        """A system + user prompt is returned with both messages and without the reply."""
        prompt_messages = [
            {"role": "system", "content": "You are a math tutor.", "token_ids": torch.tensor([1, 2, 3])},
            {"role": "user", "content": "What is 2+2?", "token_ids": torch.tensor([4, 5, 6, 7])},
        ]
        response_message = {
            "role": "assistant",
            "content": "4",
            "token_ids": torch.tensor([8, 9]),
        }

        extracted = self._extract_prompt(prompt_messages, response_message)

        assert len(extracted) == 2
        assert [msg["role"] for msg in extracted] == ["system", "user"]
        assert response_message not in extracted

    def test_prompt_extraction_complex_multi_turn(self):
        """A six-message history containing two assistant turns is kept in its entirety."""
        prompt_messages = [
            {"role": "system", "content": "Math tutor", "token_ids": torch.tensor([1, 2])},
            {"role": "user", "content": "2+2?", "token_ids": torch.tensor([3, 4])},
            {"role": "assistant", "content": "4", "token_ids": torch.tensor([5, 6])},
            {"role": "user", "content": "3+3?", "token_ids": torch.tensor([7, 8])},
            {"role": "assistant", "content": "6", "token_ids": torch.tensor([9, 10])},
            {"role": "user", "content": "4+4?", "token_ids": torch.tensor([11, 12])},
        ]
        response_message = {
            "role": "assistant",
            "content": "8",
            "token_ids": torch.tensor([13, 14]),
        }

        extracted = self._extract_prompt(prompt_messages, response_message)

        assert len(extracted) == 6, (
            f"Expected 6 messages, got {len(extracted)}. "
            "All history messages should be included in the prompt."
        )

        roles_expected = ["system", "user", "assistant", "user", "assistant", "user"]
        roles_found = [msg["role"] for msg in extracted]
        assert roles_found == roles_expected
        assert response_message not in extracted

0 commit comments

Comments
 (0)