Skip to content

Commit 501cd12

Browse files
yfwclaude
andauthored
fix: aggregate rollout metrics by semantic type (min/max/sum/mean) (#2175)
Signed-off-by: Yi-Fu Wu <yifu.wu@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent b603568 commit 501cd12

2 files changed

Lines changed: 111 additions & 8 deletions

File tree

nemo_rl/algorithms/grpo.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2433,6 +2433,39 @@ def validate(
24332433
return val_metrics, timing_metrics
24342434

24352435

2436+
def aggregate_rollout_metrics(
2437+
per_group_metrics: dict[str, list],
2438+
) -> dict[str, Any]:
2439+
"""Aggregate rollout metrics from multiple trajectory groups.
2440+
2441+
Different metric types are aggregated according to their semantics:
2442+
- Metrics ending with "/min" or starting with "min_" (excluding "_rate" suffix): take the minimum
2443+
- Metrics ending with "/max" or starting with "max_" (excluding "_rate" suffix): take the maximum
2444+
- "total_turns": summed
2445+
- Non-numeric values: passed through as-is
2446+
- All other numeric metrics: averaged
2447+
2448+
Args:
2449+
per_group_metrics: A dict mapping metric names to lists of per-group values.
2450+
2451+
Returns:
2452+
A dict mapping metric names to their aggregated scalar values.
2453+
"""
2454+
aggregated = {}
2455+
for k, v in per_group_metrics.items():
2456+
if not isinstance(v[0], (int, float)):
2457+
aggregated[k] = v
2458+
elif k.endswith("/min") or (k.startswith("min_") and not k.endswith("_rate")):
2459+
aggregated[k] = min(v)
2460+
elif k.endswith("/max") or (k.startswith("max_") and not k.endswith("_rate")):
2461+
aggregated[k] = max(v)
2462+
elif k == "total_turns":
2463+
aggregated[k] = sum(v)
2464+
else:
2465+
aggregated[k] = sum(v) / len(v)
2466+
return aggregated
2467+
2468+
24362469
def async_grpo_train(
24372470
policy: ColocatablePolicyInterface,
24382471
policy_generation: Optional[GenerationInterface],
@@ -2765,16 +2798,12 @@ def async_grpo_train(
27652798
# Concatenate per-prompt groups into a single training batch
27662799
per_prompt_batches = [t["batch"] for t in trajectories]
27672800
repeated_batch = BatchedDataDict.from_batches(per_prompt_batches)
2768-
# Aggregate rollout metrics across groups (simple mean where applicable)
2769-
rollout_metrics = {}
2801+
# Aggregate rollout metrics across groups with proper aggregation per metric type
2802+
per_group_metrics = {}
27702803
for t in trajectories:
27712804
for k, v in t["rollout_metrics"].items():
2772-
rollout_metrics.setdefault(k, []).append(v)
2773-
# TODO: this simple averaging might cause misleading information for such data as max_gen_tokens, etc.
2774-
rollout_metrics = {
2775-
k: (sum(v) / len(v) if isinstance(v[0], (int, float)) else v)
2776-
for k, v in rollout_metrics.items()
2777-
}
2805+
per_group_metrics.setdefault(k, []).append(v)
2806+
rollout_metrics = aggregate_rollout_metrics(per_group_metrics)
27782807

27792808
# Enforce fixed training batch: num_prompts_per_step * num_generations_per_prompt
27802809
expected_batch_size = (

tests/unit/algorithms/test_grpo.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from nemo_rl.algorithms.grpo import (
2929
MasterConfig,
3030
_default_grpo_save_state,
31+
aggregate_rollout_metrics,
3132
async_grpo_train,
3233
compute_and_apply_seq_logprob_error_masking,
3334
dynamic_sampling,
@@ -2770,3 +2771,76 @@ def test_threshold_boundary_values(self):
27702771
# At least sequence 2 should be masked
27712772
assert num_masked >= 1, "At least one sequence should be masked"
27722773
assert train_data["sample_mask"][0] == 1.0, "Sequence 0 should be kept"
2774+
2775+
2776+
class TestAggregateRolloutMetrics:
2777+
"""Tests for aggregate_rollout_metrics which aggregates per-group metrics by semantic type."""
2778+
2779+
def test_min_metrics_take_minimum(self):
2780+
metrics = {
2781+
"gen_tokens/min": [10, 5, 8],
2782+
"min_reward": [0.3, 0.1, 0.5],
2783+
}
2784+
result = aggregate_rollout_metrics(metrics)
2785+
assert result["gen_tokens/min"] == 5
2786+
assert result["min_reward"] == 0.1
2787+
2788+
def test_max_metrics_take_maximum(self):
2789+
metrics = {
2790+
"gen_tokens/max": [10, 50, 30],
2791+
"max_reward": [0.3, 0.9, 0.5],
2792+
}
2793+
result = aggregate_rollout_metrics(metrics)
2794+
assert result["gen_tokens/max"] == 50
2795+
assert result["max_reward"] == 0.9
2796+
2797+
def test_rate_suffix_excluded_from_min_max(self):
2798+
"""min_*_rate and max_*_rate should be averaged, not min/maxed."""
2799+
metrics = {
2800+
"min_completion_rate": [0.2, 0.8, 0.5],
2801+
"max_completion_rate": [0.3, 0.9, 0.6],
2802+
}
2803+
result = aggregate_rollout_metrics(metrics)
2804+
assert result["min_completion_rate"] == pytest.approx(0.5)
2805+
assert result["max_completion_rate"] == pytest.approx(0.6)
2806+
2807+
def test_total_turns_summed(self):
2808+
metrics = {"total_turns": [10, 20, 30]}
2809+
result = aggregate_rollout_metrics(metrics)
2810+
assert result["total_turns"] == 60
2811+
2812+
def test_mean_metrics_averaged(self):
2813+
metrics = {
2814+
"mean_gen_tokens_per_sample": [100, 200, 300],
2815+
"reward/mean": [0.5, 0.7, 0.9],
2816+
}
2817+
result = aggregate_rollout_metrics(metrics)
2818+
assert result["mean_gen_tokens_per_sample"] == pytest.approx(200.0)
2819+
assert result["reward/mean"] == pytest.approx(0.7)
2820+
2821+
def test_non_numeric_passed_through(self):
2822+
metrics = {"some_list_metric": [["a", "b"], ["c", "d"]]}
2823+
result = aggregate_rollout_metrics(metrics)
2824+
assert result["some_list_metric"] == [["a", "b"], ["c", "d"]]
2825+
2826+
def test_mixed_metrics(self):
2827+
"""Full integration test with a realistic mix of metric types."""
2828+
metrics = {
2829+
"gen_tokens/min": [5, 3, 7],
2830+
"gen_tokens/max": [100, 200, 150],
2831+
"gen_tokens/mean": [50, 60, 70],
2832+
"min_reward": [0.1, 0.2, 0.05],
2833+
"max_reward": [0.9, 0.8, 0.95],
2834+
"total_turns": [10, 15, 20],
2835+
"accuracy": [0.8, 0.9, 0.7],
2836+
"min_accuracy_rate": [0.1, 0.2, 0.3],
2837+
}
2838+
result = aggregate_rollout_metrics(metrics)
2839+
assert result["gen_tokens/min"] == 3
2840+
assert result["gen_tokens/max"] == 200
2841+
assert result["gen_tokens/mean"] == pytest.approx(60.0)
2842+
assert result["min_reward"] == 0.05
2843+
assert result["max_reward"] == 0.95
2844+
assert result["total_turns"] == 45
2845+
assert result["accuracy"] == pytest.approx(0.8)
2846+
assert result["min_accuracy_rate"] == pytest.approx(0.2)

0 commit comments

Comments
 (0)