1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -39,6 +39,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
 - Add code and instructions for replicating Reward Modeling training in HelpSteer2 and HelpSteer2-Preference
 - Implement REINFORCE algorithm.
 - Add support for multiple validation sets when training a Reward Model.
+- Add support for `validation_drop_last=False` when training a Reward Model.
 
 ### Breaking Changes
 - Upgrade TRTLLM dependency from v0.10.0 to v0.12.0 and migrate from `GPTSession` cpp runtime to `ModelRunner` python runtime. Please use the latest Dockerfile.
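Previously, any validation samples that did not fill a complete global batch were silently dropped. Setting the new flag to `False` (for example via a Hydra-style override like `model.data.validation_drop_last=False` when launching `examples/nlp/gpt/train_reward_model.py`; the exact invocation is an assumption, not part of this diff) keeps the final partial batch, pads it up to the global batch size, and masks the padding out of the metrics. A sketch of the arithmetic with invented numbers:

```python
# Hypothetical numbers, for illustration only.
global_batch_size = 8
num_val_samples = 21

full_batches = num_val_samples // global_batch_size            # 2 -> 16 samples evaluated
leftover = num_val_samples - full_batches * global_batch_size  # 5 samples
# drop_last=True: the 5 leftover samples never reach the validation metrics.
# drop_last=False: they are padded up to 8, and the 3 pad entries are masked out.
print(full_batches, leftover)  # 2 5
```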
3 changes: 3 additions & 0 deletions examples/nlp/gpt/train_reward_model.py
@@ -136,6 +136,7 @@ def main(cfg) -> None:
             use_random_sampler=cfg.trainer.rm.val_random_sampler,
         )
     elif isinstance(validation_ds, dict):
+        drop_last = cfg.model.data.get("validation_drop_last", True)
         val_dataloader = {
             key: build_dataloader(
                 cfg=cfg,
@@ -145,6 +146,8 @@ def main(cfg) -> None:
                 gbs=cfg.model.global_batch_size,
                 load_gbs=True,
                 use_random_sampler=False,
+                drop_last=drop_last,
+                pad_samples_to_global_batch_size=not drop_last,
             )
             for key, dataset in validation_ds.items()
         }
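The `drop_last`/`pad_samples_to_global_batch_size` pairing works because downstream code recognizes padding by an all-zero `loss_mask`. A toy illustration (invented tensors, not code from this PR) of the masking the reward model applies later:

```python
import torch

# Each row is one example in a padded validation batch.
loss_mask = torch.tensor([
    [1, 1, 1, 0],  # real sample
    [1, 1, 0, 0],  # real sample
    [0, 0, 0, 0],  # pad sample added to reach the global batch size
])
mask_valid_pairs = loss_mask.sum(dim=1) > 0  # tensor([ True,  True, False])
num_valid_pairs = mask_valid_pairs.sum()     # tensor(2)
```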
8 changes: 7 additions & 1 deletion nemo_aligner/algorithms/supervised.py
@@ -143,7 +143,13 @@ def run_validation_one_dataset(self, key: str):
             log_val_metrics = {f"val_{k}": v for k, v in metrics.items()}
             val_pbar.set_postfix(log_val_metrics)
 
-        val_metrics = {k: mean(v) for k, v in val_metrics.items()}
+        if "weights" in val_metrics:
+            w = val_metrics.pop("weights")
+            val_metrics = {
+                k: sum([value * weight for value, weight in zip(v, w)]) / sum(w) for k, v in val_metrics.items()
+            }
+        else:
+            val_metrics = {k: mean(v) for k, v in val_metrics.items()}
         val_metrics.update(self.inference_metrics_handler.compute())
         self.inference_metrics_handler.reset()
 
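Because a padded validation batch contains fewer real pairs, per-batch metrics are now combined with a weighted mean, using the `"weights"` entry (the number of valid pairs each batch contributed) rather than a plain mean. A toy walk-through with invented numbers:

```python
# Two validation batches: one full (8 valid pairs), one padded (3 valid pairs).
val_metrics = {"acc": [0.75, 0.33], "weights": [8.0, 3.0]}

w = val_metrics.pop("weights")
agg = {k: sum(value * weight for value, weight in zip(v, w)) / sum(w)
       for k, v in val_metrics.items()}
print(agg)  # {'acc': 0.6354...}; a plain mean would report 0.54
```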
2 changes: 1 addition & 1 deletion nemo_aligner/data/nlp/builders.py
@@ -129,7 +129,7 @@ def _build_dataset(current_data_prefix, current_num_samples):
     logging.info(" Total {} documents is : {} ".format(name, total_num_of_documents))
 
     drop_last = True
-    if name == "valid":
+    if name.startswith("validation"):
         drop_last = cfg.data.get("validation_drop_last", True)
 
     dataset = cls(
109 changes: 57 additions & 52 deletions nemo_aligner/models/nlp/gpt/megatron_gpt_reward_model.py
@@ -184,49 +184,54 @@ def gather_and_split_rewards(rewards_out):
 
         def loss_func(output_tensor):
             # Loss per micro batch (ub).
-            loss_for_ub, acc_chosen = self.loss_func(output_tensor)
+            loss_for_ub = self.loss_func(output_tensor)
+            # Number of valid pairs in the micro batch.
+            mask_valid_pairs = batch["loss_mask"].sum(dim=1) > 0
+            num_valid_pairs = mask_valid_pairs.sum()
+            # Compute loss average over valid pairs only.
+            loss_for_ub = loss_for_ub[mask_valid_pairs].mean()
+            # Compute accuracy over valid pairs only.
+            out_chosen, out_rejected = self.split_output_tensor(output_tensor)
+            comp = out_chosen > out_rejected
+            num_correct_pairs = comp[mask_valid_pairs].sum()
+            acc_chosen = num_correct_pairs / max(1, num_valid_pairs)
             if validation_step and not self.cfg.data.get("validation_drop_last", True):
-                num_valid_tokens_in_ub = batch["loss_mask"].sum()
-
                 if loss_for_ub.isnan():
-                    assert batch["loss_mask"].count_nonzero() == 0, "Got NaN loss with non-empty input"
-                    loss_sum_for_ub = torch.zeros_like(num_valid_tokens_in_ub)
+                    assert num_valid_pairs == 0, "Got NaN loss with non-empty input"
+                    loss_sum_for_ub = torch.zeros_like(num_valid_pairs, dtype=loss_for_ub.dtype)
                 else:
-                    loss_sum_for_ub = num_valid_tokens_in_ub * loss_for_ub
-
-                loss_sum_and_ub_size_all_gpu = torch.cat(
-                    [
-                        loss_sum_for_ub.clone().detach().view(1),
-                        torch.tensor([num_valid_tokens_in_ub]).cuda().clone().detach(),
-                    ]
-                )
-                torch.distributed.all_reduce(
-                    loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group()
-                )
-                out_chosen, out_rejected = gather_and_split_rewards(output_tensor)
-
-                return (
-                    loss_for_ub,
-                    {
-                        "loss_sum_and_ub_size": loss_sum_and_ub_size_all_gpu,
-                        "out_chosen": out_chosen,
-                        "out_rejected": out_rejected,
-                    },
-                )
+                    loss_sum_for_ub = num_valid_pairs * loss_for_ub
+
+                tensor_to_reduce = torch.stack([loss_sum_for_ub, num_valid_pairs, num_correct_pairs,])
+                torch.distributed.all_reduce(tensor_to_reduce, group=parallel_state.get_data_parallel_group())
+                loss_sum_for_ub, num_valid_pairs, num_correct_pairs = tensor_to_reduce
+
+                if num_valid_pairs > 0:
+                    reduced_loss = loss_sum_for_ub / num_valid_pairs
+                    reduced_acc = num_correct_pairs / num_valid_pairs
+                else:
+                    reduced_loss = torch.zeros_like(loss_sum_for_ub)
+                    reduced_acc = torch.zeros_like(num_correct_pairs)
 
             else:
                 reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
                 reduced_acc = average_losses_across_data_parallel_group([acc_chosen])
-
-                out_chosen, out_rejected = gather_and_split_rewards(output_tensor)
-
-                return (
-                    loss_for_ub,
-                    {
-                        "avg": reduced_loss,
-                        "acc": reduced_acc,
-                        "out_chosen": out_chosen,
-                        "out_rejected": out_rejected,
-                    },
-                )
+                # This assumes `drop_last=True` -- which is normally the case during training.
+                num_valid_pairs = num_valid_pairs * parallel_state.get_data_parallel_world_size()
+
+            out_chosen, out_rejected = gather_and_split_rewards(output_tensor)
+
+            return (
+                loss_for_ub,
+                {
+                    "num_valid_pairs": num_valid_pairs,
+                    "avg": reduced_loss,
+                    "acc": reduced_acc,
+                    "out_chosen": out_chosen,
+                    "out_rejected": out_rejected,
+                },
+            )
 
         return output_tensor, loss_func
 
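The two branches count pairs differently on purpose: with `validation_drop_last=False`, data-parallel ranks can hold different numbers of valid pairs (only some ranks receive padded batches), so the sums and counts are all-reduced explicitly; with `drop_last=True`, every rank holds the same count, so the global total is just the local count times the world size. A condensed restatement of that assumption (illustrative only, not code from this PR):

```python
def global_pair_count(local_pairs, dp_world_size, drop_last, all_reduce_sum):
    """Summarizes the counting logic in loss_func above."""
    if drop_last:
        # Every data-parallel rank sees the same number of valid pairs,
        # so no collective communication is needed.
        return local_pairs * dp_world_size
    # Ranks may disagree (padded last batch), so counts are summed
    # collectively -- loss_func does this with torch.distributed.all_reduce.
    return all_reduce_sum(local_pairs)
```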
@@ -238,10 +243,8 @@ def split_output_tensor(self, output_tensor):
 
     def loss_func(self, output_tensor):
         out_chosen, out_rejected = self.split_output_tensor(output_tensor)
-        comp = out_chosen > out_rejected
-        acc_chosen = torch.sum(comp) / comp.shape[0]
-        loss = -torch.nn.functional.logsigmoid(out_chosen - out_rejected).mean()
-        return loss, acc_chosen
+        loss = -torch.nn.functional.logsigmoid(out_chosen - out_rejected)
+        return loss
 
     def get_loss_and_metrics(self, batch, forward_only):
         data_iter = get_iterator_k_split(batch, get_num_microbatches())
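`self.loss_func` now returns the unreduced per-pair Bradley-Terry loss, `-log(sigmoid(r_chosen - r_rejected))`, so the caller can drop padded pairs before averaging; accuracy is likewise computed by the caller over valid pairs only. A self-contained numeric check (values invented):

```python
import torch
import torch.nn.functional as F

# Per-pair reward-model loss, returned without .mean() so invalid
# (padded) pairs can be masked out by the caller.
out_chosen = torch.tensor([1.2, 0.0, -0.5])
out_rejected = torch.tensor([0.3, 0.0, 0.5])
per_pair_loss = -F.logsigmoid(out_chosen - out_rejected)
print(per_pair_loss)  # tensor([0.3412, 0.6931, 1.3133])
```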
@@ -272,21 +275,21 @@ def get_loss_and_metrics(self, batch, forward_only):
             rewards_all_mean = rewards_all.mean()
             rewards_all_std = rewards_all.std()
 
-            # average loss across micro batches
-            loss_tensors_list = [loss_reduced["avg"] for loss_reduced in losses_reduced_per_micro_batch]
-            loss_tensor = torch.concat(loss_tensors_list)
-            loss_mean = loss_tensor.mean()
-            acc_tensors_list = [loss_reduced["acc"] for loss_reduced in losses_reduced_per_micro_batch]
-
-            if len(acc_tensors_list) == 1:
-                acc_tensor = acc_tensors_list[0]
-            elif len(acc_tensors_list) > 1:
-                acc_tensor = torch.concat(acc_tensors_list)
-            acc_mean = acc_tensor.mean()
+            num_valid_pairs = torch.stack(
+                [loss_reduced["num_valid_pairs"] for loss_reduced in losses_reduced_per_micro_batch]
+            )
+            loss_tensor = torch.cat([loss_reduced["avg"].view(1) for loss_reduced in losses_reduced_per_micro_batch])
+            acc_tensor = torch.cat([loss_reduced["acc"].view(1) for loss_reduced in losses_reduced_per_micro_batch])
+
+            weights = num_valid_pairs.sum().float()
+            loss_mean = (loss_tensor * num_valid_pairs).sum() / weights
+            acc_mean = (acc_tensor * num_valid_pairs).sum() / weights
+
         else:
-
             loss_mean = torch.tensor(0.0, device=torch.cuda.current_device())
             acc_mean = torch.tensor(0.0, device=torch.cuda.current_device())
+            weights = torch.tensor(0.0, device=torch.cuda.current_device())
 
             rewards_chosen_mean = torch.tensor(0.0, device=torch.cuda.current_device())
             rewards_rejected_mean = torch.tensor(0.0, device=torch.cuda.current_device())
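Micro-batch results are merged the same way: each micro batch's average is weighted by its valid-pair count, so a mostly-padded final micro batch cannot skew the epoch metrics. A toy check with made-up values:

```python
import torch

num_valid_pairs = torch.tensor([4.0, 4.0, 1.0])  # last micro batch was padded
loss_per_mb = torch.tensor([0.70, 0.65, 0.20])

weights = num_valid_pairs.sum()                              # 9.0
loss_mean = (loss_per_mb * num_valid_pairs).sum() / weights  # ~0.622
# An unweighted mean (~0.517) would overweight the single-pair batch.
```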
@@ -296,6 +299,7 @@ def get_loss_and_metrics(self, batch, forward_only):
         # we can only log on one rank if it is rank zero so we broadcast from last rank
         torch.distributed.broadcast(loss_mean, get_last_rank())
         torch.distributed.broadcast(acc_mean, get_last_rank())
+        torch.distributed.broadcast(weights, get_last_rank())
 
         torch.distributed.broadcast(rewards_chosen_mean, get_last_rank())
         torch.distributed.broadcast(rewards_rejected_mean, get_last_rank())
Expand All @@ -305,6 +309,7 @@ def get_loss_and_metrics(self, batch, forward_only):
metrics = {
"loss": loss_mean,
"acc": acc_mean,
"weights": weights,
"rewards_chosen_mean": rewards_chosen_mean,
"rewards_rejected_mean": rewards_rejected_mean,
"rewards_all_mean": rewards_all_mean,
Expand Down