Skip to content

Commit 7d96ad3

Browse files
author
Anish Mahishi
committed
fix: only train on generated assistant turns
Signed-off-by: Anish Mahishi <amahishi@cw-dfw-cs-001-vscode-02.cm.cluster>
1 parent f20a14d commit 7d96ad3

2 files changed

Lines changed: 159 additions & 43 deletions

File tree

nemo_rl/algorithms/grpo.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
from nemo_rl.data.collate_fn import rl_collate_fn
5454
from nemo_rl.data.dataloader import MultipleDataloaderWrapper
5555
from nemo_rl.data.datasets import AllTaskProcessedDataset
56-
from nemo_rl.data.interfaces import DatumSpec
56+
from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType, VLMMessageLogType
5757
from nemo_rl.data.llm_message_utils import (
5858
batched_message_log_to_flat_message,
5959
get_keys_from_message_log,
@@ -1035,6 +1035,37 @@ def extract_initial_prompt_messages(
10351035
return initial_prompt_message_logs
10361036

10371037

1038+
def add_grpo_token_loss_masks_and_generation_logprobs(
1039+
message_logs: list[LLMMessageLogType | VLMMessageLogType],
1040+
) -> None:
1041+
"""Add GRPO loss masks and ensure generation logprobs exist in message logs.
1042+
1043+
Assistant messages can be part of the original multi-turn prompt history. Only
1044+
generated assistant messages have generation_logprobs, so use that field as the
1045+
trainable-token marker. This function mutates each message in-place by adding a
1046+
token_loss_mask and, when missing, a zero-valued generation_logprobs tensor.
1047+
1048+
Args:
1049+
message_logs: Batch of tokenized message logs. Each message must contain a
1050+
``role`` and ``token_ids`` field. Messages that already contain
1051+
``generation_logprobs`` are treated as rollout-generated messages.
1052+
"""
1053+
for message_log in message_logs:
1054+
for message in message_log:
1055+
role = cast(str, message["role"])
1056+
token_ids = cast(torch.Tensor, message["token_ids"])
1057+
1058+
if role == "assistant" and "generation_logprobs" in message:
1059+
message["token_loss_mask"] = torch.ones_like(token_ids)
1060+
else:
1061+
message["token_loss_mask"] = torch.zeros_like(token_ids)
1062+
1063+
if "generation_logprobs" not in message:
1064+
message["generation_logprobs"] = torch.zeros_like(
1065+
token_ids, dtype=torch.float32
1066+
)
1067+
1068+
10381069
def _should_use_async_rollouts(master_config: MasterConfig) -> bool:
10391070
"""Determine if async rollouts should be used based on the configuration.
10401071
@@ -1714,21 +1745,9 @@ def grpo_train(
17141745

17151746
loss_multiplier[truncated] = 0
17161747
repeated_batch["loss_multiplier"] = loss_multiplier
1717-
# Add loss mask to each message in LLMMessageLogType
1718-
for i, message_log in enumerate(repeated_batch["message_log"]):
1719-
for j, message in enumerate(message_log):
1720-
if message["role"] == "assistant":
1721-
message["token_loss_mask"] = torch.ones_like(
1722-
message["token_ids"]
1723-
)
1724-
else:
1725-
message["token_loss_mask"] = torch.zeros_like(
1726-
message["token_ids"]
1727-
)
1728-
if "generation_logprobs" not in message:
1729-
message["generation_logprobs"] = torch.zeros_like(
1730-
message["token_ids"], dtype=torch.float32
1731-
)
1748+
add_grpo_token_loss_masks_and_generation_logprobs(
1749+
repeated_batch["message_log"]
1750+
)
17321751

17331752
# Convert updated LLMMessageLogType to FlatMessagesType for training
17341753
flat_messages, input_lengths = batched_message_log_to_flat_message(
@@ -2818,21 +2837,9 @@ def async_grpo_train(
28182837

28192838
# Prepare training data (same as sync version)
28202839
with timer.time("data_processing"):
2821-
# Add loss mask to each message
2822-
for i, message_log in enumerate(repeated_batch["message_log"]):
2823-
for j, message in enumerate(message_log):
2824-
if message["role"] == "assistant":
2825-
message["token_loss_mask"] = torch.ones_like(
2826-
message["token_ids"]
2827-
)
2828-
else:
2829-
message["token_loss_mask"] = torch.zeros_like(
2830-
message["token_ids"]
2831-
)
2832-
if "generation_logprobs" not in message:
2833-
message["generation_logprobs"] = torch.zeros_like(
2834-
message["token_ids"], dtype=torch.float32
2835-
)
2840+
add_grpo_token_loss_masks_and_generation_logprobs(
2841+
repeated_batch["message_log"]
2842+
)
28362843

28372844
# Convert to flat format for training
28382845
flat_messages, input_lengths = batched_message_log_to_flat_message(

tests/unit/algorithms/test_async_utils.py

Lines changed: 121 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,11 @@
3333
ReplayBuffer,
3434
)
3535
from nemo_rl.algorithms.async_utils.replay_buffer import ReplayBufferNew
36-
from nemo_rl.algorithms.grpo import MasterConfig, extract_initial_prompt_messages
36+
from nemo_rl.algorithms.grpo import (
37+
MasterConfig,
38+
add_grpo_token_loss_masks_and_generation_logprobs,
39+
extract_initial_prompt_messages,
40+
)
3741
from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType
3842
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
3943
from nemo_rl.environments.interfaces import (
@@ -855,9 +859,21 @@ def test_prompt_extraction_with_multi_turn_history(self):
855859
# Create a multi-turn prompt with assistant messages in the history
856860
# Original prompt: user -> assistant -> user (3 messages, 15 tokens total)
857861
original_prompt_messages = [
858-
{"role": "user", "content": "What is 2+2?", "token_ids": torch.tensor([1, 2, 3, 4, 5])},
859-
{"role": "assistant", "content": "4", "token_ids": torch.tensor([6, 7, 8, 9, 10])},
860-
{"role": "user", "content": "Now what is 3+3?", "token_ids": torch.tensor([11, 12, 13, 14, 15])},
862+
{
863+
"role": "user",
864+
"content": "What is 2+2?",
865+
"token_ids": torch.tensor([1, 2, 3, 4, 5]),
866+
},
867+
{
868+
"role": "assistant",
869+
"content": "4",
870+
"token_ids": torch.tensor([6, 7, 8, 9, 10]),
871+
},
872+
{
873+
"role": "user",
874+
"content": "Now what is 3+3?",
875+
"token_ids": torch.tensor([11, 12, 13, 14, 15]),
876+
},
861877
]
862878

863879
# Generated response (added after original prompt)
@@ -871,7 +887,9 @@ def test_prompt_extraction_with_multi_turn_history(self):
871887
full_message_log = original_prompt_messages + [generated_message]
872888

873889
# Original prompt length = sum of token_ids in original prompt
874-
original_prompt_length = sum(len(m["token_ids"]) for m in original_prompt_messages) # 15
890+
original_prompt_length = sum(
891+
len(m["token_ids"]) for m in original_prompt_messages
892+
) # 15
875893

876894
message_logs = [full_message_log]
877895
original_prompt_lengths = torch.tensor([original_prompt_length])
@@ -893,7 +911,11 @@ def test_prompt_extraction_with_multi_turn_history(self):
893911
def test_prompt_extraction_with_single_turn(self):
894912
"""Test that prompt extraction works correctly for single-turn prompts (regression test)."""
895913
original_prompt_messages = [
896-
{"role": "user", "content": "What is 2+2?", "token_ids": torch.tensor([1, 2, 3, 4, 5])},
914+
{
915+
"role": "user",
916+
"content": "What is 2+2?",
917+
"token_ids": torch.tensor([1, 2, 3, 4, 5]),
918+
},
897919
]
898920

899921
generated_message = {
@@ -903,7 +925,9 @@ def test_prompt_extraction_with_single_turn(self):
903925
}
904926

905927
full_message_log = original_prompt_messages + [generated_message]
906-
original_prompt_length = sum(len(m["token_ids"]) for m in original_prompt_messages)
928+
original_prompt_length = sum(
929+
len(m["token_ids"]) for m in original_prompt_messages
930+
)
907931

908932
result = extract_initial_prompt_messages(
909933
[full_message_log], torch.tensor([original_prompt_length])
@@ -917,8 +941,16 @@ def test_prompt_extraction_with_single_turn(self):
917941
def test_prompt_extraction_with_system_message(self):
918942
"""Test prompt extraction with system message included."""
919943
original_prompt_messages = [
920-
{"role": "system", "content": "You are a math tutor.", "token_ids": torch.tensor([1, 2, 3])},
921-
{"role": "user", "content": "What is 2+2?", "token_ids": torch.tensor([4, 5, 6, 7])},
944+
{
945+
"role": "system",
946+
"content": "You are a math tutor.",
947+
"token_ids": torch.tensor([1, 2, 3]),
948+
},
949+
{
950+
"role": "user",
951+
"content": "What is 2+2?",
952+
"token_ids": torch.tensor([4, 5, 6, 7]),
953+
},
922954
]
923955

924956
generated_message = {
@@ -928,7 +960,9 @@ def test_prompt_extraction_with_system_message(self):
928960
}
929961

930962
full_message_log = original_prompt_messages + [generated_message]
931-
original_prompt_length = sum(len(m["token_ids"]) for m in original_prompt_messages)
963+
original_prompt_length = sum(
964+
len(m["token_ids"]) for m in original_prompt_messages
965+
)
932966

933967
result = extract_initial_prompt_messages(
934968
[full_message_log], torch.tensor([original_prompt_length])
@@ -943,7 +977,11 @@ def test_prompt_extraction_with_system_message(self):
943977
def test_prompt_extraction_complex_multi_turn(self):
944978
"""Test prompt extraction with complex multi-turn history (multiple assistant turns)."""
945979
original_prompt_messages = [
946-
{"role": "system", "content": "Math tutor", "token_ids": torch.tensor([1, 2])},
980+
{
981+
"role": "system",
982+
"content": "Math tutor",
983+
"token_ids": torch.tensor([1, 2]),
984+
},
947985
{"role": "user", "content": "2+2?", "token_ids": torch.tensor([3, 4])},
948986
{"role": "assistant", "content": "4", "token_ids": torch.tensor([5, 6])},
949987
{"role": "user", "content": "3+3?", "token_ids": torch.tensor([7, 8])},
@@ -958,7 +996,9 @@ def test_prompt_extraction_complex_multi_turn(self):
958996
}
959997

960998
full_message_log = original_prompt_messages + [generated_message]
961-
original_prompt_length = sum(len(m["token_ids"]) for m in original_prompt_messages)
999+
original_prompt_length = sum(
1000+
len(m["token_ids"]) for m in original_prompt_messages
1001+
)
9621002

9631003
result = extract_initial_prompt_messages(
9641004
[full_message_log], torch.tensor([original_prompt_length])
@@ -974,3 +1014,72 @@ def test_prompt_extraction_complex_multi_turn(self):
9741014
actual_roles = [m["role"] for m in initial_prompt_log]
9751015
assert actual_roles == expected_roles
9761016
assert generated_message not in initial_prompt_log
1017+
1018+
def test_grpo_loss_mask_excludes_assistant_prompt_history(self):
1019+
"""Test that assistant messages in the original prompt are not trained on."""
1020+
original_prompt_messages = [
1021+
{
1022+
"role": "user",
1023+
"content": "What is 2+2?",
1024+
"token_ids": torch.tensor([1, 2]),
1025+
},
1026+
{
1027+
"role": "assistant",
1028+
"content": "4",
1029+
"token_ids": torch.tensor([3, 4]),
1030+
},
1031+
{
1032+
"role": "user",
1033+
"content": "Now what is 3+3?",
1034+
"token_ids": torch.tensor([5, 6]),
1035+
},
1036+
]
1037+
generated_logprobs = torch.tensor([0.1, 0.2])
1038+
generated_message = {
1039+
"role": "assistant",
1040+
"content": "6",
1041+
"token_ids": torch.tensor([7, 8]),
1042+
"generation_logprobs": generated_logprobs,
1043+
}
1044+
full_message_log = original_prompt_messages + [generated_message]
1045+
1046+
add_grpo_token_loss_masks_and_generation_logprobs([full_message_log])
1047+
1048+
assert torch.equal(full_message_log[0]["token_loss_mask"], torch.tensor([0, 0]))
1049+
assert torch.equal(full_message_log[1]["token_loss_mask"], torch.tensor([0, 0]))
1050+
assert torch.equal(full_message_log[2]["token_loss_mask"], torch.tensor([0, 0]))
1051+
assert torch.equal(full_message_log[3]["token_loss_mask"], torch.tensor([1, 1]))
1052+
assert torch.equal(
1053+
full_message_log[3]["generation_logprobs"], generated_logprobs
1054+
)
1055+
1056+
def test_grpo_loss_mask_uses_generation_logprobs_marker(self):
1057+
"""Test that only assistant messages with generation logprobs are trainable."""
1058+
message_log = [
1059+
{
1060+
"role": "assistant",
1061+
"content": "prompt history",
1062+
"token_ids": torch.tensor([1, 2]),
1063+
},
1064+
{
1065+
"role": "user",
1066+
"content": "next question",
1067+
"token_ids": torch.tensor([3, 4]),
1068+
"generation_logprobs": torch.tensor([0.3, 0.4]),
1069+
},
1070+
{
1071+
"role": "assistant",
1072+
"content": "generated response",
1073+
"token_ids": torch.tensor([5, 6]),
1074+
"generation_logprobs": torch.tensor([0.5, 0.6]),
1075+
},
1076+
]
1077+
1078+
add_grpo_token_loss_masks_and_generation_logprobs([message_log])
1079+
1080+
assert torch.equal(message_log[0]["token_loss_mask"], torch.tensor([0, 0]))
1081+
assert torch.equal(
1082+
message_log[0]["generation_logprobs"], torch.tensor([0.0, 0.0])
1083+
)
1084+
assert torch.equal(message_log[1]["token_loss_mask"], torch.tensor([0, 0]))
1085+
assert torch.equal(message_log[2]["token_loss_mask"], torch.tensor([1, 1]))

0 commit comments

Comments
 (0)