77from dotenv import load_dotenv
88
99from .graphgen import GraphGen
10- from .models import OpenAIModel , Tokenizer , TraverseStrategy
11- from .utils import logger , read_file , set_logger
10+ from .utils import logger , set_logger
1211
1312sys_path = os .path .abspath (os .path .dirname (__file__ ))
1413
@@ -53,10 +52,8 @@ def main():
5352
5453 with open (args .config_file , "r" , encoding = "utf-8" ) as f :
5554 config = yaml .load (f , Loader = yaml .FullLoader )
56- input_file = config ["input_file" ]
57- data = read_file (input_file )
58- output_data_type = config ["output_data_type" ]
5955
56+ output_data_type = config ["output_data_type" ]
6057 unique_id = int (time .time ())
6158 set_logger (
6259 os .path .join (
@@ -72,41 +69,26 @@ def main():
7269 ),
7370 )
7471
75- tokenizer_instance = Tokenizer (model_name = config ["tokenizer" ])
76- synthesizer_llm_client = OpenAIModel (
77- model_name = os .getenv ("SYNTHESIZER_MODEL" ),
78- api_key = os .getenv ("SYNTHESIZER_API_KEY" ),
79- base_url = os .getenv ("SYNTHESIZER_BASE_URL" ),
80- tokenizer_instance = tokenizer_instance ,
81- )
82- trainee_llm_client = OpenAIModel (
83- model_name = os .getenv ("TRAINEE_MODEL" ),
84- api_key = os .getenv ("TRAINEE_API_KEY" ),
85- base_url = os .getenv ("TRAINEE_BASE_URL" ),
86- tokenizer_instance = tokenizer_instance ,
87- )
88-
89- graph_gen = GraphGen (
90- working_dir = working_dir ,
91- unique_id = unique_id ,
92- synthesizer_llm_client = synthesizer_llm_client ,
93- trainee_llm_client = trainee_llm_client ,
94- search_config = config ["search" ],
95- tokenizer_instance = tokenizer_instance ,
96- )
72+ graph_gen = GraphGen (working_dir = working_dir , unique_id = unique_id , config = config )
9773
98- graph_gen .insert (data , config [ "input_data_type" ] )
74+ graph_gen .insert ()
9975
10076 if config ["search" ]["enabled" ]:
10177 graph_gen .search ()
10278
10379 # Use pipeline according to the output data type
10480 if output_data_type in ["atomic" , "aggregated" , "multi_hop" ]:
105- graph_gen .quiz (max_samples = config ["quiz_samples" ])
106- graph_gen .judge (re_judge = config ["re_judge" ])
107- traverse_strategy = TraverseStrategy (** config ["traverse_strategy" ])
108- traverse_strategy .qa_form = output_data_type
109- graph_gen .traverse (traverse_strategy = traverse_strategy )
81+ if "quiz_and_judge_strategy" in config and config [
82+ "quiz_and_judge_strategy"
83+ ].get ("enabled" , False ):
84+ graph_gen .quiz ()
85+ graph_gen .judge ()
86+ else :
87+ logger .warning (
88+ "Quiz and Judge strategy is disabled. Edge sampling falls back to random."
89+ )
90+ graph_gen .traverse_strategy .edge_sampling = "random"
91+ graph_gen .traverse ()
11092 elif output_data_type == "cot" :
11193 graph_gen .generate_reasoning (method_params = config ["method_params" ])
11294 else :
0 commit comments