Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 94 additions & 67 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from nemo_rl.data.collate_fn import rl_collate_fn
from nemo_rl.data.dataloader import MultipleDataloaderWrapper
from nemo_rl.data.datasets import AllTaskProcessedDataset
from nemo_rl.data.interfaces import DatumSpec
from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType, VLMMessageLogType
from nemo_rl.data.llm_message_utils import (
batched_message_log_to_flat_message,
get_keys_from_message_log,
Expand Down Expand Up @@ -1002,6 +1002,70 @@ def _scale(reward_tensor: torch.Tensor) -> torch.Tensor:
return repeated_batch


def extract_initial_prompt_messages(
    message_logs: list,
    original_prompt_lengths: torch.Tensor,
) -> list:
    """Extract the original prompt messages from message logs using token length.

    For every sample, messages are taken from the front of its log until the
    running total of their ``token_ids`` lengths reaches that sample's original
    prompt length. Selection is driven by token count rather than by role, so
    assistant messages that are part of the prompt (e.g. multi-turn conversation
    history) are kept.

    Args:
        message_logs: List of message logs, where each log is a list of messages.
            Every message must have a ``token_ids`` entry that supports ``len()``.
        original_prompt_lengths: Tensor of original prompt token lengths per
            sample, indexed in the same order as ``message_logs``.
            NOTE(review): assumed to be 1-D with one scalar per sample -- confirm
            against the caller.

    Returns:
        List of message logs containing only the original prompt messages. The
        returned messages are the same objects as in ``message_logs`` (no copy).
        A message is included whenever the running total is still below the
        target before it is added, so a message that straddles the target
        length is kept whole; a target length of 0 yields an empty log.
    """
    # Convert the lengths to Python numbers once, instead of issuing one
    # `.item()` tensor-to-host conversion per sample inside the loop.
    target_lengths = original_prompt_lengths.tolist()

    initial_prompt_message_logs = []
    for i, message_log in enumerate(message_logs):
        target_length = target_lengths[i]
        initial_prompt_log = []
        cumulative_length = 0

        for message in message_log:
            # Stop as soon as the prompt's token budget has been covered;
            # everything after this point was appended during the rollout.
            if cumulative_length >= target_length:
                break
            initial_prompt_log.append(message)
            cumulative_length += len(message["token_ids"])

        initial_prompt_message_logs.append(initial_prompt_log)

    return initial_prompt_message_logs


def add_grpo_token_loss_masks_and_generation_logprobs(
message_logs: list[LLMMessageLogType | VLMMessageLogType],
) -> None:
"""Add GRPO loss masks and ensure generation logprobs exist in message logs.

Assistant messages can be part of the original multi-turn prompt history. Only
generated assistant messages have generation_logprobs, so use that field as the
trainable-token marker. This function mutates each message in-place by adding a
token_loss_mask and, when missing, a zero-valued generation_logprobs tensor.

Args:
message_logs: Batch of tokenized message logs. Each message must contain a
``role`` and ``token_ids`` field. Messages that already contain
``generation_logprobs`` are treated as rollout-generated messages.
"""
for message_log in message_logs:
for message in message_log:
role = cast(str, message["role"])
token_ids = cast(torch.Tensor, message["token_ids"])

if role == "assistant" and "generation_logprobs" in message:
message["token_loss_mask"] = torch.ones_like(token_ids)
else:
message["token_loss_mask"] = torch.zeros_like(token_ids)

if "generation_logprobs" not in message:
message["generation_logprobs"] = torch.zeros_like(
token_ids, dtype=torch.float32
)


def _should_use_async_rollouts(master_config: MasterConfig) -> bool:
"""Determine if async rollouts should be used based on the configuration.

Expand Down Expand Up @@ -1098,28 +1162,6 @@ def _create_advantage_estimator(master_config: MasterConfig):
return adv_estimator


def _extract_prompt_only_messages(message_logs: list) -> list:
    """Extract only prompt messages (user/system) from message logs.

    This is used to get prompt IDs for advantage estimation, excluding
    any assistant responses.

    Selection is purely role-based: every message whose ``role`` is ``"user"``
    or ``"system"`` is kept, wherever it sits in the log, and every other
    message is dropped -- including assistant messages that belong to a
    multi-turn prompt history.

    Args:
        message_logs: List of message logs, where each log is a list of messages.
            Every message must have a ``role`` entry.

    Returns:
        List of message logs containing only user and system messages, in their
        original order. The messages are the same objects as in the input.
    """
    prompt_roles = ("user", "system")
    return [
        [message for message in message_log if message["role"] in prompt_roles]
        for message_log in message_logs
    ]


def refit_policy_generation(
policy: ColocatablePolicyInterface,
policy_generation: GenerationInterface,
Expand Down Expand Up @@ -1694,16 +1736,20 @@ def grpo_train(
# Save baseline for logging (before deletion)
baseline_for_log = baseline.clone()

# Extract prompt-only messages for advantage estimation
prompt_only_message_logs = _extract_prompt_only_messages(
repeated_batch["message_log"]
# Extract original prompt messages using the length field
# This correctly handles multi-turn prompts that contain assistant messages
initial_prompt_message_logs = extract_initial_prompt_messages(
repeated_batch["message_log"],
repeated_batch["length"],
)
prompt_batched_flat, _ = batched_message_log_to_flat_message(
prompt_only_message_logs,
pad_value_dict={"token_ids": tokenizer.pad_token_id},
prompt_batched_flat, prompt_input_lengths = (
batched_message_log_to_flat_message(
initial_prompt_message_logs,
pad_value_dict={"token_ids": tokenizer.pad_token_id},
)
)
prompt_ids_for_adv = prompt_batched_flat["token_ids"]
del prompt_only_message_logs
del initial_prompt_message_logs
del prompt_batched_flat
del input_ids
del baseline
Expand All @@ -1720,21 +1766,9 @@ def grpo_train(

loss_multiplier[truncated] = 0
repeated_batch["loss_multiplier"] = loss_multiplier
# Add loss mask to each message in LLMMessageLogType
for i, message_log in enumerate(repeated_batch["message_log"]):
for j, message in enumerate(message_log):
if message["role"] == "assistant":
message["token_loss_mask"] = torch.ones_like(
message["token_ids"]
)
else:
message["token_loss_mask"] = torch.zeros_like(
message["token_ids"]
)
if "generation_logprobs" not in message:
message["generation_logprobs"] = torch.zeros_like(
message["token_ids"], dtype=torch.float32
)
add_grpo_token_loss_masks_and_generation_logprobs(
repeated_batch["message_log"]
)

# Convert updated LLMMessageLogType to FlatMessagesType for training
flat_messages, input_lengths = batched_message_log_to_flat_message(
Expand Down Expand Up @@ -2799,16 +2833,21 @@ def async_grpo_train(

print("▶ Processing rewards...")
with timer.time("reward_calculation"):
# Extract prompt-only messages for advantage estimation
prompt_only_message_logs = _extract_prompt_only_messages(
repeated_batch["message_log"]
# Extract original prompt messages using the length field
# This correctly handles multi-turn prompts that contain assistant messages
initial_prompt_message_logs = extract_initial_prompt_messages(
repeated_batch["message_log"],
repeated_batch["length"],
)
prompt_batched_flat, _ = batched_message_log_to_flat_message(
prompt_only_message_logs,
pad_value_dict={"token_ids": tokenizer.pad_token_id},

prompt_batched_flat, prompt_input_lengths = (
batched_message_log_to_flat_message(
initial_prompt_message_logs,
pad_value_dict={"token_ids": tokenizer.pad_token_id},
)
)
prompt_ids_for_adv = prompt_batched_flat["token_ids"]
del prompt_only_message_logs
del initial_prompt_message_logs
del prompt_batched_flat

rewards = repeated_batch["total_reward"]
Expand All @@ -2819,21 +2858,9 @@ def async_grpo_train(

# Prepare training data (same as sync version)
with timer.time("data_processing"):
# Add loss mask to each message
for i, message_log in enumerate(repeated_batch["message_log"]):
for j, message in enumerate(message_log):
if message["role"] == "assistant":
message["token_loss_mask"] = torch.ones_like(
message["token_ids"]
)
else:
message["token_loss_mask"] = torch.zeros_like(
message["token_ids"]
)
if "generation_logprobs" not in message:
message["generation_logprobs"] = torch.zeros_like(
message["token_ids"], dtype=torch.float32
)
add_grpo_token_loss_masks_and_generation_logprobs(
repeated_batch["message_log"]
)

# Convert to flat format for training
flat_messages, input_lengths = batched_message_log_to_flat_message(
Expand Down
Loading
Loading