|
1 | 1 | from tqdm.asyncio import tqdm_asyncio as tqdm_async |
2 | 2 |
|
3 | | -from graphgen.models import NetworkXStorage, OpenAIModel, WikiSearch |
4 | | -from graphgen.templates import SEARCH_JUDGEMENT_PROMPT |
| 3 | +from graphgen.models import WikiSearch |
5 | 4 | from graphgen.utils import logger |
6 | 5 |
|
7 | 6 |
|
async def _process_single_entity(
    entity_name: str,
    wiki_search_client: WikiSearch,
) -> str | None:
    """
    Look up a single entity on Wikipedia and fetch a summary for it

    :param entity_name: name of the entity to search for
    :param wiki_search_client: client used for both the search and the
        summary lookup
    :return: the summary returned by the client, or None when the search
        yields nothing or the summary lookup fails
    """
    hits = await wiki_search_client.search(entity_name)

    summary = None
    if hits:
        try:
            # NOTE(review): the last search hit is the one summarised --
            # confirm this matches the ordering WikiSearch.search returns.
            summary = await wiki_search_client.summary(hits[-1])
            logger.info(
                "Entity %s search result: %s summary: %s",
                entity_name,
                str(hits),
                summary,
            )
        except Exception as exc:  # pylint: disable=broad-except
            # Best effort: a failed lookup is logged and reported as None.
            logger.error("Error processing entity %s: %s", entity_name, str(exc))

    return summary
52 | 34 |
|
53 | 35 |
|
async def search_wikipedia(
    wiki_search_client: WikiSearch,
    entities: set[str],
) -> dict:
    """
    Search wikipedia for entities

    :param wiki_search_client: wiki search client
    :param entities: set of entity names to search
    :return: dict mapping each entity name to its summary; entities with no
        summary (nothing found, or the lookup failed) are left out
    """
    wiki_data = {}

    async for entity in tqdm_async(
        entities, desc="Searching Wikipedia", total=len(entities)
    ):
        try:
            # _process_single_entity returns just the summary (or None), so
            # it must not be tuple-unpacked: doing so raised on every call
            # and the broad except below silently dropped every result.
            summary = await _process_single_entity(entity, wiki_search_client)
            if summary:
                wiki_data[entity] = summary
        except Exception as e:  # pylint: disable=broad-except
            # Keep going: one failing entity should not abort the whole batch.
            logger.error("Error processing entity %s: %s", entity, str(e))
    return wiki_data
0 commit comments