-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate_agent.py
More file actions
84 lines (74 loc) · 2.48 KB
/
Copy pathevaluate_agent.py
File metadata and controls
84 lines (74 loc) · 2.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from __future__ import annotations
import argparse
import json
import os
from typing import Any
from inference import (
API_BASE_URL,
HF_TOKEN,
MODEL_NAME,
DEFAULT_ENV_URL,
RL_POLICY_PATH,
resolve_rl_policy_path,
run_all_tasks,
resolve_tasks,
)
def evaluate_policy(
    *,
    env_url: str,
    policy: str,
    task_set: str,
    api_base_url: str,
    model_name: str,
    hf_token: str | None,
    rl_policy_path: str | None,
) -> dict[str, Any]:
    """Run one policy over one task set and summarise the final scores.

    The task list and the RL policy path are both resolved from ``task_set``
    (via ``resolve_tasks`` and ``resolve_rl_policy_path``), after which
    ``run_all_tasks`` is invoked with logging disabled.

    Args:
        env_url: Forwarded unchanged to ``run_all_tasks``.
        policy: Policy identifier; forwarded to ``run_all_tasks`` and echoed
            back in the result.
        task_set: Task-set identifier used to resolve the tasks and the RL
            policy path; echoed back in the result.
        api_base_url: Forwarded unchanged to ``run_all_tasks``.
        model_name: Forwarded unchanged to ``run_all_tasks``.
        hf_token: Forwarded unchanged to ``run_all_tasks``; may be ``None``.
        rl_policy_path: Candidate RL policy path handed to
            ``resolve_rl_policy_path``; may be ``None``.

    Returns:
        A dict with the keys ``policy``, ``task_set``, ``tasks`` (whatever
        ``resolve_tasks`` returned), ``final_scores`` (name -> ``final_score``
        rounded to 4 decimals) and ``average_score`` (mean of those rounded
        scores, itself rounded to 4 decimals; ``0.0`` when there are none).
    """
    # Resolve the tasks first, then the policy path, matching the call order
    # callers have always observed.
    selected_tasks = resolve_tasks(tasks=None, task_set=task_set)
    policy_file = resolve_rl_policy_path(
        rl_policy_path=rl_policy_path,
        task_set=task_set,
    )

    outcomes = run_all_tasks(
        env_url=env_url,
        tasks=selected_tasks,
        api_base_url=api_base_url,
        model_name=model_name,
        hf_token=hf_token,
        policy=policy,
        rl_policy_path=policy_file,
        emit_logs=False,
    )

    per_task = {}
    for task_name, task_run in outcomes.items():
        per_task[task_name] = round(task_run["final_score"], 4)

    # Fall back to a divisor of 1 so an empty result set averages to 0.0
    # instead of raising ZeroDivisionError.
    divisor = len(per_task) or 1
    mean_score = round(sum(per_task.values()) / divisor, 4)

    report: dict[str, Any] = {}
    report["policy"] = policy
    report["task_set"] = task_set
    report["tasks"] = selected_tasks
    report["final_scores"] = per_task
    report["average_score"] = mean_score
    return report
def main(argv: list[str] | None = None) -> int:
    """Benchmark each requested policy on each requested task set.

    Parses the command line, calls ``evaluate_policy`` once per
    (policy, task set) pair -- policies in the outer loop, task sets in the
    inner loop -- and prints all results to stdout as a single JSON document
    of the form ``{"results": [...]}``.

    Args:
        argv: Argument strings to parse instead of the process command line.
            ``None`` (the default) makes ``argparse`` read ``sys.argv[1:]``,
            which is the behaviour of the script entry point; passing a list
            allows the benchmark to be driven programmatically.

    Returns:
        ``0`` once the summary has been printed.
    """
    parser = argparse.ArgumentParser(description="Quick policy benchmark across task sets.")
    # The environment URL can be overridden through ENV_BASE_URL before
    # falling back to the default imported from ``inference``.
    parser.add_argument("--env-url", default=os.getenv("ENV_BASE_URL", DEFAULT_ENV_URL))
    parser.add_argument("--policies", nargs="+", default=["strategic", "rl"])
    parser.add_argument("--task-sets", nargs="+", default=["standard", "challenge"])
    parser.add_argument("--model", default=MODEL_NAME)
    parser.add_argument("--api-base-url", default=API_BASE_URL)
    parser.add_argument("--hf-token", default=HF_TOKEN)
    parser.add_argument("--rl-policy-path", default=RL_POLICY_PATH)
    args = parser.parse_args(argv)

    summary: list[dict[str, Any]] = [
        evaluate_policy(
            env_url=args.env_url,
            policy=policy,
            task_set=task_set,
            api_base_url=args.api_base_url,
            model_name=args.model,
            hf_token=args.hf_token,
            rl_policy_path=args.rl_policy_path,
        )
        for policy in args.policies
        for task_set in args.task_sets
    ]

    print(json.dumps({"results": summary}, indent=2, sort_keys=True))
    return 0
# Script entry point: run the benchmark only when executed directly and use
# main()'s return value as the process exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)