We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent c746af2 commit e09b8eaCopy full SHA for e09b8ea
1 file changed
agentlightning/verl/trainer.py
@@ -417,10 +417,17 @@ def _train_step(self, batch_dict: dict) -> dict:
417
print(batch.batch.keys())
418
inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
419
outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
420
+ sample_gts = [
421
+ item.non_tensor_batch.get("reward_model", {}).get(
422
+ "ground_truth", None
423
+ )
424
+ for item in batch
425
+ ]
426
scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
427
self._dump_generations(
428
inputs=inputs,
429
outputs=outputs,
430
+ gts=sample_gts,
431
scores=scores,
432
reward_extra_infos_dict=reward_extra_infos_dict,
433
dump_path=rollout_data_dir,
0 commit comments