-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathmath_answer_as_judge.py
More file actions
65 lines (55 loc) · 2.53 KB
/
math_answer_as_judge.py
File metadata and controls
65 lines (55 loc) · 2.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import re
from ajet.task_judge.base_judge import BaseJudge
from ajet.task_rollout.dashscope_llm_bridge import create_external_llm_fn
from ajet.workflow import WorkflowOutput, WorkflowTask
class MathAnswerAsJudge(BaseJudge):
    """Rule-based judge: extract a LaTeX ``\\boxed{...}`` answer from the
    workflow output and compare it exactly against the reference answer.
    """

    # Compiled once instead of on every call. Known limitation: ``[^}]*``
    # cannot match nested braces, e.g. \boxed{\frac{1}{2}} captures only
    # "\frac{1" — acceptable for plain numeric answers (GSM8K style).
    _BOXED_PATTERN = re.compile(r"\\boxed\{([^}]*)\}")

    def __init__(self, config):
        # Full experiment config; kept for interface parity with other judges.
        self.config = config

    def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> tuple:
        """Score the rollout by exact string match on the boxed answer.

        Args:
            workflow_task: Task whose ``task.metadata["answer"]`` holds the
                reference; text after "####" is the final answer
                (GSM8K-style convention).
            workflow_output: Rollout output; its metadata must contain
                "final_answer".

        Returns:
            tuple: ``(raw_reward, is_success)`` — 1.0/True on exact match,
            0.0/False otherwise (including when no ``\\boxed{}`` is found).

        Raises:
            KeyError: if "final_answer" was never registered (see comment
                below) or the task metadata lacks "answer".
        """
        final_answer = workflow_output.metadata[
            "final_answer"
        ]  # By default there's no final_answer; register it by calling ajet_proxy.update_judge_input_dictionary(final_answer=final_answer) in the workflow
        reference_answer = workflow_task.task.metadata["answer"]
        reference_answer = reference_answer.split("####")[-1].strip()

        match = self._BOXED_PATTERN.search(final_answer)
        # Fix: strip the captured text so "\boxed{ 42 }" still matches "42".
        # The original compared the raw capture and failed on padding
        # whitespace even though the reference side IS stripped.
        is_success = bool(match) and match.group(1).strip() == reference_answer
        raw_reward = 1.0 if is_success else 0.0
        return raw_reward, is_success
class MathAnswerAndLlmAsJudge(BaseJudge):
    """LLM-as-judge: ask an external ("alien") LLM whether the workflow's
    final answer agrees with the reference answer.
    """

    def __init__(self, config):
        # Full experiment config; the judge-model settings live under
        # config.ajet.task_judge.
        self.config = config

    def compute_reward(self, workflow_task: WorkflowTask, workflow_output: WorkflowOutput) -> tuple:
        """Return ``(raw_reward, is_success)`` as decided by an external LLM.

        The judge model is prompted to reply with the literal tag
        <Correct> or <NotCorrect>; success is detected by substring search
        for "<Correct>" in its reply.
        """
        final_answer = workflow_output.metadata[
            "final_answer"
        ]  # By default there's no final_answer; register it by calling ajet_proxy.update_judge_input_dictionary(final_answer=final_answer) in the workflow

        # Reference follows the GSM8K convention: final answer after "####".
        reference_answer = workflow_task.task.metadata["answer"].split("####")[-1].strip()

        judge_cfg = self.config.ajet.task_judge
        external_llm_fn = create_external_llm_fn(
            alien_llm_model=judge_cfg.alien_llm_model,
            alien_llm_response_length=judge_cfg.alien_llm_response_length,
        )

        system_msg = {
            "role": "system",
            "content": "Is my result correct? If correct, say <Correct>, otherwise say <NotCorrect>.",
        }
        user_msg = {
            "role": "user",
            "content": f"Is my result correct?\n\n\n----\nMy result: {final_answer}\n\n\n----\nReal result: {reference_answer}",
        }
        res = external_llm_fn(messages=[system_msg, user_msg])

        # NOTE: "<Correct>" is not a substring of "<NotCorrect>", so this
        # check is unambiguous for the two instructed replies.
        is_success = "<Correct>" in res["content"]
        raw_reward = 1.0 if is_success else 0.0
        return raw_reward, is_success