@@ -713,17 +713,17 @@ def _execute_locations_batch(self, locations: list, round_num: int) -> None:
713713
714714 def _collect_location_trajectories (self , locations : list , round_num : int ) -> list [dict ]:
715715 locations_data = []
716- # select_regression_file = ""
717- # select_regression_dict = dict()
718- # with open(select_regression_file, 'r', encoding='utf-8') as f:
719- # for item in f:
720- # item = json.loads(item)
721- # select_regression_dict[item['instance_id']] = dict()
722- # select_regression_dict[item['instance_id']]['tests_passing_in_original_repo'] = item['tests_passing_in_original_repo']
716+ select_regression_file = "/home/jiangty9/PhoenixRepair/sweagent/run/select_regression/output.jsonl "
717+ select_regression_dict = dict ()
718+ with open (select_regression_file , 'r' , encoding = 'utf-8' ) as f :
719+ for item in f :
720+ item = json .loads (item )
721+ select_regression_dict [item ['instance_id' ]] = dict ()
722+ select_regression_dict [item ['instance_id' ]]['tests_passing_in_original_repo' ] = item ['tests_passing_in_original_repo' ]
723723 for loc in locations :
724724 instance_id = loc .problem_statement .id
725725 original_id = re .sub (r'-location-\d+.*$' , '' , instance_id )
726- # loc_item = select_regression_dict[original_id]
726+ loc_item = select_regression_dict [original_id ]
727727 traj_file = self .output_dir / instance_id / f"{ instance_id } .traj"
728728 if not traj_file .exists ():
729729 continue
@@ -741,6 +741,7 @@ def _collect_location_trajectories(self, locations: list, round_num: int) -> lis
741741 'exit_status' : data .get ('info' , {}).get ('exit_status' , 'unknown' ),
742742 'trajectory' : data .get ('trajectory' , []),
743743 'round' : round_num ,
744+ 'tests_passing_in_original_repo' : loc_item ['tests_passing_in_original_repo' ]
744745 })
745746 except Exception as e :
746747 self .logger .error (f"读取失败 { instance_id } : { e } " )
@@ -764,9 +765,12 @@ def _select_top_half_locations(self, original_id: str, locations_data: list[dict
764765 expected_count = math .ceil (len (locations_data ) / 2 )
765766 self .logger .info (f"Round{ round_num } : 从{ len (locations_data )} 个location中选{ expected_count } 个" )
766767 # use regression test
767- # for data in locations_data:
768- # regression_dict = _run_regression(data)
769- # data['regression_result'] = regression_dict
768+ llm_config = None
769+ if self ._comparator_config :
770+ llm_config = {'api_key' : self ._comparator_config .get ('api_key' , '' ), 'api_base' : self ._comparator_config .get ('base_url' , '' ), 'model_name' : self ._comparator_config .get ('model_name' , '' )}
771+ for data in locations_data :
772+ regression_dict = _run_regression (data , llm_config = llm_config )
773+ data ['regression_result' ] = regression_dict
770774 analysis_dir = self .output_dir / f"{ original_id } -analysis"
771775 analysis_dir .mkdir (parents = True , exist_ok = True )
772776 analysis_file = analysis_dir / f"round{ round_num } _comparison.json"
@@ -1483,4 +1487,4 @@ def run_from_cli(args: list[str] | None = None):
14831487 run_from_config (BasicCLI (RunBatchConfig , help_text = help_text ).get_config (args ))
14841488
14851489if __name__ == "__main__" :
1486- run_from_cli ()
1490+ run_from_cli ()
0 commit comments