3333 ReplayBuffer ,
3434)
3535from nemo_rl .algorithms .async_utils .replay_buffer import ReplayBufferNew
36- from nemo_rl .algorithms .grpo import MasterConfig , extract_initial_prompt_messages
36+ from nemo_rl .algorithms .grpo import (
37+ MasterConfig ,
38+ add_grpo_token_loss_masks_and_generation_logprobs ,
39+ extract_initial_prompt_messages ,
40+ )
3741from nemo_rl .data .interfaces import DatumSpec , LLMMessageLogType
3842from nemo_rl .distributed .batched_data_dict import BatchedDataDict
3943from nemo_rl .environments .interfaces import (
@@ -855,9 +859,21 @@ def test_prompt_extraction_with_multi_turn_history(self):
855859 # Create a multi-turn prompt with assistant messages in the history
856860 # Original prompt: user -> assistant -> user (3 messages, 15 tokens total)
857861 original_prompt_messages = [
858- {"role" : "user" , "content" : "What is 2+2?" , "token_ids" : torch .tensor ([1 , 2 , 3 , 4 , 5 ])},
859- {"role" : "assistant" , "content" : "4" , "token_ids" : torch .tensor ([6 , 7 , 8 , 9 , 10 ])},
860- {"role" : "user" , "content" : "Now what is 3+3?" , "token_ids" : torch .tensor ([11 , 12 , 13 , 14 , 15 ])},
862+ {
863+ "role" : "user" ,
864+ "content" : "What is 2+2?" ,
865+ "token_ids" : torch .tensor ([1 , 2 , 3 , 4 , 5 ]),
866+ },
867+ {
868+ "role" : "assistant" ,
869+ "content" : "4" ,
870+ "token_ids" : torch .tensor ([6 , 7 , 8 , 9 , 10 ]),
871+ },
872+ {
873+ "role" : "user" ,
874+ "content" : "Now what is 3+3?" ,
875+ "token_ids" : torch .tensor ([11 , 12 , 13 , 14 , 15 ]),
876+ },
861877 ]
862878
863879 # Generated response (added after original prompt)
@@ -871,7 +887,9 @@ def test_prompt_extraction_with_multi_turn_history(self):
871887 full_message_log = original_prompt_messages + [generated_message ]
872888
873889 # Original prompt length = sum of token_ids in original prompt
874- original_prompt_length = sum (len (m ["token_ids" ]) for m in original_prompt_messages ) # 15
890+ original_prompt_length = sum (
891+ len (m ["token_ids" ]) for m in original_prompt_messages
892+ ) # 15
875893
876894 message_logs = [full_message_log ]
877895 original_prompt_lengths = torch .tensor ([original_prompt_length ])
@@ -893,7 +911,11 @@ def test_prompt_extraction_with_multi_turn_history(self):
893911 def test_prompt_extraction_with_single_turn (self ):
894912 """Test that prompt extraction works correctly for single-turn prompts (regression test)."""
895913 original_prompt_messages = [
896- {"role" : "user" , "content" : "What is 2+2?" , "token_ids" : torch .tensor ([1 , 2 , 3 , 4 , 5 ])},
914+ {
915+ "role" : "user" ,
916+ "content" : "What is 2+2?" ,
917+ "token_ids" : torch .tensor ([1 , 2 , 3 , 4 , 5 ]),
918+ },
897919 ]
898920
899921 generated_message = {
@@ -903,7 +925,9 @@ def test_prompt_extraction_with_single_turn(self):
903925 }
904926
905927 full_message_log = original_prompt_messages + [generated_message ]
906- original_prompt_length = sum (len (m ["token_ids" ]) for m in original_prompt_messages )
928+ original_prompt_length = sum (
929+ len (m ["token_ids" ]) for m in original_prompt_messages
930+ )
907931
908932 result = extract_initial_prompt_messages (
909933 [full_message_log ], torch .tensor ([original_prompt_length ])
@@ -917,8 +941,16 @@ def test_prompt_extraction_with_single_turn(self):
917941 def test_prompt_extraction_with_system_message (self ):
918942 """Test prompt extraction with system message included."""
919943 original_prompt_messages = [
920- {"role" : "system" , "content" : "You are a math tutor." , "token_ids" : torch .tensor ([1 , 2 , 3 ])},
921- {"role" : "user" , "content" : "What is 2+2?" , "token_ids" : torch .tensor ([4 , 5 , 6 , 7 ])},
944+ {
945+ "role" : "system" ,
946+ "content" : "You are a math tutor." ,
947+ "token_ids" : torch .tensor ([1 , 2 , 3 ]),
948+ },
949+ {
950+ "role" : "user" ,
951+ "content" : "What is 2+2?" ,
952+ "token_ids" : torch .tensor ([4 , 5 , 6 , 7 ]),
953+ },
922954 ]
923955
924956 generated_message = {
@@ -928,7 +960,9 @@ def test_prompt_extraction_with_system_message(self):
928960 }
929961
930962 full_message_log = original_prompt_messages + [generated_message ]
931- original_prompt_length = sum (len (m ["token_ids" ]) for m in original_prompt_messages )
963+ original_prompt_length = sum (
964+ len (m ["token_ids" ]) for m in original_prompt_messages
965+ )
932966
933967 result = extract_initial_prompt_messages (
934968 [full_message_log ], torch .tensor ([original_prompt_length ])
@@ -943,7 +977,11 @@ def test_prompt_extraction_with_system_message(self):
943977 def test_prompt_extraction_complex_multi_turn (self ):
944978 """Test prompt extraction with complex multi-turn history (multiple assistant turns)."""
945979 original_prompt_messages = [
946- {"role" : "system" , "content" : "Math tutor" , "token_ids" : torch .tensor ([1 , 2 ])},
980+ {
981+ "role" : "system" ,
982+ "content" : "Math tutor" ,
983+ "token_ids" : torch .tensor ([1 , 2 ]),
984+ },
947985 {"role" : "user" , "content" : "2+2?" , "token_ids" : torch .tensor ([3 , 4 ])},
948986 {"role" : "assistant" , "content" : "4" , "token_ids" : torch .tensor ([5 , 6 ])},
949987 {"role" : "user" , "content" : "3+3?" , "token_ids" : torch .tensor ([7 , 8 ])},
@@ -958,7 +996,9 @@ def test_prompt_extraction_complex_multi_turn(self):
958996 }
959997
960998 full_message_log = original_prompt_messages + [generated_message ]
961- original_prompt_length = sum (len (m ["token_ids" ]) for m in original_prompt_messages )
999+ original_prompt_length = sum (
1000+ len (m ["token_ids" ]) for m in original_prompt_messages
1001+ )
9621002
9631003 result = extract_initial_prompt_messages (
9641004 [full_message_log ], torch .tensor ([original_prompt_length ])
@@ -974,3 +1014,72 @@ def test_prompt_extraction_complex_multi_turn(self):
9741014 actual_roles = [m ["role" ] for m in initial_prompt_log ]
9751015 assert actual_roles == expected_roles
9761016 assert generated_message not in initial_prompt_log
1017+
def test_grpo_loss_mask_excludes_assistant_prompt_history(self):
    """Assistant turns that are part of the original prompt must stay masked.

    Builds a user -> assistant -> user prompt followed by one generated
    assistant turn (the only message carrying ``generation_logprobs``),
    runs ``add_grpo_token_loss_masks_and_generation_logprobs`` on it, and
    checks that:

    * every prompt turn, including the assistant one, ends up with an
      all-zero ``token_loss_mask``;
    * the generated turn ends up with an all-one ``token_loss_mask``;
    * the generated turn's ``generation_logprobs`` still equal the values
      supplied up front.
    """
    sampled_logprobs = torch.tensor([0.1, 0.2])

    prompt_history = [
        {
            "role": "user",
            "content": "What is 2+2?",
            "token_ids": torch.tensor([1, 2]),
        },
        {
            "role": "assistant",
            "content": "4",
            "token_ids": torch.tensor([3, 4]),
        },
        {
            "role": "user",
            "content": "Now what is 3+3?",
            "token_ids": torch.tensor([5, 6]),
        },
    ]
    sampled_turn = {
        "role": "assistant",
        "content": "6",
        "token_ids": torch.tensor([7, 8]),
        "generation_logprobs": sampled_logprobs,
    }
    conversation = [*prompt_history, sampled_turn]

    add_grpo_token_loss_masks_and_generation_logprobs([conversation])

    # One expected mask per turn, in conversation order: the three prompt
    # turns are masked out, only the final generated turn is trainable.
    expected_masks = ([0, 0], [0, 0], [0, 0], [1, 1])
    for position, mask_values in enumerate(expected_masks):
        observed_mask = conversation[position]["token_loss_mask"]
        assert torch.equal(observed_mask, torch.tensor(mask_values))

    assert torch.equal(conversation[3]["generation_logprobs"], sampled_logprobs)
1055+
def test_grpo_loss_mask_uses_generation_logprobs_marker(self):
    """Only assistant turns that carry generation logprobs are trainable.

    The conversation mixes three cases and checks the mask each receives
    from ``add_grpo_token_loss_masks_and_generation_logprobs``:

    * an assistant turn without ``generation_logprobs`` -> all-zero mask,
      and its ``generation_logprobs`` compare equal to zeros afterwards;
    * a user turn that does have ``generation_logprobs`` -> all-zero mask;
    * an assistant turn with ``generation_logprobs`` -> all-one mask.
    """
    history_assistant_turn = {
        "role": "assistant",
        "content": "prompt history",
        "token_ids": torch.tensor([1, 2]),
    }
    user_turn_with_logprobs = {
        "role": "user",
        "content": "next question",
        "token_ids": torch.tensor([3, 4]),
        "generation_logprobs": torch.tensor([0.3, 0.4]),
    }
    generated_assistant_turn = {
        "role": "assistant",
        "content": "generated response",
        "token_ids": torch.tensor([5, 6]),
        "generation_logprobs": torch.tensor([0.5, 0.6]),
    }
    turns = [
        history_assistant_turn,
        user_turn_with_logprobs,
        generated_assistant_turn,
    ]

    add_grpo_token_loss_masks_and_generation_logprobs([turns])

    masked_out = torch.tensor([0, 0])
    trainable = torch.tensor([1, 1])

    # Assistant role alone is not enough: no logprobs means no training signal.
    assert torch.equal(turns[0]["token_loss_mask"], masked_out)
    zero_logprobs = torch.tensor([0.0, 0.0])
    assert torch.equal(turns[0]["generation_logprobs"], zero_logprobs)

    # Logprobs alone are not enough either: the role must be assistant.
    assert torch.equal(turns[1]["token_loss_mask"], masked_out)

    # Assistant role plus logprobs marks the turn as generated.
    assert torch.equal(turns[2]["token_loss_mask"], trainable)
0 commit comments