Skip to content

Commit e09b8ea

Browse files
committed
fix: pass gts argument in _dump_generations call in _train_step
1 parent c746af2 commit e09b8ea

1 file changed

Lines changed: 7 additions & 0 deletions

File tree

agentlightning/verl/trainer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,10 +417,17 @@ def _train_step(self, batch_dict: dict) -> dict:
417417
print(batch.batch.keys())
418418
inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
419419
outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
420+
sample_gts = [
421+
item.non_tensor_batch.get("reward_model", {}).get(
422+
"ground_truth", None
423+
)
424+
for item in batch
425+
]
420426
scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
421427
self._dump_generations(
422428
inputs=inputs,
423429
outputs=outputs,
430+
gts=sample_gts,
424431
scores=scores,
425432
reward_extra_infos_dict=reward_extra_infos_dict,
426433
dump_path=rollout_data_dir,

0 commit comments

Comments
 (0)