Skip to content

Commit bf34543

Browse files
author
Adam Wright
authored
Merge pull request #78 from reactome/integrate-uniprot
Integrate uniprot
2 parents 4bdc9df + 8c28256 commit bf34543

33 files changed

Lines changed: 1175 additions & 176 deletions

.config.schema.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ properties:
5555
type: array
5656
items:
5757
type: string
58-
enum: ["React-to-Me"]
58+
enum: ["React-to-Me", "Cross-Database Prototype"]
5959
usage_limits:
6060
type: object
6161
properties:

bin/chat-chainlit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
from typing import Any
32

43
import chainlit as cl
54
from chainlit.data.base import BaseDataLayer
@@ -10,6 +9,7 @@
109

1110
from agent.graph import AgentGraph
1211
from agent.profiles import ProfileName, get_chat_profiles
12+
from agent.profiles.base import OutputState
1313
from util.chainlit_helpers import (is_feature_enabled, message_rate_limited,
1414
save_openai_metrics, static_messages,
1515
update_search_results)
@@ -93,7 +93,7 @@ async def main(message: cl.Message) -> None:
9393
openai_cb = OpenAICallbackHandler()
9494

9595
enable_postprocess: bool = is_feature_enabled(config, "postprocessing")
96-
result: dict[str, Any] = await llm_graph.ainvoke(
96+
result: OutputState = await llm_graph.ainvoke(
9797
message.content,
9898
chat_profile.lower(),
9999
callbacks=[chainlit_cb, openai_cb],

bin/embeddings_manager

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ from botocore.client import Config
1414

1515
from data_generation.alliance import generate_alliance_embeddings
1616
from data_generation.reactome import generate_reactome_embeddings
17+
from data_generation.uniprot import generate_uniprot_embeddings
1718
from util.embedding_environment import EM_ARCHIVE, EmbeddingEnvironment
1819

1920
S3_BUCKET = "download.reactome.org"
@@ -86,6 +87,8 @@ def make(
8687
os.environ["HUGGINGFACEHUB_API_TOKEN"] = hf_key
8788
if embedding.db == "reactome":
8889
generate_reactome_embeddings(str(embedding_path), hf_model=embedding.model, **kwargs)
90+
elif embedding.db == "uniprot":
91+
generate_uniprot_embeddings(embedding_path, hf_model=embedding.model, **kwargs)
8992
elif embedding.db == "alliance":
9093
generate_alliance_embeddings(str(embedding_path), hf_model=embedding.model, **kwargs)
9194
else:

mypy.ini

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ ignore_missing_imports = True
33
allow_untyped_calls = True
44
allow_untyped_defs = True
55
allow_untyped_globals = True
6+
explicit_package_bases = True
67
exclude = data/
8+
files = bin/,src/
79

810
[mypy.plugins.pandas.*]
911
init_forbid_dynamic = False

poetry.lock

Lines changed: 26 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ psycopg = {extras = ["binary"], version = "^3.2.3"}
4545
pydantic = "^2.10.5"
4646
pyyaml = "^6.0.2"
4747
tavily-python = "^0.5.0"
48+
openpyxl = "^3.1.5"
4849

4950
[tool.poetry.group.dev.dependencies]
5051
ruff = "^0.7.1"

src/agent/graph.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from agent.models import get_embedding, get_llm
1717
from agent.profiles import ProfileName, create_profile_graphs
18+
from agent.profiles.base import InputState, OutputState
1819
from util.logging import logging
1920

2021
LANGGRAPH_DB_URI = f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_LANGGRAPH_DB')}?sslmode=disable"
@@ -81,13 +82,13 @@ async def ainvoke(
8182
callbacks: Callbacks,
8283
thread_id: str,
8384
enable_postprocess: bool = True,
84-
) -> dict[str, Any]:
85+
) -> OutputState:
8586
if self.graph is None:
8687
self.graph = await self.initialize()
8788
if profile not in self.graph:
88-
return {}
89-
result: dict[str, Any] = await self.graph[profile].ainvoke(
90-
{"user_input": user_input},
89+
return OutputState()
90+
result: OutputState = await self.graph[profile].ainvoke(
91+
InputState(user_input=user_input),
9192
config=RunnableConfig(
9293
callbacks=callbacks,
9394
configurable={

src/agent/profiles/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
from langchain_core.language_models.chat_models import BaseChatModel
66
from langgraph.graph.state import StateGraph
77

8-
from agent.profiles.react_to_me import create_reacttome_graph
8+
from agent.profiles.cross_database import create_cross_database_graph
9+
from agent.profiles.react_to_me import create_reactome_graph
910

1011

1112
class ProfileName(StrEnum):
1213
# These should exactly match names in .config.schema.yaml
1314
React_to_Me = "React-to-Me"
15+
Cross_Database_Prototype = "Cross-Database Prototype"
1416

1517

1618
class Profile(NamedTuple):
@@ -23,7 +25,12 @@ class Profile(NamedTuple):
2325
ProfileName.React_to_Me.lower(): Profile(
2426
name=ProfileName.React_to_Me,
2527
description="An AI assistant specialized in exploring **Reactome** biological pathways and processes.",
26-
graph_builder=create_reacttome_graph,
28+
graph_builder=create_reactome_graph,
29+
),
30+
ProfileName.Cross_Database_Prototype.lower(): Profile(
31+
name=ProfileName.Cross_Database_Prototype,
32+
description="Early version of an AI assistant with knowledge from multiple bio-databases (**Reactome** + **Uniprot**).",
33+
graph_builder=create_cross_database_graph,
2734
),
2835
}
2936

src/agent/profiles/base.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,49 @@
11
from typing import Annotated, TypedDict
22

3-
from langchain_core.documents import Document
3+
from langchain_core.embeddings import Embeddings
4+
from langchain_core.language_models.chat_models import BaseChatModel
45
from langchain_core.messages import BaseMessage
6+
from langchain_core.runnables import Runnable, RunnableConfig
57
from langgraph.graph.message import add_messages
68

9+
from agent.tasks.rephrase import create_rephrase_chain
710
from tools.external_search.state import WebSearchResult
811

912

10-
class AdditionalContent(TypedDict):
13+
class AdditionalContent(TypedDict, total=False):
1114
search_results: list[WebSearchResult]
1215

1316

14-
class BaseState(TypedDict):
15-
# (Everything the Chainlit layer uses should be included here)
16-
17+
class InputState(TypedDict, total=False):
1718
user_input: str # User input text
18-
chat_history: Annotated[list[BaseMessage], add_messages]
19-
context: list[Document]
19+
20+
21+
class OutputState(TypedDict, total=False):
2022
answer: str # primary LLM response that is streamed to the user
2123
additional_content: AdditionalContent # sends on graph completion
2224

2325

26+
class BaseState(InputState, OutputState, total=False):
27+
rephrased_input: str # LLM-generated query from user input
28+
chat_history: Annotated[list[BaseMessage], add_messages]
29+
30+
2431
class BaseGraphBuilder:
25-
pass # NOTE: Anything that is common to all graph builders goes here
32+
# NOTE: Anything that is common to all graph builders goes here
33+
34+
def __init__(
35+
self,
36+
llm: BaseChatModel,
37+
embedding: Embeddings,
38+
) -> None:
39+
self.rephrase_chain: Runnable = create_rephrase_chain(llm)
40+
41+
async def preprocess(self, state: BaseState, config: RunnableConfig) -> BaseState:
42+
rephrased_input: str = await self.rephrase_chain.ainvoke(
43+
{
44+
"user_input": state["user_input"],
45+
"chat_history": state["chat_history"],
46+
},
47+
config,
48+
)
49+
return BaseState(rephrased_input=rephrased_input)

0 commit comments

Comments
 (0)