Skip to content

Commit b708fef

Browse files
committed
feat: support structured reward outputs and grouped reward aggregation
1 parent d489178 commit b708fef

2 files changed

Lines changed: 8 additions & 3 deletions

File tree

areal/api/reward_api.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from concurrent.futures import ProcessPoolExecutor
1010
from concurrent.futures.process import BrokenProcessPool
1111
from functools import partial
12-
1312
from areal.utils import logging
1413

1514
logger = logging.getLogger("RewardAPI")
@@ -56,7 +55,7 @@ def reward_fn(
5655
:param completion_ids: The token IDs of the trajectory generated by the model.
5756
:param kwargs: Other attributes of the data in the dataset, such as solutions, input_outputs, etc.
5857
Any other attributes in the dataset will be passed as keyword arguments to this function.
59-
:rtype: float
58+
:rtype: float | dict[str, float]
6059
"""
6160

6261

@@ -135,7 +134,7 @@ def _recreate_executor(cls, executor_key, max_workers):
135134
return cls._executors[executor_key]
136135
return None
137136

138-
async def __call__(self, *args, **kwargs) -> float:
137+
async def __call__(self, *args, **kwargs) -> float | dict[str, float]:
139138
last_exception = None
140139

141140
for attempt in range(self.max_retries + 1):

areal/infra/remote_inf_engine.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ async def arun_episode(
102102
"trajectories returned None, using remaining results"
103103
)
104104

105+
aggregate_group_results = getattr(
106+
self.workflow, "aggregate_group_results", None
107+
)
108+
if callable(aggregate_group_results):
109+
return aggregate_group_results(valid_results)
110+
105111
# Check if results are InteractionWithTokenLogpReward dicts
106112
first = valid_results[0]
107113
if (

0 commit comments

Comments
 (0)