From 63165f954bd0eaf4e62d29c07aeb12669a6ca5ab Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Wed, 9 Jul 2025 21:48:28 +0800 Subject: [PATCH 01/23] fix typo in comments --- hypogenic/algorithm/generation/default.py | 12 ++++++------ hypogenic/algorithm/inference/default.py | 4 ++-- hypogenic/algorithm/update/default.py | 4 ++-- hypogenic/tasks.py | 2 +- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/hypogenic/algorithm/generation/default.py b/hypogenic/algorithm/generation/default.py index bc33a3c..e6a1b27 100644 --- a/hypogenic/algorithm/generation/default.py +++ b/hypogenic/algorithm/generation/default.py @@ -15,7 +15,7 @@ @generation_register.register("default") class DefaultGeneration(Generation): """ - Add on extra functionality to the generation fucntion - we consider this + Add on extra functionality to the generation function - we consider this the "default" task. """ @@ -29,8 +29,8 @@ def __init__( """ Parameters: api: the language model that you're using, which may or may not be local - prompt_class: let's us know how the prompt is going to look - inference_class: gives us a way to predict labels for accuracy sake + prompt_class: let us know how the prompt is going to look + inference_class: gives us a way to predict labels for accuracy’s sake task: determines the goal to accomplish """ super().__init__(api, prompt_class, inference_class, task) @@ -38,7 +38,7 @@ def __init__( # ------------------------------------------------------------------------ # # # # ------------------------------------------------------------------------ # - # BATCH INITLALIZE HYPOTHESES # + # BATCH INITIALIZE HYPOTHESES # # ------------------------------------------------------------------------ # # # # ------------------------------------------------------------------------ # @@ -57,7 +57,7 @@ def batched_initialize_hypotheses( Parameters: num_init: the total amount of examples you want to use for initialize hypotheses init_batch size: the number of examples that will be used to generate these hypotheses - init_hypotheses_per_batch: the amount of hypotheses that you want to generate per btach + init_hypotheses_per_batch: the amount of hypotheses that you want to generate per batch cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number max_concurrent: the maximum amount of concurrent calls to the API @@ -132,7 +132,7 @@ def batched_hypothesis_generation( example_ids: The ids of the examples for which hypotheses need to be generated current_sample: the current sample in data which the algorithm is on num_hypotheses_generate: the number of hypotheses that we expect our response to generate - alpha: eploration constant in hypogenic reward funciton + alpha: exploration constant in hypogenic reward function cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number max_concurrent: The maximum number of concurrent requests diff --git a/hypogenic/algorithm/inference/default.py b/hypogenic/algorithm/inference/default.py index 84d9c4c..e66884d 100644 --- a/hypogenic/algorithm/inference/default.py +++ b/hypogenic/algorithm/inference/default.py @@ -35,7 +35,7 @@ def batched_predict( **generate_kwargs, ): """ - Makes a batch of preductions on a hypothesis. + Makes a batch of predictions on a hypothesis. Parameters: data: the data to predict on @@ -78,7 +78,7 @@ def run_inference_final( """ Function for testing the best hypothesis - Prameters: + Parameters: data: the data to predict on hyp_bank: the hypotheses that we want to predict from cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number diff --git a/hypogenic/algorithm/update/default.py b/hypogenic/algorithm/update/default.py index 7b11386..26bab5a 100644 --- a/hypogenic/algorithm/update/default.py +++ b/hypogenic/algorithm/update/default.py @@ -174,7 +174,7 @@ def update( # generate new hypotheses for j in range(self.num_hypotheses_to_update): - # Go through poorly performing exmaples and generate hypotheses for them + # Go through poorly performing examples and generate hypotheses for them # TODO: batched? new_hypotheses = ( self.generation_class.batched_hypothesis_generation( @@ -188,7 +188,7 @@ def update( ) ) - # If we onlt take the best performing hypothesis from the batch + # If we only take the best performing hypothesis from the batch if self.only_best_hypothesis: best_hypothesis = max( new_hypotheses, key=lambda x: new_hypotheses[x].reward diff --git a/hypogenic/tasks.py b/hypogenic/tasks.py index 264600f..5285b22 100644 --- a/hypogenic/tasks.py +++ b/hypogenic/tasks.py @@ -47,7 +47,7 @@ def __init__( self.test_data_path = data["ood_data_path"] self.val_data_path = data["ood_data_path"] - # getting omrpt templates from yaml file + # getting prompt templates from yaml file self.prompt_template = data["prompt_templates"] # task label From 495c985f06cba329b9df3857c2a8001c94410c20 Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Thu, 10 Jul 2025 17:45:55 +0800 Subject: [PATCH 02/23] Add reference_hypotheses parameter to support error-based hypo augmentation --- hypogenic/algorithm/generation/base.py | 2 + hypogenic/algorithm/generation/default.py | 54 +++++++++++++++++++++++ hypogenic/algorithm/update/default.py | 1 + hypogenic/prompt.py | 29 +++++++++++- 4 files changed, 85 insertions(+), 1 deletion(-) diff --git a/hypogenic/algorithm/generation/base.py b/hypogenic/algorithm/generation/base.py index 0133762..dba9632 100644 --- a/hypogenic/algorithm/generation/base.py +++ b/hypogenic/algorithm/generation/base.py @@ -68,6 +68,7 @@ def batched_hyp_list_generation( example_indices: List[int], num_hypotheses_generate: int, cache_seed=None, + reference_hypotheses=None, **generate_kwargs ) -> List[str]: """Batched hypothesis generation method. Takes multiple examples and creates a hypothesis with them. @@ -76,6 +77,7 @@ def batched_hyp_list_generation( example_indices: the indices of examples being used to generate hypotheses num_hypotheses_generate: the number of hypotheses that we expect our response to generate cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number + reference_hypotheses: the current hypotheses that we have in the bank (if any) Returns: hypotheses_list: A list containing all newly generated hypotheses. diff --git a/hypogenic/algorithm/generation/default.py b/hypogenic/algorithm/generation/default.py index e6a1b27..3228938 100644 --- a/hypogenic/algorithm/generation/default.py +++ b/hypogenic/algorithm/generation/default.py @@ -123,6 +123,7 @@ def batched_hypothesis_generation( alpha: float, cache_seed=None, max_concurrent=3, + reference_hypotheses=None, **generate_kwargs, ): """ @@ -135,6 +136,7 @@ def batched_hypothesis_generation( alpha: exploration constant in hypogenic reward function cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number max_concurrent: The maximum number of concurrent requests + reference_hypotheses: the current hypotheses that we have in the bank (if any) Returns: hypotheses_bank: A dictionary with keys as hypotheses and the values as the Summary Information class @@ -143,6 +145,7 @@ def batched_hypothesis_generation( example_ids, num_hypotheses_generate, cache_seed=cache_seed, + reference_hypotheses=reference_hypotheses, **generate_kwargs, ) @@ -155,3 +158,54 @@ def batched_hypothesis_generation( max_concurrent=max_concurrent, **generate_kwargs, ) + + # ------------------------------------------------------------------------ # + # # + # ------------------------------------------------------------------------ # + # BATCHED HYPOTHESIS LIST GENERATION # + # ------------------------------------------------------------------------ # + # # + # ------------------------------------------------------------------------ # + def batched_hyp_list_generation( + self, + example_indices: List[int], + num_hypotheses_generate: int, + cache_seed=None, + reference_hypotheses=None, + **generate_kwargs + ) -> List[str]: + """Batched hypothesis generation method. Takes multiple examples and creates a hypothesis with them. + + Parameters: + example_indices: the indices of examples being used to generate hypotheses + num_hypotheses_generate: the number of hypotheses that we expect our response to generate + cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number + reference_hypotheses: the current hypotheses that we have in the bank (if any) + + Returns: + hypotheses_list: A list containing all newly generated hypotheses. + """ + + # ---------------------------------------------------------------------- + # Gather the examples to use for generation + # ---------------------------------------------------------------------- + # Gather examples based on example_indices + # TODO: need copy()? + example_bank = ( + self.train_data.loc[list(example_indices)].copy().reset_index(drop=True) + ) + + # ---------------------------------------------------------------------- + # Prompt LLM to generate hypotheses + # ---------------------------------------------------------------------- + # Batch generate a bunch of prompts based on yaml file + prompt_input = self.prompt_class.batched_error_augmented_generation( + example_bank, num_hypotheses_generate, reference_hypotheses + ) + + # Batch generate responses based on the prompts that we just generated + response = self.api.generate( + prompt_input, cache_seed=cache_seed, **generate_kwargs + ) + + return extract_hypotheses(response, num_hypotheses_generate) diff --git a/hypogenic/algorithm/update/default.py b/hypogenic/algorithm/update/default.py index 26bab5a..3e4d7a3 100644 --- a/hypogenic/algorithm/update/default.py +++ b/hypogenic/algorithm/update/default.py @@ -184,6 +184,7 @@ def update( self.alpha, cache_seed=cache_seed, max_concurrent=max_concurrent, + reference_hypotheses=top_k_hypotheses, **generate_kwargs, ) ) diff --git a/hypogenic/prompt.py b/hypogenic/prompt.py index d220c90..6f7f5d6 100644 --- a/hypogenic/prompt.py +++ b/hypogenic/prompt.py @@ -2,7 +2,7 @@ import os import textwrap from string import Template -from typing import List, Tuple, Union, Dict +from typing import List, Tuple, Union, Dict, Any from copy import deepcopy import pandas as pd @@ -187,6 +187,33 @@ def batched_generation(self, train_data, num_hypotheses): return prompt + def batched_error_augmented_generation(self, train_data, num_hypotheses, reference_hypotheses: List[Any]): + """ + Generate hypotheses that is useful for predicting the color of the shoes given the appearance of the person. + """ + + substitute_dict = {"num_hypotheses": num_hypotheses} + + multi_sub_dicts = { + "observations": [], + "reference_hypotheses": [ + {"hypothesis": h, "idx": i + 1} for i, h in enumerate(reference_hypotheses) + ] + } + + for example_idx in range(len(train_data)): + multi_sub_dicts["observations"].append( + self._get_substitute_dict(train_data, example_idx) + ) + + substitute_dict = self._fill_multi_in_sub_dict( + substitute_dict, multi_sub_dicts, "batched_error_augmented_generation" + ) + + prompt = self._information_prompt(substitute_dict, "batched_error_augmented_generation") + + return prompt + def inference(self, hypotheses_dict, test_data, test_idx): """ Create inference prompt. From 5b8af3b2c7702f4fc50a7cb8b4da73285ad45e33 Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Thu, 10 Jul 2025 17:57:46 +0800 Subject: [PATCH 03/23] Record wrong hypotheses for each sample during prediction evaluation for error-based hypo augmentation --- hypogenic/algorithm/update/default.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/hypogenic/algorithm/update/default.py b/hypogenic/algorithm/update/default.py index 3e4d7a3..dda6c5c 100644 --- a/hypogenic/algorithm/update/default.py +++ b/hypogenic/algorithm/update/default.py @@ -138,6 +138,8 @@ def update( **generate_kwargs, ) + wrong_hypos_this_sample = [] # to record which hypos are wrong for this sample + # Comparison of the label and prediction for pred, label, hypothesis in zip(preds, labels, top_k_hypotheses): if pred != label: @@ -145,6 +147,9 @@ def update( hypotheses_bank[hypothesis].update_info_if_not_useful( current_sample, self.alpha ) # let the bank know it got one wrong + + # record the wrong hypothesis + wrong_hypos_this_sample.append(hypothesis) else: hypotheses_bank[hypothesis].update_info_if_useful( current_sample, self.alpha @@ -184,7 +189,7 @@ def update( self.alpha, cache_seed=cache_seed, max_concurrent=max_concurrent, - reference_hypotheses=top_k_hypotheses, + reference_hypotheses=wrong_hypos_this_sample, **generate_kwargs, ) ) From ed5ee3ae9ffdf3b90e63aa94c63c67f31790e45a Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Thu, 10 Jul 2025 20:17:43 +0800 Subject: [PATCH 04/23] encapsulate error-based hypo augmentation into AugmentedGeneration class and add run_augmented_hypogenic param in Pipeline --- hypogenic/algorithm/generation/augmented.py | 57 ++++++++++++ hypogenic/algorithm/generation/default.py | 51 ----------- pipeline.py | 98 +++++++++++++++++++++ run_pipeline.sh | 1 + 4 files changed, 156 insertions(+), 51 deletions(-) create mode 100644 hypogenic/algorithm/generation/augmented.py diff --git a/hypogenic/algorithm/generation/augmented.py b/hypogenic/algorithm/generation/augmented.py new file mode 100644 index 0000000..e7e5479 --- /dev/null +++ b/hypogenic/algorithm/generation/augmented.py @@ -0,0 +1,57 @@ +from typing import List + +from . import generation_register, DefaultGeneration +from .utils import extract_hypotheses + + +@generation_register.register("Augmented") +class AugmentedGeneration(DefaultGeneration): + # ------------------------------------------------------------------------ # + # # + # ------------------------------------------------------------------------ # + # BATCHED HYPOTHESIS LIST GENERATION # + # ------------------------------------------------------------------------ # + # # + # ------------------------------------------------------------------------ # + def batched_hyp_list_generation( + self, + example_indices: List[int], + num_hypotheses_generate: int, + cache_seed=None, + reference_hypotheses=None, + **generate_kwargs + ) -> List[str]: + """Batched hypothesis generation method. Takes multiple examples and creates a hypothesis with them. + + Parameters: + example_indices: the indices of examples being used to generate hypotheses + num_hypotheses_generate: the number of hypotheses that we expect our response to generate + cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number + reference_hypotheses: the current hypotheses that we have in the bank (if any) + + Returns: + hypotheses_list: A list containing all newly generated hypotheses. + """ + + # ---------------------------------------------------------------------- + # Gather the examples to use for generation + # ---------------------------------------------------------------------- + # Gather examples based on example_indices + example_bank = ( + self.train_data.loc[list(example_indices)].copy().reset_index(drop=True) + ) + + # ---------------------------------------------------------------------- + # Prompt LLM to generate hypotheses + # ---------------------------------------------------------------------- + # Batch generate a bunch of prompts based on yaml file + prompt_input = self.prompt_class.batched_error_augmented_generation( + example_bank, num_hypotheses_generate, reference_hypotheses + ) + + # Batch generate responses based on the prompts that we just generated + response = self.api.generate( + prompt_input, cache_seed=cache_seed, **generate_kwargs + ) + + return extract_hypotheses(response, num_hypotheses_generate) diff --git a/hypogenic/algorithm/generation/default.py b/hypogenic/algorithm/generation/default.py index 3228938..d20fd42 100644 --- a/hypogenic/algorithm/generation/default.py +++ b/hypogenic/algorithm/generation/default.py @@ -158,54 +158,3 @@ def batched_hypothesis_generation( max_concurrent=max_concurrent, **generate_kwargs, ) - - # ------------------------------------------------------------------------ # - # # - # ------------------------------------------------------------------------ # - # BATCHED HYPOTHESIS LIST GENERATION # - # ------------------------------------------------------------------------ # - # # - # ------------------------------------------------------------------------ # - def batched_hyp_list_generation( - self, - example_indices: List[int], - num_hypotheses_generate: int, - cache_seed=None, - reference_hypotheses=None, - **generate_kwargs - ) -> List[str]: - """Batched hypothesis generation method. Takes multiple examples and creates a hypothesis with them. - - Parameters: - example_indices: the indices of examples being used to generate hypotheses - num_hypotheses_generate: the number of hypotheses that we expect our response to generate - cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number - reference_hypotheses: the current hypotheses that we have in the bank (if any) - - Returns: - hypotheses_list: A list containing all newly generated hypotheses. - """ - - # ---------------------------------------------------------------------- - # Gather the examples to use for generation - # ---------------------------------------------------------------------- - # Gather examples based on example_indices - # TODO: need copy()? - example_bank = ( - self.train_data.loc[list(example_indices)].copy().reset_index(drop=True) - ) - - # ---------------------------------------------------------------------- - # Prompt LLM to generate hypotheses - # ---------------------------------------------------------------------- - # Batch generate a bunch of prompts based on yaml file - prompt_input = self.prompt_class.batched_error_augmented_generation( - example_bank, num_hypotheses_generate, reference_hypotheses - ) - - # Batch generate responses based on the prompts that we just generated - response = self.api.generate( - prompt_input, cache_seed=cache_seed, **generate_kwargs - ) - - return extract_hypotheses(response, num_hypotheses_generate) diff --git a/pipeline.py b/pipeline.py index 46725db..97b9bac 100644 --- a/pipeline.py +++ b/pipeline.py @@ -2,6 +2,8 @@ import json import logging import os + +from hypogenic.algorithm.generation.augmented import AugmentedGeneration from hypogenic.utils import set_seed, get_results from hypogenic.tasks import BaseTask from hypogenic.extract_label import extract_label_register @@ -69,6 +71,7 @@ parser.add_argument("--run_hyperwrite", action="store_true", help="Run HyperWrite") parser.add_argument("--run_notebooklm", action="store_true", help="Run NotebookLM") parser.add_argument("--run_hypogenic", action="store_true", help="Run original HypoGeniC") +parser.add_argument("--run_augmented_hypogenic", action="store_true", help="Run augmented HypoGeniC") parser.add_argument("--run_hyporefine", action="store_true", help="Run HypoRefine") parser.add_argument("--run_union_hypo", action="store_true", help="Run Union HypoGeniC and Paper") parser.add_argument("--run_union_refine", action="store_true", help="Run Union HypoRefine and Paper") @@ -396,6 +399,69 @@ def original_hypogenic(task_name, api, model_name): epoch=epoch, ) +def augmented_hypogenic(task_name, api, model_name): + output_folder = f"./results/{task_name}/{model_name}/hyp_{max_num_hypotheses}/" + + os.makedirs(output_folder, exist_ok=True) + + task = BaseTask( + config_path=f"./data/{task_name}/config.yaml", + from_register=extract_label_register, + use_ood=use_ood + ) + + set_seed(seed) + train_data, _, _ = task.get_data(num_train, num_test, num_val, seed) + prompt_class = BasePrompt(task) + inference_class = DefaultInference(api, prompt_class, train_data, task) + generation_class = AugmentedGeneration(api, prompt_class, inference_class, task) + + update_class = DefaultUpdate( + generation_class=generation_class, + inference_class=inference_class, + replace_class=DefaultReplace(max_num_hypotheses), + save_path=output_folder, + num_init=num_init, + k=k, + alpha=alpha, + update_batch_size=update_batch_size, + num_hypotheses_to_update=num_hypotheses_to_update, + save_every_n_examples=save_every_10_examples, + ) + + hypotheses_bank = {} + hypotheses_bank = update_class.batched_initialize_hypotheses( + num_init, + init_batch_size=init_batch_size, + init_hypotheses_per_batch=init_hypotheses_per_batch, + cache_seed=cache_seed, + temperature=temperature, + max_tokens=max_tokens, + max_concurrent=64, + ) + update_class.save_to_json( + hypotheses_bank, + sample=num_init, + seed=seed, + epoch=0, + ) + for epoch in range(1): + hypotheses_bank = update_class.update( + current_epoch=epoch, + hypotheses_bank=hypotheses_bank, + current_seed=seed, + cache_seed=cache_seed, + temperature=temperature, + max_tokens=max_tokens, + max_concurrent=64, + ) + update_class.save_to_json( + hypotheses_bank, + sample="final", + seed=seed, + epoch=epoch, + ) + def IO_iterative_refinement(task_name, api, model_name): output_folder = f"./results/{task_name}/{model_name}/IO_refinement/" @@ -871,6 +937,7 @@ def log_arguments(logger, args): ("Run HyperWrite", args.run_hyperwrite), ("Run NotebookLM", args.run_notebooklm), ("Run HypoGeniC", args.run_hypogenic), + ("Run Augmented HypoGeniC", args.run_augmented_hypogenic), ("Run HypoRefine", args.run_hyporefine), ("Run Union HypoGeniC", args.run_union_hypo), ("Run Union HypoRefine", args.run_union_refine), @@ -1034,6 +1101,37 @@ def log_arguments(logger, args): ) save_method_results(method_name, results, task_name, model_name, seed, use_ood=use_ood) + if args.run_augmented_hypogenic: + logger.info("=-=-=-=-=-=-=-=-=-=-=-=Augmented HypoGeniC=-=-=-=-=-=-=-=-=-=-=-=") + if DO_TRAIN: + augmented_hypogenic(task_name=task_name, api=api, model_name=model_name) + + method_name = "augmented_hypogenic_no_update" + methods_run.append(method_name) + logger.info("=-=-=-=-=-=-=-=-=-=-=-=No Update=-=-=-=-=-=-=-=-=-=-=-=") + results = get_res( + f"results/{task_name}/{model_name}/hyp_{max_num_hypotheses}/hypotheses_training_sample_10_seed_{seed}_epoch_0.json", + task_name=task_name, + api=api, + model_name=model_name, + use_val=use_val, + multihyp=multihyp, + ) + save_method_results(method_name, results, task_name, model_name, seed, use_ood=use_ood) + + method_name = "augmented_hypogenic" + methods_run.append(method_name) + logger.info("=-=-=-=-=-=-=-=-=-=-=-=With Update=-=-=-=-=-=-=-=-=-=-=-=") + results = get_res( + f"results/{task_name}/{model_name}/hyp_{max_num_hypotheses}/hypotheses_training_sample_final_seed_{seed}_epoch_0.json", + task_name=task_name, + api=api, + model_name=model_name, + use_val=use_val, + multihyp=multihyp, + ) + save_method_results(method_name, results, task_name, model_name, seed, use_ood=use_ood) + if args.run_hyporefine: logger.info("=-=-=-=-=-=-=-=-=-=-=-=HypoRefine=-=-=-=-=-=-=-=-=-=-=-=") if DO_TRAIN: diff --git a/run_pipeline.sh b/run_pipeline.sh index 9fd9794..c2590b3 100755 --- a/run_pipeline.sh +++ b/run_pipeline.sh @@ -44,6 +44,7 @@ METHODS=( # "zero_shot_gen" # "only_paper" "hypogenic" + # "augmented_hypogenic" # "hyporefine" # "union_hypo" # "union_refine" From 87acb242a0eece13695ec4d60fc5e99a0adfe36d Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Fri, 11 Jul 2025 00:03:14 +0800 Subject: [PATCH 05/23] rename output folder in augmented_hypogenic function to avoid overwritten --- pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pipeline.py b/pipeline.py index 97b9bac..3a7f41a 100644 --- a/pipeline.py +++ b/pipeline.py @@ -400,7 +400,7 @@ def original_hypogenic(task_name, api, model_name): ) def augmented_hypogenic(task_name, api, model_name): - output_folder = f"./results/{task_name}/{model_name}/hyp_{max_num_hypotheses}/" + output_folder = f"./results/{task_name}/{model_name}/aug_hyp_{max_num_hypotheses}/" os.makedirs(output_folder, exist_ok=True) From 5ea9973a4a6f1bdb9e5627aed58ae5abdf0eb0f0 Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Fri, 11 Jul 2025 00:20:56 +0800 Subject: [PATCH 06/23] Fix bug: accumulate wrong hypotheses across all bad samples instead of only the last one when generating new hypotheses --- hypogenic/algorithm/update/default.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/hypogenic/algorithm/update/default.py b/hypogenic/algorithm/update/default.py index dda6c5c..2f13727 100644 --- a/hypogenic/algorithm/update/default.py +++ b/hypogenic/algorithm/update/default.py @@ -84,6 +84,7 @@ def update( # initialize variables num_train_examples = len(self.train_data) wrong_example_ids = set() + wrong_hypos_accumulated = set() # ---------------------------------------------------------------------- # Figuring out starting samples @@ -138,8 +139,6 @@ def update( **generate_kwargs, ) - wrong_hypos_this_sample = [] # to record which hypos are wrong for this sample - # Comparison of the label and prediction for pred, label, hypothesis in zip(preds, labels, top_k_hypotheses): if pred != label: @@ -149,7 +148,7 @@ def update( ) # let the bank know it got one wrong # record the wrong hypothesis - wrong_hypos_this_sample.append(hypothesis) + wrong_hypos_accumulated.add(hypothesis) else: hypotheses_bank[hypothesis].update_info_if_useful( current_sample, self.alpha @@ -189,7 +188,7 @@ def update( self.alpha, cache_seed=cache_seed, max_concurrent=max_concurrent, - reference_hypotheses=wrong_hypos_this_sample, + reference_hypotheses=list(wrong_hypos_accumulated), **generate_kwargs, ) ) @@ -206,6 +205,8 @@ def update( new_hyp_bank.update(new_hypotheses) # reset wrong examples to be empty wrong_example_ids = set() + # reset accumulated wrong hypotheses to be empty + wrong_hypos_accumulated.clear() # call replace class to update the bank hypotheses_bank = self.replace_class.replace( From 6d2e65014849bbb3509d7209df80aad62ea62541 Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Sat, 12 Jul 2025 18:20:33 +0800 Subject: [PATCH 07/23] Include correspondingly wrong hypotheses for each sample in the generation prompt --- hypogenic/algorithm/generation/augmented.py | 13 ++------ hypogenic/algorithm/generation/base.py | 2 +- hypogenic/algorithm/generation/default.py | 2 +- hypogenic/algorithm/update/default.py | 12 ++++--- hypogenic/prompt.py | 37 +++++++++++++++------ 5 files changed, 39 insertions(+), 27 deletions(-) diff --git a/hypogenic/algorithm/generation/augmented.py b/hypogenic/algorithm/generation/augmented.py index e7e5479..ef84b94 100644 --- a/hypogenic/algorithm/generation/augmented.py +++ b/hypogenic/algorithm/generation/augmented.py @@ -27,26 +27,17 @@ def batched_hyp_list_generation( example_indices: the indices of examples being used to generate hypotheses num_hypotheses_generate: the number of hypotheses that we expect our response to generate cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number - reference_hypotheses: the current hypotheses that we have in the bank (if any) + reference_hypotheses: A dictionary that accumulates the set of wrong hypotheses for each sample Returns: hypotheses_list: A list containing all newly generated hypotheses. """ - - # ---------------------------------------------------------------------- - # Gather the examples to use for generation - # ---------------------------------------------------------------------- - # Gather examples based on example_indices - example_bank = ( - self.train_data.loc[list(example_indices)].copy().reset_index(drop=True) - ) - # ---------------------------------------------------------------------- # Prompt LLM to generate hypotheses # ---------------------------------------------------------------------- # Batch generate a bunch of prompts based on yaml file prompt_input = self.prompt_class.batched_error_augmented_generation( - example_bank, num_hypotheses_generate, reference_hypotheses + self.train_data, num_hypotheses_generate, reference_hypotheses ) # Batch generate responses based on the prompts that we just generated diff --git a/hypogenic/algorithm/generation/base.py b/hypogenic/algorithm/generation/base.py index dba9632..56a7a5b 100644 --- a/hypogenic/algorithm/generation/base.py +++ b/hypogenic/algorithm/generation/base.py @@ -77,7 +77,7 @@ def batched_hyp_list_generation( example_indices: the indices of examples being used to generate hypotheses num_hypotheses_generate: the number of hypotheses that we expect our response to generate cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number - reference_hypotheses: the current hypotheses that we have in the bank (if any) + reference_hypotheses: A dictionary that accumulates the set of wrong hypotheses for each sample Returns: hypotheses_list: A list containing all newly generated hypotheses. diff --git a/hypogenic/algorithm/generation/default.py b/hypogenic/algorithm/generation/default.py index d20fd42..5735830 100644 --- a/hypogenic/algorithm/generation/default.py +++ b/hypogenic/algorithm/generation/default.py @@ -136,7 +136,7 @@ def batched_hypothesis_generation( alpha: exploration constant in hypogenic reward function cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number max_concurrent: The maximum number of concurrent requests - reference_hypotheses: the current hypotheses that we have in the bank (if any) + reference_hypotheses: A dictionary that accumulates the set of wrong hypotheses for each sample Returns: hypotheses_bank: A dictionary with keys as hypotheses and the values as the Summary Information class diff --git a/hypogenic/algorithm/update/default.py b/hypogenic/algorithm/update/default.py index 2f13727..263cd51 100644 --- a/hypogenic/algorithm/update/default.py +++ b/hypogenic/algorithm/update/default.py @@ -84,7 +84,7 @@ def update( # initialize variables num_train_examples = len(self.train_data) wrong_example_ids = set() - wrong_hypos_accumulated = set() + accumulated_sample_wrong_hypos = {} # {sample_id: set(wrong_hypothesis)} # ---------------------------------------------------------------------- # Figuring out starting samples @@ -128,6 +128,9 @@ def update( # We need to see how good our hypothesis is, which we do by way of the inference class # ------------------------------------------------------------------ num_wrong_hypotheses = 0 + # record the hypotheses that are wrong for the current sample + current_sample_wrong_hypos = set() + preds, labels = self.inference_class.batched_predict( self.train_data, [ @@ -148,7 +151,7 @@ def update( ) # let the bank know it got one wrong # record the wrong hypothesis - wrong_hypos_accumulated.add(hypothesis) + current_sample_wrong_hypos.add(hypothesis) else: hypotheses_bank[hypothesis].update_info_if_useful( current_sample, self.alpha @@ -170,6 +173,7 @@ def update( # We note it as a bad sample wrong_example_ids.add(i) + accumulated_sample_wrong_hypos[i] = current_sample_wrong_hypos if ( len(wrong_example_ids) == self.update_batch_size * self.num_hypotheses_to_update @@ -188,7 +192,7 @@ def update( self.alpha, cache_seed=cache_seed, max_concurrent=max_concurrent, - reference_hypotheses=list(wrong_hypos_accumulated), + reference_hypotheses=accumulated_sample_wrong_hypos, **generate_kwargs, ) ) @@ -206,7 +210,7 @@ def update( # reset wrong examples to be empty wrong_example_ids = set() # reset accumulated wrong hypotheses to be empty - wrong_hypos_accumulated.clear() + accumulated_sample_wrong_hypos.clear() # call replace class to update the bank hypotheses_bank = self.replace_class.replace( diff --git a/hypogenic/prompt.py b/hypogenic/prompt.py index 6f7f5d6..d66e970 100644 --- a/hypogenic/prompt.py +++ b/hypogenic/prompt.py @@ -187,25 +187,42 @@ def batched_generation(self, train_data, num_hypotheses): return prompt - def batched_error_augmented_generation(self, train_data, num_hypotheses, reference_hypotheses: List[Any]): + def batched_error_augmented_generation(self, train_data, num_hypotheses, reference_hypotheses): """ Generate hypotheses that is useful for predicting the color of the shoes given the appearance of the person. + + Parameters: + train_data: Training data + num_hypotheses: Number of hypotheses to generate + reference_hypotheses: A dictionary that accumulates the set of wrong hypotheses for each sample """ substitute_dict = {"num_hypotheses": num_hypotheses} - multi_sub_dicts = { - "observations": [], - "reference_hypotheses": [ - {"hypothesis": h, "idx": i + 1} for i, h in enumerate(reference_hypotheses) - ] - } + multi_sub_dicts = {"error_augmented_observation": []} - for example_idx in range(len(train_data)): - multi_sub_dicts["observations"].append( - self._get_substitute_dict(train_data, example_idx) + for sample_id, wrong_hypos in reference_hypotheses.items(): + sample_data = self._get_substitute_dict(train_data, sample_id) + + wrong_hypotheses_info = [] + for idx, hypothesis in enumerate(wrong_hypos): + wrong_hypotheses_info.append({ + "idx": idx + 1, + "hypothesis_text": hypothesis + }) + + wrong_hypotheses_text = self._fill_multi_content( + ({}, wrong_hypotheses_info), + self._get_prompt_template("wrong_hypotheses") ) + error_info = { + "review_sentence": sample_data["review_sentence"], + "label": sample_data["label"], + "wrong_hypotheses": wrong_hypotheses_text + } + multi_sub_dicts["error_augmented_observation"].append(error_info) + substitute_dict = self._fill_multi_in_sub_dict( substitute_dict, multi_sub_dicts, "batched_error_augmented_generation" ) From 0c0d3da5583f84abec8b56128bdb63c81e8bbd24 Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Sat, 12 Jul 2025 20:01:45 +0800 Subject: [PATCH 08/23] Generalize batched_error_augmented_generation to support all datasets --- hypogenic/prompt.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/hypogenic/prompt.py b/hypogenic/prompt.py index d66e970..f527115 100644 --- a/hypogenic/prompt.py +++ b/hypogenic/prompt.py @@ -215,13 +215,10 @@ def batched_error_augmented_generation(self, train_data, num_hypotheses, referen ({}, wrong_hypotheses_info), self._get_prompt_template("wrong_hypotheses") ) - - error_info = { - "review_sentence": sample_data["review_sentence"], - "label": sample_data["label"], - "wrong_hypotheses": wrong_hypotheses_text - } - multi_sub_dicts["error_augmented_observation"].append(error_info) + + sample_data["wrong_hypotheses"] = wrong_hypotheses_text + + multi_sub_dicts["error_augmented_observation"].append(sample_data) substitute_dict = self._fill_multi_in_sub_dict( substitute_dict, multi_sub_dicts, "batched_error_augmented_generation" From 7b17dfb1ae43f51870db78847920b76b6b0973ba Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Sat, 12 Jul 2025 20:05:44 +0800 Subject: [PATCH 09/23] Add 'aug' prefix to filenames for augmented hypotheses in get_res --- pipeline.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pipeline.py b/pipeline.py index 3a7f41a..64eb0c9 100644 --- a/pipeline.py +++ b/pipeline.py @@ -1110,7 +1110,7 @@ def log_arguments(logger, args): methods_run.append(method_name) logger.info("=-=-=-=-=-=-=-=-=-=-=-=No Update=-=-=-=-=-=-=-=-=-=-=-=") results = get_res( - f"results/{task_name}/{model_name}/hyp_{max_num_hypotheses}/hypotheses_training_sample_10_seed_{seed}_epoch_0.json", + f"results/{task_name}/{model_name}/aug_hyp_{max_num_hypotheses}/hypotheses_training_sample_10_seed_{seed}_epoch_0.json", task_name=task_name, api=api, model_name=model_name, @@ -1123,7 +1123,7 @@ def log_arguments(logger, args): methods_run.append(method_name) logger.info("=-=-=-=-=-=-=-=-=-=-=-=With Update=-=-=-=-=-=-=-=-=-=-=-=") results = get_res( - f"results/{task_name}/{model_name}/hyp_{max_num_hypotheses}/hypotheses_training_sample_final_seed_{seed}_epoch_0.json", + f"results/{task_name}/{model_name}/aug_hyp_{max_num_hypotheses}/hypotheses_training_sample_final_seed_{seed}_epoch_0.json", task_name=task_name, api=api, model_name=model_name, From 6e0661a74e3129b57407f3d5076db67a186fb9f3 Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Tue, 15 Jul 2025 22:44:57 +0800 Subject: [PATCH 10/23] list error samples for each incorrect hypo during augmented generation --- hypogenic/algorithm/generation/augmented.py | 6 ++- hypogenic/algorithm/update/default.py | 11 ++++-- hypogenic/prompt.py | 44 +++++++++------------ 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/hypogenic/algorithm/generation/augmented.py b/hypogenic/algorithm/generation/augmented.py index ef84b94..05272b2 100644 --- a/hypogenic/algorithm/generation/augmented.py +++ b/hypogenic/algorithm/generation/augmented.py @@ -27,7 +27,7 @@ def batched_hyp_list_generation( example_indices: the indices of examples being used to generate hypotheses num_hypotheses_generate: the number of hypotheses that we expect our response to generate cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number - reference_hypotheses: A dictionary that accumulates the set of wrong hypotheses for each sample + reference_hypotheses: A dictionary {wrong_hypothesis: set(sample_id)} that accumulates the set of wrong samples for each hypothesis Returns: hypotheses_list: A list containing all newly generated hypotheses. @@ -40,6 +40,10 @@ def batched_hyp_list_generation( self.train_data, num_hypotheses_generate, reference_hypotheses ) + print(f">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>" + f"Prompt is {str(prompt_input)}" + f"<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<") + # Batch generate responses based on the prompts that we just generated response = self.api.generate( prompt_input, cache_seed=cache_seed, **generate_kwargs diff --git a/hypogenic/algorithm/update/default.py b/hypogenic/algorithm/update/default.py index 263cd51..f4e042b 100644 --- a/hypogenic/algorithm/update/default.py +++ b/hypogenic/algorithm/update/default.py @@ -84,7 +84,7 @@ def update( # initialize variables num_train_examples = len(self.train_data) wrong_example_ids = set() - accumulated_sample_wrong_hypos = {} # {sample_id: set(wrong_hypothesis)} + accumulated_wrong_hyp_samples = {} # {wrong_hypothesis: set(sample_id)} # ---------------------------------------------------------------------- # Figuring out starting samples @@ -173,7 +173,10 @@ def update( # We note it as a bad sample wrong_example_ids.add(i) - accumulated_sample_wrong_hypos[i] = current_sample_wrong_hypos + for hypo in current_sample_wrong_hypos: + if hypo not in accumulated_wrong_hyp_samples: + accumulated_wrong_hyp_samples[hypo] = set() + accumulated_wrong_hyp_samples[hypo].add(i) if ( len(wrong_example_ids) == self.update_batch_size * self.num_hypotheses_to_update @@ -192,7 +195,7 @@ def update( self.alpha, cache_seed=cache_seed, max_concurrent=max_concurrent, - reference_hypotheses=accumulated_sample_wrong_hypos, + reference_hypotheses=accumulated_wrong_hyp_samples, **generate_kwargs, ) ) @@ -210,7 +213,7 @@ def update( # reset wrong examples to be empty wrong_example_ids = set() # reset accumulated wrong hypotheses to be empty - accumulated_sample_wrong_hypos.clear() + accumulated_wrong_hyp_samples.clear() # call replace class to update the bank hypotheses_bank = self.replace_class.replace( diff --git a/hypogenic/prompt.py b/hypogenic/prompt.py index f527115..f335e5b 100644 --- a/hypogenic/prompt.py +++ b/hypogenic/prompt.py @@ -189,43 +189,35 @@ def batched_generation(self, train_data, num_hypotheses): def batched_error_augmented_generation(self, train_data, num_hypotheses, reference_hypotheses): """ - Generate hypotheses that is useful for predicting the color of the shoes given the appearance of the person. - - Parameters: - train_data: Training data - num_hypotheses: Number of hypotheses to generate - reference_hypotheses: A dictionary that accumulates the set of wrong hypotheses for each sample + reference_hypotheses: {wrong_hypothesis: set(sample_id)} """ - substitute_dict = {"num_hypotheses": num_hypotheses} - multi_sub_dicts = {"error_augmented_observation": []} - for sample_id, wrong_hypos in reference_hypotheses.items(): - sample_data = self._get_substitute_dict(train_data, sample_id) - - wrong_hypotheses_info = [] - for idx, hypothesis in enumerate(wrong_hypos): - wrong_hypotheses_info.append({ - "idx": idx + 1, - "hypothesis_text": hypothesis - }) - - wrong_hypotheses_text = self._fill_multi_content( - ({}, wrong_hypotheses_info), - self._get_prompt_template("wrong_hypotheses") + for hypo_idx, (hypothesis, sample_ids) in enumerate(reference_hypotheses.items()): + # 收集所有 wrong sample 的详细信息 + wrong_samples_info = [] + for idx, sample_id in enumerate(sample_ids): + sample_data = self._get_substitute_dict(train_data, sample_id) + sample_data["idx"] = idx + 1 + wrong_samples_info.append(sample_data) + # 组装 wrong_samples 的文本 + wrong_samples_text = self._fill_multi_content( + ({}, wrong_samples_info), + self._get_prompt_template("wrong_samples") ) - - sample_data["wrong_hypotheses"] = wrong_hypotheses_text - - multi_sub_dicts["error_augmented_observation"].append(sample_data) + # 组装每个 hypothesis 的内容 + hyp_info = { + "hypothesis_text": hypothesis, + "wrong_samples": wrong_samples_text + } + multi_sub_dicts["error_augmented_observation"].append(hyp_info) substitute_dict = self._fill_multi_in_sub_dict( substitute_dict, multi_sub_dicts, "batched_error_augmented_generation" ) prompt = self._information_prompt(substitute_dict, "batched_error_augmented_generation") - return prompt def inference(self, hypotheses_dict, test_data, test_idx): From ffec6a19efd51eda872f74c5b7eae65141a84bb9 Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Tue, 29 Jul 2025 16:41:59 -0500 Subject: [PATCH 11/23] Include correspondingly correct and wrong hypotheses for each sample in the generation prompt --- hypogenic/algorithm/generation/augmented.py | 45 +++++++-------------- hypogenic/algorithm/update/default.py | 22 +++++----- hypogenic/prompt.py | 26 ++++++++---- 3 files changed, 47 insertions(+), 46 deletions(-) diff --git a/hypogenic/algorithm/generation/augmented.py b/hypogenic/algorithm/generation/augmented.py index 05272b2..ffbdc75 100644 --- a/hypogenic/algorithm/generation/augmented.py +++ b/hypogenic/algorithm/generation/augmented.py @@ -18,35 +18,20 @@ def batched_hyp_list_generation( example_indices: List[int], num_hypotheses_generate: int, cache_seed=None, - reference_hypotheses=None, + reference_hypotheses=None,# {hypo: {"correct": set(), "wrong": set()}} **generate_kwargs ) -> List[str]: - """Batched hypothesis generation method. Takes multiple examples and creates a hypothesis with them. - - Parameters: - example_indices: the indices of examples being used to generate hypotheses - num_hypotheses_generate: the number of hypotheses that we expect our response to generate - cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number - reference_hypotheses: A dictionary {wrong_hypothesis: set(sample_id)} that accumulates the set of wrong samples for each hypothesis - - Returns: - hypotheses_list: A list containing all newly generated hypotheses. - """ - # ---------------------------------------------------------------------- - # Prompt LLM to generate hypotheses - # ---------------------------------------------------------------------- - # Batch generate a bunch of prompts based on yaml file - prompt_input = self.prompt_class.batched_error_augmented_generation( - self.train_data, num_hypotheses_generate, reference_hypotheses - ) - - print(f">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>" - f"Prompt is {str(prompt_input)}" - f"<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<") - - # Batch generate responses based on the prompts that we just generated - response = self.api.generate( - prompt_input, cache_seed=cache_seed, **generate_kwargs - ) - - return extract_hypotheses(response, num_hypotheses_generate) + batch_size = 1 + all_new_hypos = [] + hypo_items = list(reference_hypotheses.items()) + total = len(hypo_items) + for i in range(0, total, batch_size): + batch = dict(hypo_items[i:i+batch_size]) + prompt_input = self.prompt_class.batched_error_augmented_generation( + self.train_data, len(batch), batch + ) + response = self.api.generate( + prompt_input, cache_seed=cache_seed, **generate_kwargs + ) + all_new_hypos.extend(extract_hypotheses(response, 1)) + return all_new_hypos diff --git a/hypogenic/algorithm/update/default.py b/hypogenic/algorithm/update/default.py index f4e042b..1eb3c46 100644 --- a/hypogenic/algorithm/update/default.py +++ b/hypogenic/algorithm/update/default.py @@ -84,7 +84,7 @@ def update( # initialize variables num_train_examples = len(self.train_data) wrong_example_ids = set() - accumulated_wrong_hyp_samples = {} # {wrong_hypothesis: set(sample_id)} + accumulated_wrong_hyp_samples = {} # {hypo: {"correct": set(), "wrong": set()}} # ---------------------------------------------------------------------- # Figuring out starting samples @@ -129,7 +129,7 @@ def update( # ------------------------------------------------------------------ num_wrong_hypotheses = 0 # record the hypotheses that are wrong for the current sample - current_sample_wrong_hypos = set() + current_hypo_samples = {} # {hypo: {"correct": set(), "wrong": set()}} preds, labels = self.inference_class.batched_predict( self.train_data, @@ -144,14 +144,14 @@ def update( # Comparison of the label and prediction for pred, label, hypothesis in zip(preds, labels, top_k_hypotheses): + if hypothesis not in current_hypo_samples: + current_hypo_samples[hypothesis] = {"correct": set(), "wrong": set()} if pred != label: num_wrong_hypotheses += 1 hypotheses_bank[hypothesis].update_info_if_not_useful( current_sample, self.alpha ) # let the bank know it got one wrong - - # record the wrong hypothesis - current_sample_wrong_hypos.add(hypothesis) + current_hypo_samples[hypothesis]["wrong"].add(i) else: hypotheses_bank[hypothesis].update_info_if_useful( current_sample, self.alpha @@ -159,6 +159,7 @@ def update( # keeping track of good examples as we do in generation hypotheses_bank[hypothesis].update_useful_examples(i, label) + current_hypo_samples[hypothesis]["correct"].add(i) # ------------------------------------------------------------------ # Generating a new hypothesis @@ -173,10 +174,13 @@ def update( # We note it as a bad sample wrong_example_ids.add(i) - for hypo in current_sample_wrong_hypos: - if hypo not in accumulated_wrong_hyp_samples: - accumulated_wrong_hyp_samples[hypo] = set() - accumulated_wrong_hyp_samples[hypo].add(i) + for hypothesis in current_hypo_samples: + if hypothesis not in accumulated_wrong_hyp_samples: + accumulated_wrong_hyp_samples[hypothesis] = {"correct": set(), "wrong": set()} + accumulated_wrong_hyp_samples[hypothesis]["correct"].update( + current_hypo_samples[hypothesis]["correct"]) + accumulated_wrong_hyp_samples[hypothesis]["wrong"].update( + current_hypo_samples[hypothesis]["wrong"]) if ( len(wrong_example_ids) == self.update_batch_size * self.num_hypotheses_to_update diff --git a/hypogenic/prompt.py b/hypogenic/prompt.py index f335e5b..77dae5f 100644 --- a/hypogenic/prompt.py +++ b/hypogenic/prompt.py @@ -189,26 +189,38 @@ def batched_generation(self, train_data, num_hypotheses): def batched_error_augmented_generation(self, train_data, num_hypotheses, reference_hypotheses): """ - reference_hypotheses: {wrong_hypothesis: set(sample_id)} + reference_hypotheses: {hypo: {"correct": set(), "wrong": set()}} """ - substitute_dict = {"num_hypotheses": num_hypotheses} + substitute_dict = {} multi_sub_dicts = {"error_augmented_observation": []} - for hypo_idx, (hypothesis, sample_ids) in enumerate(reference_hypotheses.items()): - # 收集所有 wrong sample 的详细信息 + for hypo_idx, (hypothesis, sample_dict) in enumerate(reference_hypotheses.items()): + # 处理 correct samples,只取前3个 + correct_samples_info = [] + correct_ids = list(sample_dict.get("correct", []))[:3] + for idx, sample_id in enumerate(correct_ids): + sample_data = self._get_substitute_dict(train_data, sample_id) + sample_data["idx"] = idx + 1 + correct_samples_info.append(sample_data) + correct_samples_text = self._fill_multi_content( + ({}, correct_samples_info), + self._get_prompt_template("correct_samples") + ) + + # 处理 wrong samples wrong_samples_info = [] - for idx, sample_id in enumerate(sample_ids): + for idx, sample_id in enumerate(sample_dict.get("wrong", [])): sample_data = self._get_substitute_dict(train_data, sample_id) sample_data["idx"] = idx + 1 wrong_samples_info.append(sample_data) - # 组装 wrong_samples 的文本 wrong_samples_text = self._fill_multi_content( ({}, wrong_samples_info), self._get_prompt_template("wrong_samples") ) - # 组装每个 hypothesis 的内容 + hyp_info = { "hypothesis_text": hypothesis, + "correct_samples": correct_samples_text, "wrong_samples": wrong_samples_text } multi_sub_dicts["error_augmented_observation"].append(hyp_info) From 9201ad26b5e1c3f82336f2096da2f4790bb30889 Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Fri, 1 Aug 2025 15:44:57 -0500 Subject: [PATCH 12/23] Implement Hierarchical Inference by creating stump from the final hypotheses --- .../data_analysis_agent/inference.py | 80 +++++++++++++++++++ .../data_analysis_agent/prompt.py | 9 +++ pipeline.py | 61 +++++++++++++- 3 files changed, 149 insertions(+), 1 deletion(-) diff --git a/hypothesis_agent/data_analysis_agent/inference.py b/hypothesis_agent/data_analysis_agent/inference.py index f8caeeb..4590760 100644 --- a/hypothesis_agent/data_analysis_agent/inference.py +++ b/hypothesis_agent/data_analysis_agent/inference.py @@ -169,3 +169,83 @@ def batched_predict( actual_labels = [data[self.task.label_name][index] for index, _ in idx_hyp_pair] return predictions, actual_labels, idx_hyp_pair + +class MultiHypHierarchicalInference(DefaultInference): + def __init__( + self, + api, + prompt_class: TestPrompt, + train_data: pd.DataFrame, + task: BaseTask, + ): + super().__init__(api, prompt_class, train_data, task) + + def create_stump_from_hypotheses( + self, + hyp_bank, + cache_seed=None, + max_concurrent=3, + **generate_kwargs, + ): + prompt_inputs = [ + self.prompt_class.create_stump_from_hypotheses(hyp_bank) + ] + responses = self.api.batched_generate( + prompt_inputs, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) + print(responses) + return {responses[0]:SummaryInformation} + + def multiple_hypotheses_batched_predict( + self, + data: pd.DataFrame, + idx_hyp_pair=List[Tuple[int, Dict[str, SummaryInformation]]], + cache_seed=None, + max_concurrent=3, + **generate_kwargs, + ): + prompt_inputs = [ + self.prompt_class.multiple_hypotheses_inference(hyp_bank, data, index) + for index, hyp_bank in idx_hyp_pair + ] + responses = self.api.batched_generate( + prompt_inputs, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) + actual_labels = [data[self.task.label_name][index] for index, _ in idx_hyp_pair] + predictions = [self.task.extract_label(responses[i]) for i in range(len(responses))] + + return predictions, actual_labels + + def run_inference_final( + self, + data, + hyp_bank, + cache_seed=None, + max_concurrent=3, + **generate_kwargs, + ): + stump_hyp = self.create_stump_from_hypotheses( + hyp_bank, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) + + num_samples = len(data) + + return self.multiple_hypotheses_batched_predict( + data, + [ + (i, stump_hyp) + for i in range(num_samples) + ], + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) \ No newline at end of file diff --git a/hypothesis_agent/data_analysis_agent/prompt.py b/hypothesis_agent/data_analysis_agent/prompt.py index e607a85..a047b56 100644 --- a/hypothesis_agent/data_analysis_agent/prompt.py +++ b/hypothesis_agent/data_analysis_agent/prompt.py @@ -197,6 +197,15 @@ def multiple_hypotheses_inference(self, hypotheses_dict, test_data, test_idx): return prompt + def create_stump_from_hypotheses(self, hypotheses_dict): + hypotheses_list = list(hypotheses_dict.keys()) + + substitute_dict= {"hypotheses": "\n".join([f"{idx + 1}. {hyp}" for idx, hyp in enumerate(hypotheses_list)])} + + prompt = self._information_prompt(substitute_dict, "create_stump_from_hypotheses") + + return prompt + def test_autogen( self, train_data: pd.DataFrame, num_hypotheses, paper_infos: List[Dict[str, str]] ): diff --git a/pipeline.py b/pipeline.py index 64eb0c9..32b5a9f 100644 --- a/pipeline.py +++ b/pipeline.py @@ -30,7 +30,7 @@ OnlyPaperGeneration, ZeroShotGeneration, ) -from hypothesis_agent.data_analysis_agent.inference import MultiHypDefaultInference +from hypothesis_agent.data_analysis_agent.inference import MultiHypDefaultInference, MultiHypHierarchicalInference from hypothesis_agent.data_analysis_agent.update import TestUpdate from hypothesis_agent.literature_review_agent import LiteratureAgent from hypothesis_agent.literature_review_agent.literature_processor.extract_info import ( @@ -72,6 +72,7 @@ parser.add_argument("--run_notebooklm", action="store_true", help="Run NotebookLM") parser.add_argument("--run_hypogenic", action="store_true", help="Run original HypoGeniC") parser.add_argument("--run_augmented_hypogenic", action="store_true", help="Run augmented HypoGeniC") +parser.add_argument("--run_hierarchical_inference", action="store_true", help="Run hierarchical inference") parser.add_argument("--run_hyporefine", action="store_true", help="Run HypoRefine") parser.add_argument("--run_union_hypo", action="store_true", help="Run Union HypoGeniC and Paper") parser.add_argument("--run_union_refine", action="store_true", help="Run Union HypoRefine and Paper") @@ -744,6 +745,49 @@ def get_res(filename: str, task_name, api, model_name, use_val=False, multihyp=F } return formatted_results +def get_res_hierarchical(filename: str, task_name, api, model_name, use_val=False, multihyp=False): + logger = LoggerConfig.get_logger("Agent - get_res_hierarchical") + + set_seed(seed) + + task = BaseTask( + config_path=f"./data/{task_name}/config.yaml", + from_register=extract_label_register, + use_ood=use_ood + ) + + train_data, test_data, val_data = task.get_data(num_train, num_test, num_val, seed) + if use_val: + test_data = val_data + + prompt_class = TestPrompt(task) + + with open(filename) as f: + hyp_dict = json.load(f) + hyp_bank = {} + for hypothesis in hyp_dict: + hyp_bank[hypothesis] = SummaryInformation.from_dict(hyp_dict[hypothesis]) + + + inference_class = MultiHypHierarchicalInference(api, prompt_class, train_data, task) + pred_list, label_list = inference_class.run_inference_final( + test_data, + hyp_bank, + cache_seed=cache_seed, + max_concurrent=64, + temperature=temperature, + max_tokens=max_tokens, + ) + + results_dict = get_results(pred_list, label_list) + f1 = results_dict["f1"] + acc = results_dict["accuracy"] + logger_str = "Results:\n" + logger_str += f"Accuracy: {acc}\n" + logger_str += f"F1: {f1}\n\n" + logger.info(logger_str) + return results_dict # Return results dictionary + def baseline(few_shot_k, task_name, api, model_name, seed=42, use_val=False): def few_shot( api: LLMWrapper, @@ -938,6 +982,7 @@ def log_arguments(logger, args): ("Run NotebookLM", args.run_notebooklm), ("Run HypoGeniC", args.run_hypogenic), ("Run Augmented HypoGeniC", args.run_augmented_hypogenic), + ("Run Hierarchical Inference", args.run_hierarchical_inference), ("Run HypoRefine", args.run_hyporefine), ("Run Union HypoGeniC", args.run_union_hypo), ("Run Union HypoRefine", args.run_union_refine), @@ -1132,6 +1177,20 @@ def log_arguments(logger, args): ) save_method_results(method_name, results, task_name, model_name, seed, use_ood=use_ood) + if args.run_hierarchical_inference: + method_name = "hypogenic" + methods_run.append(method_name) + logger.info("=-=-=-=-=-=-=-=-=-=-=-=With Update=-=-=-=-=-=-=-=-=-=-=-=") + results = get_res_hierarchical( + f"results/{task_name}/{model_name}/hierarchical_hyp_{max_num_hypotheses}/hypotheses_training_sample_final_seed_{seed}_epoch_0.json", + task_name=task_name, + api=api, + model_name=model_name, + use_val=use_val, + multihyp=multihyp, + ) + save_method_results(method_name, results, task_name, model_name, seed, use_ood=use_ood) + if args.run_hyporefine: logger.info("=-=-=-=-=-=-=-=-=-=-=-=HypoRefine=-=-=-=-=-=-=-=-=-=-=-=") if DO_TRAIN: From ddc33b276fdab5ff8fdecf32db0be20e8b8d4d57 Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Sat, 2 Aug 2025 19:20:47 -0500 Subject: [PATCH 13/23] =?UTF-8?q?use=20regex=20in=20create=5Fstump=5Ffrom?= =?UTF-8?q?=5Fhypotheses=20to=20strip=20the=20reasoning=20model=E2=80=99s?= =?UTF-8?q?=20"think"=20section?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- hypothesis_agent/data_analysis_agent/inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/hypothesis_agent/data_analysis_agent/inference.py b/hypothesis_agent/data_analysis_agent/inference.py index 4590760..5a5f6a6 100644 --- a/hypothesis_agent/data_analysis_agent/inference.py +++ b/hypothesis_agent/data_analysis_agent/inference.py @@ -196,8 +196,9 @@ def create_stump_from_hypotheses( max_concurrent=max_concurrent, **generate_kwargs, ) + responses = re.sub(r'.*?', '', responses[0], flags=re.IGNORECASE | re.DOTALL) print(responses) - return {responses[0]:SummaryInformation} + return {responses:SummaryInformation} def multiple_hypotheses_batched_predict( self, From 43842889c1694bb5f4e7198ad019d06edae34ce9 Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Sun, 3 Aug 2025 21:05:49 -0500 Subject: [PATCH 14/23] undo Default class modifications; use Augmented class for isolation --- hypogenic/algorithm/generation/default.py | 15 ++++++--------- hypogenic/algorithm/inference/default.py | 4 ++-- hypogenic/algorithm/update/default.py | 22 ++-------------------- hypogenic/tasks.py | 2 +- 4 files changed, 11 insertions(+), 32 deletions(-) diff --git a/hypogenic/algorithm/generation/default.py b/hypogenic/algorithm/generation/default.py index 5735830..bc33a3c 100644 --- a/hypogenic/algorithm/generation/default.py +++ b/hypogenic/algorithm/generation/default.py @@ -15,7 +15,7 @@ @generation_register.register("default") class DefaultGeneration(Generation): """ - Add on extra functionality to the generation function - we consider this + Add on extra functionality to the generation fucntion - we consider this the "default" task. """ @@ -29,8 +29,8 @@ def __init__( """ Parameters: api: the language model that you're using, which may or may not be local - prompt_class: let us know how the prompt is going to look - inference_class: gives us a way to predict labels for accuracy’s sake + prompt_class: let's us know how the prompt is going to look + inference_class: gives us a way to predict labels for accuracy sake task: determines the goal to accomplish """ super().__init__(api, prompt_class, inference_class, task) @@ -38,7 +38,7 @@ def __init__( # ------------------------------------------------------------------------ # # # # ------------------------------------------------------------------------ # - # BATCH INITIALIZE HYPOTHESES # + # BATCH INITLALIZE HYPOTHESES # # ------------------------------------------------------------------------ # # # # ------------------------------------------------------------------------ # @@ -57,7 +57,7 @@ def batched_initialize_hypotheses( Parameters: num_init: the total amount of examples you want to use for initialize hypotheses init_batch size: the number of examples that will be used to generate these hypotheses - init_hypotheses_per_batch: the amount of hypotheses that you want to generate per batch + init_hypotheses_per_batch: the amount of hypotheses that you want to generate per btach cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number max_concurrent: the maximum amount of concurrent calls to the API @@ -123,7 +123,6 @@ def batched_hypothesis_generation( alpha: float, cache_seed=None, max_concurrent=3, - reference_hypotheses=None, **generate_kwargs, ): """ @@ -133,10 +132,9 @@ def batched_hypothesis_generation( example_ids: The ids of the examples for which hypotheses need to be generated current_sample: the current sample in data which the algorithm is on num_hypotheses_generate: the number of hypotheses that we expect our response to generate - alpha: exploration constant in hypogenic reward function + alpha: eploration constant in hypogenic reward funciton cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number max_concurrent: The maximum number of concurrent requests - reference_hypotheses: A dictionary that accumulates the set of wrong hypotheses for each sample Returns: hypotheses_bank: A dictionary with keys as hypotheses and the values as the Summary Information class @@ -145,7 +143,6 @@ def batched_hypothesis_generation( example_ids, num_hypotheses_generate, cache_seed=cache_seed, - reference_hypotheses=reference_hypotheses, **generate_kwargs, ) diff --git a/hypogenic/algorithm/inference/default.py b/hypogenic/algorithm/inference/default.py index e66884d..84d9c4c 100644 --- a/hypogenic/algorithm/inference/default.py +++ b/hypogenic/algorithm/inference/default.py @@ -35,7 +35,7 @@ def batched_predict( **generate_kwargs, ): """ - Makes a batch of predictions on a hypothesis. + Makes a batch of preductions on a hypothesis. Parameters: data: the data to predict on @@ -78,7 +78,7 @@ def run_inference_final( """ Function for testing the best hypothesis - Parameters: + Prameters: data: the data to predict on hyp_bank: the hypotheses that we want to predict from cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number diff --git a/hypogenic/algorithm/update/default.py b/hypogenic/algorithm/update/default.py index 1eb3c46..7b11386 100644 --- a/hypogenic/algorithm/update/default.py +++ b/hypogenic/algorithm/update/default.py @@ -84,7 +84,6 @@ def update( # initialize variables num_train_examples = len(self.train_data) wrong_example_ids = set() - accumulated_wrong_hyp_samples = {} # {hypo: {"correct": set(), "wrong": set()}} # ---------------------------------------------------------------------- # Figuring out starting samples @@ -128,9 +127,6 @@ def update( # We need to see how good our hypothesis is, which we do by way of the inference class # ------------------------------------------------------------------ num_wrong_hypotheses = 0 - # record the hypotheses that are wrong for the current sample - current_hypo_samples = {} # {hypo: {"correct": set(), "wrong": set()}} - preds, labels = self.inference_class.batched_predict( self.train_data, [ @@ -144,14 +140,11 @@ def update( # Comparison of the label and prediction for pred, label, hypothesis in zip(preds, labels, top_k_hypotheses): - if hypothesis not in current_hypo_samples: - current_hypo_samples[hypothesis] = {"correct": set(), "wrong": set()} if pred != label: num_wrong_hypotheses += 1 hypotheses_bank[hypothesis].update_info_if_not_useful( current_sample, self.alpha ) # let the bank know it got one wrong - current_hypo_samples[hypothesis]["wrong"].add(i) else: hypotheses_bank[hypothesis].update_info_if_useful( current_sample, self.alpha @@ -159,7 +152,6 @@ def update( # keeping track of good examples as we do in generation hypotheses_bank[hypothesis].update_useful_examples(i, label) - current_hypo_samples[hypothesis]["correct"].add(i) # ------------------------------------------------------------------ # Generating a new hypothesis @@ -174,13 +166,6 @@ def update( # We note it as a bad sample wrong_example_ids.add(i) - for hypothesis in current_hypo_samples: - if hypothesis not in accumulated_wrong_hyp_samples: - accumulated_wrong_hyp_samples[hypothesis] = {"correct": set(), "wrong": set()} - accumulated_wrong_hyp_samples[hypothesis]["correct"].update( - current_hypo_samples[hypothesis]["correct"]) - accumulated_wrong_hyp_samples[hypothesis]["wrong"].update( - current_hypo_samples[hypothesis]["wrong"]) if ( len(wrong_example_ids) == self.update_batch_size * self.num_hypotheses_to_update @@ -189,7 +174,7 @@ def update( # generate new hypotheses for j in range(self.num_hypotheses_to_update): - # Go through poorly performing examples and generate hypotheses for them + # Go through poorly performing exmaples and generate hypotheses for them # TODO: batched? new_hypotheses = ( self.generation_class.batched_hypothesis_generation( @@ -199,12 +184,11 @@ def update( self.alpha, cache_seed=cache_seed, max_concurrent=max_concurrent, - reference_hypotheses=accumulated_wrong_hyp_samples, **generate_kwargs, ) ) - # If we only take the best performing hypothesis from the batch + # If we onlt take the best performing hypothesis from the batch if self.only_best_hypothesis: best_hypothesis = max( new_hypotheses, key=lambda x: new_hypotheses[x].reward @@ -216,8 +200,6 @@ def update( new_hyp_bank.update(new_hypotheses) # reset wrong examples to be empty wrong_example_ids = set() - # reset accumulated wrong hypotheses to be empty - accumulated_wrong_hyp_samples.clear() # call replace class to update the bank hypotheses_bank = self.replace_class.replace( diff --git a/hypogenic/tasks.py b/hypogenic/tasks.py index 5285b22..264600f 100644 --- a/hypogenic/tasks.py +++ b/hypogenic/tasks.py @@ -47,7 +47,7 @@ def __init__( self.test_data_path = data["ood_data_path"] self.val_data_path = data["ood_data_path"] - # getting prompt templates from yaml file + # getting omrpt templates from yaml file self.prompt_template = data["prompt_templates"] # task label From ac0b31d050f124c7f779249dd308fe86df8dba5c Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Mon, 4 Aug 2025 17:01:11 -0500 Subject: [PATCH 15/23] Implement redundancy clearance in update process; refactor Generation class; introduce AugmentedUpdate for isolated updates --- hypogenic/algorithm/generation/augmented.py | 81 ++++++- hypogenic/algorithm/generation/base.py | 43 +++- hypogenic/algorithm/generation/utils.py | 11 +- hypogenic/algorithm/update/augmented.py | 248 ++++++++++++++++++++ hypogenic/prompt.py | 9 + pipeline.py | 5 +- 6 files changed, 382 insertions(+), 15 deletions(-) create mode 100644 hypogenic/algorithm/update/augmented.py diff --git a/hypogenic/algorithm/generation/augmented.py b/hypogenic/algorithm/generation/augmented.py index ffbdc75..5dc2a64 100644 --- a/hypogenic/algorithm/generation/augmented.py +++ b/hypogenic/algorithm/generation/augmented.py @@ -4,8 +4,60 @@ from .utils import extract_hypotheses -@generation_register.register("Augmented") +@generation_register.register("augmented") class AugmentedGeneration(DefaultGeneration): + + # ------------------------------------------------------------------------ # + # # + # ------------------------------------------------------------------------ # + # BATCHED_HYPOTHESIS GENERATION # + # ------------------------------------------------------------------------ # + # # + # ------------------------------------------------------------------------ # + def batched_hypothesis_generation( + self, + example_ids, + current_sample, + num_hypotheses_generate: int, + alpha: float, + cache_seed=None, + max_concurrent=3, + reference_hypotheses=None,# {hypo: {"correct": set(), "wrong": set()}} + **generate_kwargs, + ): + """ + Generates new hypotheses for the given examples + + Parameters: + example_ids: The ids of the examples for which hypotheses need to be generated + current_sample: the current sample in data which the algorithm is on + num_hypotheses_generate: the number of hypotheses that we expect our response to generate + alpha: exploration constant in hypogenic reward function + cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number + max_concurrent: The maximum number of concurrent requests + reference_hypotheses: A dictionary of reference hypotheses with their associated correct and wrong sets + + Returns: + hypotheses_bank: A dictionary with keys as hypotheses and the values as the Summary Information class + """ + new_hypotheses = self.batched_hyp_list_generation( + example_ids, + num_hypotheses_generate, + cache_seed=cache_seed, + reference_hypotheses=reference_hypotheses, + **generate_kwargs, + ) + + return self.make_hypotheses_bank( + example_ids, + current_sample, + alpha, + new_hypotheses, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) + # ------------------------------------------------------------------------ # # # # ------------------------------------------------------------------------ # @@ -35,3 +87,30 @@ def batched_hyp_list_generation( ) all_new_hypos.extend(extract_hypotheses(response, 1)) return all_new_hypos + + def remove_redundancy( + self, + example_ids, + current_sample, + current_hyp_bank, + alpha: float, + cache_seed=None, + max_concurrent=3, + **generate_kwargs, + ): + prompt_input = self.prompt_class.remove_redundancy(current_hyp_bank) + response = self.api.generate( + prompt_input, + cache_seed=cache_seed, + **generate_kwargs, + ) + new_hyp_list = extract_hypotheses(response) + return self.make_hypotheses_bank( + example_ids, + current_sample, + alpha, + new_hyp_list, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) \ No newline at end of file diff --git a/hypogenic/algorithm/generation/base.py b/hypogenic/algorithm/generation/base.py index 56a7a5b..480f0ad 100644 --- a/hypogenic/algorithm/generation/base.py +++ b/hypogenic/algorithm/generation/base.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod import math -import os from typing import List from .utils import extract_hypotheses @@ -23,7 +22,7 @@ def __init__( """Initialize the update class Parameters: - api: The LLM API to call for intialization and batched hypothesis generation + api: The LLM API to call for initialization and batched hypothesis generation It could also be a local LLM. prompt_class: the class containing specific prompts for the task inference_class: The Inference Class to call when checking for accuracy @@ -49,7 +48,11 @@ def batched_initialize_hypotheses( """Initialization method for generating hypotheses. Make sure to only loop till args.num_init Parameters: - args: the parsed arguments + num_init: Total number of hypotheses to generate during initialization phase + init_batch_size: Number of samples to process in each batch + init_hypotheses_per_batch: Number of hypotheses to generate per batch + cache_seed: Cache seed, if None will not use cache, otherwise will use cache with corresponding seed number + max_concurrent: Maximum number of concurrent requests to make to the API Returns: hypotheses_bank: A dictionary with keys as hypotheses and the values as the Summary Information class @@ -68,7 +71,6 @@ def batched_hyp_list_generation( example_indices: List[int], num_hypotheses_generate: int, cache_seed=None, - reference_hypotheses=None, **generate_kwargs ) -> List[str]: """Batched hypothesis generation method. Takes multiple examples and creates a hypothesis with them. @@ -77,7 +79,6 @@ def batched_hyp_list_generation( example_indices: the indices of examples being used to generate hypotheses num_hypotheses_generate: the number of hypotheses that we expect our response to generate cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number - reference_hypotheses: A dictionary that accumulates the set of wrong hypotheses for each sample Returns: hypotheses_list: A list containing all newly generated hypotheses. @@ -130,9 +131,8 @@ def make_hypotheses_bank( Parameters: example_indices: the indices of examples being used to generate hypotheses current_sample: the current sample in data which the algorithm is on - num_hypotheses_generate: the number of hypotheses that we expect our repsonse to generate - hypotheses: a list of hypotheses generated by the LM - alpha: eploration constant in hypogenic reward funciton + alpha: exploration constant in hypogenic reward function + hypotheses_list: the list of hypotheses that we want to make bank for cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number max_concurrent: the maximum number of concurrent requests to make to the API @@ -190,3 +190,30 @@ def make_hypotheses_bank( new_generated_hypotheses[hyp].set_example(ex) return new_generated_hypotheses + + @abstractmethod + def batched_hypothesis_generation( + self, + example_ids, + current_sample, + num_hypotheses_generate: int, + alpha: float, + cache_seed=None, + max_concurrent=3, + reference_hypotheses=None, + **generate_kwargs, + ): + pass + + @abstractmethod + def remove_redundancy( + self, + example_ids, + current_sample, + current_hyp_bank, + alpha: float, + cache_seed=None, + max_concurrent=3, + **generate_kwargs, + ): + pass \ No newline at end of file diff --git a/hypogenic/algorithm/generation/utils.py b/hypogenic/algorithm/generation/utils.py index 6558afd..8ccf3c2 100644 --- a/hypogenic/algorithm/generation/utils.py +++ b/hypogenic/algorithm/generation/utils.py @@ -4,7 +4,7 @@ logger_name = "HypoGenic - Generation" -def extract_hypotheses(text: str, num_hypotheses) -> List[str]: +def extract_hypotheses(text: str, num_hypotheses = None) -> List[str]: """ Given a response with hypotheses, we want to take all of them out from the text. This function might need to be adjusted depending on the prompt and the @@ -31,7 +31,10 @@ def extract_hypotheses(text: str, num_hypotheses) -> List[str]: hypotheses = list(set([hypothesis.strip() for hypothesis in hypotheses])) # this is a bit sketchy - if len(hypotheses) != num_hypotheses: - logger.warn(f"Expected {num_hypotheses} hypotheses, but got {len(hypotheses)}.") + if num_hypotheses: + if len(hypotheses) != num_hypotheses: + logger.warning(f"Expected {num_hypotheses} hypotheses, but got {len(hypotheses)}.") - return hypotheses[:num_hypotheses] + return hypotheses[:num_hypotheses] + else: + return hypotheses diff --git a/hypogenic/algorithm/update/augmented.py b/hypogenic/algorithm/update/augmented.py new file mode 100644 index 0000000..e8f5d24 --- /dev/null +++ b/hypogenic/algorithm/update/augmented.py @@ -0,0 +1,248 @@ +from typing import Dict + +from . import update_register +from .base import Update +from ..generation import Generation +from ..inference import Inference +from ..replace import Replace +from ..summary_information import SummaryInformation +from ...logger_config import LoggerConfig + +logger_name = "HypoGenic - Augmented Update" + + +@update_register.register("augmented") +class AugmentedUpdate(Update): + """ + DefaultUpdate uses ONE hypothesis to make a prediction on a new example. + """ + + def __init__( + self, + generation_class: Generation, + inference_class: Inference, + replace_class: Replace, + save_path: str, + file_name_template: str = "hypotheses_training_sample_${sample}_seed_${seed}_epoch_${epoch}.json", + sample_num_to_restart_from=-1, + num_init=25, + epoch_to_start_from=0, + num_wrong_scale=0.8, + k=-1, + alpha=5e-1, + update_batch_size=5, + num_hypotheses_to_update=5, + update_hypotheses_per_batch=5, + only_best_hypothesis=False, + save_every_n_examples=100, + ): + super().__init__( + generation_class, + inference_class, + replace_class, + save_path, + file_name_template, + sample_num_to_restart_from, + num_init, + epoch_to_start_from, + num_wrong_scale, + k, + alpha, + update_batch_size, + num_hypotheses_to_update, + update_hypotheses_per_batch, + only_best_hypothesis, + save_every_n_examples, + ) + + def update( + self, + hypotheses_bank: Dict[str, SummaryInformation], + current_epoch, + current_seed, + cache_seed=None, + max_concurrent=3, + redundancy_threshold = 5, + **generate_kwargs, + ): + """ + We update the hypothesis bank once we reach a certain amount of regret + + Parameters: + hypotheses_bank: The hypothesis bank + current_epoch: The current epoch + current_seed: The current seed + cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number + max_concurrent: The maximum number of concurrent requests + redundancy_threshold: The threshold for removing redundancy in the hypotheses bank + """ + logger = LoggerConfig.get_logger(logger_name) + + # initialize variables + num_train_examples = len(self.train_data) + wrong_example_ids = set() + accumulated_wrong_hyp_samples = {} # {hypo: {"correct": set(), "wrong": set()}} + generate_count = 0 + + # ---------------------------------------------------------------------- + # Figuring out starting samples + # ---------------------------------------------------------------------- + # go through training examples + # When restarting from epoch > 0, no need to start at num_init + # When not restarting, then default sample_num_to_restart_from = -1. start with num_init. + # For multiple epochs restarts, there should always be a non-negative sample_num_to_restart_from + if self.sample_num_to_restart_from >= 0: + start_sample = self.sample_num_to_restart_from + else: + start_sample = self.num_init + + # This is to check if we are running more epochs than the starting epoch, if so, start at sample 0 + # basically, if we've completed the starting epoch, we want to start the next one + if current_epoch > self.epoch_to_start_from: + start_sample = 0 + + # ---------------------------------------------------------------------- + # Creating the new hypotheses + # ---------------------------------------------------------------------- + # from the start to the end + for i in range(start_sample, num_train_examples): + # the 'i' here is the sample we are testing each of the top hypotheses + + current_sample = i + 1 + logger.info(f"Training on example {i}") + + # We need to get the best k for testing the strength of our hypothesis bank + top_k_hypotheses = sorted( + hypotheses_bank, key=lambda x: hypotheses_bank[x].reward, reverse=True + )[: self.k] + + # We are at the regret that we need in order to generate a new hypothesis + if self.num_wrong_scale > 0: + num_wrong_to_add_bank = ( + len(top_k_hypotheses) * i / num_train_examples + ) * self.num_wrong_scale + else: + raise ValueError("num_wrong_scale should be greater than 0.") + # ------------------------------------------------------------------ + # We need to see how good our hypothesis is, which we do by way of the inference class + # ------------------------------------------------------------------ + num_wrong_hypotheses = 0 + # record the hypotheses that are wrong for the current sample + current_hypo_samples = {} # {hypo: {"correct": set(), "wrong": set()}} + + preds, labels = self.inference_class.batched_predict( + self.train_data, + [ + (i, {hypothesis: hypotheses_bank[hypothesis]}) + for hypothesis in top_k_hypotheses + ], + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) + + # Comparison of the label and prediction + for pred, label, hypothesis in zip(preds, labels, top_k_hypotheses): + if hypothesis not in current_hypo_samples: + current_hypo_samples[hypothesis] = {"correct": set(), "wrong": set()} + if pred != label: + num_wrong_hypotheses += 1 + hypotheses_bank[hypothesis].update_info_if_not_useful( + current_sample, self.alpha + ) # let the bank know it got one wrong + current_hypo_samples[hypothesis]["wrong"].add(i) + else: + hypotheses_bank[hypothesis].update_info_if_useful( + current_sample, self.alpha + ) # let the bank know it got one right + + # keeping track of good examples as we do in generation + hypotheses_bank[hypothesis].update_useful_examples(i, label) + current_hypo_samples[hypothesis]["correct"].add(i) + + # ------------------------------------------------------------------ + # Generating a new hypothesis + # ------------------------------------------------------------------ + + # if we get enough wrong examples as determined by num_wrong_to_add_bank, + # we need to generate new hypotheses + if ( + num_wrong_hypotheses >= num_wrong_to_add_bank + or len(top_k_hypotheses) == 0 + ): + # We note it as a bad sample + wrong_example_ids.add(i) + for hypothesis in current_hypo_samples: + if hypothesis not in accumulated_wrong_hyp_samples: + accumulated_wrong_hyp_samples[hypothesis] = {"correct": set(), "wrong": set()} + accumulated_wrong_hyp_samples[hypothesis]["correct"].update( + current_hypo_samples[hypothesis]["correct"]) + accumulated_wrong_hyp_samples[hypothesis]["wrong"].update( + current_hypo_samples[hypothesis]["wrong"]) + if ( + len(wrong_example_ids) + == self.update_batch_size * self.num_hypotheses_to_update + ): + generate_count += 1 + new_hyp_bank = {} + + # generate new hypotheses + for j in range(self.num_hypotheses_to_update): + # Go through poorly performing examples and generate hypotheses for them + # TODO: batched? + new_hypotheses = ( + self.generation_class.batched_hypothesis_generation( + wrong_example_ids, + current_sample, + self.update_hypotheses_per_batch, + self.alpha, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + reference_hypotheses=accumulated_wrong_hyp_samples, + **generate_kwargs, + ) + ) + + # If we only take the best performing hypothesis from the batch + if self.only_best_hypothesis: + best_hypothesis = max( + new_hypotheses, key=lambda x: new_hypotheses[x].reward + ) + new_hyp_bank.update( + {best_hypothesis: new_hypotheses[best_hypothesis]} + ) + else: + new_hyp_bank.update(new_hypotheses) + + # call replace class to update the bank + hypotheses_bank = self.replace_class.replace( + hypotheses_bank, new_hyp_bank + ) + + if generate_count >= redundancy_threshold: + hypotheses_bank = self.generation_class.remove_redundancy( + wrong_example_ids, + current_sample, + hypotheses_bank, + self.alpha, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) + + # reset wrong examples to be empty + wrong_example_ids = set() + # reset accumulated wrong hypotheses to be empty + accumulated_wrong_hyp_samples.clear() + + # save hypotheses to json + if (i + 1) % self.save_every_n_examples == 0: + self.save_to_json( + hypotheses_bank, + sample=i + 1, + seed=current_seed, + epoch=current_epoch, + ) + + # Our new bank + return hypotheses_bank diff --git a/hypogenic/prompt.py b/hypogenic/prompt.py index 77dae5f..37c4415 100644 --- a/hypogenic/prompt.py +++ b/hypogenic/prompt.py @@ -232,6 +232,15 @@ def batched_error_augmented_generation(self, train_data, num_hypotheses, referen prompt = self._information_prompt(substitute_dict, "batched_error_augmented_generation") return prompt + def remove_redundancy(self, hypotheses_dict): + hypotheses_list = list(hypotheses_dict.keys()) + + substitute_dict = {"hypotheses": "\n".join([f"{idx + 1}. {hyp}" for idx, hyp in enumerate(hypotheses_list)])} + + prompt = self._information_prompt(substitute_dict, "remove_redundancy") + + return prompt + def inference(self, hypotheses_dict, test_data, test_idx): """ Create inference prompt. diff --git a/pipeline.py b/pipeline.py index 32b5a9f..44d2f6b 100644 --- a/pipeline.py +++ b/pipeline.py @@ -4,6 +4,7 @@ import os from hypogenic.algorithm.generation.augmented import AugmentedGeneration +from hypogenic.algorithm.update.augmented import AugmentedUpdate from hypogenic.utils import set_seed, get_results from hypogenic.tasks import BaseTask from hypogenic.extract_label import extract_label_register @@ -417,7 +418,7 @@ def augmented_hypogenic(task_name, api, model_name): inference_class = DefaultInference(api, prompt_class, train_data, task) generation_class = AugmentedGeneration(api, prompt_class, inference_class, task) - update_class = DefaultUpdate( + update_class = AugmentedUpdate( generation_class=generation_class, inference_class=inference_class, replace_class=DefaultReplace(max_num_hypotheses), @@ -430,7 +431,6 @@ def augmented_hypogenic(task_name, api, model_name): save_every_n_examples=save_every_10_examples, ) - hypotheses_bank = {} hypotheses_bank = update_class.batched_initialize_hypotheses( num_init, init_batch_size=init_batch_size, @@ -454,6 +454,7 @@ def augmented_hypogenic(task_name, api, model_name): cache_seed=cache_seed, temperature=temperature, max_tokens=max_tokens, + redundancy_threshold=5, max_concurrent=64, ) update_class.save_to_json( From 9e5840fb65b89589decb8ac31c3fedec38627a10 Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Mon, 4 Aug 2025 17:24:28 -0500 Subject: [PATCH 16/23] Implement redundancy clearance for the final hypotheses; Add control statement in AugmentedUpdate class for redundancy clearance. --- hypogenic/algorithm/generation/augmented.py | 23 +++++++++++++++++++-- hypogenic/algorithm/generation/base.py | 12 ++++++++++- hypogenic/algorithm/update/augmented.py | 16 ++++++++++++-- 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/hypogenic/algorithm/generation/augmented.py b/hypogenic/algorithm/generation/augmented.py index 5dc2a64..3a21099 100644 --- a/hypogenic/algorithm/generation/augmented.py +++ b/hypogenic/algorithm/generation/augmented.py @@ -88,7 +88,7 @@ def batched_hyp_list_generation( all_new_hypos.extend(extract_hypotheses(response, 1)) return all_new_hypos - def remove_redundancy( + def clear_redundancy_update( self, example_ids, current_sample, @@ -113,4 +113,23 @@ def remove_redundancy( cache_seed=cache_seed, max_concurrent=max_concurrent, **generate_kwargs, - ) \ No newline at end of file + ) + + def clear_redundancy_final( + self, + hyp_bank, + cache_seed=None, + max_concurrent=3, + **generate_kwargs, + ): + prompt_inputs = [ + self.prompt_class.remove_redundancy(hyp_bank) + ] + responses = self.api.batched_generate( + prompt_inputs, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) + new_hyp_list = extract_hypotheses(responses[0]) + return new_hyp_list \ No newline at end of file diff --git a/hypogenic/algorithm/generation/base.py b/hypogenic/algorithm/generation/base.py index 480f0ad..210bb7c 100644 --- a/hypogenic/algorithm/generation/base.py +++ b/hypogenic/algorithm/generation/base.py @@ -206,7 +206,7 @@ def batched_hypothesis_generation( pass @abstractmethod - def remove_redundancy( + def clear_redundancy_update( self, example_ids, current_sample, @@ -215,5 +215,15 @@ def remove_redundancy( cache_seed=None, max_concurrent=3, **generate_kwargs, + ): + pass + + @abstractmethod + def clear_redundancy_final( + self, + hyp_bank, + cache_seed=None, + max_concurrent=3, + **generate_kwargs, ): pass \ No newline at end of file diff --git a/hypogenic/algorithm/update/augmented.py b/hypogenic/algorithm/update/augmented.py index e8f5d24..9bde527 100644 --- a/hypogenic/algorithm/update/augmented.py +++ b/hypogenic/algorithm/update/augmented.py @@ -63,6 +63,8 @@ def update( cache_seed=None, max_concurrent=3, redundancy_threshold = 5, + clear_redundancy_update = False, + clear_redundancy_final = True, **generate_kwargs, ): """ @@ -75,6 +77,8 @@ def update( cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number max_concurrent: The maximum number of concurrent requests redundancy_threshold: The threshold for removing redundancy in the hypotheses bank + clear_redundancy_update: Whether to remove redundancy in the hypotheses bank during the update process + clear_redundancy_final: Whether to remove redundancy in the final hypotheses bank """ logger = LoggerConfig.get_logger(logger_name) @@ -219,8 +223,8 @@ def update( hypotheses_bank, new_hyp_bank ) - if generate_count >= redundancy_threshold: - hypotheses_bank = self.generation_class.remove_redundancy( + if clear_redundancy_update and generate_count >= redundancy_threshold: + hypotheses_bank = self.generation_class.clear_redundancy_update( wrong_example_ids, current_sample, hypotheses_bank, @@ -244,5 +248,13 @@ def update( epoch=current_epoch, ) + if clear_redundancy_final: + hypotheses_bank = self.generation_class.clear_redundancy_final( + hyp_bank=hypotheses_bank, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) + # Our new bank return hypotheses_bank From a36e7eb894cc4c545339587d0cb681d169ada18a Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Mon, 4 Aug 2025 17:49:44 -0500 Subject: [PATCH 17/23] Use batched_generate in batched_hyp_list_generation; Use generate in clear_redundancy_final; Rename reference_hypotheses to reference_info; Rename batched_error_augmented_generation to error_augmented_generation --- hypogenic/algorithm/generation/augmented.py | 45 ++++++++++++--------- hypogenic/algorithm/generation/base.py | 1 - hypogenic/algorithm/update/augmented.py | 2 +- hypogenic/prompt.py | 10 ++--- 4 files changed, 31 insertions(+), 27 deletions(-) diff --git a/hypogenic/algorithm/generation/augmented.py b/hypogenic/algorithm/generation/augmented.py index 3a21099..3df3a94 100644 --- a/hypogenic/algorithm/generation/augmented.py +++ b/hypogenic/algorithm/generation/augmented.py @@ -22,7 +22,7 @@ def batched_hypothesis_generation( alpha: float, cache_seed=None, max_concurrent=3, - reference_hypotheses=None,# {hypo: {"correct": set(), "wrong": set()}} + reference_info=None,# {hypo: {"correct": set(), "wrong": set()}} **generate_kwargs, ): """ @@ -35,7 +35,7 @@ def batched_hypothesis_generation( alpha: exploration constant in hypogenic reward function cache_seed: If `None`, will not use cache, otherwise will use cache with corresponding seed number max_concurrent: The maximum number of concurrent requests - reference_hypotheses: A dictionary of reference hypotheses with their associated correct and wrong sets + reference_info: A dictionary of reference hypotheses with their associated correct and wrong sets Returns: hypotheses_bank: A dictionary with keys as hypotheses and the values as the Summary Information class @@ -44,7 +44,7 @@ def batched_hypothesis_generation( example_ids, num_hypotheses_generate, cache_seed=cache_seed, - reference_hypotheses=reference_hypotheses, + reference_info=reference_info, **generate_kwargs, ) @@ -70,22 +70,29 @@ def batched_hyp_list_generation( example_indices: List[int], num_hypotheses_generate: int, cache_seed=None, - reference_hypotheses=None,# {hypo: {"correct": set(), "wrong": set()}} + reference_info=None,# {hypo: {"correct": set(), "wrong": set()}} **generate_kwargs ) -> List[str]: - batch_size = 1 all_new_hypos = [] - hypo_items = list(reference_hypotheses.items()) - total = len(hypo_items) - for i in range(0, total, batch_size): - batch = dict(hypo_items[i:i+batch_size]) - prompt_input = self.prompt_class.batched_error_augmented_generation( - self.train_data, len(batch), batch - ) - response = self.api.generate( - prompt_input, cache_seed=cache_seed, **generate_kwargs + reference_items = list(reference_info.items()) + total = len(reference_items) + + prompt_inputs = [] + for i in range(0, total): + prompt_input = self.prompt_class.error_augmented_generation( + self.train_data, dict(reference_items[i]) ) + prompt_inputs.append(prompt_input) + + responses = self.api.batched_generate( + prompt_inputs, + cache_seed=cache_seed, + **generate_kwargs + ) + + for response in responses: all_new_hypos.extend(extract_hypotheses(response, 1)) + return all_new_hypos def clear_redundancy_update( @@ -122,14 +129,12 @@ def clear_redundancy_final( max_concurrent=3, **generate_kwargs, ): - prompt_inputs = [ - self.prompt_class.remove_redundancy(hyp_bank) - ] - responses = self.api.batched_generate( - prompt_inputs, + prompt_input = self.prompt_class.remove_redundancy(hyp_bank) + responses = self.api.generate( + prompt_input, cache_seed=cache_seed, max_concurrent=max_concurrent, **generate_kwargs, ) - new_hyp_list = extract_hypotheses(responses[0]) + new_hyp_list = extract_hypotheses(responses) return new_hyp_list \ No newline at end of file diff --git a/hypogenic/algorithm/generation/base.py b/hypogenic/algorithm/generation/base.py index 210bb7c..cb353c5 100644 --- a/hypogenic/algorithm/generation/base.py +++ b/hypogenic/algorithm/generation/base.py @@ -200,7 +200,6 @@ def batched_hypothesis_generation( alpha: float, cache_seed=None, max_concurrent=3, - reference_hypotheses=None, **generate_kwargs, ): pass diff --git a/hypogenic/algorithm/update/augmented.py b/hypogenic/algorithm/update/augmented.py index 9bde527..38b2e44 100644 --- a/hypogenic/algorithm/update/augmented.py +++ b/hypogenic/algorithm/update/augmented.py @@ -202,7 +202,7 @@ def update( self.alpha, cache_seed=cache_seed, max_concurrent=max_concurrent, - reference_hypotheses=accumulated_wrong_hyp_samples, + reference_info=accumulated_wrong_hyp_samples, **generate_kwargs, ) ) diff --git a/hypogenic/prompt.py b/hypogenic/prompt.py index 37c4415..08228d8 100644 --- a/hypogenic/prompt.py +++ b/hypogenic/prompt.py @@ -187,14 +187,14 @@ def batched_generation(self, train_data, num_hypotheses): return prompt - def batched_error_augmented_generation(self, train_data, num_hypotheses, reference_hypotheses): + def error_augmented_generation(self, train_data, reference_info): """ - reference_hypotheses: {hypo: {"correct": set(), "wrong": set()}} + reference_info: {hypo: {"correct": set(), "wrong": set()}} """ substitute_dict = {} multi_sub_dicts = {"error_augmented_observation": []} - for hypo_idx, (hypothesis, sample_dict) in enumerate(reference_hypotheses.items()): + for hypo_idx, (hypothesis, sample_dict) in enumerate(reference_info.items()): # 处理 correct samples,只取前3个 correct_samples_info = [] correct_ids = list(sample_dict.get("correct", []))[:3] @@ -226,10 +226,10 @@ def batched_error_augmented_generation(self, train_data, num_hypotheses, referen multi_sub_dicts["error_augmented_observation"].append(hyp_info) substitute_dict = self._fill_multi_in_sub_dict( - substitute_dict, multi_sub_dicts, "batched_error_augmented_generation" + substitute_dict, multi_sub_dicts, "error_augmented_generation" ) - prompt = self._information_prompt(substitute_dict, "batched_error_augmented_generation") + prompt = self._information_prompt(substitute_dict, "error_augmented_generation") return prompt def remove_redundancy(self, hypotheses_dict): From 902db004092ce6c035b0b27f9d3ab04ae5ec23d2 Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Mon, 4 Aug 2025 18:06:05 -0500 Subject: [PATCH 18/23] =?UTF-8?q?Fix=20some=20bug=20in=20batched=5Fhyp=5Fl?= =?UTF-8?q?ist=5Fgeneration=20of=20AugmentedGeneration=20class=EF=BC=9B=20?= =?UTF-8?q?rename=20some=20parameters?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- hypogenic/algorithm/generation/augmented.py | 4 +++- hypogenic/algorithm/update/augmented.py | 7 ++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/hypogenic/algorithm/generation/augmented.py b/hypogenic/algorithm/generation/augmented.py index 3df3a94..62163c3 100644 --- a/hypogenic/algorithm/generation/augmented.py +++ b/hypogenic/algorithm/generation/augmented.py @@ -79,8 +79,10 @@ def batched_hyp_list_generation( prompt_inputs = [] for i in range(0, total): + hypothesis, sample_dict = reference_items[i] + batch = {hypothesis: sample_dict} prompt_input = self.prompt_class.error_augmented_generation( - self.train_data, dict(reference_items[i]) + self.train_data, batch ) prompt_inputs.append(prompt_input) diff --git a/hypogenic/algorithm/update/augmented.py b/hypogenic/algorithm/update/augmented.py index 38b2e44..fcbf091 100644 --- a/hypogenic/algorithm/update/augmented.py +++ b/hypogenic/algorithm/update/augmented.py @@ -86,7 +86,7 @@ def update( num_train_examples = len(self.train_data) wrong_example_ids = set() accumulated_wrong_hyp_samples = {} # {hypo: {"correct": set(), "wrong": set()}} - generate_count = 0 + generation_count = 0 # ---------------------------------------------------------------------- # Figuring out starting samples @@ -187,7 +187,7 @@ def update( len(wrong_example_ids) == self.update_batch_size * self.num_hypotheses_to_update ): - generate_count += 1 + generation_count += 1 new_hyp_bank = {} # generate new hypotheses @@ -223,7 +223,8 @@ def update( hypotheses_bank, new_hyp_bank ) - if clear_redundancy_update and generate_count >= redundancy_threshold: + if clear_redundancy_update and generation_count >= redundancy_threshold: + generation_count = 0 hypotheses_bank = self.generation_class.clear_redundancy_update( wrong_example_ids, current_sample, From dc74cc2475295783323a07279ebf5d222c04074c Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Tue, 5 Aug 2025 13:45:13 -0500 Subject: [PATCH 19/23] Revert changes to inference and prompt in data_analysis_agent to ensure isolation --- .../data_analysis_agent/inference.py | 81 ------------------- .../data_analysis_agent/prompt.py | 9 --- pipeline.py | 3 +- 3 files changed, 2 insertions(+), 91 deletions(-) diff --git a/hypothesis_agent/data_analysis_agent/inference.py b/hypothesis_agent/data_analysis_agent/inference.py index 5a5f6a6..f8caeeb 100644 --- a/hypothesis_agent/data_analysis_agent/inference.py +++ b/hypothesis_agent/data_analysis_agent/inference.py @@ -169,84 +169,3 @@ def batched_predict( actual_labels = [data[self.task.label_name][index] for index, _ in idx_hyp_pair] return predictions, actual_labels, idx_hyp_pair - -class MultiHypHierarchicalInference(DefaultInference): - def __init__( - self, - api, - prompt_class: TestPrompt, - train_data: pd.DataFrame, - task: BaseTask, - ): - super().__init__(api, prompt_class, train_data, task) - - def create_stump_from_hypotheses( - self, - hyp_bank, - cache_seed=None, - max_concurrent=3, - **generate_kwargs, - ): - prompt_inputs = [ - self.prompt_class.create_stump_from_hypotheses(hyp_bank) - ] - responses = self.api.batched_generate( - prompt_inputs, - cache_seed=cache_seed, - max_concurrent=max_concurrent, - **generate_kwargs, - ) - responses = re.sub(r'.*?', '', responses[0], flags=re.IGNORECASE | re.DOTALL) - print(responses) - return {responses:SummaryInformation} - - def multiple_hypotheses_batched_predict( - self, - data: pd.DataFrame, - idx_hyp_pair=List[Tuple[int, Dict[str, SummaryInformation]]], - cache_seed=None, - max_concurrent=3, - **generate_kwargs, - ): - prompt_inputs = [ - self.prompt_class.multiple_hypotheses_inference(hyp_bank, data, index) - for index, hyp_bank in idx_hyp_pair - ] - responses = self.api.batched_generate( - prompt_inputs, - cache_seed=cache_seed, - max_concurrent=max_concurrent, - **generate_kwargs, - ) - actual_labels = [data[self.task.label_name][index] for index, _ in idx_hyp_pair] - predictions = [self.task.extract_label(responses[i]) for i in range(len(responses))] - - return predictions, actual_labels - - def run_inference_final( - self, - data, - hyp_bank, - cache_seed=None, - max_concurrent=3, - **generate_kwargs, - ): - stump_hyp = self.create_stump_from_hypotheses( - hyp_bank, - cache_seed=cache_seed, - max_concurrent=max_concurrent, - **generate_kwargs, - ) - - num_samples = len(data) - - return self.multiple_hypotheses_batched_predict( - data, - [ - (i, stump_hyp) - for i in range(num_samples) - ], - cache_seed=cache_seed, - max_concurrent=max_concurrent, - **generate_kwargs, - ) \ No newline at end of file diff --git a/hypothesis_agent/data_analysis_agent/prompt.py b/hypothesis_agent/data_analysis_agent/prompt.py index a047b56..e607a85 100644 --- a/hypothesis_agent/data_analysis_agent/prompt.py +++ b/hypothesis_agent/data_analysis_agent/prompt.py @@ -197,15 +197,6 @@ def multiple_hypotheses_inference(self, hypotheses_dict, test_data, test_idx): return prompt - def create_stump_from_hypotheses(self, hypotheses_dict): - hypotheses_list = list(hypotheses_dict.keys()) - - substitute_dict= {"hypotheses": "\n".join([f"{idx + 1}. {hyp}" for idx, hyp in enumerate(hypotheses_list)])} - - prompt = self._information_prompt(substitute_dict, "create_stump_from_hypotheses") - - return prompt - def test_autogen( self, train_data: pd.DataFrame, num_hypotheses, paper_infos: List[Dict[str, str]] ): diff --git a/pipeline.py b/pipeline.py index 44d2f6b..fc8f2e9 100644 --- a/pipeline.py +++ b/pipeline.py @@ -31,7 +31,8 @@ OnlyPaperGeneration, ZeroShotGeneration, ) -from hypothesis_agent.data_analysis_agent.inference import MultiHypDefaultInference, MultiHypHierarchicalInference +from hypothesis_agent.data_analysis_agent.hierarchical_inference import MultiHypHierarchicalInference +from hypothesis_agent.data_analysis_agent.inference import MultiHypDefaultInference from hypothesis_agent.data_analysis_agent.update import TestUpdate from hypothesis_agent.literature_review_agent import LiteratureAgent from hypothesis_agent.literature_review_agent.literature_processor.extract_info import ( From 26f290189fccc1de9d89c89e6aa2e9816c839e57 Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Thu, 7 Aug 2025 13:22:34 -0500 Subject: [PATCH 20/23] Implement stump-based hierarchical inference --- hypogenic/prompt.py | 31 +++ .../hierarchical_inference.py | 213 ++++++++++++++++++ 2 files changed, 244 insertions(+) create mode 100644 hypothesis_agent/data_analysis_agent/hierarchical_inference.py diff --git a/hypogenic/prompt.py b/hypogenic/prompt.py index 08228d8..fcb225f 100644 --- a/hypogenic/prompt.py +++ b/hypogenic/prompt.py @@ -347,3 +347,34 @@ def is_relevant(self, hypotheses_dict, test_data, test_idx): prompt = self._information_prompt(substitute_dict, "is_relevant") return prompt + + def create_stump_from_hypotheses(self, hypotheses_dict): + hypotheses_list = list(hypotheses_dict.keys()) + + substitute_dict= {"hypotheses": "\n".join([f"{idx + 1}. {hyp}" for idx, hyp in enumerate(hypotheses_list)])} + + prompt = self._information_prompt(substitute_dict, "create_stump_from_hypotheses") + + return prompt + + def determine_group(self, group_conditions, test_data, test_idx): + """ + Create prompt to determine which group a sample belongs to. + """ + substitute_dict = self._get_substitute_dict(test_data, test_idx) + substitute_dict["group_conditions"] = group_conditions + + prompt = self._information_prompt(substitute_dict, "determine_group") + return prompt + + def hierarchical_inference(self, hypotheses_dict, test_data, test_idx, group_condition): + hypotheses_list = list(hypotheses_dict.keys()) + + substitute_dict = self._get_substitute_dict(test_data, test_idx) + substitute_dict["hypotheses"] = "\n".join([f"{idx+1}. {hyp}" for idx, hyp in enumerate(hypotheses_list)]) + + substitute_dict["group_condition"] = group_condition + + prompt = self._information_prompt(substitute_dict, "hierarchical_inference") + + return prompt \ No newline at end of file diff --git a/hypothesis_agent/data_analysis_agent/hierarchical_inference.py b/hypothesis_agent/data_analysis_agent/hierarchical_inference.py new file mode 100644 index 0000000..945eae5 --- /dev/null +++ b/hypothesis_agent/data_analysis_agent/hierarchical_inference.py @@ -0,0 +1,213 @@ +import re + +from hypogenic.algorithm.inference import DefaultInference +from hypogenic.logger_config import LoggerConfig +from hypogenic.tasks import BaseTask +from hypogenic.algorithm.summary_information import ( + SummaryInformation, +) + +from hypothesis_agent.data_analysis_agent.prompt import TestPrompt +import pandas as pd + + +class MultiHypHierarchicalInference(DefaultInference): + def __init__( + self, + api, + prompt_class: TestPrompt, + train_data: pd.DataFrame, + task: BaseTask, + ): + super().__init__(api, prompt_class, train_data, task) + + @staticmethod + def _parse_stump_response(response): + """ + Parse the stump response to extract groups and their hypotheses. + Returns a dictionary mapping group numbers to their hypotheses. + """ + # Remove think tags if present + response = re.sub(r'.*?', '', response, flags=re.IGNORECASE | re.DOTALL) + + # Parse the response to extract groups and their conditions/hypotheses + groups = {} + current_group = None + + lines = response.strip().split('\n') + for line in lines: + line = line.strip() + if not line: + continue + + # Look for group indicators (e.g., "Group 1:", "1.", etc.) + group_match = re.match(r'^(?:Group\s*)?(\d+)[:.]?\s*(.*)', line, re.IGNORECASE) + if group_match: + current_group = int(group_match.group(1)) + current_condition = group_match.group(2).strip() + if current_group not in groups: + groups[current_group] = { + 'condition': current_condition, + 'hypotheses': [] + } + elif current_group is not None and line.startswith('-') or line.startswith('•'): + # This is a hypothesis under the current group + hypothesis = line.lstrip('- ').lstrip('• ').strip() + if hypothesis: + groups[current_group]['hypotheses'].append(hypothesis) + elif current_group is not None and re.match(r'^\d+\.', line): + # This is a numbered hypothesis under the current group + hypothesis = re.sub(r'^\d+\.\s*', '', line).strip() + if hypothesis: + groups[current_group]['hypotheses'].append(hypothesis) + + return groups + + def _create_stump_from_hypotheses( + self, + hyp_bank, + cache_seed=None, + **generate_kwargs, + ): + logger = LoggerConfig.get_logger("StumpCreation") + prompt_input = self.prompt_class.create_stump_from_hypotheses(hyp_bank) + response = self.api.generate( + prompt_input, + cache_seed=cache_seed, + **generate_kwargs, + ) + logger.info (f"Stump response: {response}") + groups = self._parse_stump_response(response) + + return groups + + def _determine_sample_groups_batched( + self, + data: pd.DataFrame, + groups, + cache_seed=None, + max_concurrent=3, + **generate_kwargs, + ): + """ + Determine which group each sample belongs to in batch. + """ + # Format group conditions for the prompt + group_conditions = [] + for group_num, group_info in groups.items(): + condition = group_info['condition'] + group_conditions.append(f"Group {group_num}: {condition}") + group_conditions_text = "\n".join(group_conditions) + + # Create prompts for all samples + prompt_inputs = [] + for idx in range(len(data)): + prompt_input = self.prompt_class.determine_group(group_conditions_text, data, idx) + prompt_inputs.append(prompt_input) + + # Batch generate responses + responses = self.api.batched_generate( + prompt_inputs, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) + # Extract group numbers from responses + group_nums = [] + for response in responses: + response = re.sub(r'.*?', '', response, flags=re.IGNORECASE | re.DOTALL) + response = response.strip() + try: + group_num = int(re.search(r'\d+', response).group()) + group_nums.append(group_num) + except (ValueError, AttributeError): + # Default to group 1 if parsing fails + print(f"Failed to parse group number from response: {response}, defaulting to group 1") + group_nums.append(1) + + return group_nums + + def _stump_batched_predict( + self, + data: pd.DataFrame, + groups, + cache_seed=None, + max_concurrent=3, + **generate_kwargs, + ): + """ + Make predictions using group-specific hypotheses with batched processing. + """ + # First, determine groups for all samples in batch + group_nums = self._determine_sample_groups_batched( + data, groups, cache_seed, max_concurrent, **generate_kwargs + ) + + # Create prompts for all samples with their respective group hypotheses + prompt_inputs = [] + sample_indices = [] + + for idx in range(len(data)): + group_num = group_nums[idx] + + if group_num in groups: + group_hypotheses = groups[group_num]['hypotheses'] + group_condition = groups[group_num]['condition'] + # Convert to the format expected by multiple_hypotheses_inference + # Create a dictionary mapping hypothesis text to SummaryInformation + hyp_dict = {} + for i, hypothesis_text in enumerate(group_hypotheses): + hyp_dict[hypothesis_text] = SummaryInformation() + + # Create prompt for this sample with group-specific hypotheses + prompt_input = self.prompt_class.hierarchical_inference( + hyp_dict, data, idx, group_condition + ) + prompt_inputs.append(prompt_input) + sample_indices.append(idx) + else: + raise Exception(f"Group {group_num} not found in groups.") + + # Batch generate predictions + responses = self.api.batched_generate( + prompt_inputs, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) + + # Extract predictions and actual labels + predictions = [] + actual_labels = [] + + for i, response in enumerate(responses): + response = re.sub(r'.*?', '', response, flags=re.IGNORECASE | re.DOTALL) + prediction = self.task.extract_label(response) + predictions.append(prediction) + actual_labels.append(data[self.task.label_name][sample_indices[i]]) + + return predictions, actual_labels + + def run_inference_final( + self, + data, + hyp_bank, + cache_seed=None, + max_concurrent=3, + **generate_kwargs, + ): + # Create stump (groups) from hypotheses + groups = self._create_stump_from_hypotheses( + hyp_bank, + cache_seed=cache_seed, + **generate_kwargs, + ) + + # Run inference using group-specific hypotheses + return self._stump_batched_predict( + data, + groups, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) \ No newline at end of file From 31918a4ce1c62b2cc2b9d80bfda1f2be52eac124 Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Tue, 12 Aug 2025 21:22:54 -0500 Subject: [PATCH 21/23] Implement tree based inference --- hypogenic/prompt.py | 101 ++- .../hierarchical_inference.py | 805 +++++++++++++++--- 2 files changed, 787 insertions(+), 119 deletions(-) diff --git a/hypogenic/prompt.py b/hypogenic/prompt.py index fcb225f..8a39706 100644 --- a/hypogenic/prompt.py +++ b/hypogenic/prompt.py @@ -195,7 +195,6 @@ def error_augmented_generation(self, train_data, reference_info): multi_sub_dicts = {"error_augmented_observation": []} for hypo_idx, (hypothesis, sample_dict) in enumerate(reference_info.items()): - # 处理 correct samples,只取前3个 correct_samples_info = [] correct_ids = list(sample_dict.get("correct", []))[:3] for idx, sample_id in enumerate(correct_ids): @@ -207,7 +206,6 @@ def error_augmented_generation(self, train_data, reference_info): self._get_prompt_template("correct_samples") ) - # 处理 wrong samples wrong_samples_info = [] for idx, sample_id in enumerate(sample_dict.get("wrong", [])): sample_data = self._get_substitute_dict(train_data, sample_id) @@ -348,33 +346,106 @@ def is_relevant(self, hypotheses_dict, test_data, test_idx): return prompt - def create_stump_from_hypotheses(self, hypotheses_dict): - hypotheses_list = list(hypotheses_dict.keys()) + def tree_split(self, hypotheses_dict, current_path): + if isinstance(hypotheses_dict, (list, tuple, set)): + hypotheses_list = list(hypotheses_dict) + else: + hypotheses_list = list(hypotheses_dict.keys()) + + substitute_dict = {"hypotheses": "\n".join([f"{idx + 1}. {hyp}" for idx, hyp in enumerate(hypotheses_list)])} + + if current_path: + substitute_dict["current_path"] = self._format_condition_path(current_path) + else: + substitute_dict["current_path"] = "Root" + + prompt = self._information_prompt(substitute_dict, "tree_split") + + return prompt + + def tree_split_decision(self, hypotheses_dict, current_depth, current_path): + """ + Prompt for LLM to decide whether to continue splitting or create a leaf node + """ + if isinstance(hypotheses_dict, (list, tuple, set)): + hypotheses_list = list(hypotheses_dict) + else: + hypotheses_list = list(hypotheses_dict.keys()) + + substitute_dict = { + "hypotheses": "\n".join([f"{idx + 1}. {hyp}" for idx, hyp in enumerate(hypotheses_list)]), + "current_depth": current_depth + } - substitute_dict= {"hypotheses": "\n".join([f"{idx + 1}. {hyp}" for idx, hyp in enumerate(hypotheses_list)])} + if current_path: + substitute_dict["current_path"] = self._format_condition_path(current_path) + else: + substitute_dict["current_path"] = "Root" - prompt = self._information_prompt(substitute_dict, "create_stump_from_hypotheses") + prompt = self._information_prompt(substitute_dict, "tree_split_decision") return prompt - def determine_group(self, group_conditions, test_data, test_idx): + def internal_inference(self, group_conditions, test_data, test_idx): """ Create prompt to determine which group a sample belongs to. """ substitute_dict = self._get_substitute_dict(test_data, test_idx) substitute_dict["group_conditions"] = group_conditions - prompt = self._information_prompt(substitute_dict, "determine_group") + prompt = self._information_prompt(substitute_dict, "internal_inference") return prompt - def hierarchical_inference(self, hypotheses_dict, test_data, test_idx, group_condition): - hypotheses_list = list(hypotheses_dict.keys()) - - substitute_dict = self._get_substitute_dict(test_data, test_idx) + def multiple_hypotheses_inference_with_path(self, hyp_dict, sample_data, param, tree_path=None, condition_path=None): + """ + Enhanced inference prompt that includes tree path context + """ + hypotheses_list = list(hyp_dict.keys()) + substitute_dict = self._get_substitute_dict(sample_data, 0) substitute_dict["hypotheses"] = "\n".join([f"{idx+1}. {hyp}" for idx, hyp in enumerate(hypotheses_list)]) - substitute_dict["group_condition"] = group_condition + if condition_path: + substitute_dict["current_path"] = self._format_condition_path(condition_path) + else: + substitute_dict["current_path"] = "Root" + + prompt = self._information_prompt(substitute_dict, "multiple_hypotheses_inference_with_path") + return prompt + + def tree_refinement(self, original_hypotheses, tree_structure, validation_analysis): + if isinstance(original_hypotheses, (list, tuple, set)): + hypotheses_list = list(original_hypotheses) + else: + hypotheses_list = list(original_hypotheses.keys()) + + substitute_dict = { + 'original_hypotheses': "\n".join([f"{idx + 1}. {hyp}" for idx, hyp in enumerate(hypotheses_list)]), + 'tree_structure': tree_structure, + 'validation_analysis': validation_analysis + } + + return self._information_prompt(substitute_dict, "tree_refinement") - prompt = self._information_prompt(substitute_dict, "hierarchical_inference") + def tree_validation(self, original_hypotheses, tree_structure): + if isinstance(original_hypotheses, (list, tuple, set)): + hypotheses_list = list(original_hypotheses) + else: + hypotheses_list = list(original_hypotheses.keys()) + + substitute_dict = { + 'original_hypotheses': "\n".join([f"{idx + 1}. {hyp}" for idx, hyp in enumerate(hypotheses_list)]), + 'tree_structure': tree_structure + } + + return self._information_prompt(substitute_dict, "tree_validation") - return prompt \ No newline at end of file + @staticmethod + def _format_condition_path(condition_path): + if not condition_path: + return "Root" + + path_parts = [] + for i, condition_name in enumerate(condition_path): + path_parts.append(f"Condition {i+1}: {condition_name}") + + return " → ".join(path_parts) \ No newline at end of file diff --git a/hypothesis_agent/data_analysis_agent/hierarchical_inference.py b/hypothesis_agent/data_analysis_agent/hierarchical_inference.py index 945eae5..a8587d3 100644 --- a/hypothesis_agent/data_analysis_agent/hierarchical_inference.py +++ b/hypothesis_agent/data_analysis_agent/hierarchical_inference.py @@ -11,6 +11,15 @@ import pandas as pd +class TreeNode: + def __init__(self, groups=None, hypotheses=None, children=None, is_leaf=False, path=None): + self.groups = groups or {} # {group_num: {condition, hypotheses, examples}} + self.hypotheses = hypotheses or [] + self.children = children or {} # {index} + self.is_leaf = is_leaf + self.path = path or [] # [condition1, condition2, ...] + + class MultiHypHierarchicalInference(DefaultInference): def __init__( self, @@ -21,18 +30,27 @@ def __init__( ): super().__init__(api, prompt_class, train_data, task) + + # -------------------------------- + # Tree create + # -------------------------------- @staticmethod - def _parse_stump_response(response): - """ - Parse the stump response to extract groups and their hypotheses. - Returns a dictionary mapping group numbers to their hypotheses. - """ - # Remove think tags if present + def _parse_split_decision(response): response = re.sub(r'.*?', '', response, flags=re.IGNORECASE | re.DOTALL) + response = response.strip().upper() + if any(keyword in response for keyword in ["NO_SPLIT"]): + return "NO_SPLIT" + elif any(keyword in response for keyword in ["SPLIT"]): + return "SPLIT" + else: + raise ValueError(f"Could not parse split decision from response: {response}") - # Parse the response to extract groups and their conditions/hypotheses + @staticmethod + def _parse_tree_response(response): + response = re.sub(r'.*?', '', response, flags=re.IGNORECASE | re.DOTALL) groups = {} current_group = None + current_section = None # 'condition', 'hypotheses', or 'examples' lines = response.strip().split('\n') for line in lines: @@ -40,174 +58,753 @@ def _parse_stump_response(response): if not line: continue - # Look for group indicators (e.g., "Group 1:", "1.", etc.) - group_match = re.match(r'^(?:Group\s*)?(\d+)[:.]?\s*(.*)', line, re.IGNORECASE) + group_match = re.match(r'^Group\s*(\d+)[:.]?\s*(.*)', line, re.IGNORECASE) if group_match: current_group = int(group_match.group(1)) current_condition = group_match.group(2).strip() if current_group not in groups: groups[current_group] = { 'condition': current_condition, - 'hypotheses': [] + 'hypotheses': [], + 'examples': [] } - elif current_group is not None and line.startswith('-') or line.startswith('•'): - # This is a hypothesis under the current group - hypothesis = line.lstrip('- ').lstrip('• ').strip() + current_section = 'condition' + continue + + if re.match(r'^Examples?[:.]?\s*$', line, re.IGNORECASE): + if current_group is not None: + current_section = 'examples' + continue + + if current_group is not None and re.match(r'^\d+\.', line): + current_section = 'hypotheses' + hypothesis = re.sub(r'^\d+\.\s*', '', line).strip() if hypothesis: groups[current_group]['hypotheses'].append(hypothesis) - elif current_group is not None and re.match(r'^\d+\.', line): - # This is a numbered hypothesis under the current group - hypothesis = re.sub(r'^\d+\.\s*', '', line).strip() + continue + + if current_group is not None and current_section == 'examples' and re.match(r'^\d+\.', line): + example = re.sub(r'^\d+\.\s*', '', line).strip() + if example: + groups[current_group]['examples'].append(example) + continue + + if current_group is not None and current_section == 'hypotheses' and (line.startswith('-') or line.startswith('•')): + hypothesis = line.lstrip('- ').lstrip('• ').strip() if hypothesis: groups[current_group]['hypotheses'].append(hypothesis) + elif current_group is not None and current_section == 'examples' and (line.startswith('-') or line.startswith('•')): + example = line.lstrip('- ').lstrip('• ').strip() + if example: + groups[current_group]['examples'].append(example) return groups - def _create_stump_from_hypotheses( + def _tree_split( self, hyp_bank, cache_seed=None, + max_depth=5, + current_depth=0, + current_path=None, **generate_kwargs, ): - logger = LoggerConfig.get_logger("StumpCreation") - prompt_input = self.prompt_class.create_stump_from_hypotheses(hyp_bank) + logger = LoggerConfig.get_logger("TreeSplit") + if current_depth >= max_depth: + logger.info(f"Reached max depth {max_depth}, creating leaf node") + return TreeNode(is_leaf=True, hypotheses=hyp_bank, path=current_path or []) + + logger.info(f"=== Tree Split Decision at Depth {current_depth} ===") + logger.info(f"Current path: {current_path or []}") + prompt_input = self.prompt_class.tree_split_decision(hypotheses_dict=hyp_bank, current_depth=current_depth, current_path=current_path) response = self.api.generate( prompt_input, cache_seed=cache_seed, **generate_kwargs, ) - logger.info (f"Stump response: {response}") - groups = self._parse_stump_response(response) + logger.info(f"--- LLM Output for Split Decision ---") + logger.info(f"Response: {response}") + split_decision = self._parse_split_decision(response) - return groups + if split_decision == "NO_SPLIT": + logger.info(f"LLM decided not to split at depth {current_depth}, creating leaf node") + return TreeNode(is_leaf=True, hypotheses=hyp_bank, path=current_path or []) + elif split_decision == "SPLIT": + logger.info(f"LLM decided to split at depth {current_depth}, generating groups") + split_prompt_input = self.prompt_class.tree_split(hypotheses_dict=hyp_bank, current_path=current_path) + split_response = self.api.generate( + split_prompt_input, + cache_seed=cache_seed, + **generate_kwargs, + ) + logger.info(f"--- LLM Output for Tree Split ---") + logger.info(f"Response: {split_response}") + groups = self._parse_tree_response(split_response) - def _determine_sample_groups_batched( - self, - data: pd.DataFrame, - groups, - cache_seed=None, - max_concurrent=3, + node = TreeNode() + node.groups = groups + node.hypotheses = hyp_bank + node.path = current_path or [] + + for group_num, group_info in groups.items(): + child_path = (current_path or []) + [group_info['condition']] + logger.info(f"--- Recursing to Group {group_num} ---") + logger.info(f"Child condition path: {child_path}") + child_node = self._tree_split( + group_info['hypotheses'], + cache_seed=cache_seed, + max_depth=max_depth, + current_depth=current_depth + 1, + current_path=child_path, + **generate_kwargs + ) + node.children[group_num] = child_node + return node + else: + logger.warning(f"Could not parse split decision at depth {current_depth}, defaulting to leaf node") + return TreeNode(is_leaf=True, hypotheses=hyp_bank, path=current_path or []) + + # -------------------------------- + # Tree update (validation and refinement) + # -------------------------------- + def _visualize_tree(self, node, depth=0, prefix=""): + if node.is_leaf: + # Leaf node - show hypotheses + result = f"{prefix}LEAF: {len(node.hypotheses)} hypotheses\n" + for i, hyp in enumerate(node.hypotheses): + if i == len(node.hypotheses) - 1: + result += f"{prefix} └── {hyp}\n" + else: + result += f"{prefix} ├── {hyp}\n" + return result + else: + # Internal node - show groups and recurse + result = f"{prefix}NODE: {len(node.groups)} groups, {len(node.hypotheses)} hypotheses\n" + + group_items = list(node.groups.items()) + for i, (group_num, group_info) in enumerate(group_items): + is_last_group = (i == len(group_items) - 1) + + # Group header + if is_last_group: + result += f"{prefix} └── Group {group_num}: {group_info['condition']}\n" + else: + result += f"{prefix} ├── Group {group_num}: {group_info['condition']}\n" + + # Group details + if group_info.get('hypotheses'): + hyp_items = group_info['hypotheses'] + for j, hyp in enumerate(hyp_items): + is_last_hyp = (j == len(hyp_items) - 1) + if is_last_group: + if is_last_hyp: + result += f"{prefix} └── {hyp}\n" + else: + result += f"{prefix} ├── {hyp}\n" + else: + if is_last_hyp: + result += f"{prefix} └── {hyp}\n" + else: + result += f"{prefix} ├── {hyp}\n" + + if group_info.get('examples'): + example_items = group_info['examples'] + for j, example in enumerate(example_items): + is_last_example = (j == len(example_items) - 1) + if is_last_group: + if is_last_example: + result += f"{prefix} └── Example: {example}\n" + else: + result += f"{prefix} ├── Example: {example}\n" + else: + if is_last_example: + result += f"{prefix} └── Example: {example}\n" + else: + result += f"{prefix} ├── Example: {example}\n" + + # Recurse to children + if group_num in node.children: + child_prefix = f"{prefix} " if is_last_group else f"{prefix} " + result += self._visualize_tree(node.children[group_num], depth + 1, child_prefix) + + return result + + def _tree_validation(self, tree_root, original_hypotheses, cache_seed=None, **generate_kwargs): + logger = LoggerConfig.get_logger("TreeValidation") + logger.info("=== Tree Validation Phase ===") + tree_structure = self._visualize_tree(tree_root) + prompt_input = self.prompt_class.tree_validation(original_hypotheses=original_hypotheses, + tree_structure=tree_structure) + validation_response = self.api.generate( + prompt_input, + cache_seed=cache_seed, **generate_kwargs, - ): - """ - Determine which group each sample belongs to in batch. - """ - # Format group conditions for the prompt - group_conditions = [] - for group_num, group_info in groups.items(): - condition = group_info['condition'] - group_conditions.append(f"Group {group_num}: {condition}") - group_conditions_text = "\n".join(group_conditions) - - # Create prompts for all samples - prompt_inputs = [] - for idx in range(len(data)): - prompt_input = self.prompt_class.determine_group(group_conditions_text, data, idx) - prompt_inputs.append(prompt_input) + ) + logger.info("--- LLM Output for Tree Validation ---") + logger.info(f"Validation response: {validation_response}") + validation_passed = self._check_validation_result(validation_response) + logger.info(f"Validation passed: {validation_passed}") + if not validation_passed: + logger.warning(f"Validation failed.") + return validation_response, validation_passed - # Batch generate responses - responses = self.api.batched_generate( - prompt_inputs, + @staticmethod + def _check_validation_result(validation_response): + response_lower = validation_response.lower() + pass_keywords = ['VALID'] + fail_keywords = ['INVALID'] + if any(keyword in response_lower for keyword in fail_keywords): + return False + if any(keyword in response_lower for keyword in pass_keywords): + return True + return False + + def _tree_refinement(self, tree_root, original_hypotheses, validation_analysis, cache_seed=None, **generate_kwargs): + logger = LoggerConfig.get_logger("TreeRefinement") + logger.info("=== Tree Refinement Phase ===") + tree_structure = self._visualize_tree(tree_root) + prompt_input = self.prompt_class.tree_refinement(original_hypotheses=original_hypotheses, + tree_structure=tree_structure, + validation_analysis=validation_analysis) + refinement_response = self.api.generate( + prompt_input, cache_seed=cache_seed, - max_concurrent=max_concurrent, **generate_kwargs, ) - # Extract group numbers from responses - group_nums = [] - for response in responses: - response = re.sub(r'.*?', '', response, flags=re.IGNORECASE | re.DOTALL) - response = response.strip() - try: - group_num = int(re.search(r'\d+', response).group()) - group_nums.append(group_num) - except (ValueError, AttributeError): - # Default to group 1 if parsing fails - print(f"Failed to parse group number from response: {response}, defaulting to group 1") - group_nums.append(1) + logger.info("--- LLM Output for Tree Refinement ---") + logger.info(f"Refinement response: {refinement_response}") + logger.info("--- Parsing Improved Tree Structure ---") + improved_tree_root = self._parse_and_build_improved_tree(refinement_response, original_hypotheses) + return improved_tree_root + + def _parse_and_build_improved_tree(self, refinement_response, original_hypotheses): + logger = LoggerConfig.get_logger("TreeRefinement") + try: + logger.info("--- Parsing Complete Tree Structure ---") + tree_structure = self._parse_complete_tree_structure(refinement_response) + if not tree_structure: + logger.warning("Could not parse complete tree structure from refinement response") + return None + logger.info("--- Building Complete Improved Tree Structure ---") + improved_root = self._build_tree_from_structure(tree_structure, original_hypotheses, []) + if improved_root: + logger.info("=== Successfully Built Complete Improved Tree Structure ===") + return improved_root + else: + logger.error("Failed to build tree from parsed structure") + return None + except Exception as e: + logger.error(f"Error building improved tree: {e}") + return None + + def _parse_complete_tree_structure(self, refinement_response): + response = re.sub(r'.*?', '', refinement_response, flags=re.IGNORECASE | re.DOTALL) + lines = response.strip().split('\n') + # Parse the tree structure recursively + tree_structure = self._parse_nested_groups(lines, 0, len(lines)) + return tree_structure + + def _parse_nested_groups(self, lines, start_idx, end_idx): + groups = {} + i = start_idx + + while i < end_idx: + line = lines[i].strip() + + if not line: + i += 1 + continue + + # Check for group start - handle both formats: + # 1. "├── Group X: condition" (indented format) + # 2. "Group X: condition" (direct format) + group_match = None + + # Try indented format first + indented_match = re.match(r'^[├└]──\s*Group\s*(\d+):\s*(.*)', line, re.IGNORECASE) + if indented_match: + group_match = indented_match + else: + # Try direct format + direct_match = re.match(r'^Group\s*(\d+):\s*(.*)', line, re.IGNORECASE) + if direct_match: + group_match = direct_match + + if group_match: + group_num = int(group_match.group(1)) + condition = group_match.group(2).strip() + + # Initialize new group + groups[group_num] = { + 'condition': condition, + 'hypotheses': [], + 'examples': [], + 'subgroups': {}, + 'is_leaf': True # Default to leaf, will be updated if subgroups found + } + + # Look ahead to see if this group has subgroups + j = i + 1 + nested_start = None + + while j < end_idx: + next_line = lines[j].strip() + if not next_line: + j += 1 + continue + + # Check if next line is a nested group (more indented) + # Look for deeper indentation patterns + if (next_line.startswith('├──') or next_line.startswith('└──')) and 'Group' in next_line: + # This is a nested group + nested_start = j + break + elif re.match(r'^Group\s*\d+:', next_line, re.IGNORECASE): + # This is a sibling group at same level + break + elif re.match(r'^[├└]──\s*Group\s*\d+:', next_line, re.IGNORECASE): + # This is a sibling group at same level (indented format) + break + j += 1 + + # If we found nested groups, parse them + if nested_start is not None: + # Find the end of nested structure + nested_end = self._find_nested_end(lines, nested_start, end_idx) + + # Parse nested subgroups + nested_groups = self._parse_nested_groups(lines, nested_start, nested_end) + groups[group_num]['subgroups'] = nested_groups + groups[group_num]['is_leaf'] = False - return group_nums + # Skip to end of nested structure + i = nested_end + else: + # This is a leaf group, look for hypotheses/examples + j = i + 1 + while j < end_idx: + next_line = lines[j].strip() + if not next_line: + j += 1 + continue - def _stump_batched_predict( + # Check if we've reached the next group at same level + if re.match(r'^Group\s*\d+:', next_line, re.IGNORECASE): + break + if re.match(r'^[├└]──\s*Group\s*\d+:', next_line, re.IGNORECASE): + break + + # Extract content from indented lines + if next_line.startswith('├──') or next_line.startswith('└──'): + # Extract hypothesis from indented line + content = re.sub(r'^[├└]──\s*', '', next_line).strip() + if content: + groups[group_num]['hypotheses'].append(content) + elif next_line.startswith('-') or next_line.startswith('•'): + # Bullet points + content = next_line.lstrip('- ').lstrip('• ').strip() + if content: + groups[group_num]['hypotheses'].append(content) + elif next_line.startswith('"') and next_line.endswith('"'): + # Quoted examples + content = next_line.strip('"') + if content: + groups[group_num]['examples'].append(content) + + j += 1 + + i = j - 1 # Adjust for the loop increment + + i += 1 + + return groups + + @staticmethod + def _find_nested_end(lines, start_idx, end_idx): + i = start_idx + 1 + + while i < end_idx: + line = lines[i].strip() + if not line: + i += 1 + continue + + # Check if we've reached the end of nested structure + # Look for next group at same level (either direct or indented) + if re.match(r'^Group\s*\d+:', line, re.IGNORECASE): + # Found next group at same level, end of nested structure + return i + elif re.match(r'^[├└]──\s*Group\s*\d+:', line, re.IGNORECASE): + # Found next group at same level (indented format), end of nested structure + return i + + # Check for other structural markers that might indicate end + if line.startswith('###') or line.startswith('##'): + # Found section header, end of nested structure + return i + + # Check for deeper nesting - if we see more indentation, continue + if line.startswith('├──') or line.startswith('└──'): + # This is still part of nested structure + pass + + i += 1 + + return end_idx + + @staticmethod + def _build_tree_from_structure(tree_structure, original_hypotheses, current_path): + logger = LoggerConfig.get_logger("TreeRefinement") + + if not tree_structure: + return None + + root = TreeNode() + root.groups = tree_structure + root.hypotheses = original_hypotheses + root.path = current_path + root.children = {} + + for group_num, group_info in tree_structure.items(): + if group_info.get('is_leaf', False): + # Create leaf node + leaf_node = TreeNode( + is_leaf=True, + hypotheses=group_info['hypotheses'], + path=current_path + [group_info['condition']] + ) + root.children[group_num] = leaf_node + logger.info(f"Created leaf node for Group {group_num}") + else: + # Create internal node with subgroups + if group_info.get('subgroups'): + internal_node = TreeNode() + internal_node.groups = group_info['subgroups'] + internal_node.hypotheses = group_info['hypotheses'] + internal_node.path = current_path + [group_info['condition']] + internal_node.children = {} + + # Recursively build subtree for subgroups + for subgroup_num, subgroup_info in group_info['subgroups'].items(): + if subgroup_info.get('is_leaf', False): + subgroup_node = TreeNode( + is_leaf=True, + hypotheses=subgroup_info['hypotheses'], + path=internal_node.path + [subgroup_info['condition']] + ) + internal_node.children[subgroup_num] = subgroup_node + else: + # Handle deeper nesting if needed + logger.warning(f"Deep nesting detected in subgroup {subgroup_num}, treating as leaf") + subgroup_node = TreeNode( + is_leaf=True, + hypotheses=subgroup_info.get('hypotheses', []), + path=internal_node.path + [subgroup_info['condition']] + ) + internal_node.children[subgroup_num] = subgroup_node + + root.children[group_num] = internal_node + logger.info( + f"Created internal node for Group {group_num} with {len(group_info['subgroups'])} subgroups") + else: + # Fallback: treat as leaf node + leaf_node = TreeNode( + is_leaf=True, + hypotheses=group_info['hypotheses'], + path=current_path + [group_info['condition']] + ) + root.children[group_num] = leaf_node + logger.info(f"Created fallback leaf node for Group {group_num}") + + return root + + # -------------------------------- + # Tree inference + # -------------------------------- + def _tree_batched_predict( self, data: pd.DataFrame, - groups, + tree_root: TreeNode, cache_seed=None, max_concurrent=3, **generate_kwargs, ): - """ - Make predictions using group-specific hypotheses with batched processing. - """ - # First, determine groups for all samples in batch - group_nums = self._determine_sample_groups_batched( - data, groups, cache_seed, max_concurrent, **generate_kwargs + logger = LoggerConfig.get_logger("TreeBatchedPredict") + sample_paths = self._determine_paths_batched( + data, tree_root, cache_seed, max_concurrent, **generate_kwargs ) - # Create prompts for all samples with their respective group hypotheses - prompt_inputs = [] - sample_indices = [] + path_groups = {} + for idx, path in enumerate(sample_paths): + path_key = tuple(path) + if path_key not in path_groups: + path_groups[path_key] = [] + path_groups[path_key].append(idx) + + all_predictions = [None] * len(data) + all_actual_labels = [None] * len(data) + + logger.info("=== Leaf Node Analysis ===") + for path, indices in path_groups.items(): + if path: + leaf_node = self._get_leaf_node_by_path(tree_root, path) + if leaf_node and leaf_node.is_leaf: + leaf_data = data.iloc[indices].reset_index(drop=True) + + batch_predictions = self._leaf_batch_predict( + leaf_data, leaf_node, cache_seed, max_concurrent, **generate_kwargs + ) + logger.info(f"Predictions: {batch_predictions}") + + for i, pred in enumerate(batch_predictions): + all_predictions[indices[i]] = pred + all_actual_labels[indices[i]] = data[self.task.label_name][indices[i]] + else: + raise RuntimeError(f"Resolved path {path} did not lead to a valid leaf node.") + else: + raise RuntimeError("Empty path encountered for some samples.") - for idx in range(len(data)): - group_num = group_nums[idx] - - if group_num in groups: - group_hypotheses = groups[group_num]['hypotheses'] - group_condition = groups[group_num]['condition'] - # Convert to the format expected by multiple_hypotheses_inference - # Create a dictionary mapping hypothesis text to SummaryInformation - hyp_dict = {} - for i, hypothesis_text in enumerate(group_hypotheses): - hyp_dict[hypothesis_text] = SummaryInformation() - - # Create prompt for this sample with group-specific hypotheses - prompt_input = self.prompt_class.hierarchical_inference( - hyp_dict, data, idx, group_condition + return all_predictions, all_actual_labels + + @staticmethod + def _extract_group_num(response_text: str, allowed_groups: set[int]): + cleaned = re.sub(r'.*?', '', str(response_text), flags=re.IGNORECASE | re.DOTALL).strip() + candidates = re.findall(r'\d+', cleaned) + for c in candidates: + try: + num = int(c) + if num in allowed_groups: + return num + except ValueError: + continue + return None + + def _determine_paths_batched(self, data, tree_root, cache_seed=None, max_concurrent=3, **generate_kwargs): + logger = LoggerConfig.get_logger("DeterminePathsBatched") + current_nodes = [tree_root] * len(data) + current_paths = [[] for _ in range(len(data))] + + layer_count = 0 + while any(not node.is_leaf for node in current_nodes): + layer_count += 1 + logger.info(f"=== Processing Layer {layer_count} ===") + non_leaf_indices = [i for i, node in enumerate(current_nodes) if not node.is_leaf] + if not non_leaf_indices: + break + node_groups = {} + for idx in non_leaf_indices: + node = current_nodes[idx] + node_key = id(node) + if node_key not in node_groups: + node_groups[node_key] = {'node': node, 'indices': []} + node_groups[node_key]['indices'].append(idx) + for node_key, group_info in node_groups.items(): + node = group_info['node'] + indices = group_info['indices'] + group_conditions = [] + for group_num, group_detail in node.groups.items(): + condition = group_detail['condition'] + group_conditions.append(f"Group {group_num}: {condition}") + allowed_list = sorted(list(node.groups.keys())) + allowed_set = set(allowed_list) + allowed_text = ", ".join(str(x) for x in allowed_list) + group_conditions_text = "\n".join(group_conditions) + f"\nValid group numbers: {allowed_text}. Answer must be ONE of these numbers ONLY." + logger.info(f"Group conditions: {group_conditions}") + + prompt_inputs = [] + for idx in indices: + sample_data = data.iloc[[idx]] + sample_data = sample_data.reset_index(drop=True) + prompt_input = self.prompt_class.internal_inference(group_conditions_text, sample_data, 0) + prompt_inputs.append(prompt_input) + + responses = self.api.batched_generate( + prompt_inputs, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, ) - prompt_inputs.append(prompt_input) - sample_indices.append(idx) + logger.info(f"--- LLM Outputs for Group Assignment ---") + for i, response in enumerate(responses): + sample_idx = indices[i] + logger.info(f"Sample {sample_idx} response: {response}") + + parsed_groups: dict[int, int | None] = {} + invalid_local_positions = [] # position within indices + for i, response in enumerate(responses): + sample_idx = indices[i] + gnum = self._extract_group_num(response, allowed_set) + if gnum is None: + invalid_local_positions.append(i) + parsed_groups[sample_idx] = None + logger.warning(f"Sample {sample_idx}: Failed to parse group from response: '{response}'") + else: + parsed_groups[sample_idx] = gnum + if invalid_local_positions: + logger.info(f"Retrying {len(invalid_local_positions)} samples with stronger constraints") + retry_inputs = [] + for pos in invalid_local_positions: + idx_global = indices[pos] + sample_data = data.iloc[[idx_global]].reset_index(drop=True) + retry_text = group_conditions_text + f"\nChoose strictly one from [{allowed_text}] and output ONLY the numeral." + retry_inputs.append(self.prompt_class.internal_inference(retry_text, sample_data, 0)) + + retry_responses = self.api.batched_generate( + retry_inputs, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) + + logger.info(f"--- LLM Retry Outputs ---") + for j, resp in enumerate(retry_responses): + pos = invalid_local_positions[j] + sample_idx = indices[pos] + logger.info(f"Sample {sample_idx} retry response: {resp}") + + for j, resp in enumerate(retry_responses): + pos = invalid_local_positions[j] + sample_idx = indices[pos] + gnum = self._extract_group_num(resp, allowed_set) + parsed_groups[sample_idx] = gnum + if gnum is not None: + logger.info(f"Sample {sample_idx}: Retry successful, assigned to group {gnum}") + else: + logger.warning(f"Sample {sample_idx}: Retry failed, response: '{resp}'") + + group_distribution = {} + for sample_idx, gnum in parsed_groups.items(): + if gnum is None: + raise RuntimeError( + f"Failed to get a valid group among {allowed_list} for sample index {sample_idx}." + ) + + if gnum in node.children: + current_nodes[sample_idx] = node.children[gnum] + current_paths[sample_idx].append(gnum) + group_distribution[gnum] = group_distribution.get(gnum, 0) + 1 + logger.info(f"Sample {sample_idx}: Path updated to {current_paths[sample_idx]}") + else: + raise RuntimeError( + f"Predicted group {gnum} not in node.children. Path so far: {current_paths[sample_idx]}" + ) + + logger.info(f"Layer {layer_count} group distribution: {group_distribution}") + logger.info(f"Updated paths: {[current_paths[i] for i in indices]}") + + logger.info(f"=== Path Determination Complete ===") + logger.info(f"Total layers processed: {layer_count}") + + return current_paths + + @staticmethod + def _get_leaf_node_by_path(tree_root, path): + current_node = tree_root + for group_num in path: + if not current_node.is_leaf and group_num in current_node.children: + current_node = current_node.children[group_num] else: - raise Exception(f"Group {group_num} not found in groups.") + return None + return current_node + + def _leaf_batch_predict(self, leaf_data, leaf_node, cache_seed=None, max_concurrent=3, **generate_kwargs): + logger = LoggerConfig.get_logger("LeafBatchPredict") - # Batch generate predictions + logger.info(f"=== Leaf Node Prediction ===") + hyp_dict = {} + for hypothesis_text in leaf_node.hypotheses: + hyp_dict[hypothesis_text] = SummaryInformation() + prompt_inputs = [] + for idx in range(len(leaf_data)): + sample_data = leaf_data.iloc[[idx]] + sample_data = sample_data.reset_index(drop=True) + prompt_input = self.prompt_class.multiple_hypotheses_inference_with_path( + hyp_dict, sample_data, 0, leaf_node.path, leaf_node.path + ) + prompt_inputs.append(prompt_input) responses = self.api.batched_generate( prompt_inputs, cache_seed=cache_seed, max_concurrent=max_concurrent, **generate_kwargs, ) - - # Extract predictions and actual labels + logger.info(f"--- LLM Outputs for Leaf Prediction ---") + for i, response in enumerate(responses): + logger.info(f"Sample {i} response: {response}") predictions = [] - actual_labels = [] - for i, response in enumerate(responses): response = re.sub(r'.*?', '', response, flags=re.IGNORECASE | re.DOTALL) prediction = self.task.extract_label(response) predictions.append(prediction) - actual_labels.append(data[self.task.label_name][sample_indices[i]]) + logger.info(f"Sample {i}: Extracted prediction: {prediction}") + return predictions - return predictions, actual_labels + # -------------------------------- + # Interface for running tree inference + # -------------------------------- def run_inference_final( self, data, hyp_bank, cache_seed=None, max_concurrent=3, + max_depth=5, + max_iterations=5, **generate_kwargs, ): - # Create stump (groups) from hypotheses - groups = self._create_stump_from_hypotheses( - hyp_bank, + logger = LoggerConfig.get_logger("RunInferenceFinal") + logger.info("=== Starting Hierarchical Tree Inference ===") + + logger.info("--- Phase 1: Building Decision Tree ---") + tree_root = self._tree_split( + hyp_bank=hyp_bank, cache_seed=cache_seed, + max_depth=max_depth, **generate_kwargs, ) - # Run inference using group-specific hypotheses - return self._stump_batched_predict( + logger.info("--- Phase 2: Tree Validation and Refinement ---") + current_tree_root = tree_root + iteration_count = 0 + validation_passed = False + + # while iteration_count < max_iterations and not validation_passed: + # iteration_count += 1 + # logger.info(f"--- Iteration {iteration_count}/{max_iterations} ---") + # validation_analysis, validation_passed = self._tree_validation( + # tree_root=current_tree_root, original_hypotheses=hyp_bank, cache_seed=cache_seed, **generate_kwargs + # ) + # if validation_passed: + # logger.info(f"=== Tree Validation Passed at Iteration {iteration_count} ===") + # break + # else: + # logger.info(f"=== Tree Validation Failed at Iteration {iteration_count} ===") + # if iteration_count >= max_iterations: + # logger.warning(f"Reached maximum iterations ({max_iterations}), using current tree") + # break + # logger.info(f"--- Proceeding to Tree Refinement for Iteration {iteration_count} ---") + # refined_tree_root = self._tree_refinement( + # tree_root=current_tree_root, original_hypotheses=hyp_bank, validation_analysis=validation_analysis, cache_seed=cache_seed, **generate_kwargs + # ) + # current_tree_root = refined_tree_root + + logger.info("--- Tree Validation and Refinement Complete ---") + logger.info(f"Total iterations: {iteration_count}") + if validation_passed: + logger.info("Tree validation passed successfully") + else: + logger.warning("Tree validation did not pass, but proceeding with current tree") + + logger.info("--- Phase 3: Running Tree Inference ---") + predictions, actual_labels = self._tree_batched_predict( data, - groups, + current_tree_root, cache_seed=cache_seed, max_concurrent=max_concurrent, **generate_kwargs, - ) \ No newline at end of file + ) + + logger.info("=== Hierarchical Tree Inference Complete ===") + logger.info(f"Prediction distribution: {pd.Series(predictions).value_counts().to_dict()}") + return predictions, actual_labels From 1f842a4897ae31c0a422f250398f62479117bd0b Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Tue, 12 Aug 2025 23:21:23 -0500 Subject: [PATCH 22/23] Implement tree based inference with update process --- hypogenic/prompt.py | 2 +- .../hierarchical_inference.py | 424 ++++++++++-------- run_pipeline.sh | 1 + 3 files changed, 238 insertions(+), 189 deletions(-) diff --git a/hypogenic/prompt.py b/hypogenic/prompt.py index 8a39706..bf79700 100644 --- a/hypogenic/prompt.py +++ b/hypogenic/prompt.py @@ -396,7 +396,7 @@ def internal_inference(self, group_conditions, test_data, test_idx): prompt = self._information_prompt(substitute_dict, "internal_inference") return prompt - def multiple_hypotheses_inference_with_path(self, hyp_dict, sample_data, param, tree_path=None, condition_path=None): + def multiple_hypotheses_inference_with_path(self, hyp_dict, sample_data, condition_path=None): """ Enhanced inference prompt that includes tree path context """ diff --git a/hypothesis_agent/data_analysis_agent/hierarchical_inference.py b/hypothesis_agent/data_analysis_agent/hierarchical_inference.py index a8587d3..b23b076 100644 --- a/hypothesis_agent/data_analysis_agent/hierarchical_inference.py +++ b/hypothesis_agent/data_analysis_agent/hierarchical_inference.py @@ -150,8 +150,15 @@ def _tree_split( child_path = (current_path or []) + [group_info['condition']] logger.info(f"--- Recursing to Group {group_num} ---") logger.info(f"Child condition path: {child_path}") + + # Create a new hypothesis bank that includes both hypotheses and examples + child_hyp_bank = group_info['hypotheses'].copy() + if group_info.get('examples'): + # Add examples as additional context for child nodes + child_hyp_bank.extend(group_info['examples']) + child_node = self._tree_split( - group_info['hypotheses'], + child_hyp_bank, # Pass both hypotheses and examples cache_seed=cache_seed, max_depth=max_depth, current_depth=current_depth + 1, @@ -167,65 +174,75 @@ def _tree_split( # -------------------------------- # Tree update (validation and refinement) # -------------------------------- - def _visualize_tree(self, node, depth=0, prefix=""): + def _visualize_tree_log(self, node, depth=0): if node.is_leaf: - # Leaf node - show hypotheses - result = f"{prefix}LEAF: {len(node.hypotheses)} hypotheses\n" - for i, hyp in enumerate(node.hypotheses): - if i == len(node.hypotheses) - 1: - result += f"{prefix} └── {hyp}\n" - else: - result += f"{prefix} ├── {hyp}\n" + # Leaf node + result = f"LEAF NODE:\n" + if node.hypotheses: + result += "Hypotheses:\n" + for i, hyp in enumerate(node.hypotheses, 1): + result += f" {i}. {hyp}\n" return result else: - # Internal node - show groups and recurse - result = f"{prefix}NODE: {len(node.groups)} groups, {len(node.hypotheses)} hypotheses\n" + # Internal node + result = f"INTERNAL NODE:\n" + if node.hypotheses: + result += "General Hypotheses:\n" + for i, hyp in enumerate(node.hypotheses, 1): + result += f" {i}. {hyp}\n" - group_items = list(node.groups.items()) - for i, (group_num, group_info) in enumerate(group_items): - is_last_group = (i == len(group_items) - 1) - - # Group header - if is_last_group: - result += f"{prefix} └── Group {group_num}: {group_info['condition']}\n" - else: - result += f"{prefix} ├── Group {group_num}: {group_info['condition']}\n" + result += "Groups:\n" + for group_num, group_info in node.groups.items(): + result += f" Group {group_num}: {group_info['condition']}\n" - # Group details if group_info.get('hypotheses'): - hyp_items = group_info['hypotheses'] - for j, hyp in enumerate(hyp_items): - is_last_hyp = (j == len(hyp_items) - 1) - if is_last_group: - if is_last_hyp: - result += f"{prefix} └── {hyp}\n" - else: - result += f"{prefix} ├── {hyp}\n" - else: - if is_last_hyp: - result += f"{prefix} └── {hyp}\n" - else: - result += f"{prefix} ├── {hyp}\n" + result += " Refined Hypotheses:\n" + for i, hyp in enumerate(group_info['hypotheses'], 1): + result += f" {i}. {hyp}\n" if group_info.get('examples'): - example_items = group_info['examples'] - for j, example in enumerate(example_items): - is_last_example = (j == len(example_items) - 1) - if is_last_group: - if is_last_example: - result += f"{prefix} └── Example: {example}\n" - else: - result += f"{prefix} ├── Example: {example}\n" - else: - if is_last_example: - result += f"{prefix} └── Example: {example}\n" - else: - result += f"{prefix} ├── Example: {example}\n" + result += " Examples:\n" + for i, example in enumerate(group_info['examples'], 1): + result += f" {i}. {example}\n" # Recurse to children if group_num in node.children: - child_prefix = f"{prefix} " if is_last_group else f"{prefix} " - result += self._visualize_tree(node.children[group_num], depth + 1, child_prefix) + child_result = self._visualize_tree(node.children[group_num], depth + 1) + # Indent child result + child_lines = child_result.split('\n') + indented_child = '\n'.join(f" {line}" if line else "" for line in child_lines) + result += f" Subtree:\n{indented_child}\n" + + return result + + def _visualize_tree(self, node, depth=0): + if node.is_leaf: + # Leaf node - show hypotheses + result = f"LEAF NODE:\n" + if node.hypotheses: + result += "Hypotheses:\n" + for i, hyp in enumerate(node.hypotheses, 1): + result += f" {i}. {hyp}\n" + return result + else: + # Internal node + result = f"INTERNAL NODE:\n" + + result += "Groups:\n" + for group_num, group_info in node.groups.items(): + result += f" Group {group_num}: {group_info['condition']}\n" + if group_info.get('examples'): + result += " Examples:\n" + for i, example in enumerate(group_info['examples'], 1): + result += f" {i}. {example}\n" + + # Recurse to children + if group_num in node.children: + child_result = self._visualize_tree(node.children[group_num], depth + 1) + # Indent child result + child_lines = child_result.split('\n') + indented_child = '\n'.join(f" {line}" if line else "" for line in child_lines) + result += f" Subtree:\n{indented_child}\n" return result @@ -243,7 +260,6 @@ def _tree_validation(self, tree_root, original_hypotheses, cache_seed=None, **ge logger.info("--- LLM Output for Tree Validation ---") logger.info(f"Validation response: {validation_response}") validation_passed = self._check_validation_result(validation_response) - logger.info(f"Validation passed: {validation_passed}") if not validation_passed: logger.warning(f"Validation failed.") return validation_response, validation_passed @@ -298,162 +314,152 @@ def _parse_and_build_improved_tree(self, refinement_response, original_hypothese return None def _parse_complete_tree_structure(self, refinement_response): + """Parse the structured tree format from LLM response""" response = re.sub(r'.*?', '', refinement_response, flags=re.IGNORECASE | re.DOTALL) lines = response.strip().split('\n') - # Parse the tree structure recursively - tree_structure = self._parse_nested_groups(lines, 0, len(lines)) + + # Parse the structured tree format + tree_structure = self._parse_structured_tree(lines, 0, len(lines)) return tree_structure - def _parse_nested_groups(self, lines, start_idx, end_idx): + def _parse_structured_tree(self, lines, start_idx, end_idx): + """Parse the new structured tree format""" groups = {} i = start_idx - + while i < end_idx: line = lines[i].strip() - if not line: i += 1 continue - - # Check for group start - handle both formats: - # 1. "├── Group X: condition" (indented format) - # 2. "Group X: condition" (direct format) - group_match = None - - # Try indented format first - indented_match = re.match(r'^[├└]──\s*Group\s*(\d+):\s*(.*)', line, re.IGNORECASE) - if indented_match: - group_match = indented_match - else: - # Try direct format - direct_match = re.match(r'^Group\s*(\d+):\s*(.*)', line, re.IGNORECASE) - if direct_match: - group_match = direct_match - + + # Look for group start + group_match = re.match(r'^Group\s*(\d+):\s*(.*)', line, re.IGNORECASE) if group_match: group_num = int(group_match.group(1)) condition = group_match.group(2).strip() - + # Initialize new group groups[group_num] = { 'condition': condition, 'hypotheses': [], 'examples': [], 'subgroups': {}, - 'is_leaf': True # Default to leaf, will be updated if subgroups found + 'is_leaf': True # Default to leaf } - - # Look ahead to see if this group has subgroups + + # Parse group content j = i + 1 - nested_start = None - while j < end_idx: next_line = lines[j].strip() if not next_line: j += 1 continue - - # Check if next line is a nested group (more indented) - # Look for deeper indentation patterns - if (next_line.startswith('├──') or next_line.startswith('└──')) and 'Group' in next_line: - # This is a nested group - nested_start = j - break - elif re.match(r'^Group\s*\d+:', next_line, re.IGNORECASE): - # This is a sibling group at same level + + # Check if we've reached the next group or section + if re.match(r'^Group\s*\d+:', next_line, re.IGNORECASE): break - elif re.match(r'^[├└]──\s*Group\s*\d+:', next_line, re.IGNORECASE): - # This is a sibling group at same level (indented format) + if next_line.startswith('INTERNAL NODE:') or next_line.startswith('LEAF NODE:'): break - j += 1 - - # If we found nested groups, parse them - if nested_start is not None: - # Find the end of nested structure - nested_end = self._find_nested_end(lines, nested_start, end_idx) - - # Parse nested subgroups - nested_groups = self._parse_nested_groups(lines, nested_start, nested_end) - groups[group_num]['subgroups'] = nested_groups - groups[group_num]['is_leaf'] = False - - # Skip to end of nested structure - i = nested_end - else: - # This is a leaf group, look for hypotheses/examples - j = i + 1 - while j < end_idx: - next_line = lines[j].strip() - if not next_line: + + # Parse refined hypotheses + if next_line.startswith('Refined Hypotheses:'): + j += 1 + while j < end_idx: + hyp_line = lines[j].strip() + if not hyp_line: + j += 1 + continue + + # Check for end of hypotheses section + if (hyp_line.startswith('Examples:') or + hyp_line.startswith('Subtree:') or + re.match(r'^Group\s*\d+:', hyp_line, re.IGNORECASE) or + hyp_line.startswith('INTERNAL NODE:') or + hyp_line.startswith('LEAF NODE:')): + break + + # Extract hypothesis + hyp_match = re.match(r'^\s*\d+\.\s*(.+)', hyp_line) + if hyp_match: + hypothesis = hyp_match.group(1).strip() + groups[group_num]['hypotheses'].append(hypothesis) + j += 1 - continue - - # Check if we've reached the next group at same level - if re.match(r'^Group\s*\d+:', next_line, re.IGNORECASE): - break - if re.match(r'^[├└]──\s*Group\s*\d+:', next_line, re.IGNORECASE): - break - - # Extract content from indented lines - if next_line.startswith('├──') or next_line.startswith('└──'): - # Extract hypothesis from indented line - content = re.sub(r'^[├└]──\s*', '', next_line).strip() - if content: - groups[group_num]['hypotheses'].append(content) - elif next_line.startswith('-') or next_line.startswith('•'): - # Bullet points - content = next_line.lstrip('- ').lstrip('• ').strip() - if content: - groups[group_num]['hypotheses'].append(content) - elif next_line.startswith('"') and next_line.endswith('"'): - # Quoted examples - content = next_line.strip('"') - if content: - groups[group_num]['examples'].append(content) - + continue + + # Parse examples + if next_line.startswith('Examples:'): j += 1 - - i = j - 1 # Adjust for the loop increment - + while j < end_idx: + ex_line = lines[j].strip() + if not ex_line: + j += 1 + continue + + # Check for end of examples section + if (ex_line.startswith('Subtree:') or + re.match(r'^Group\s*\d+:', ex_line, re.IGNORECASE) or + ex_line.startswith('INTERNAL NODE:') or + ex_line.startswith('LEAF NODE:')): + break + + # Extract example + ex_match = re.match(r'^\s*\d+\.\s*(.+)', ex_line) + if ex_match: + example = ex_match.group(1).strip() + groups[group_num]['examples'].append(example) + + j += 1 + continue + + # Parse subtree + if next_line.startswith('Subtree:'): + j += 1 + subtree_start = j + + # Find subtree end + subtree_end = self._find_subtree_end(lines, subtree_start, end_idx) + + # Parse nested subgroups + nested_groups = self._parse_structured_tree(lines, subtree_start, subtree_end) + groups[group_num]['subgroups'] = nested_groups + groups[group_num]['is_leaf'] = False + + j = subtree_end + continue + + j += 1 + + i = j - 1 # Adjust for the loop increment + i += 1 - + return groups @staticmethod - def _find_nested_end(lines, start_idx, end_idx): - i = start_idx + 1 - + def _find_subtree_end(lines, start_idx, end_idx): + """Find the end of a subtree section""" + i = start_idx + while i < end_idx: line = lines[i].strip() if not line: i += 1 continue - - # Check if we've reached the end of nested structure - # Look for next group at same level (either direct or indented) - if re.match(r'^Group\s*\d+:', line, re.IGNORECASE): - # Found next group at same level, end of nested structure - return i - elif re.match(r'^[├└]──\s*Group\s*\d+:', line, re.IGNORECASE): - # Found next group at same level (indented format), end of nested structure - return i - - # Check for other structural markers that might indicate end - if line.startswith('###') or line.startswith('##'): - # Found section header, end of nested structure + + # Check for end of subtree + if (line.startswith('Group') or + line.startswith('INTERNAL NODE:') or + line.startswith('LEAF NODE:') or + (line.startswith('Group') and ':' in line)): return i - # Check for deeper nesting - if we see more indentation, continue - if line.startswith('├──') or line.startswith('└──'): - # This is still part of nested structure - pass - i += 1 return end_idx - @staticmethod - def _build_tree_from_structure(tree_structure, original_hypotheses, current_path): + def _build_tree_from_structure(self, tree_structure, original_hypotheses, current_path): logger = LoggerConfig.get_logger("TreeRefinement") if not tree_structure: @@ -494,14 +500,38 @@ def _build_tree_from_structure(tree_structure, original_hypotheses, current_path ) internal_node.children[subgroup_num] = subgroup_node else: - # Handle deeper nesting if needed - logger.warning(f"Deep nesting detected in subgroup {subgroup_num}, treating as leaf") - subgroup_node = TreeNode( - is_leaf=True, - hypotheses=subgroup_info.get('hypotheses', []), - path=internal_node.path + [subgroup_info['condition']] - ) - internal_node.children[subgroup_num] = subgroup_node + # Handle deeper nesting recursively + if subgroup_info.get('subgroups'): + # This is an internal node with its own subgroups + nested_node = self._build_tree_from_structure( + subgroup_info['subgroups'], + subgroup_info.get('hypotheses', []), + internal_node.path + [subgroup_info['condition']] + ) + if nested_node: + nested_node.groups = subgroup_info['subgroups'] + nested_node.hypotheses = subgroup_info.get('hypotheses', []) + nested_node.path = internal_node.path + [subgroup_info['condition']] + internal_node.children[subgroup_num] = nested_node + logger.info(f"Created nested internal node for Group {group_num} -> Subgroup {subgroup_num}") + else: + # Fallback: treat as leaf node + logger.warning(f"Failed to build nested node for subgroup {subgroup_num}, treating as leaf") + subgroup_node = TreeNode( + is_leaf=True, + hypotheses=subgroup_info.get('hypotheses', []), + path=internal_node.path + [subgroup_info['condition']] + ) + internal_node.children[subgroup_num] = subgroup_node + else: + # Fallback: treat as leaf node + logger.warning(f"Subgroup {subgroup_num} has no subgroups but is not marked as leaf, treating as leaf") + subgroup_node = TreeNode( + is_leaf=True, + hypotheses=subgroup_info.get('hypotheses', []), + path=internal_node.path + [subgroup_info['condition']] + ) + internal_node.children[subgroup_num] = subgroup_node root.children[group_num] = internal_node logger.info( @@ -719,7 +749,7 @@ def _leaf_batch_predict(self, leaf_data, leaf_node, cache_seed=None, max_concurr sample_data = leaf_data.iloc[[idx]] sample_data = sample_data.reset_index(drop=True) prompt_input = self.prompt_class.multiple_hypotheses_inference_with_path( - hyp_dict, sample_data, 0, leaf_node.path, leaf_node.path + hyp_dict=hyp_dict, sample_data=sample_data, condition_path=leaf_node.path ) prompt_inputs.append(prompt_input) responses = self.api.batched_generate( @@ -764,31 +794,49 @@ def run_inference_final( **generate_kwargs, ) + # Display the built tree structure after Phase 1 + logger.info("=== Phase 1 Complete: Built Decision Tree ===") + logger.info("Tree Structure:") + tree_structure = self._visualize_tree(tree_root) + logger.info(tree_structure) + logger.info("=" * 80) + logger.info("--- Phase 2: Tree Validation and Refinement ---") current_tree_root = tree_root iteration_count = 0 validation_passed = False - # while iteration_count < max_iterations and not validation_passed: - # iteration_count += 1 - # logger.info(f"--- Iteration {iteration_count}/{max_iterations} ---") - # validation_analysis, validation_passed = self._tree_validation( - # tree_root=current_tree_root, original_hypotheses=hyp_bank, cache_seed=cache_seed, **generate_kwargs - # ) - # if validation_passed: - # logger.info(f"=== Tree Validation Passed at Iteration {iteration_count} ===") - # break - # else: - # logger.info(f"=== Tree Validation Failed at Iteration {iteration_count} ===") - # if iteration_count >= max_iterations: - # logger.warning(f"Reached maximum iterations ({max_iterations}), using current tree") - # break - # logger.info(f"--- Proceeding to Tree Refinement for Iteration {iteration_count} ---") - # refined_tree_root = self._tree_refinement( - # tree_root=current_tree_root, original_hypotheses=hyp_bank, validation_analysis=validation_analysis, cache_seed=cache_seed, **generate_kwargs - # ) - # current_tree_root = refined_tree_root + while iteration_count < max_iterations and not validation_passed: + iteration_count += 1 + logger.info(f"--- Iteration {iteration_count}/{max_iterations} ---") + + # Show current tree structure before validation + logger.info(f"Current tree structure before validation):") + current_tree_structure = self._visualize_tree(current_tree_root) + logger.info(current_tree_structure) + + validation_analysis, validation_passed = self._tree_validation( + tree_root=current_tree_root, original_hypotheses=hyp_bank, cache_seed=cache_seed, **generate_kwargs + ) + if validation_passed: + logger.info(f"=== Tree Validation Passed at Iteration {iteration_count} ===") + break + else: + logger.info(f"=== Tree Validation Failed at Iteration {iteration_count} ===") + if iteration_count >= max_iterations: + logger.warning(f"Reached maximum iterations ({max_iterations}), using current tree") + break + logger.info(f"--- Proceeding to Tree Refinement for Iteration {iteration_count} ---") + refined_tree_root = self._tree_refinement( + tree_root=current_tree_root, original_hypotheses=hyp_bank, validation_analysis=validation_analysis, cache_seed=cache_seed, **generate_kwargs + ) + current_tree_root = refined_tree_root + # Show refined tree structure after refinement + logger.info(f"Refined tree structure (after iteration {iteration_count}):") + refined_tree_structure = self._visualize_tree(current_tree_root) + logger.info(refined_tree_structure) + logger.info("-" * 60) logger.info("--- Tree Validation and Refinement Complete ---") logger.info(f"Total iterations: {iteration_count}") if validation_passed: diff --git a/run_pipeline.sh b/run_pipeline.sh index c2590b3..d9392ab 100755 --- a/run_pipeline.sh +++ b/run_pipeline.sh @@ -45,6 +45,7 @@ METHODS=( # "only_paper" "hypogenic" # "augmented_hypogenic" + #"hierarchical_inference" # "hyporefine" # "union_hypo" # "union_refine" From 8f6ab3bb795196162e37637702e8d29e3e1192f5 Mon Sep 17 00:00:00 2001 From: yfyfyufeng Date: Wed, 20 Aug 2025 15:33:09 -0500 Subject: [PATCH 23/23] Implement stump based inference --- hypogenic/prompt.py | 40 ++- .../data_analysis_agent/stump_inference.py | 275 ++++++++++++++++++ ...rchical_inference.py => tree_inference.py} | 2 +- pipeline.py | 23 +- 4 files changed, 336 insertions(+), 4 deletions(-) create mode 100644 hypothesis_agent/data_analysis_agent/stump_inference.py rename hypothesis_agent/data_analysis_agent/{hierarchical_inference.py => tree_inference.py} (99%) diff --git a/hypogenic/prompt.py b/hypogenic/prompt.py index bf79700..53e1a39 100644 --- a/hypogenic/prompt.py +++ b/hypogenic/prompt.py @@ -448,4 +448,42 @@ def _format_condition_path(condition_path): for i, condition_name in enumerate(condition_path): path_parts.append(f"Condition {i+1}: {condition_name}") - return " → ".join(path_parts) \ No newline at end of file + return " → ".join(path_parts) + + def create_stump(self, hyp_bank): + """ + Create stump (decision groups) from hypotheses. + """ + if isinstance(hyp_bank, (list, tuple, set)): + hypotheses_list = list(hyp_bank) + else: + hypotheses_list = list(hyp_bank.keys()) + + substitute_dict = { + "hypotheses": "\n".join([f"{idx + 1}. {hyp}" for idx, hyp in enumerate(hypotheses_list)]) + } + + prompt = self._information_prompt(substitute_dict, "create_stump") + return prompt + + def determine_group(self, group_conditions, test_data, test_idx): + """ + Determine which group a sample belongs to based on decision stump conditions. + """ + substitute_dict = self._get_substitute_dict(test_data, test_idx) + substitute_dict["group_conditions"] = group_conditions + + prompt = self._information_prompt(substitute_dict, "determine_group") + return prompt + + def stump_predict(self, hyp_dict, test_data, test_idx, group_condition): + """ + Create prediction prompt for decision stump. + """ + hypotheses_list = list(hyp_dict.keys()) + substitute_dict = self._get_substitute_dict(test_data, test_idx) + substitute_dict["hypotheses"] = "\n".join([f"{idx + 1}. {hyp}" for idx, hyp in enumerate(hypotheses_list)]) + substitute_dict["group_condition"] = group_condition + + prompt = self._information_prompt(substitute_dict, "stump_predict") + return prompt \ No newline at end of file diff --git a/hypothesis_agent/data_analysis_agent/stump_inference.py b/hypothesis_agent/data_analysis_agent/stump_inference.py new file mode 100644 index 0000000..139d905 --- /dev/null +++ b/hypothesis_agent/data_analysis_agent/stump_inference.py @@ -0,0 +1,275 @@ +import re + +from hypogenic.algorithm.inference import DefaultInference +from hypogenic.logger_config import LoggerConfig +from hypogenic.tasks import BaseTask +from hypogenic.algorithm.summary_information import ( + SummaryInformation, +) + +from hypothesis_agent.data_analysis_agent.prompt import TestPrompt +import pandas as pd + + +class StumpInference(DefaultInference): + def __init__( + self, + api, + prompt_class: TestPrompt, + train_data: pd.DataFrame, + task: BaseTask, + grouping_api, + ): + super().__init__(api, prompt_class, train_data, task) + self.grouping_api = grouping_api + + @staticmethod + def _parse_stump_response(response): + """ + Parse the stump response to extract groups and their hypotheses. + Returns a dictionary mapping group numbers to their hypotheses. + """ + # Remove think tags if present + response = re.sub(r'.*?', '', response, flags=re.IGNORECASE | re.DOTALL) + + # Parse the response to extract groups and their hypotheses + groups = {} + current_group = None + + lines = response.strip().split('\n') + for line in lines: + line = line.strip() + if not line: + continue + + # Look for group indicators (e.g., "Group 1:" or "Group 1: ") + # Must start with "Group" followed by a number and colon + group_match = re.match(r'^Group\s*(\d+)\s*:\s*(.*)', line, re.IGNORECASE) + if group_match: + current_group = int(group_match.group(1)) + current_condition = group_match.group(2).strip() + + if current_group not in groups: + groups[current_group] = { + 'condition': current_condition, + 'hypotheses': [] + } + elif current_group is not None and (line.startswith('-') or line.startswith('•')): + # This is a hypothesis under the current group (indicated by '-' or '•') + hypothesis = line.lstrip('- ').lstrip('• ').strip() + if hypothesis: + groups[current_group]['hypotheses'].append(hypothesis) + elif current_group is not None and re.match(r'^\d+\.', line): + # This is a numbered hypothesis under the current group (e.g., "1. Hypothesis text") + hypothesis = re.sub(r'^\d+\.\s*', '', line).strip() + if hypothesis: + groups[current_group]['hypotheses'].append(hypothesis) + + # Log parsed results for debugging + logger = LoggerConfig.get_logger("StumpParsing") + logger.info("=== Parsed Stump Response ===") + log_text = '' + for group_num, group_info in sorted(groups.items()): + log_text += f"Group {group_num}: {group_info['condition']}\n" + for i, hypothesis in enumerate(group_info['hypotheses'], 1): + log_text += f" {i}. {hypothesis}\n" + logger.info(log_text) + logger.info("=============================") + + return groups + + def _create_stump( + self, + hyp_bank, + cache_seed=None, + **generate_kwargs, + ): + logger = LoggerConfig.get_logger("StumpCreation") + prompt_input = self.prompt_class.create_stump(hyp_bank) + response = self.grouping_api.generate( + prompt_input, + cache_seed=cache_seed, + **generate_kwargs, + ) + logger.info(f"Stump response: {response}") + groups = self._parse_stump_response(response) + return groups + + def _determine_groups_batched( + self, + data: pd.DataFrame, + groups, + cache_seed=None, + max_concurrent=3, + **generate_kwargs, + ): + """ + Determine which group each sample belongs to in batch. + """ + logger = LoggerConfig.get_logger("SampleGroupDetermination") + # Format group conditions for the prompt + group_conditions = [] + for group_num, group_info in groups.items(): + condition = group_info['condition'] + group_conditions.append(f"Group {group_num}: {condition}") + group_conditions_text = "\n".join(group_conditions) + + # Create prompts for all samples + prompt_inputs = [] + for idx in range(len(data)): + prompt_input = self.prompt_class.determine_group(group_conditions_text, data, idx) + prompt_inputs.append(prompt_input) + + # Batch generate responses + responses = self.api.batched_generate( + prompt_inputs, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) + # Extract group numbers from responses + group_nums = [] + for response in responses: + response = re.sub(r'.*?', '', response, flags=re.IGNORECASE | re.DOTALL) + response = response.strip() + try: + group_num = int(re.search(r'\d+', response).group()) + group_nums.append(group_num) + except (ValueError, AttributeError): + logger.error(f"Failed to parse group number from response: {response}, defaulting to group 1") + + return group_nums + + def _stump_batched_predict( + self, + data: pd.DataFrame, + groups, + cache_seed=None, + max_concurrent=3, + **generate_kwargs, + ): + """ + Make predictions using group-specific hypotheses with batched processing. + """ + logger = LoggerConfig.get_logger("StumpBatchedPredict") + # First, determine groups for all samples in batch + group_nums = self._determine_groups_batched( + data, groups, cache_seed, max_concurrent, **generate_kwargs + ) + + # Create prompts for all samples with their respective group condition and hypotheses + prompt_inputs = [] + sample_indices = [] + + for idx in range(len(data)): + group_num = group_nums[idx] + + if group_num in groups: + group_hypotheses = groups[group_num]['hypotheses'] + group_condition = groups[group_num]['condition'] + # Convert to the format expected by multiple_hypotheses_inference + # Create a dictionary mapping hypothesis text to SummaryInformation + hyp_dict = {} + for i, hypothesis_text in enumerate(group_hypotheses): + hyp_dict[hypothesis_text] = SummaryInformation() + + # Create prompt for this sample with group-specific hypotheses + prompt_input = self.prompt_class.stump_predict( + hyp_dict, data, idx, group_condition + ) + prompt_inputs.append(prompt_input) + sample_indices.append(idx) + else: + logger.error(f"Group {group_num} not found in groups.") + + # Batch generate predictions + responses = self.api.batched_generate( + prompt_inputs, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) + + # Extract predictions and actual labels + predictions = [] + actual_labels = [] + + for i, response in enumerate(responses): + response = re.sub(r'.*?', '', response, flags=re.IGNORECASE | re.DOTALL) + prediction = self.task.extract_label(response) + predictions.append(prediction) + actual_labels.append(data[self.task.label_name][sample_indices[i]]) + + # Log group statistics and prediction distribution + self._log_group_statistics(groups, group_nums, predictions, sample_indices) + + return predictions, actual_labels + + @staticmethod + def _log_group_statistics(groups, group_nums, predictions, sample_indices): + """ + Log statistics for each group including sample count and prediction distribution. + """ + logger = LoggerConfig.get_logger("GroupStatistics") + log_text = '' + + # Create a mapping from sample index to group number + sample_to_group = {} + for i, group_num in enumerate(group_nums): + sample_to_group[sample_indices[i]] = group_num + + # Analyze each group + for group_num, group_info in groups.items(): + group_condition = group_info['condition'] + + # Find samples belonging to this group + group_samples = [idx for idx, g_num in sample_to_group.items() if g_num == group_num] + sample_count = len(group_samples) + + # Get predictions for samples in this group + group_predictions = [] + for sample_idx in group_samples: + if sample_idx in sample_indices: + pred_idx = sample_indices.index(sample_idx) + if pred_idx < len(predictions): + group_predictions.append(predictions[pred_idx]) + + # Count prediction distribution + pred_distribution = {} + for pred in group_predictions: + pred_distribution[pred] = pred_distribution.get(pred, 0) + 1 + + # Format prediction distribution string + pred_dist_str = ", ".join([f"{label}: {count}" for label, count in pred_distribution.items()]) + + # Log group statistics + log_text += f"Group {group_num}: {group_condition}\n" + log_text += f"Sample count: {sample_count}\n" + log_text += f"Prediction distribution: {pred_dist_str}\n" + log_text += "-" * 50 +"\n" + + logger.info(log_text) + + def run_inference_final( + self, + data, + hyp_bank, + cache_seed=None, + max_concurrent=3, + **generate_kwargs, + ): + # Create stump from hypotheses + groups = self._create_stump( + hyp_bank, + cache_seed=cache_seed, + **generate_kwargs, + ) + + # Run inference using group-specific hypotheses + return self._stump_batched_predict( + data, + groups, + cache_seed=cache_seed, + max_concurrent=max_concurrent, + **generate_kwargs, + ) \ No newline at end of file diff --git a/hypothesis_agent/data_analysis_agent/hierarchical_inference.py b/hypothesis_agent/data_analysis_agent/tree_inference.py similarity index 99% rename from hypothesis_agent/data_analysis_agent/hierarchical_inference.py rename to hypothesis_agent/data_analysis_agent/tree_inference.py index b23b076..bafc354 100644 --- a/hypothesis_agent/data_analysis_agent/hierarchical_inference.py +++ b/hypothesis_agent/data_analysis_agent/tree_inference.py @@ -20,7 +20,7 @@ def __init__(self, groups=None, hypotheses=None, children=None, is_leaf=False, p self.path = path or [] # [condition1, condition2, ...] -class MultiHypHierarchicalInference(DefaultInference): +class TreeInference(DefaultInference): def __init__( self, api, diff --git a/pipeline.py b/pipeline.py index fc8f2e9..a7bf1ad 100644 --- a/pipeline.py +++ b/pipeline.py @@ -31,7 +31,8 @@ OnlyPaperGeneration, ZeroShotGeneration, ) -from hypothesis_agent.data_analysis_agent.hierarchical_inference import MultiHypHierarchicalInference +from hypothesis_agent.data_analysis_agent.stump_inference import StumpInference +# from hypothesis_agent.data_analysis_agent.tree_inference import TreeInference from hypothesis_agent.data_analysis_agent.inference import MultiHypDefaultInference from hypothesis_agent.data_analysis_agent.update import TestUpdate from hypothesis_agent.literature_review_agent import LiteratureAgent @@ -61,6 +62,11 @@ parser.add_argument("--task_name", type=str, required=True) parser.add_argument("--literature_folder", type=str) +# This is needed for the grouping model +parser.add_argument("--grouping_model_type", type=str) +parser.add_argument("--grouping_model_name", type=str) +parser.add_argument("--grouping_model_path", type=str) + # This is needed for local models parser.add_argument("--model_path", type=str) parser.add_argument("--do_train", action="store_true", default=False) @@ -770,8 +776,21 @@ def get_res_hierarchical(filename: str, task_name, api, model_name, use_val=Fals for hypothesis in hyp_dict: hyp_bank[hypothesis] = SummaryInformation.from_dict(hyp_dict[hypothesis]) + # Initialize the API for the grouping model + grouping_model_type = args.grouping_model_type + grouping_model_name = args.grouping_model_name + groping_model_path = args.grouping_model_path + if grouping_model_type and grouping_model_name: + logger.info(f"Using grouping model: {grouping_model_type} - {grouping_model_name}") + grouping_api = llm_wrapper_register.build(grouping_model_type)(model=grouping_model_name, path_name=groping_model_path) + else: + logger.warning('Grouping model type or name not provided, using the same API as the main model.') + grouping_api = api + + # inference_class = TreeInference(api, prompt_class, train_data, task) + # inference_class = DefaultInference(api, prompt_class, train_data, task) + inference_class = StumpInference(api, prompt_class, train_data, task, grouping_api) - inference_class = MultiHypHierarchicalInference(api, prompt_class, train_data, task) pred_list, label_list = inference_class.run_inference_final( test_data, hyp_bank,