|
| 1 | +# Copyright (C) 2025 Intel Corporation |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | + |
| 4 | +import argparse |
| 5 | +import json |
| 6 | +import os |
| 7 | +import time |
| 8 | +from dataclasses import dataclass |
| 9 | +from typing import Any, Callable, Tuple |
| 10 | + |
| 11 | +import requests |
| 12 | +from datasets import load_dataset |
| 13 | +from litellm import completion |
| 14 | +from requests.exceptions import RequestException |
| 15 | + |
| 16 | + |
@dataclass
class Result:
    """One benchmark question together with the agent's reply and the reference answer."""

    # Question text that was posed to the agent.
    question: str
    # Text the agent returned for that question.
    agent_answer: str
    # Reference answer the agent's reply is compared against.
    correct_answer: str
| 22 | + |
| 23 | + |
# Type alias for scoring callables: each takes a Result and returns a bool.
ScoringFunction = Callable[[Result], bool]
| 25 | + |
| 26 | + |
def _parse_judge_verdict(response_text: str | None) -> bool:
    """Extract the 0/1 verdict from the judge's ``<answer>...</answer>`` block.

    Args:
        response_text: Raw text returned by the judge model.

    Returns:
        ``True`` when the integer inside the first ``<answer>`` block is non-zero.

    Raises:
        ValueError: If the response is empty, has no ``<answer>`` tag, or the
            tag content is not an integer.
    """
    if not response_text:
        raise ValueError("Judge LLM returned an empty response")

    _, open_tag, tail = response_text.partition("<answer>")
    if not open_tag:
        raise ValueError(f"Judge LLM response has no <answer> tag: {response_text!r}")

    # A missing closing tag (e.g. a truncated completion) is tolerated:
    # everything after the opening tag is then treated as the verdict.
    verdict = tail.partition("</answer>")[0].strip()
    try:
        return bool(int(verdict))
    except ValueError as err:
        raise ValueError(f"Judge LLM verdict is not an integer: {verdict!r}") from err


def llm_as_a_judge_scoring(result: Result, model_id=None, llm_service=None) -> bool:
    """Ask an LLM to judge whether the agent's answer matches the reference answer.

    Args:
        result: Question, agent answer and correct answer to evaluate.
        model_id: Model identifier passed to ``completion`` as ``model``.
        llm_service: Base URL passed to ``completion`` as ``api_base``.

    Returns:
        ``True`` if the judge answered with a non-zero integer, ``False`` for 0.

    Raises:
        ValueError: If the judge's reply does not contain a parsable
            ``<answer>`` block.
    """
    prompt = f"""
    Given the following question and answer, evaluate the answer against the correct answer:

    <question>
    {result.question}
    </question>

    <agent_answer>
    {result.agent_answer}
    </agent_answer>

    <correct_answer>
    {result.correct_answer}
    </correct_answer>

    Note that the agent answer might be a long text containing a lot of information or it might be a short answer.

    You should read the entire text and think if the agent answers the question somewhere
    in the text. You should try to be flexible with the answer but careful.

    For example, answering with names instead of name and surname is fine.

    The important thing is that the answer of the agent either contains the correct answer or is equal to the correct answer.

    <reasoning>
    The agent answer is correct because I can read that ....
    </reasoning>

    <answer>
    1
    </answer>

    Otherwise, return

    <reasoning>
    The agent answer is incorrect because there is ...
    </reasoning>

    <answer>
    0
    </answer>

    """

    messages = [
        {"role": "system", "content": "You are an helpful assistant that returns a number between 0 and 1."},
        {"role": "user", "content": prompt},
    ]
    answer = (
        completion(
            model=model_id,
            api_base=llm_service,
            api_key="empty",
            timeout=1200,
            messages=messages,
            max_tokens=1000,
            temperature=0.0,
        )
        .choices[0]  # type: ignore
        .message["content"]  # type: ignore
    )

    return _parse_judge_verdict(answer)
| 91 | + |
| 92 | + |
def _open_benchmark_dataset(dataset_name: str) -> Tuple[Any, str]:
    """Load one benchmark via ``load_dataset`` and name the split holding its questions."""
    if dataset_name == "together-search-bench":
        return load_dataset("togethercomputer/together-search-bench"), "test"
    if dataset_name == "hotpotqa":
        # HotpotQA is loaded in its "distractor" configuration; the validation split is used.
        return load_dataset("hotpotqa/hotpot_qa", "distractor", trust_remote_code=True), "validation"
    if dataset_name == "simpleqa":
        return load_dataset("basicv8vc/SimpleQA"), "test"
    # Any other name is treated as a smolagents/benchmark-v1 configuration;
    # an optional "smolagents:" prefix is stripped first.
    return load_dataset("smolagents/benchmark-v1", dataset_name.split(":")[-1]), "test"


def _to_question_record(dataset_name: str, item: Any) -> dict[str, str] | None:
    """Convert one dataset row to a question/answer/dataset dict, or ``None`` to skip it."""
    if dataset_name == "together-search-bench":
        return {
            "question": item["question"],
            "answer": item["answer"],
            "dataset": item.get("dataset", "together-search-bench"),
        }
    if dataset_name == "hotpotqa":
        # Only "hard" questions are kept, to reduce the number of questions.
        if item["level"] != "hard":
            return None
        return {"question": item["question"], "answer": item["answer"], "dataset": dataset_name}
    if dataset_name == "simpleqa":
        # SimpleQA stores the question text under "problem".
        return {"question": item["problem"], "answer": item["answer"], "dataset": dataset_name}
    return {"question": item["question"], "answer": item["true_answer"], "dataset": dataset_name}


def load_questions(dataset_names: list[str] | None = None) -> list[dict[str, str]]:
    """Load questions from the specified Hugging Face dataset configurations.

    A dataset that fails to download, lacks the expected split, or contains
    rows without the expected fields is reported on stdout and skipped; the
    remaining datasets are still loaded.

    Args:
        dataset_names: List of dataset configurations to load
            Options:
            "smolagents:simpleqa",
            "hotpotqa",
            "simpleqa",
            "together-search-bench"
            If None, only "smolagents:simpleqa" is loaded.

    Returns:
        List of dicts with "question", "answer" and "dataset" keys.
    """
    if dataset_names is None:
        dataset_names = ["smolagents:simpleqa"]

    all_questions: list[dict[str, str]] = []

    for dataset_name in dataset_names:
        print(f"Loading dataset: {dataset_name}")

        try:
            ds, split_name = _open_benchmark_dataset(dataset_name)
        except Exception as e:
            print(f"Failed to load dataset {dataset_name}: {str(e)}")
            continue  # Skip this dataset if it fails to load

        print(f"Dataset structure for {dataset_name}: {ds}")
        print(f"Available splits: {list(ds)}")

        if split_name not in ds:
            print(f"No '{split_name}' split found in dataset {dataset_name}")
            continue

        try:
            records = [_to_question_record(dataset_name, item) for item in ds[split_name]]
        except Exception as e:
            # Best-effort: a malformed dataset must not abort the other datasets.
            print(f"Failed to load dataset {dataset_name}: {str(e)}")
            continue

        loaded = [record for record in records if record is not None]
        all_questions.extend(loaded)
        print(f"Loaded {len(loaded)} questions from {dataset_name} dataset")

    print(f"Loaded {len(all_questions)} questions in total")
    return all_questions
| 192 | + |
| 193 | + |
def process_single_question(
    question_data: dict[str, str],
    agent_service: str,
    model_id: str,
    llm_service: str,
    timeout: float | None = None,
) -> dict[str, Any]:
    """Process a single benchmark question with the agent.

    The question is POSTed as JSON to the agent endpoint, and the agent's
    answer is then scored with ``llm_as_a_judge_scoring``.

    Args:
        question_data: Dictionary containing "question" and "answer"; any
            other keys are copied into the result's "metadata".
        agent_service: URL of the agent endpoint that receives the question.
        model_id: Model identifier forwarded to the judge.
        llm_service: LLM endpoint forwarded to the judge.
        timeout: Seconds to wait for the agent's HTTP response. ``None``
            (the default) applies no timeout.

    Returns:
        Dictionary with question, answers and evaluation results

    Raises:
        Exception: If the HTTP request fails or the agent's reply has no
            "answer" field.
    """
    question = question_data["question"]
    correct_answer = question_data["answer"]

    try:
        response = requests.post(
            f"{agent_service}",
            headers={"Content-Type": "application/json"},
            data=json.dumps({"question": question}),
            timeout=timeout,
        )
        response.raise_for_status()
        payload = response.json()
    except RequestException as e:
        # Chain the cause so the original network/HTTP error stays in the traceback.
        raise Exception(f"An unexpected error occurred: {str(e)}") from e

    if not isinstance(payload, dict) or "answer" not in payload:
        raise Exception(f"Agent response is missing the 'answer' field: {payload!r}")

    agent_answer = payload["answer"]

    result = Result(question=question, agent_answer=agent_answer, correct_answer=correct_answer)

    evaluation = llm_as_a_judge_scoring(result, model_id, llm_service)

    single_benchmark_result = {
        "question": question,
        "correct_answer": correct_answer,
        "agent_answer": agent_answer,
        "evaluation": evaluation,
        "metadata": {k: v for k, v in question_data.items() if k not in ["question", "answer"]},
    }
    print(single_benchmark_result)

    return single_benchmark_result
| 240 | + |
| 241 | + |
def run_benchmark(
    questions: list[dict[str, str]],
    agent_service: str = "http://localhost:8022/v1/deep_research_agent",
    model_id: str = "",
    llm_service: str = "http://localhost:8000/v1/",
) -> Tuple[float, list[dict[str, Any]]]:
    """Run the benchmark on a list of questions, one after another.

    A question whose processing raises is counted as incorrect and recorded
    in the details with the exception text as the agent answer.

    Args:
        questions: List of question-answer pairs
        agent_service: URL of the agent endpoint.
        model_id: Model identifier used by the judge.
        llm_service: LLM endpoint used by the judge.

    Returns:
        Tuple of (accuracy score, detailed results). The details list has one
        entry per question, successful or failed. Accuracy is 0.0 when
        ``questions`` is empty.
    """
    evaluations = []
    details = []

    for idx, question_data in enumerate(questions):
        try:
            result = process_single_question(question_data, agent_service, model_id, llm_service)
            evaluation = result["evaluation"]
        except Exception as exc:
            import traceback

            traceback.print_exc()
            print(f"Question {idx+1} generated an exception: {exc}")
            evaluations.append(0)
            details.append({"question": question_data.get("question"), "agent_answer": str(exc), "evaluation": 0})
            continue

        evaluations.append(evaluation)
        # Successful results are recorded too, so the report covers every question.
        details.append(result)

    if not evaluations:
        # Nothing was evaluated; avoid dividing by zero.
        return 0.0, details

    return sum(evaluations) / len(evaluations), details
| 274 | + |
| 275 | + |
def main(argv: list[str] | None = None):
    """Parse command-line options, run the benchmark and save a JSON report.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) makes
            argparse read ``sys.argv``.

    Returns:
        The overall accuracy, or 0.0 when no questions could be loaded.
    """

    # Set up argument parser
    parser = argparse.ArgumentParser(description="Run scoring with benchmarking options")
    parser.add_argument(
        "--datasets",
        nargs="+",
        choices=["smolagents:simpleqa", "hotpotqa", "simpleqa", "together-search-bench"],
        help="Specific datasets to load (default: together-search-bench)",
        default=["together-search-bench"],
    )
    parser.add_argument("--limit", type=int, default=None, help="Limit number of questions to process (default: all)")
    parser.add_argument(
        "--service-url",
        default="http://localhost:8022/v1/deep_research_agent",
        help="the endpoint of deep research agent.",
    )
    parser.add_argument(
        "--llm-endpoint",
        default="http://localhost:8000/v1/",
        help="llm service for llm-as-judge.",
    )
    parser.add_argument(
        "--model",
        default="openai/meta-llama/Llama-3.3-70B-Instruct",
        help="model id of llm service.",
    )

    args = parser.parse_args(argv)

    questions = load_questions(args.datasets)

    if args.limit is not None:
        questions = questions[: args.limit]
        print(f"Limited to {len(questions)} questions")

    if not questions:
        # Without questions there is no accuracy to compute or report.
        print("No questions loaded; nothing to benchmark")
        return 0.0

    results, details = run_benchmark(
        questions, agent_service=args.service_url, model_id=args.model, llm_service=args.llm_endpoint
    )

    print(f"Completed benchmark with {results} accuracy")

    benchmark_results_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "benchmark", "benchmark_results")
    os.makedirs(benchmark_results_dir, exist_ok=True)

    # One timestamp is shared by the file name and the metadata so they always agree.
    now = time.localtime()
    # Dataset names such as "smolagents:simpleqa" contain ":", which is not a
    # valid file-name character on every platform.
    dataset_tag = "_".join(args.datasets).replace(":", "-")

    output_file = os.path.join(
        benchmark_results_dir,
        f"benchmark_{dataset_tag}_{time.strftime('%Y-%m-%d_%H-%M-%S', now)}.json",
    )

    output_data = {
        "metadata": {
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S", now),
            "datasets": args.datasets,
            # The parser defines no --agent-config option, so the settings
            # that were actually used are recorded here instead.
            "agent_config": {
                "service_url": args.service_url,
                "llm_endpoint": args.llm_endpoint,
                "model": args.model,
            },
            "scoring_method": "llm_as_a_judge_scoring",
            "sample_count": len(questions),
        },
        "overall_accuracy": results,
        "question_details": details,
    }

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=2)

    return results
| 343 | + |
| 344 | + |
# Run the benchmark only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
0 commit comments