|
10 | 10 | from agent.graph import AgentGraph |
11 | 11 | from agent.profiles import ProfileName, get_chat_profiles |
12 | 12 | from agent.profiles.base import OutputState |
13 | | -from util.chainlit_helpers import (is_feature_enabled, message_rate_limited, |
14 | | - save_openai_metrics, static_messages, |
15 | | - update_search_results) |
| 13 | +from util.chainlit_helpers import (PrefixedS3StorageClient, is_feature_enabled, |
| 14 | + message_rate_limited, save_openai_metrics, |
| 15 | + static_messages, update_search_results) |
16 | 16 | from util.config_yml import Config, TriggerEvent |
17 | 17 | from util.logging import logging |
18 | 18 |
|
|
22 | 22 | profiles: list[ProfileName] = config.profiles if config else [ProfileName.React_to_Me] |
23 | 23 | llm_graph = AgentGraph(profiles) |
24 | 24 |
|
25 | | -if os.getenv("POSTGRES_CHAINLIT_DB"): |
26 | | - CHAINLIT_DB_URI = f"postgresql+psycopg://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}@postgres:5432/{os.getenv('POSTGRES_CHAINLIT_DB')}?sslmode=disable" |
| 25 | +POSTGRES_CHAINLIT_DB = os.getenv("POSTGRES_CHAINLIT_DB") |
| 26 | +POSTGRES_USER = os.getenv("POSTGRES_USER") |
| 27 | +POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD") |
| 28 | +S3_BUCKET = os.getenv("S3_BUCKET") |
| 29 | +S3_CHAINLIT_PREFIX = os.getenv("S3_CHAINLIT_PREFIX") |
| 30 | + |
| 31 | +if POSTGRES_CHAINLIT_DB and POSTGRES_USER and POSTGRES_PASSWORD: |
| 32 | + CHAINLIT_DB_URI = f"postgresql+psycopg://{POSTGRES_USER}:{POSTGRES_PASSWORD}@postgres:5432/{POSTGRES_CHAINLIT_DB}?sslmode=disable" |
| 33 | + |
| 34 | + if S3_BUCKET and S3_CHAINLIT_PREFIX: |
| 35 | + storage_client = PrefixedS3StorageClient(S3_BUCKET, S3_CHAINLIT_PREFIX) |
| 36 | + else: |
| 37 | + storage_client = None |
27 | 38 |
|
28 | 39 | @cl.data_layer |
29 | 40 | def get_data_layer() -> BaseDataLayer: |
30 | | - return SQLAlchemyDataLayer(conninfo=CHAINLIT_DB_URI) |
| 41 | + return SQLAlchemyDataLayer( |
| 42 | + conninfo=CHAINLIT_DB_URI, |
| 43 | + storage_provider=storage_client, |
| 44 | + ) |
31 | 45 |
|
32 | 46 | else: |
33 | 47 | logging.warning("POSTGRES_CHAINLIT_DB undefined; Chainlit persistence disabled.") |
@@ -100,15 +114,16 @@ async def main(message: cl.Message) -> None: |
100 | 114 | thread_id=thread_id, |
101 | 115 | enable_postprocess=enable_postprocess, |
102 | 116 | ) |
| 117 | + assistant_message: cl.Message | None = chainlit_cb.final_stream |
103 | 118 |
|
104 | 119 | if ( |
105 | 120 | enable_postprocess |
106 | | - and chainlit_cb.final_stream |
| 121 | + and assistant_message |
107 | 122 | and len(result["additional_content"]["search_results"]) > 0 |
108 | 123 | ): |
109 | 124 | await update_search_results( |
110 | 125 | result["additional_content"]["search_results"], |
111 | | - chainlit_cb.final_stream, |
| 126 | + assistant_message, |
112 | 127 | ) |
113 | 128 |
|
114 | 129 | await static_messages(config, after_messages=message_count) |
|
0 commit comments