Skip to content

Commit 2d963f6

Browse files
fix(graphgen): standardize types for qa_form
1 parent c443180 commit 2d963f6

3 files changed

Lines changed: 5 additions & 3 deletions

File tree

graphgen/configs/graphgen_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ input_file: resources/examples/raw_demo.jsonl
33
tokenizer: cl100k_base
44
quiz_samples: 2
55
traverse_strategy:
6-
qa_form: open
6+
qa_form: aggregated
77
bidirectional: true
88
edge_sampling: max_loss
99
expand_method: max_width

graphgen/graphgen.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,10 +237,12 @@ async def async_traverse(self):
237237
self.traverse_strategy,
238238
self.text_chunks_storage,
239239
self.progress_bar)
240-
else:
240+
elif self.traverse_strategy.qa_form == "aggregated":
241241
results = await traverse_graph_by_edge(self.synthesizer_llm_client, self.tokenizer_instance,
242242
self.graph_storage, self.traverse_strategy, self.text_chunks_storage,
243243
self.progress_bar)
244+
else:
245+
raise ValueError(f"Unknown qa_form: {self.traverse_strategy.qa_form}")
244246
await self.qa_storage.upsert(results)
245247
await self.qa_storage.index_done_callback()
246248

graphgen/models/strategy/travserse_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
@dataclass
77
class TraverseStrategy(BaseStrategy):
88
# 生成的QA形式:原子、多跳、开放性
9-
qa_form: str = "multi_hop" # "atomic" or "multi_hop" or "open"
9+
qa_form: str = "atomic" # "atomic" or "multi_hop" or "aggregated"
1010
# 最大边数和最大token数方法中选择一个生效
1111
expand_method: str = "max_tokens" # "max_width" or "max_tokens"
1212
# 单向拓展还是双向拓展

0 commit comments

Comments
 (0)