Skip to content

Commit 3bcf88a

Browse files
committed
Add S3StorageClient for elements persistence
1 parent 95bba31 commit 3bcf88a

3 files changed

Lines changed: 52 additions & 9 deletions

File tree

bin/chat-chainlit.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from agent.graph import AgentGraph
1111
from agent.profiles import ProfileName, get_chat_profiles
1212
from agent.profiles.base import OutputState
13-
from util.chainlit_helpers import (is_feature_enabled, message_rate_limited,
14-
save_openai_metrics, static_messages,
15-
update_search_results)
13+
from util.chainlit_helpers import (PrefixedS3StorageClient, is_feature_enabled,
14+
message_rate_limited, save_openai_metrics,
15+
static_messages, update_search_results)
1616
from util.config_yml import Config, TriggerEvent
1717
from util.logging import logging
1818

@@ -22,12 +22,26 @@
2222
profiles: list[ProfileName] = config.profiles if config else [ProfileName.React_to_Me]
2323
llm_graph = AgentGraph(profiles)
2424

25-
if os.getenv("POSTGRES_CHAINLIT_DB"):
26-
CHAINLIT_DB_URI = f"postgresql+psycopg://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_CHAINLIT_DB')}?sslmode=disable"
25+
POSTGRES_CHAINLIT_DB = os.getenv("POSTGRES_CHAINLIT_DB")
26+
POSTGRES_USER = os.getenv("POSTGRES_USER")
27+
POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD")
28+
S3_BUCKET = os.getenv("S3_BUCKET")
29+
S3_CHAINLIT_PREFIX = os.getenv("S3_CHAINLIT_PREFIX")
30+
31+
if POSTGRES_CHAINLIT_DB and POSTGRES_USER and POSTGRES_PASSWORD:
32+
CHAINLIT_DB_URI = f"postgresql+psycopg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@postgres:5432/{POSTGRES_CHAINLIT_DB}?sslmode=disable"
33+
34+
if S3_BUCKET and S3_CHAINLIT_PREFIX:
35+
storage_client = PrefixedS3StorageClient(S3_BUCKET, S3_CHAINLIT_PREFIX)
36+
else:
37+
storage_client = None
2738

2839
@cl.data_layer
2940
def get_data_layer() -> BaseDataLayer:
30-
return SQLAlchemyDataLayer(conninfo=CHAINLIT_DB_URI)
41+
return SQLAlchemyDataLayer(
42+
conninfo=CHAINLIT_DB_URI,
43+
storage_provider=storage_client,
44+
)
3145

3246
else:
3347
logging.warning("POSTGRES_CHAINLIT_DB undefined; Chainlit persistence disabled.")
@@ -100,15 +114,16 @@ async def main(message: cl.Message) -> None:
100114
thread_id=thread_id,
101115
enable_postprocess=enable_postprocess,
102116
)
117+
assistant_message: cl.Message | None = chainlit_cb.final_stream
103118

104119
if (
105120
enable_postprocess
106-
and chainlit_cb.final_stream
121+
and assistant_message
107122
and len(result["additional_content"]["search_results"]) > 0
108123
):
109124
await update_search_results(
110125
result["additional_content"]["search_results"],
111-
chainlit_cb.final_stream,
126+
assistant_message,
112127
)
113128

114129
await static_messages(config, after_messages=message_count)

docker-compose.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ services:
88
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD}
99
- POSTGRES_CHAINLIT_DB=${POSTGRES_CHAINLIT_DB}
1010
- POSTGRES_LANGGRAPH_DB=${POSTGRES_LANGGRAPH_DB}
11+
- S3_BUCKET=${S3_BUCKET}
12+
- S3_CHAINLIT_PREFIX=${S3_CHAINLIT_PREFIX}
1113
- LOG_LEVEL=${LOG_LEVEL}
1214
- UVICORN_LOG_LEVEL=${LOG_LEVEL}
1315
- CHAT_ENV=${CHAT_ENV}

src/util/chainlit_helpers.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import os
22
from datetime import datetime
3+
from pathlib import PurePosixPath
34
from typing import Any, Iterable
45

56
import chainlit as cl
67
from chainlit.data import get_data_layer
8+
from chainlit.data.storage_clients.s3 import S3StorageClient
79
from langchain_community.callbacks import OpenAICallbackHandler
810

911
from util.config_yml import Config, TriggerEvent
@@ -12,6 +14,30 @@
1214
guest_user_metadata: dict[str, Any] = {}
1315

1416

17+
class PrefixedS3StorageClient(S3StorageClient):
18+
def __init__(self, bucket: str, prefix: str, **kwargs: Any) -> None:
19+
super().__init__(bucket, **kwargs)
20+
self._prefix = PurePosixPath(prefix)
21+
22+
async def upload_file(
23+
self,
24+
object_key: str,
25+
data: bytes | str,
26+
mime: str = "application/octet-stream",
27+
overwrite: bool = True,
28+
) -> dict[str, Any]:
29+
object_key = str(self._prefix / object_key)
30+
return await super().upload_file(object_key, data, mime, overwrite)
31+
32+
async def delete_file(self, object_key: str) -> bool:
33+
object_key = str(self._prefix / object_key)
34+
return await super().delete_file(object_key)
35+
36+
async def get_read_url(self, object_key: str) -> str:
37+
object_key = str(self._prefix / object_key)
38+
return await super().get_read_url(object_key)
39+
40+
1541
def get_user_id() -> str | None:
1642
user: cl.User | None = cl.user_session.get("user")
1743
return user.identifier if user else None
@@ -136,7 +162,7 @@ async def update_search_results(
136162
name="SearchResults",
137163
props={"results": search_results},
138164
)
139-
message.elements = [search_results_element] # type: ignore
165+
message.elements.append(search_results_element) # type: ignore[arg-type]
140166
await message.update()
141167

142168

0 commit comments

Comments
 (0)