Skip to content

Commit ed27cd1

Browse files
refactor: adjust templates
1 parent 7d4f1e5 commit ed27cd1

7 files changed

Lines changed: 38 additions & 125 deletions

File tree

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,37 @@
1-
from typing import Any, List
1+
from typing import List
22

3-
from graphgen.bases import BaseGraphStorage, BasePartitioner
3+
from graphgen.bases import BaseGraphStorage
44
from graphgen.bases.datatypes import Community
5+
from graphgen.models import BFSPartitioner
56

67

7-
class ECEPartitioner(BasePartitioner):
8-
def partition(
9-
self,
10-
g: BaseGraphStorage,
11-
bidirectional: bool = False,
12-
**kwargs: Any,
13-
) -> List[Community]:
14-
pass
8+
class ECEPartitioner(BFSPartitioner):
9+
"""
10+
ECE partitioner that partitions the graph into communities based on Expected Calibration Error (ECE).
11+
We calculate ECE for edges in KG (represented as 'comprehension loss') and group edges with similar ECE values into the same community.
12+
1. Select a sampling strategy.
13+
2. Choose a unit based on the sampling strategy.
14+
3. Expand the community using BFS.
15+
4. When expanding, prefer to add units with the sampling strategy.
16+
5. Stop when the max unit size is reached or the max input length is reached.
17+
(A unit is a node or an edge.)
18+
"""
1519

16-
def split_communities(self, communities: List[Community]) -> List[Community]:
17-
pass
20+
# async def partition(
21+
# self,
22+
# g: BaseGraphStorage,
23+
# *,
24+
# ):
25+
# pass
26+
27+
28+
# 修改
29+
# max_depth 取消
30+
# expand_method 改名为 xxx
31+
# edge_sampling
32+
# loss_strategy取消,因为node和edge可以看作同一种unit
33+
# bidirectional 取消
34+
# max_extra_edges 改名为 max_units_per_community
35+
# max_tokens 改名为 max_tokens_per_community
36+
37+
# 可以退化成BFS

graphgen/operators/partition/traverse_graph.py

Lines changed: 0 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -299,117 +299,6 @@ async def _process_single_batch(
299299
return results
300300

301301

302-
# pylint: disable=too-many-branches, too-many-statements
303-
async def traverse_graph_for_atomic(
304-
llm_client: OpenAIClient,
305-
tokenizer: Tokenizer,
306-
graph_storage: NetworkXStorage,
307-
traverse_strategy: Dict,
308-
text_chunks_storage: JsonKVStorage,
309-
progress_bar: gr.Progress = None,
310-
max_concurrent: int = 1000,
311-
) -> dict:
312-
"""
313-
Traverse the graph atomically
314-
315-
:param llm_client
316-
:param tokenizer
317-
:param graph_storage
318-
:param traverse_strategy
319-
:param text_chunks_storage
320-
:param progress_bar
321-
:param max_concurrent
322-
:return: question and answer
323-
"""
324-
325-
semaphore = asyncio.Semaphore(max_concurrent)
326-
327-
def _parse_qa(qa: str) -> tuple:
328-
if "Question:" in qa and "Answer:" in qa:
329-
question = qa.split("Question:")[1].split("Answer:")[0].strip()
330-
answer = qa.split("Answer:")[1].strip()
331-
elif "问题:" in qa and "答案:" in qa:
332-
question = qa.split("问题:")[1].split("答案:")[0].strip()
333-
answer = qa.split("答案:")[1].strip()
334-
else:
335-
return None, None
336-
return question.strip('"'), answer.strip('"')
337-
338-
async def _generate_question(node_or_edge: tuple):
339-
if len(node_or_edge) == 2:
340-
des = node_or_edge[0] + ": " + node_or_edge[1]["description"]
341-
loss = node_or_edge[1]["loss"] if "loss" in node_or_edge[1] else -1.0
342-
else:
343-
des = node_or_edge[2]["description"]
344-
loss = node_or_edge[2]["loss"] if "loss" in node_or_edge[2] else -1.0
345-
346-
async with semaphore:
347-
try:
348-
language = "Chinese" if detect_main_language(des) == "zh" else "English"
349-
350-
qa = await llm_client.generate_answer(
351-
QUESTION_GENERATION_PROMPT[language]["SINGLE_QA_TEMPLATE"].format(
352-
doc=des
353-
)
354-
)
355-
356-
question, answer = _parse_qa(qa)
357-
if question is None or answer is None:
358-
return {}
359-
360-
question = question.strip('"')
361-
answer = answer.strip('"')
362-
363-
logger.info("Question: %s", question)
364-
logger.info("Answer: %s", answer)
365-
return {
366-
compute_content_hash(question): {
367-
"question": question,
368-
"answer": answer,
369-
"loss": loss,
370-
}
371-
}
372-
except Exception as e: # pylint: disable=broad-except
373-
logger.error("Error occurred while generating question: %s", e)
374-
return {}
375-
376-
results = {}
377-
edges = list(await graph_storage.get_all_edges())
378-
nodes = list(await graph_storage.get_all_nodes())
379-
380-
edges, nodes = await _pre_tokenize(graph_storage, tokenizer, edges, nodes)
381-
382-
tasks = []
383-
for node in nodes:
384-
if "<SEP>" in node[1]["description"]:
385-
description_list = node[1]["description"].split("<SEP>")
386-
for item in description_list:
387-
tasks.append((node[0], {"description": item}))
388-
if "loss" in node[1]:
389-
tasks[-1][1]["loss"] = node[1]["loss"]
390-
else:
391-
tasks.append((node[0], node[1]))
392-
for edge in edges:
393-
if "<SEP>" in edge[2]["description"]:
394-
description_list = edge[2]["description"].split("<SEP>")
395-
for item in description_list:
396-
tasks.append((edge[0], edge[1], {"description": item}))
397-
if "loss" in edge[2]:
398-
tasks[-1][2]["loss"] = edge[2]["loss"]
399-
else:
400-
tasks.append((edge[0], edge[1], edge[2]))
401-
402-
results_list = await run_concurrent(
403-
_generate_question,
404-
tasks,
405-
progress_bar=progress_bar,
406-
desc="[4/4]Generating QAs",
407-
)
408-
for res in results_list:
409-
results.update(res)
410-
return results
411-
412-
413302
async def traverse_graph_for_multi_hop(
414303
llm_client: OpenAIClient,
415304
tokenizer: Tokenizer,

graphgen/templates/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from .answer_rephrasing import ANSWER_REPHRASING_PROMPT
2-
from .atomic_generation import ATOMIC_GENERATION_PROMPT
3-
from .community import COT_GENERATION_PROMPT, COT_TEMPLATE_DESIGN_PROMPT
42
from .coreference_resolution import COREFERENCE_RESOLUTION_PROMPT
53
from .description_rephrasing import DESCRIPTION_REPHRASING_PROMPT
4+
from .generation import (
5+
ATOMIC_GENERATION_PROMPT,
6+
COT_GENERATION_PROMPT,
7+
COT_TEMPLATE_DESIGN_PROMPT,
8+
)
69
from .kg_extraction import KG_EXTRACTION_PROMPT
710
from .kg_summarization import KG_SUMMARIZATION_PROMPT
811
from .multi_hop_generation import MULTI_HOP_GENERATION_PROMPT
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1+
from .atomic_generation import ATOMIC_GENERATION_PROMPT
12
from .cot_generation import COT_GENERATION_PROMPT
23
from .cot_template_design import COT_TEMPLATE_DESIGN_PROMPT
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)