forked from zilliztech/claude-context
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path run_evaluation.py
More file actions
123 lines (108 loc) · 3.63 KB
/
run_evaluation.py
File metadata and controls
123 lines (108 loc) · 3.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import logging
import os
from argparse import ArgumentParser, ArgumentTypeError
from typing import List, Optional

from retrieval.custom import CustomRetrieval
from utils.constant import evaluation_path, project_path
# Configure the root logger for INFO-level output with a timestamped format
# (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
# Logger used by this script's functions.
logger = logging.getLogger(name=__name__)
def main(
    dataset_name_or_path: str,
    output_dir: str,
    retrieval_types: List[str],
    llm_type: str = "openai",
    llm_model: Optional[str] = None,
    splits: Optional[List[str]] = None,
    root_dir: str = str(evaluation_path / "repos"),
    max_instances: Optional[int] = 5,
):
    """
    Run CustomRetrieval over a dataset with the requested retrieval types.

    Args:
        dataset_name_or_path: Dataset path or name
        output_dir: Output directory for results
        retrieval_types: List of retrieval types to use ('cc', 'grep', or both)
        llm_type: Type of LLM to use
        llm_model: LLM model name
        splits: Dataset splits to process; ``None`` means ["test"]
        root_dir: Root directory for repositories
        max_instances: Maximum number of instances to process. Passed through
            unchanged; the CLI documents -1 as "all", which is presumably
            interpreted by CustomRetrieval (TODO confirm).
    """
    # Build the default per call: a list literal in the signature would be a
    # single object shared (and mutable) across every call of main().
    if splits is None:
        splits = ["test"]

    logger.info("Starting custom retrieval with types: %s", retrieval_types)
    retrieval = CustomRetrieval(
        dataset_name_or_path=dataset_name_or_path,
        splits=splits,
        output_dir=output_dir,
        retrieval_types=retrieval_types,
        llm_type=llm_type,
        llm_model=llm_model,
        max_instances=max_instances,
    )
    # NOTE(review): when GITHUB_TOKEN is unset the literal "git" is passed as
    # the token; confirm CustomRetrieval.run treats it as "no/anonymous token".
    retrieval.run(root_dir, token=os.environ.get("GITHUB_TOKEN", "git"))
class InvalidRetrievalTypeError(ArgumentTypeError, ValueError):
    """Raised by parse_retrieval_types for an unusable retrieval-type string.

    Subclasses ArgumentTypeError so argparse prints this message verbatim when
    the parser is used as a ``type=`` callable (for a plain ValueError argparse
    prints only a generic "invalid ... value"), and ValueError so existing
    ``except ValueError`` callers keep working.
    """


def parse_retrieval_types(value: str) -> List[str]:
    """Parse a comma-separated retrieval-type string into a list.

    Each token is stripped and lower-cased. Empty tokens (e.g. from a trailing
    comma) are ignored, and duplicates are dropped while keeping first-seen
    order, so "cc,cc,grep" yields ["cc", "grep"].

    Args:
        value: Comma-separated string such as "cc,grep".

    Returns:
        The validated retrieval types, in the order given.

    Raises:
        InvalidRetrievalTypeError: (a ValueError) if a token is not a known
            retrieval type, or if no retrieval type is given at all.
    """
    # Tuple rather than set so the error message lists options in a fixed order
    # (set iteration order varies between runs with hash randomization).
    valid_types = ("cc", "grep")
    allowed = ", ".join(valid_types)

    types: List[str] = []
    for token in value.split(","):
        t = token.strip().lower()
        if not t or t in types:
            continue
        if t not in valid_types:
            raise InvalidRetrievalTypeError(
                f"Invalid retrieval type '{t}'. Must be one of: {allowed}"
            )
        types.append(t)

    if not types:
        raise InvalidRetrievalTypeError(
            f"No retrieval type given. Must be one of: {allowed}"
        )
    return types
if __name__ == "__main__":
    # Model used when --llm_model is omitted, chosen per --llm_type so a
    # non-OpenAI backend is not handed an OpenAI model name. Types missing
    # here (e.g. ollama) forward None; presumably CustomRetrieval then picks
    # its own default (TODO confirm it accepts None for every llm_type).
    default_llm_models = {
        "openai": "gpt-4o-mini",
        "moonshot": "kimi-k2-0711-preview",
    }

    parser = ArgumentParser(
        description="Custom Retrieval for SWE-bench with flexible retrieval types"
    )
    parser.add_argument(
        "--dataset_name_or_path",
        type=str,
        # A hub dataset name such as "SWE-bench/SWE-bench_Lite" also works here.
        default="swe_verified_15min1h_2files_instances.json",
        help="Dataset name or path",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=str(evaluation_path / "retrieval_results_custom"),
        help="Output directory",
    )
    parser.add_argument(
        "--retrieval_types",
        # argparse also runs string defaults through `type`, so the default
        # below arrives in main() as ["cc", "grep"].
        type=parse_retrieval_types,
        default="cc,grep",
        help="Comma-separated list of retrieval types to use. Options: 'cc', 'grep', or 'cc,grep' (default: 'cc,grep')",
    )
    parser.add_argument(
        "--llm_type",
        type=str,
        choices=["openai", "ollama", "moonshot"],
        default="openai",
        help="LLM type",
    )
    parser.add_argument(
        "--llm_model",
        type=str,
        default=None,
        help=(
            "LLM model name, e.g. gpt-4o-mini (default depends on --llm_type: "
            "gpt-4o-mini for openai, kimi-k2-0711-preview for moonshot)"
        ),
    )
    parser.add_argument(
        "--splits", nargs="+", default=["test"], help="Dataset splits to process"
    )
    parser.add_argument(
        "--root_dir",
        type=str,
        default=str(evaluation_path / "repos"),
        help="Temporary directory for repositories",
    )
    parser.add_argument(
        "--max_instances",
        type=int,
        default=5,
        help="Maximum number of instances to process (default: 5, set to -1 for all)",
    )
    args = parser.parse_args()

    # Fill in the per-provider model only when the user did not choose one.
    if args.llm_model is None:
        args.llm_model = default_llm_models.get(args.llm_type)

    main(**vars(args))