Skip to content

Commit e63968b

Browse files
feat: uniprot search
1 parent 760b27e commit e63968b

7 files changed

Lines changed: 182 additions & 7 deletions

File tree

graphgen/models/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
from .llm.openai_model import OpenAIModel
66
from .llm.tokenizer import Tokenizer
77
from .llm.topk_token_model import Token, TopkTokenModel
8+
from .search.db.uniprot_search import UniProtSearch
89
from .search.kg.wiki_search import WikiSearch
10+
from .search.web.bing_search import BingSearch
911
from .search.web.google_search import GoogleSearch
1012
from .storage.json_storage import JsonKVStorage
1113
from .storage.networkx_storage import NetworkXStorage
@@ -26,6 +28,8 @@
2628
# search models
2729
"WikiSearch",
2830
"GoogleSearch",
31+
"BingSearch",
32+
"UniProtSearch",
2933
# evaluate models
3034
"TextPair",
3135
"LengthEvaluator",
File renamed without changes.
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from dataclasses import dataclass
2+
3+
import requests
4+
from fastapi import HTTPException
5+
6+
from graphgen.utils import logger
7+
8+
UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search"
9+
10+
11+
@dataclass
class UniProtSearch:
    """
    UniProt Search client to search with UniProt.
    1) Get the protein by accession number.
    2) Search with keywords or protein names.

    Every request is issued through ``_safe_get``, which raises
    ``HTTPException`` when UniProt answers with a non-success status code.
    """

    # Single entries are served directly under /uniprotkb/<accession>, not
    # under the /uniprotkb/search endpoint that UNIPROT_BASE points to.
    # Deliberately un-annotated so the dataclass does not treat it as a field.
    _ENTRY_BASE = "https://rest.uniprot.org/uniprotkb"

    @staticmethod
    def _entry_url(accession: str) -> str:
        """
        Build the JSON retrieval URL for a single UniProtKB entry.
        :param accession: The accession number (e.g., P04637).
        :return: The URL of the entry in JSON format.
        """
        # pylint: disable=import-outside-toplevel
        from urllib.parse import quote

        # Percent-encode so characters such as "/" or "?" cannot change the
        # request path; plain accessions pass through unchanged.
        encoded = quote(accession.strip(), safe="")
        return f"{UniProtSearch._ENTRY_BASE}/{encoded}.json"

    def get_entry(self, accession: str) -> dict:
        """
        Get the UniProt entry by accession number (e.g., P04637).
        :param accession: The accession number of the entry.
        :return: The decoded JSON body of the entry.
        """
        return self._safe_get(self._entry_url(accession)).json()

    def search(
        self,
        query: str,
        *,
        size: int = 10,
        cursor: str = None,
        fields: list[str] = None,
    ) -> dict:
        """
        Search UniProt with a query string.
        :param query: The search query.
        :param size: The number of results to return.
        :param cursor: The cursor for pagination.
        :param fields: The fields to return in the response.
        :return: A dictionary containing the search results.
        """
        params = {
            "query": query,
            "size": size,
        }
        # Optional parameters are only sent when the caller supplied them.
        if cursor:
            params["cursor"] = cursor
        if fields:
            params["fields"] = ",".join(fields)
        return self._safe_get(UNIPROT_BASE, params=params).json()

    @staticmethod
    def _safe_get(url: str, params: dict = None) -> "requests.Response":
        """
        Send a GET request asking for JSON and fail loudly on HTTP errors.
        :param url: The URL to request.
        :param params: Optional query-string parameters.
        :return: The successful response object.
        :raises HTTPException: If the response status is not OK; the response
            body is logged and the upstream status code is propagated.
        """
        r = requests.get(
            url,
            params=params,
            headers={"Accept": "application/json"},
            timeout=10,
        )
        if not r.ok:
            logger.error("Search engine error: %s", r.text)
            raise HTTPException(r.status_code, "Search engine error.")
        return r
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from dataclasses import dataclass
2+
3+
import requests
4+
from fastapi import HTTPException
5+
6+
from graphgen.utils import logger
7+
8+
BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"
9+
BING_MKT = "en-US"
10+
11+
12+
@dataclass
class BingSearch:
    """
    Thin client around the Bing Web Search v7 endpoint.
    """

    # Sent as the "Ocp-Apim-Subscription-Key" header on every request.
    subscription_key: str

    def search(self, query: str, num_results: int = 1):
        """
        Query Bing and return the web page hits.
        :param query: The search query.
        :param num_results: The maximum number of hits to return.
        :return: The first ``num_results`` items of ``webPages.value``, or an
            empty list when the response has no such section.
        :raises HTTPException: If Bing answers with a non-success status.
        """
        resp = requests.get(
            BING_SEARCH_V7_ENDPOINT,
            headers={"Ocp-Apim-Subscription-Key": self.subscription_key},
            params={"q": query, "mkt": BING_MKT, "count": num_results},
            timeout=10,
        )
        if not resp.ok:
            logger.error("Search engine error: %s", resp.text)
            raise HTTPException(resp.status_code, "Search engine error.")

        payload = resp.json()
        try:
            # Slice as well, so at most num_results hits are returned.
            return payload["webPages"]["value"][:num_results]
        except KeyError:
            logger.error("Error encountered: %s", payload)
            return []

graphgen/operators/search/kg/search_google_kg.py

Whitespace-only changes.

graphgen/operators/search/search_all.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,31 @@ async def search_all(
4949
if description:
5050
results[entity_name] = results.get(entity_name, {})
5151
results[entity_name]["google"] = description
52+
elif search_type == "bing":
53+
from graphgen.models import BingSearch
54+
from graphgen.operators.search.web.search_bing import search_bing
5255

53-
# elif search_type == "bing":
54-
# from graphgen.operators.search.web.search_bing import search_bing
55-
# return await search_bing(llm_client, kg_instance)
56+
bing_search_client = BingSearch(
57+
subscription_key=os.environ["BING_SEARCH_API_KEY"]
58+
)
59+
60+
bing_results = await search_bing(bing_search_client, search_entities)
61+
for entity_name, description in bing_results.items():
62+
if description:
63+
results[entity_name] = results.get(entity_name, {})
64+
results[entity_name]["bing"] = description
65+
elif search_type == "uniprot":
66+
# from graphgen.models import UniProtSearch
67+
# from graphgen.operators.search.db.search_uniprot import search_uniprot
68+
#
69+
# uniprot_search_client = UniProtSearch()
70+
#
71+
# uniprot_results = await search_uniprot(
72+
# uniprot_search_client, search_entities
73+
# )
74+
raise NotImplementedError(
75+
"Processing of UniProt search results is not implemented yet."
76+
)
5677

5778
else:
5879
logger.error("Search type %s is not supported yet.", search_type)
Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,53 @@
1-
BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search"
2-
BING_MKT = "en-US"
1+
import trafilatura
2+
from tqdm.asyncio import tqdm_asyncio as tqdm_async
33

4+
from graphgen.models import BingSearch
5+
from graphgen.utils import logger
46

5-
async def search_bing():
7+
8+
async def _process_single_entity(
    entity_name: str, bing_search_client: BingSearch
) -> str | None:
    """
    Process single entity by searching Bing.

    The page behind the first search hit is downloaded and its main text is
    extracted with trafilatura.
    :param entity_name: The name of the entity to search.
    :param bing_search_client: The Bing search client.
    :return: Summary of the entity or None if not found (no search hits,
        download failure, or no extractable text).
    """
    search_results = bing_search_client.search(entity_name)
    if not search_results:
        return None

    # Get more details from the first search result
    first_result = search_results[0]
    page_url = first_result["url"]

    # fetch_url yields None when the download fails; bail out instead of
    # handing None to the extractor.
    content = trafilatura.fetch_url(page_url)
    if not content:
        logger.error(
            "Failed to download %s for entity %s", page_url, entity_name
        )
        return None

    # extract yields None when no main text can be found, so guard before
    # calling .strip() on the result.
    summary = trafilatura.extract(content, include_comments=False, include_links=False)
    if not summary:
        logger.error(
            "No text could be extracted from %s for entity %s",
            page_url,
            entity_name,
        )
        return None

    summary = summary.strip()
    if not summary:
        return None

    logger.info(
        "Entity %s search result: %s",
        entity_name,
        summary,
    )
    return summary
32+
33+
34+
async def search_bing(
    bing_search_client: BingSearch,
    entities: set[str],
) -> dict[str, str]:
    """
    Look up every entity on Bing and collect the summaries that were found.
    :param bing_search_client: The Bing search client.
    :param entities: The names of the entities to search.
    :return: A mapping from entity name to its summary; entities without a
        summary, or whose lookup raised, are left out.
    """
    summaries = {}
    progress = tqdm_async(entities, desc="Searching Bing", total=len(entities))

    async for name in progress:
        try:
            text = await _process_single_entity(name, bing_search_client)
            if text:
                summaries[name] = text
        except Exception as e:  # pylint: disable=broad-except
            # Best effort: one failing entity must not abort the whole run.
            logger.error("Error processing entity %s: %s", name, str(e))

    return summaries

0 commit comments

Comments
 (0)