Skip to content

Commit 55e77d1

Browse files
authored
Update run_batch.py
1 parent 4e5dc72 commit 55e77d1

1 file changed

Lines changed: 16 additions & 12 deletions

File tree

sweagent/run/run_batch.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -713,17 +713,17 @@ def _execute_locations_batch(self, locations: list, round_num: int) -> None:
713713

714714
def _collect_location_trajectories(self, locations: list, round_num: int) -> list[dict]:
715715
locations_data = []
716-
# select_regression_file = ""
717-
# select_regression_dict = dict()
718-
# with open(select_regression_file, 'r', encoding='utf-8') as f:
719-
# for item in f:
720-
# item = json.loads(item)
721-
# select_regression_dict[item['instance_id']] = dict()
722-
# select_regression_dict[item['instance_id']]['tests_passing_in_original_repo'] = item['tests_passing_in_original_repo']
716+
select_regression_file = "/home/jiangty9/PhoenixRepair/sweagent/run/select_regression/output.jsonl"
717+
select_regression_dict = dict()
718+
with open(select_regression_file, 'r', encoding='utf-8') as f:
719+
for item in f:
720+
item = json.loads(item)
721+
select_regression_dict[item['instance_id']] = dict()
722+
select_regression_dict[item['instance_id']]['tests_passing_in_original_repo'] = item['tests_passing_in_original_repo']
723723
for loc in locations:
724724
instance_id = loc.problem_statement.id
725725
original_id = re.sub(r'-location-\d+.*$', '', instance_id)
726-
# loc_item = select_regression_dict[original_id]
726+
loc_item = select_regression_dict[original_id]
727727
traj_file = self.output_dir / instance_id / f"{instance_id}.traj"
728728
if not traj_file.exists():
729729
continue
@@ -741,6 +741,7 @@ def _collect_location_trajectories(self, locations: list, round_num: int) -> lis
741741
'exit_status': data.get('info', {}).get('exit_status', 'unknown'),
742742
'trajectory': data.get('trajectory', []),
743743
'round': round_num,
744+
'tests_passing_in_original_repo' : loc_item['tests_passing_in_original_repo']
744745
})
745746
except Exception as e:
746747
self.logger.error(f"读取失败 {instance_id}: {e}")
@@ -764,9 +765,12 @@ def _select_top_half_locations(self, original_id: str, locations_data: list[dict
764765
expected_count = math.ceil(len(locations_data) / 2)
765766
self.logger.info(f"Round{round_num}: 从{len(locations_data)}个location中选{expected_count}个")
766767
# use regression test
767-
# for data in locations_data:
768-
# regression_dict = _run_regression(data)
769-
# data['regression_result'] = regression_dict
768+
llm_config = None
769+
if self._comparator_config:
770+
llm_config = {'api_key': self._comparator_config.get('api_key', ''), 'api_base': self._comparator_config.get('base_url', ''), 'model_name': self._comparator_config.get('model_name', '')}
771+
for data in locations_data:
772+
regression_dict = _run_regression(data, llm_config=llm_config)
773+
data['regression_result'] = regression_dict
770774
analysis_dir = self.output_dir / f"{original_id}-analysis"
771775
analysis_dir.mkdir(parents=True, exist_ok=True)
772776
analysis_file = analysis_dir / f"round{round_num}_comparison.json"
@@ -1483,4 +1487,4 @@ def run_from_cli(args: list[str] | None = None):
14831487
run_from_config(BasicCLI(RunBatchConfig, help_text=help_text).get_config(args))
14841488

14851489
if __name__ == "__main__":
1486-
run_from_cli()
1490+
run_from_cli()

0 commit comments

Comments
 (0)