From bb280a8ef8948bb8dae4959cb7563b099db6e6cf Mon Sep 17 00:00:00 2001 From: ashors1 Date: Thu, 21 May 2026 15:26:02 -0700 Subject: [PATCH] make only_unmask_final configurable in SFT Signed-off-by: ashors1 --- examples/configs/sft.yaml | 3 +++ nemo_rl/algorithms/sft.py | 6 ++++++ tests/unit/reference_configs/sft.yaml | 3 +++ 3 files changed, 12 insertions(+) diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index cf02bdfc74..30dea47fde 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -12,6 +12,9 @@ sft: val_at_start: true val_at_end: false seed: 42 + # If true, only the final message in each conversation is unmasked for loss + # computation. If false, all assistant messages are unmasked. + only_unmask_final: false checkpointing: enabled: true diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index 747a05fe93..73d33802cc 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -73,6 +73,10 @@ class SFTConfig(TypedDict): # final checkpoint has validation metrics, which is required for get_best_checkpoint_path(). val_at_end: bool seed: int + # If True, only the final message in each conversation is unmasked for loss + # computation, regardless of role. If False, all messages with role in + # `roles_to_train_on` (currently hard-coded to ["assistant"]) are unmasked. + only_unmask_final: bool class MasterConfig(BaseModel, extra="allow"): @@ -271,6 +275,7 @@ def validate( add_loss_mask_to_message_log( val_batch["message_log"], roles_to_train_on=["assistant"], + only_unmask_final=master_config.sft["only_unmask_final"], ) cat_and_padded, input_lengths = batched_message_log_to_flat_message( @@ -434,6 +439,7 @@ def sft_train( add_loss_mask_to_message_log( batch["message_log"], roles_to_train_on=["assistant"], + only_unmask_final=master_config.sft["only_unmask_final"], ) cat_and_padded, input_lengths = batched_message_log_to_flat_message( diff --git a/tests/unit/reference_configs/sft.yaml b/tests/unit/reference_configs/sft.yaml index 416895b635..ce3c206190 100644 --- a/tests/unit/reference_configs/sft.yaml +++ b/tests/unit/reference_configs/sft.yaml @@ -12,6 +12,9 @@ sft: val_at_start: true val_at_end: false seed: 42 + # If true, only the final message in each conversation is unmasked for loss + # computation. If false, all assistant messages are unmasked. + only_unmask_final: false checkpointing: enabled: true