Skip to content

Commit 4bdc9df

Browse files
author
Adam Wright
authored
Merge pull request #77 from reactome/refactor-codebase
Reorganize the codebase
2 parents 7a99e2b + 3c2d6ed commit 4bdc9df

32 files changed

Lines changed: 486 additions & 458 deletions

.config.schema.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ properties:
5151
- required: ["event"]
5252
- required: ["after_messages"]
5353
required: ["message", "trigger"]
54+
profiles:
55+
type: array
56+
items:
57+
type: string
58+
enum: ["React-to-Me"]
5459
usage_limits:
5560
type: object
5661
properties:
@@ -73,4 +78,4 @@ properties:
7378
pattern: "^[0-9]+[smhdw]$"
7479
required: ["users", "max_messages", "interval"]
7580
required: ["message_rates"]
76-
required: ["features", "messages", "usage_limits"]
81+
required: ["features", "messages", "profiles", "usage_limits"]

.github/actions/verify_imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"chat-chainlit.py",
1212
"chat-fastapi.py",
1313
"embeddings_manager",
14+
"export_nologin_usage.py",
1415
"export_records.py",
1516
],
1617
)

bin/chat-chainlit.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
#!/usr/bin/env python
2-
31
import os
42
from typing import Any
53

@@ -10,28 +8,19 @@
108
from dotenv import load_dotenv
119
from langchain_community.callbacks import OpenAICallbackHandler
1210

13-
from conversational_chain.graph import RAGGraphWithMemory
14-
from retreival_chain import create_retrieval_chain
11+
from agent.graph import AgentGraph
12+
from agent.profiles import ProfileName, get_chat_profiles
1513
from util.chainlit_helpers import (is_feature_enabled, message_rate_limited,
1614
save_openai_metrics, static_messages,
1715
update_search_results)
1816
from util.config_yml import Config, TriggerEvent
19-
from util.embedding_environment import EmbeddingEnvironment
2017
from util.logging import logging
2118

2219
load_dotenv()
2320
config: Config | None = Config.from_yaml()
2421

25-
26-
ENV = os.getenv("CHAT_ENV", "reactome")
27-
logging.info(f"Selected environment: {ENV}")
28-
29-
llm_graph: RAGGraphWithMemory = create_retrieval_chain(
30-
ENV,
31-
EmbeddingEnvironment.get_dir(ENV),
32-
hf_model=EmbeddingEnvironment.get_model(ENV),
33-
oai_model=os.getenv("OPENAI_MODEL", "gpt-4o-mini"),
34-
)
22+
profiles: list[ProfileName] = config.profiles if config else [ProfileName.React_to_Me]
23+
llm_graph = AgentGraph(profiles)
3524

3625
if os.getenv("POSTGRES_CHAINLIT_DB"):
3726
CHAINLIT_DB_URI = f"postgresql+psycopg://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_CHAINLIT_DB')}?sslmode=disable"
@@ -56,12 +45,13 @@ def oauth_callback(
5645

5746

5847
@cl.set_chat_profiles
59-
async def chat_profile() -> list[cl.ChatProfile]:
48+
async def chat_profiles() -> list[cl.ChatProfile]:
6049
return [
6150
cl.ChatProfile(
62-
name="React-to-me",
63-
markdown_description="An AI assistant specialized in exploring **Reactome** biological pathways and processes.",
51+
name=profile.name,
52+
markdown_description=profile.description,
6453
)
54+
for profile in get_chat_profiles(profiles)
6555
]
6656

6757

@@ -92,6 +82,8 @@ async def main(message: cl.Message) -> None:
9282
message_count: int = cl.user_session.get("message_count", 0) + 1
9383
cl.user_session.set("message_count", message_count)
9484

85+
chat_profile: str = cl.user_session.get("chat_profile")
86+
9587
thread_id: str = cl.user_session.get("thread_id")
9688

9789
chainlit_cb = cl.AsyncLangchainCallbackHandler(
@@ -103,6 +95,7 @@ async def main(message: cl.Message) -> None:
10395
enable_postprocess: bool = is_feature_enabled(config, "postprocessing")
10496
result: dict[str, Any] = await llm_graph.ainvoke(
10597
message.content,
98+
chat_profile.lower(),
10699
callbacks=[chainlit_cb, openai_cb],
107100
thread_id=thread_id,
108101
enable_postprocess=enable_postprocess,

bin/chat-repl

Lines changed: 0 additions & 91 deletions
This file was deleted.

bin/embeddings_manager

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ import boto3
1212
from botocore import UNSIGNED
1313
from botocore.client import Config
1414

15-
from embeddings.alliance_generator import generate_alliance_embeddings
16-
from embeddings.reactome_generator import generate_reactome_embeddings
15+
from data_generation.alliance import generate_alliance_embeddings
16+
from data_generation.reactome import generate_reactome_embeddings
1717
from util.embedding_environment import EM_ARCHIVE, EmbeddingEnvironment
1818

1919
S3_BUCKET = "download.reactome.org"

config_default.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# yaml-language-server: $schema=./.config.schema.yaml
22

3+
profiles:
4+
- React-to-Me
5+
36
features:
47
postprocessing: # external web search feature
58
enabled: true

src/__init__.py

Whitespace-only changes.

src/agent/graph.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
import asyncio
2+
import os
3+
from typing import Any
4+
5+
from langchain_core.callbacks.base import Callbacks
6+
from langchain_core.embeddings import Embeddings
7+
from langchain_core.language_models.chat_models import BaseChatModel
8+
from langchain_core.runnables import RunnableConfig
9+
from langgraph.checkpoint.base import BaseCheckpointSaver
10+
from langgraph.checkpoint.memory import MemorySaver
11+
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
12+
from langgraph.graph.state import CompiledStateGraph, StateGraph
13+
from psycopg import AsyncConnection
14+
from psycopg_pool import AsyncConnectionPool
15+
16+
from agent.models import get_embedding, get_llm
17+
from agent.profiles import ProfileName, create_profile_graphs
18+
from util.logging import logging
19+
20+
LANGGRAPH_DB_URI = f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_LANGGRAPH_DB')}?sslmode=disable"
21+
22+
if not os.getenv("POSTGRES_LANGGRAPH_DB"):
23+
logging.warning("POSTGRES_LANGGRAPH_DB undefined; falling back to MemorySaver.")
24+
25+
26+
class AgentGraph:
    """Holds one LangGraph state graph per chat profile and runs them.

    The uncompiled graphs are built eagerly in ``__init__``.  Compilation
    needs a checkpointer, whose creation is asynchronous, so the compiled
    graphs are produced lazily on the first ``ainvoke()`` call.
    """

    def __init__(
        self,
        profiles: list[ProfileName],
    ) -> None:
        """Build the uncompiled graph for every profile in ``profiles``."""
        # Assigned before anything that can raise, so that __del__ and
        # close_pool() always find these attributes on the instance.
        # ``graph`` is filled in by ainvoke(); ``pool`` by create_checkpointer().
        self.graph: dict[str, CompiledStateGraph] | None = None
        self.pool: AsyncConnectionPool[AsyncConnection[dict[str, Any]]] | None = None
        # Serialises the one-time lazy initialisation performed in ainvoke().
        self._init_lock = asyncio.Lock()

        # Get base models
        llm: BaseChatModel = get_llm("openai", "gpt-4o-mini")
        embedding: Embeddings = get_embedding("openai", "text-embedding-3-large")

        self.uncompiled_graph: dict[str, StateGraph] = create_profile_graphs(
            profiles, llm, embedding
        )

    def __del__(self) -> None:
        """Best-effort release of the connection pool on garbage collection."""
        # getattr: the attribute may be missing on a partially built instance.
        pool = getattr(self, "pool", None)
        if pool is None:
            return
        self.pool = None
        try:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None
            if loop is not None:
                # asyncio.run() raises RuntimeError inside a running loop,
                # so schedule the close on that loop instead.
                loop.create_task(pool.close())
            else:
                asyncio.run(pool.close())
        except Exception as exc:
            # A finalizer must not propagate; report and carry on.
            logging.warning(f"Failed to close LangGraph connection pool: {exc}")

    async def initialize(self) -> dict[str, CompiledStateGraph]:
        """Compile every profile graph against one shared checkpointer.

        Returns:
            Mapping from profile key to its compiled graph.  The mapping is
            returned, not stored; ``ainvoke()`` stores it on ``self.graph``.
        """
        checkpointer: BaseCheckpointSaver[str] = await self.create_checkpointer()
        return {
            profile: graph.compile(checkpointer=checkpointer)
            for profile, graph in self.uncompiled_graph.items()
        }

    async def create_checkpointer(self) -> BaseCheckpointSaver[str]:
        """Create the checkpointer used when compiling the graphs.

        Returns a ``MemorySaver`` when ``POSTGRES_LANGGRAPH_DB`` is unset or
        empty.  Otherwise opens a connection pool on ``LANGGRAPH_DB_URI`` and
        returns an ``AsyncPostgresSaver`` on it after running its ``setup()``.
        """
        if not os.getenv("POSTGRES_LANGGRAPH_DB"):
            return MemorySaver()
        pool: AsyncConnectionPool[AsyncConnection[dict[str, Any]]] = (
            AsyncConnectionPool(
                conninfo=LANGGRAPH_DB_URI,
                max_size=20,
                open=False,
                timeout=30,
                kwargs={
                    "autocommit": True,
                    "prepare_threshold": 0,
                },
            )
        )
        try:
            await pool.open()
            checkpointer = AsyncPostgresSaver(pool)
            await checkpointer.setup()
        except BaseException:
            # Do not leave a half-initialised pool behind; a later retry
            # would otherwise replace it without ever closing it.
            await pool.close()
            raise
        self.pool = pool
        return checkpointer

    async def close_pool(self) -> None:
        """Close the connection pool, if one was opened, and forget it."""
        pool, self.pool = self.pool, None
        if pool is not None:
            await pool.close()

    async def ainvoke(
        self,
        user_input: str,
        profile: str,
        *,
        callbacks: Callbacks,
        thread_id: str,
        enable_postprocess: bool = True,
    ) -> dict[str, Any]:
        """Run the graph registered under ``profile`` on ``user_input``.

        Args:
            user_input: Text passed to the graph as ``user_input``.
            profile: Key selecting which compiled graph to run.
            callbacks: Callbacks placed in the run's ``RunnableConfig``.
            thread_id: Passed to the graph as the ``thread_id`` configurable.
            enable_postprocess: Passed to the graph as the
                ``enable_postprocess`` configurable.

        Returns:
            The graph's result, or an empty dict when ``profile`` does not
            match any compiled graph.
        """
        if self.graph is None:
            # Without the lock, concurrent first calls would each run
            # initialize() and each open a pool; re-check once it is held.
            async with self._init_lock:
                if self.graph is None:
                    self.graph = await self.initialize()
        compiled = self.graph.get(profile)
        if compiled is None:
            logging.warning(f"Unknown profile {profile!r}; returning empty result.")
            return {}
        result: dict[str, Any] = await compiled.ainvoke(
            {"user_input": user_input},
            config=RunnableConfig(
                callbacks=callbacks,
                configurable={
                    "thread_id": thread_id,
                    "enable_postprocess": enable_postprocess,
                },
            ),
        )
        return result

src/agent/models.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from typing import Literal
2+
3+
from langchain_core.embeddings import Embeddings
4+
from langchain_core.language_models.chat_models import BaseChatModel
5+
from langchain_huggingface import (HuggingFaceEmbeddings,
6+
HuggingFaceEndpointEmbeddings)
7+
from langchain_ollama.chat_models import ChatOllama
8+
from langchain_openai.chat_models.base import ChatOpenAI
9+
from langchain_openai.embeddings import OpenAIEmbeddings
10+
11+
12+
def get_embedding(
    provider: (
        Literal[
            "openai",
            "huggingfacehub",
            "huggingfacelocal",
        ]
        | str
    ),
    model: str | None = None,
    *,
    device: str | None = "cpu",
) -> Embeddings:
    """Instantiate an embedding model for the given provider.

    Args:
        provider: One of ``"openai"``, ``"huggingfacehub"`` or
            ``"huggingfacelocal"``.  When ``model`` is omitted, this must be
            a combined ``"provider/model"`` string, split on the first ``/``.
        model: Model name handed to the provider's embedding class.
        device: Device placed in ``model_kwargs``; only used by the
            ``"huggingfacelocal"`` provider.

    Returns:
        The embedding object for the selected provider.

    Raises:
        ValueError: If ``model`` is omitted and ``provider`` contains no
            ``/``, or if the provider is not recognised.
    """
    if model is None:
        # partition() instead of unpacking split(): a missing "/" then yields
        # a clear message rather than "not enough values to unpack".
        provider, separator, model = provider.partition("/")
        if not separator:
            raise ValueError(
                "Expected 'provider/model' when model is not given, "
                f"got: {provider!r}"
            )
    if provider == "openai":
        return OpenAIEmbeddings(model=model)
    elif provider == "huggingfacehub":
        return HuggingFaceEndpointEmbeddings(model=model)
    elif provider == "huggingfacelocal":
        # NOTE(review): trust_remote_code=True presumably lets the model
        # repository run its own code on load — confirm this is intended.
        return HuggingFaceEmbeddings(
            model_name=model,
            model_kwargs={"device": device, "trust_remote_code": True},
            encode_kwargs={"batch_size": 12, "normalize_embeddings": False},
        )
    else:
        raise ValueError(f"Unknown provider: {provider}")
39+
40+
41+
def get_llm(
    provider: (
        Literal[
            "openai",
            "ollama",
        ]
        | str
    ),
    model: str | None = None,
    *,
    base_url: str | None = None,
) -> BaseChatModel:
    """Instantiate a chat model for the given provider.

    Args:
        provider: ``"openai"`` or ``"ollama"``.  When ``model`` is omitted,
            this must be a combined ``"provider/model"`` string, split on the
            first ``/``.
        model: Model name handed to the provider's chat class.
        base_url: Passed through unchanged as the chat class's ``base_url``.

    Returns:
        The chat model for the selected provider, created with
        ``temperature=0.0``.

    Raises:
        ValueError: If ``model`` is omitted and ``provider`` contains no
            ``/``, or if the provider is not recognised.
    """
    if model is None:
        # partition() instead of unpacking split(): a missing "/" then yields
        # a clear message rather than "not enough values to unpack".
        provider, separator, model = provider.partition("/")
        if not separator:
            raise ValueError(
                "Expected 'provider/model' when model is not given, "
                f"got: {provider!r}"
            )
    if provider == "openai":
        return ChatOpenAI(
            model=model,
            temperature=0.0,
            base_url=base_url,
        )
    elif provider == "ollama":
        return ChatOllama(
            model=model,
            temperature=0.0,
            base_url=base_url,
        )
    else:
        raise ValueError(f"Unknown provider: {provider}")

0 commit comments

Comments
 (0)