Skip to content
This repository was archived by the owner on Nov 19, 2025. It is now read-only.

Commit 062dcb0

Browse files
Streamlines the augment_dataloader method in DPO (#134)
* Initial commit of DPO augment cleanup Signed-off-by: Daniel Egert <degert@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fixes for PR review Signed-off-by: Daniel Egert <degert@nvidia.com> --------- Signed-off-by: Daniel Egert <degert@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent eb15079 commit 062dcb0

2 files changed

Lines changed: 14 additions & 27 deletions

File tree

nemo_aligner/algorithms/dpo.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -308,27 +308,18 @@ def load_state_dict(self, state_dict):
308308
def augment_dataloader(self, dataloader):
309309
"""Augment dataloader with ref policy log prob"""
310310
iter_dataloader = iter(dataloader)
311-
buffer = []
312-
done = False
313-
while not done:
311+
while True:
314312
try:
315313
batch = next(iter_dataloader)
314+
logprobs = self.model.get_ref_policy_logprobs(batch).cpu()
315+
chosen_logps, reject_logps = torch.split(logprobs, len(logprobs) // 2, dim=0)
316+
batch["ref_policy_log_probs_chosen"] = chosen_logps
317+
batch["ref_policy_log_probs_rejected"] = reject_logps
318+
319+
yield batch
320+
del logprobs, chosen_logps, reject_logps
316321
except StopIteration:
317-
done = True
318-
else:
319-
buffer.append(batch)
320-
if (done and buffer) or len(buffer) == 1:
321-
logprobs = self.model.get_ref_policy_logprobs(buffer).cpu()
322-
start = 0
323-
for batch in buffer:
324-
batch_size = len(batch["chosen"])
325-
assert len(batch["rejected"]) == batch_size
326-
for key in ("chosen", "rejected"):
327-
batch[f"ref_policy_log_probs_{key}"] = logprobs[start : start + batch_size]
328-
start += batch_size
329-
yield batch
330-
buffer.clear()
331-
del logprobs
322+
break
332323

333324
@property
334325
def epoch(self):

nemo_aligner/models/nlp/gpt/megatron_gpt_dpo_model.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -355,15 +355,11 @@ def get_logprob_batch(self, global_batch):
355355

356356
return logprobs
357357

358-
def get_ref_policy_logprobs(self, list_of_batches):
359-
tokens = torch.cat([torch.cat((b["chosen"], b["rejected"]), dim=0) for b in list_of_batches], dim=0)
360-
masks = torch.cat(
361-
[torch.cat((b["attention_mask"], b["attention_mask"]), dim=0) for b in list_of_batches], dim=0
362-
)
363-
pos_ids = torch.cat([torch.cat((b["position_ids"], b["position_ids"]), dim=0) for b in list_of_batches], dim=0)
364-
labels = torch.cat(
365-
[torch.cat((b["chosen_labels"], b["rejected_labels"]), dim=0) for b in list_of_batches], dim=0
366-
)
358+
def get_ref_policy_logprobs(self, batch):
359+
tokens = torch.cat((batch["chosen"], batch["rejected"]), dim=0)
360+
masks = torch.cat((batch["attention_mask"], batch["attention_mask"]), dim=0)
361+
pos_ids = torch.cat((batch["position_ids"], batch["position_ids"]), dim=0)
362+
labels = torch.cat((batch["chosen_labels"], batch["rejected_labels"]), dim=0)
367363
global_batch = [tokens, masks, pos_ids, labels]
368364
with cpu_weight_swap(self, self.ref_policy_state_dict, megatron_amp_O2=self.megatron_amp_O2):
369365
ref_log_probs = self.get_logprob_batch(global_batch)

0 commit comments

Comments
 (0)