|
28 | 28 | from nemo_rl.algorithms.grpo import ( |
29 | 29 | MasterConfig, |
30 | 30 | _default_grpo_save_state, |
| 31 | + aggregate_rollout_metrics, |
31 | 32 | async_grpo_train, |
32 | 33 | compute_and_apply_seq_logprob_error_masking, |
33 | 34 | dynamic_sampling, |
@@ -2770,3 +2771,76 @@ def test_threshold_boundary_values(self): |
2770 | 2771 | # At least sequence 2 should be masked |
2771 | 2772 | assert num_masked >= 1, "At least one sequence should be masked" |
2772 | 2773 | assert train_data["sample_mask"][0] == 1.0, "Sequence 0 should be kept" |
| 2774 | + |
| 2775 | + |
| 2776 | +class TestAggregateRolloutMetrics: |
| 2777 | + """Tests for aggregate_rollout_metrics which aggregates per-group metrics by semantic type.""" |
| 2778 | + |
| 2779 | + def test_min_metrics_take_minimum(self): |
| 2780 | + metrics = { |
| 2781 | + "gen_tokens/min": [10, 5, 8], |
| 2782 | + "min_reward": [0.3, 0.1, 0.5], |
| 2783 | + } |
| 2784 | + result = aggregate_rollout_metrics(metrics) |
| 2785 | + assert result["gen_tokens/min"] == 5 |
| 2786 | + assert result["min_reward"] == 0.1 |
| 2787 | + |
| 2788 | + def test_max_metrics_take_maximum(self): |
| 2789 | + metrics = { |
| 2790 | + "gen_tokens/max": [10, 50, 30], |
| 2791 | + "max_reward": [0.3, 0.9, 0.5], |
| 2792 | + } |
| 2793 | + result = aggregate_rollout_metrics(metrics) |
| 2794 | + assert result["gen_tokens/max"] == 50 |
| 2795 | + assert result["max_reward"] == 0.9 |
| 2796 | + |
| 2797 | + def test_rate_suffix_excluded_from_min_max(self): |
| 2798 | + """min_*_rate and max_*_rate should be averaged, not min/maxed.""" |
| 2799 | + metrics = { |
| 2800 | + "min_completion_rate": [0.2, 0.8, 0.5], |
| 2801 | + "max_completion_rate": [0.3, 0.9, 0.6], |
| 2802 | + } |
| 2803 | + result = aggregate_rollout_metrics(metrics) |
| 2804 | + assert result["min_completion_rate"] == pytest.approx(0.5) |
| 2805 | + assert result["max_completion_rate"] == pytest.approx(0.6) |
| 2806 | + |
| 2807 | + def test_total_turns_summed(self): |
| 2808 | + metrics = {"total_turns": [10, 20, 30]} |
| 2809 | + result = aggregate_rollout_metrics(metrics) |
| 2810 | + assert result["total_turns"] == 60 |
| 2811 | + |
| 2812 | + def test_mean_metrics_averaged(self): |
| 2813 | + metrics = { |
| 2814 | + "mean_gen_tokens_per_sample": [100, 200, 300], |
| 2815 | + "reward/mean": [0.5, 0.7, 0.9], |
| 2816 | + } |
| 2817 | + result = aggregate_rollout_metrics(metrics) |
| 2818 | + assert result["mean_gen_tokens_per_sample"] == pytest.approx(200.0) |
| 2819 | + assert result["reward/mean"] == pytest.approx(0.7) |
| 2820 | + |
| 2821 | + def test_non_numeric_passed_through(self): |
| 2822 | + metrics = {"some_list_metric": [["a", "b"], ["c", "d"]]} |
| 2823 | + result = aggregate_rollout_metrics(metrics) |
| 2824 | + assert result["some_list_metric"] == [["a", "b"], ["c", "d"]] |
| 2825 | + |
| 2826 | + def test_mixed_metrics(self): |
| 2827 | + """Full integration test with a realistic mix of metric types.""" |
| 2828 | + metrics = { |
| 2829 | + "gen_tokens/min": [5, 3, 7], |
| 2830 | + "gen_tokens/max": [100, 200, 150], |
| 2831 | + "gen_tokens/mean": [50, 60, 70], |
| 2832 | + "min_reward": [0.1, 0.2, 0.05], |
| 2833 | + "max_reward": [0.9, 0.8, 0.95], |
| 2834 | + "total_turns": [10, 15, 20], |
| 2835 | + "accuracy": [0.8, 0.9, 0.7], |
| 2836 | + "min_accuracy_rate": [0.1, 0.2, 0.3], |
| 2837 | + } |
| 2838 | + result = aggregate_rollout_metrics(metrics) |
| 2839 | + assert result["gen_tokens/min"] == 3 |
| 2840 | + assert result["gen_tokens/max"] == 200 |
| 2841 | + assert result["gen_tokens/mean"] == pytest.approx(60.0) |
| 2842 | + assert result["min_reward"] == 0.05 |
| 2843 | + assert result["max_reward"] == 0.95 |
| 2844 | + assert result["total_turns"] == 45 |
| 2845 | + assert result["accuracy"] == pytest.approx(0.8) |
| 2846 | + assert result["min_accuracy_rate"] == pytest.approx(0.2) |
0 commit comments