Skip to content

Commit 760b27e

Browse files
feat: add google search
1 parent b54632a commit 760b27e

11 files changed

Lines changed: 171 additions & 70 deletions

File tree

graphgen/configs/graphgen_config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,5 @@ traverse_strategy:
1414
loss_strategy: only_edge
1515
search:
1616
enabled: true
17-
search_types: ["wikipedia", "google"]
17+
search_types: ["google"]
1818
re_judge: false

graphgen/graphgen.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,9 +237,8 @@ async def async_search(self):
237237
"[Search] Found %d entities to search", len(new_search_entities)
238238
)
239239
_add_search_data = await search_all(
240-
llm_client=self.synthesizer_llm_client,
241240
search_types=self.search_config["search_types"],
242-
kg_instance=self.graph_storage,
241+
search_entities=new_search_entities,
243242
)
244243
if _add_search_data:
245244
await self.search_storage.upsert(_add_search_data)

graphgen/models/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from graphgen.models.search.kg.wiki_search import WikiSearch
2-
31
from .evaluate.length_evaluator import LengthEvaluator
42
from .evaluate.mtld_evaluator import MTLDEvaluator
53
from .evaluate.reward_evaluator import RewardEvaluator
64
from .evaluate.uni_evaluator import UniEvaluator
75
from .llm.openai_model import OpenAIModel
86
from .llm.tokenizer import Tokenizer
97
from .llm.topk_token_model import Token, TopkTokenModel
8+
from .search.kg.wiki_search import WikiSearch
9+
from .search.web.google_search import GoogleSearch
1010
from .storage.json_storage import JsonKVStorage
1111
from .storage.networkx_storage import NetworkXStorage
1212
from .strategy.travserse_strategy import TraverseStrategy
@@ -25,6 +25,7 @@
2525
"JsonKVStorage",
2626
# search models
2727
"WikiSearch",
28+
"GoogleSearch",
2829
# evaluate models
2930
"TextPair",
3031
"LengthEvaluator",

graphgen/models/search/kg/wiki_search.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@ def set_language(language: str):
1414
assert language in ["en", "zh"], "Only support English and Chinese"
1515
set_lang(language)
1616

17-
async def search(self, query: str) -> Union[List[str], None]:
17+
async def search(self, query: str, num_results: int = 1) -> Union[List[str], None]:
1818
self.set_language(detect_main_language(query))
19-
return wikipedia.search(query)
19+
return wikipedia.search(query, results=num_results, suggestion=False)
2020

2121
async def summary(self, query: str) -> Union[str, None]:
2222
self.set_language(detect_main_language(query))

graphgen/models/search/web/__init__.py

Whitespace-only changes.
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from dataclasses import dataclass
2+
3+
import requests
4+
from fastapi import HTTPException
5+
6+
from graphgen.utils import logger
7+
8+
GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1"
9+
10+
11+
class GoogleSearch:
    """Thin client for the Google Custom Search JSON API.

    NOTE: the original declared this class as a ``@dataclass`` while also
    defining ``__init__`` by hand. With no dataclass fields, the generated
    ``__eq__`` compared empty tuples, so *all* instances compared equal;
    the decorator is removed to restore default identity semantics.
    """

    def __init__(self, subscription_key: str, cx: str):
        """
        Initialize the Google Search client with the subscription key and custom search engine ID.
        :param subscription_key: Your Google API subscription key.
        :param cx: Your custom search engine ID.
        """
        self.subscription_key = subscription_key
        self.cx = cx

    def search(self, query: str, num_results: int = 1):
        """
        Search with Google and return the contexts.
        :param query: The search query.
        :param num_results: The number of results to return (API accepts 1-10).
        :return: A list of search result items, or [] when no items were returned.
        :raises HTTPException: if the search endpoint answers with a non-2xx status.
        """
        params = {
            "key": self.subscription_key,
            "cx": self.cx,
            "q": query,
            # The Custom Search API rejects "num" outside 1..10 with a 400.
            "num": min(max(num_results, 1), 10),
        }
        response = requests.get(GOOGLE_SEARCH_ENDPOINT, params=params, timeout=10)
        if not response.ok:
            logger.error("Search engine error: %s", response.text)
            raise HTTPException(response.status_code, "Search engine error.")
        json_content = response.json()
        try:
            contexts = json_content["items"][:num_results]
        except KeyError:
            # "items" is absent when the query matched nothing (or quota issues).
            logger.error("Error encountered: %s", json_content)
            return []
        return contexts
Lines changed: 27 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,58 @@
11
from tqdm.asyncio import tqdm_asyncio as tqdm_async
22

3-
from graphgen.models import NetworkXStorage, OpenAIModel, WikiSearch
4-
from graphgen.templates import SEARCH_JUDGEMENT_PROMPT
3+
from graphgen.models import WikiSearch
54
from graphgen.utils import logger
65

76

87
async def _process_single_entity(
98
entity_name: str,
10-
description: str,
11-
llm_client: OpenAIModel,
129
wiki_search_client: WikiSearch,
13-
) -> tuple[str, None] | tuple[str, str]:
10+
) -> str | None:
1411
"""
15-
Process single entity
16-
12+
Process single entity by searching Wikipedia
13+
:param entity_name
14+
:param wiki_search_client
15+
:return: summary of the entity or None if not found
1716
"""
1817
search_results = await wiki_search_client.search(entity_name)
1918
if not search_results:
20-
return entity_name, None
21-
examples = "\n".join(SEARCH_JUDGEMENT_PROMPT["EXAMPLES"])
22-
search_results.append("None of the above")
19+
return None
2320

24-
search_results_str = "\n".join(
25-
[f"{i + 1}. {sr}" for i, sr in enumerate(search_results)]
26-
)
27-
prompt = SEARCH_JUDGEMENT_PROMPT["TEMPLATE"].format(
28-
examples=examples,
29-
entity_name=entity_name,
30-
description=description,
31-
search_results=search_results_str,
32-
)
33-
response = await llm_client.generate_answer(prompt)
21+
summary = None
3422
try:
35-
response = response.strip()
36-
response = int(response)
37-
if response < 1 or response >= len(search_results):
38-
response = None
39-
else:
40-
response = await wiki_search_client.summary(search_results[response - 1])
41-
except ValueError:
42-
response = None
43-
44-
logger.info(
45-
"Entity %s search result: %s response: %s",
46-
entity_name,
47-
str(search_results),
48-
response,
49-
)
23+
summary = await wiki_search_client.summary(search_results[-1])
24+
logger.info(
25+
"Entity %s search result: %s summary: %s",
26+
entity_name,
27+
str(search_results),
28+
summary,
29+
)
30+
except Exception as e: # pylint: disable=broad-except
31+
logger.error("Error processing entity %s: %s", entity_name, str(e))
5032

51-
return entity_name, response
33+
return summary
5234

5335

5436
async def search_wikipedia(
    wiki_search_client: WikiSearch,
    entities: set[str],
) -> dict:
    """
    Search wikipedia for entities.

    :param wiki_search_client: wiki search client
    :param entities: list of entities to search
    :return: mapping of entity name -> Wikipedia summary
    """
    wiki_data = {}

    async for entity in tqdm_async(
        entities, desc="Searching Wikipedia", total=len(entities)
    ):
        try:
            # _process_single_entity returns a single summary string (or
            # None). The previous `entity, summary = await ...` tuple
            # unpacking raised on every entity — the broad except swallowed
            # it, so no wiki data was ever collected.
            summary = await _process_single_entity(entity, wiki_search_client)
            if summary:
                wiki_data[entity] = summary
        except Exception as e:  # pylint: disable=broad-except
            logger.error("Error processing entity %s: %s", entity, str(e))
    return wiki_data

graphgen/operators/search/search_all.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,27 @@
1-
from graphgen.models import NetworkXStorage, OpenAIModel
1+
"""
2+
To use Google Web Search API,
3+
follow the instructions [here](https://developers.google.com/custom-search/v1/overview)
4+
to get your Google search api key.
5+
6+
To use Bing Web Search API,
7+
follow the instructions [here](https://www.microsoft.com/en-us/bing/apis/bing-web-search-api)
8+
and obtain your Bing subscription key.
9+
"""
10+
11+
import os
12+
213
from graphgen.utils import logger
314

415

516
async def search_all(
6-
llm_client: OpenAIModel, search_types: dict, kg_instance: NetworkXStorage
17+
search_types: dict, search_entities: set[str]
718
) -> dict[str, dict[str, str]]:
819
"""
9-
:param llm_client
1020
:param search_types
11-
:param kg_instance
21+
:param search_entities: list of entities to search
1222
:return: nodes with search results
1323
"""
1424

15-
# 增量建图时,只需要搜索新增实体
16-
1725
results = {}
1826

1927
for search_type in search_types:
@@ -23,16 +31,25 @@ async def search_all(
2331

2432
wiki_search_client = WikiSearch()
2533

26-
wiki_results = await search_wikipedia(
27-
llm_client, wiki_search_client, kg_instance
28-
)
34+
wiki_results = await search_wikipedia(wiki_search_client, search_entities)
2935
for entity_name, description in wiki_results.items():
3036
if description:
3137
results[entity_name] = {"wikipedia": description}
32-
# elif search_type == "google":
33-
# from graphgen.operators.search.web.search_google import search_google
34-
# return await search_google(llm_client, kg_instance)
35-
#
38+
elif search_type == "google":
39+
from graphgen.models import GoogleSearch
40+
from graphgen.operators.search.web.search_google import search_google
41+
42+
google_search_client = GoogleSearch(
43+
subscription_key=os.environ["GOOGLE_SEARCH_API_KEY"],
44+
cx=os.environ["GOOGLE_SEARCH_CX"],
45+
)
46+
47+
google_results = await search_google(google_search_client, search_entities)
48+
for entity_name, description in google_results.items():
49+
if description:
50+
results[entity_name] = results.get(entity_name, {})
51+
results[entity_name]["google"] = description
52+
3653
# elif search_type == "bing":
3754
# from graphgen.operators.search.web.search_bing import search_bing
3855
# return await search_bing(llm_client, kg_instance)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"
BING_MKT = "en-US"


async def search_bing():
    """
    Search with Bing and return the contexts.

    Placeholder: the Bing Web Search integration is not written yet, so
    calling this coroutine always fails.
    :raises NotImplementedError: always.
    """
    raise NotImplementedError("Bing search is not implemented yet.")
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import trafilatura
2+
from tqdm.asyncio import tqdm_asyncio as tqdm_async
3+
4+
from graphgen.models import GoogleSearch
5+
from graphgen.utils import logger
6+
7+
8+
async def _process_single_entity(
9+
entity_name: str, google_search_client: GoogleSearch
10+
) -> str | None:
11+
search_results = google_search_client.search(entity_name)
12+
if not search_results:
13+
return None
14+
15+
# Get more details from the first search result
16+
first_result = search_results[0]
17+
content = trafilatura.fetch_url(first_result["link"])
18+
summary = trafilatura.extract(content, include_comments=False, include_links=False)
19+
summary = summary.strip()
20+
logger.info(
21+
"Entity %s search result: %s",
22+
entity_name,
23+
summary,
24+
)
25+
return summary
26+
27+
28+
async def search_google(
    google_search_client: GoogleSearch,
    entities: set[str],
) -> dict:
    """
    Look up each entity on Google and collect the extracted page summaries.

    :param google_search_client: Google search client
    :param entities: list of entities to search
    :return: mapping of entity name -> summary text
    """
    collected = {}

    progress = tqdm_async(entities, desc="Searching Google", total=len(entities))
    async for entity_name in progress:
        try:
            text = await _process_single_entity(entity_name, google_search_client)
        except Exception as err:  # pylint: disable=broad-except
            logger.error("Error processing entity %s: %s", entity_name, str(err))
            continue
        if text:
            collected[entity_name] = text
    return collected

0 commit comments

Comments
 (0)