diff --git a/examples/nlp/gpt/train_reward_model.py b/examples/nlp/gpt/train_reward_model.py index 42cf5bde2..5fe04c3c6 100644 --- a/examples/nlp/gpt/train_reward_model.py +++ b/examples/nlp/gpt/train_reward_model.py @@ -126,6 +126,7 @@ def main(cfg) -> None: use_random_sampler=cfg.trainer.rm.train_random_sampler, ) if isinstance(validation_ds, RewardModelDataset): + drop_last = cfg.model.data.get("validation_drop_last", True) val_dataloader = build_dataloader( cfg=cfg, dataset=validation_ds, @@ -134,6 +135,8 @@ def main(cfg) -> None: gbs=cfg.model.global_batch_size, load_gbs=True, use_random_sampler=cfg.trainer.rm.val_random_sampler, + drop_last=drop_last, + pad_samples_to_global_batch_size=not drop_last, ) elif isinstance(validation_ds, dict): drop_last = cfg.model.data.get("validation_drop_last", True) @@ -145,7 +148,7 @@ def main(cfg) -> None: mbs=cfg.model.micro_batch_size, gbs=cfg.model.global_batch_size, load_gbs=True, - use_random_sampler=False, + use_random_sampler=cfg.trainer.rm.val_random_sampler, drop_last=drop_last, pad_samples_to_global_batch_size=not drop_last, ) diff --git a/nemo_aligner/algorithms/supervised.py b/nemo_aligner/algorithms/supervised.py index d607d1fd3..8cf0e6455 100644 --- a/nemo_aligner/algorithms/supervised.py +++ b/nemo_aligner/algorithms/supervised.py @@ -145,17 +145,19 @@ def run_validation_one_dataset(self, key: str): if "weights" in val_metrics: w = val_metrics.pop("weights") + val_loss = sum([value * weight for value, weight in zip(loss_means, w)]) / sum(w) val_metrics = { k: sum([value * weight for value, weight in zip(v, w)]) / sum(w) for k, v in val_metrics.items() } else: + val_loss = mean(loss_means) val_metrics = {k: mean(v) for k, v in val_metrics.items()} val_metrics.update(self.inference_metrics_handler.compute()) self.inference_metrics_handler.reset() self.logger.log_metrics(val_metrics, step=self.step, prefix=f"{key}/") - return mean(loss_means), val_metrics + return val_loss, val_metrics def train_single_step(self, batch): self.optimizer.zero_grad()