forked from Anubhav741/ThinkSync
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgrader.py
More file actions
224 lines (167 loc) · 7.31 KB
/
grader.py
File metadata and controls
224 lines (167 loc) · 7.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""
TrustOps-Env: Grader Module (OpenEnv-Compliant)
"""
import math
from typing import Any

from models import (
    Content, Action, ActionType, Difficulty, ContentLabel,
    CONTENT_BANK
)
# Margin that keeps every clamped score strictly inside (0, 1):
# safe() never returns anything below EPS or above 1 - EPS.
EPS: float = 1e-6
def _compute_embedding_similarity(agent_reasoning: str, label: str) -> float:
try:
from engine import _compute_embedding_similarity as _engine_sim
return _engine_sim(agent_reasoning, label)
except Exception:
return 0.05
def safe(x: float) -> float:
    """Clamp *x* into ``[EPS, 1 - EPS]``, i.e. strictly inside (0, 1).

    Conversion rules:

    * values ``float()`` rejects (``None``, non-numeric strings, ...) -> 0.5
    * NaN -> 0.5, the same neutral value as any other unusable input
    * numbers too large for a float (e.g. ``10**400``) clamp by their sign
      instead of raising ``OverflowError``
    * ``+inf`` / ``-inf`` -> ``1 - EPS`` / ``EPS``

    Never raises for numeric or non-numeric input.
    """
    try:
        val = float(x)
    except OverflowError:
        # float() overflows on huge ints/Fractions, but their sign is still
        # known, so clamp to the matching end of the range.
        try:
            val = 1.0 if x > 0 else 0.0
        except TypeError:
            val = 0.5
    except (TypeError, ValueError):
        val = 0.5
    if math.isnan(val):
        # Handle NaN explicitly: min()/max() treat it order-dependently, which
        # would otherwise silently map it to the minimum score.
        val = 0.5
    return max(EPS, min(val, 1.0 - EPS))
def normalize_score(total_reward, max_reward):
    """Return ``total_reward / max_reward`` clamped to ``[EPS, 1 - EPS]``.

    A zero *max_reward* yields the minimum score ``EPS`` instead of raising
    ``ZeroDivisionError``. Ratios at or below 0 map to ``EPS``, ratios at or
    above 1 map to ``1 - EPS``, and everything else goes through ``safe``.
    """
    if max_reward == 0:
        return EPS
    ratio = total_reward / max_reward
    if ratio <= 0:
        return EPS
    if ratio >= 1:
        return 1.0 - EPS
    return safe(ratio)
# ─── Internal scoring logic ─────────────────────
def _score_base(content: Content, agent_action: ActionType) -> float:
    """Score an agent's moderation action against ``content.expected_action``.

    * action equals the expected action          -> ``safe(0.999)``
    * otherwise, action is ``ActionType.FLAG``    -> ``safe(0.5)``
    * any other mismatch                         -> ``safe(0.001)``

    Plain-string actions such as ``"Remove"`` or ``" flag "`` are first
    normalised (lower-cased, stripped) to the matching ``ActionType`` member.
    This way callers passing raw strings are scored the same as callers
    passing enum members. Unrecognised strings are compared as-is.
    """
    if isinstance(agent_action, str) and not isinstance(agent_action, ActionType):
        token = agent_action.lower().strip()
        if token in ("approve", "remove", "flag"):
            agent_action = ActionType(token)

    if agent_action == content.expected_action:
        return safe(0.999)
    if agent_action == ActionType.FLAG:
        # Mid-range score for flagging when a different action was expected.
        return safe(0.5)
    return safe(0.001)
def _score_easy(content: Content, agent_action: ActionType, agent_reasoning: str) -> float:
    """Score an easy-tier item by delegating to ``_score_base``.

    *agent_reasoning* is accepted for signature parity with the other tiers
    but does not influence the score.
    """
    score = _score_base(content=content, agent_action=agent_action)
    return score
def _score_medium(content: Content, agent_action: ActionType, agent_reasoning: str) -> float:
    """Score a medium-tier item by delegating to ``_score_base``.

    *agent_reasoning* is accepted for signature parity with the other tiers
    but does not influence the score.
    """
    score = _score_base(content=content, agent_action=agent_action)
    return score
def _score_hard(content: Content, agent_action: ActionType, agent_reasoning: str) -> float:
    """Score a hard-tier item by delegating to ``_score_base``.

    *agent_reasoning* is accepted for signature parity with the other tiers
    but does not influence the score.
    """
    score = _score_base(content=content, agent_action=agent_action)
    return score
# ─── Helper functions ─────────────────────
def _extract_action_reasoning(sample: Any) -> tuple:
    """Return ``(agent_action, agent_reasoning)`` parsed from a loosely-typed sample.

    Accepted sample shapes:

    * ``dict`` -- the action comes from the first key present among
      ``"action_type"``, ``"action"``, ``"output"``, ``"response"``. If that
      value is itself a dict, its ``"action_type"`` is used. Reasoning comes
      from the first of ``"reasoning_chain"``, ``"reasoning"``,
      ``"explanation"``, ``"output"``.
    * an object exposing ``action_type`` -- that attribute is the action and
      ``reasoning_chain`` (if present) is the reasoning.
    * ``str`` -- the whole string is the reasoning. If it is exactly one of
      ``approve`` / ``remove`` / ``flag`` (case and surrounding whitespace
      ignored), it is also the action, matching how dict ``"output"`` /
      ``"response"`` values are read.

    An action may be an ``ActionType`` member or one of those words. In dicts
    and bare strings anything else falls back to ``ActionType.FLAG``; an
    unrecognised ``action_type`` attribute is passed through unchanged.
    Missing or ``None`` reasoning becomes ``""``.

    Extraction is best-effort: if anything raises, whatever was parsed before
    the error is returned.
    """

    def _parse_action(raw: Any) -> Any:
        # Return the matching ActionType member, or None if *raw* is not one.
        if isinstance(raw, ActionType):
            return raw
        token = str(raw).lower().strip()
        if token in ("approve", "remove", "flag"):
            return ActionType(token)
        return None

    def _as_text(raw: Any) -> str:
        # Avoid turning a missing (None) reasoning into the literal "None".
        return "" if raw is None else str(raw)

    agent_action = ActionType.FLAG
    agent_reasoning = ""
    try:
        if isinstance(sample, dict):
            raw_action = sample.get("action_type",
                                    sample.get("action",
                                               sample.get("output",
                                                          sample.get("response", "flag"))))
            if isinstance(raw_action, dict):
                raw_action = raw_action.get("action_type", "flag")
            parsed = _parse_action(raw_action)
            if parsed is not None:
                agent_action = parsed
            agent_reasoning = _as_text(sample.get("reasoning_chain",
                                                  sample.get("reasoning",
                                                             sample.get("explanation",
                                                                        sample.get("output", "")))))
        elif hasattr(sample, "action_type"):
            raw_action = sample.action_type
            parsed = _parse_action(raw_action)
            agent_action = raw_action if parsed is None else parsed
            agent_reasoning = _as_text(getattr(sample, "reasoning_chain", ""))
        elif isinstance(sample, str):
            agent_reasoning = sample
            parsed = _parse_action(sample)
            if parsed is not None:
                agent_action = parsed
    except Exception:
        # Deliberately best-effort: a malformed sample must not abort grading.
        pass
    return agent_action, agent_reasoning
def _extract_content(item: Any) -> Content:
    """Coerce *item* into a ``Content`` instance.

    * ``Content`` -> returned unchanged.
    * ``dict`` -> ``Content(**item)``. If that raises, a ``Content`` is
      rebuilt from the known keys with defaults: id ``"unknown"``, text from
      ``"text"`` or else ``"content"``, difficulty ``"EASY"``, label
      ``"SAFE"``, action ``"approve"``. Errors raised while building this
      fallback propagate to the caller.
    * anything else -> ``CONTENT_BANK[0]``.

    NOTE(review): the last case silently grades against the first bank item;
    confirm that is the intended fallback.
    """
    if isinstance(item, Content):
        return item
    if not isinstance(item, dict):
        return CONTENT_BANK[0]
    try:
        return Content(**item)
    except Exception:
        pass
    # Strict construction failed; rebuild from the individual fields.
    content_id = str(item.get("id", "unknown"))
    content_text = str(item.get("text", item.get("content", "")))
    difficulty = Difficulty(item.get("difficulty", "EASY"))
    expected_label = ContentLabel(item.get("expected_label", "SAFE"))
    expected_action = ActionType(item.get("expected_action", "approve"))
    return Content(
        id=content_id,
        text=content_text,
        difficulty=difficulty,
        expected_label=expected_label,
        expected_action=expected_action,
    )
# ─── Grader functions ─────────────────────
def _grade_with(scorer, args: tuple, kwargs: dict) -> float:
    """Resolve a grader's call convention, run *scorer*, and clamp the result.

    Supported conventions (shared by every public grader):

    1. ``(sample, item)`` -- exactly two positional args and no kwargs.
       *item* is coerced with ``_extract_content``; the action and reasoning
       are pulled from *sample* with ``_extract_action_reasoning``.
    2. ``(content, agent_action, agent_reasoning[, ...])`` -- three or more
       positional args, or any keyword args. Missing values come from the
       ``content`` / ``agent_action`` / ``agent_reasoning`` kwargs, defaulting
       to ``CONTENT_BANK[0]`` / ``ActionType.FLAG`` / ``""``. A dict *content*
       is coerced with ``_extract_content``.
    3. ``(sample,)`` -- a single dict holding the action plus either an
       ``"item"`` entry or the content fields themselves.

    Any other call shape, and any exception raised while grading, yields
    the neutral score ``safe(0.5)``.
    """
    try:
        if len(args) == 2 and not kwargs:
            sample, item = args
            content = _extract_content(item)
            agent_action, agent_reasoning = _extract_action_reasoning(sample)
        elif len(args) >= 3 or kwargs:
            # Three positional args is the natural (content, agent_action,
            # agent_reasoning) form; requiring four made it score a flat 0.5.
            content = args[0] if args else kwargs.get("content", CONTENT_BANK[0])
            agent_action = args[1] if len(args) > 1 else kwargs.get("agent_action", ActionType.FLAG)
            agent_reasoning = args[2] if len(args) > 2 else kwargs.get("agent_reasoning", "")
            if isinstance(content, dict):
                # A raw dict has no .expected_action; build a Content first.
                content = _extract_content(content)
        elif len(args) == 1 and isinstance(args[0], dict):
            sample = args[0]
            content = _extract_content(sample.get("item", sample))
            agent_action, agent_reasoning = _extract_action_reasoning(sample)
        else:
            return safe(0.5)
        return safe(scorer(content, agent_action, agent_reasoning))
    except Exception:
        # Graders are best-effort by design: failures score neutral, never raise.
        return safe(0.5)


def grade_easy_detection(*args, **kwargs) -> float:
    """Grade the ``easy_detection`` task with ``_score_easy``.

    Accepts ``(sample, item)``, ``(content, agent_action, agent_reasoning)``
    (positionally or by keyword), or a single ``sample`` dict. Always returns
    a score in ``[EPS, 1 - EPS]``; unusable input scores ``safe(0.5)``.
    """
    return _grade_with(_score_easy, args, kwargs)


def grade_medium_classification(*args, **kwargs) -> float:
    """Grade the ``medium_classification`` task with ``_score_medium``.

    Accepts ``(sample, item)``, ``(content, agent_action, agent_reasoning)``
    (positionally or by keyword), or a single ``sample`` dict. Always returns
    a score in ``[EPS, 1 - EPS]``; unusable input scores ``safe(0.5)``.
    """
    return _grade_with(_score_medium, args, kwargs)


def grade_hard_contextual(*args, **kwargs) -> float:
    """Grade the ``hard_contextual`` task with ``_score_hard``.

    Accepts ``(sample, item)``, ``(content, agent_action, agent_reasoning)``
    (positionally or by keyword), or a single ``sample`` dict. Always returns
    a score in ``[EPS, 1 - EPS]``; unusable input scores ``safe(0.5)``.
    """
    return _grade_with(_score_hard, args, kwargs)
def grade_task(task_name: str, *args, **kwargs) -> float:
    """Route *task_name* to its grader, forwarding all remaining arguments.

    Known tasks: ``easy_detection``, ``medium_classification`` and
    ``hard_contextual``. Any other name receives the neutral score
    ``safe(0.5)``.
    """
    # Ordered (name, grader) pairs compared with ==, so even unhashable
    # task names fall through to the neutral score instead of raising.
    routes = (
        ("easy_detection", grade_easy_detection),
        ("medium_classification", grade_medium_classification),
        ("hard_contextual", grade_hard_contextual),
    )
    for route_name, grader in routes:
        if task_name == route_name:
            return grader(*args, **kwargs)
    return safe(0.5)