-
Notifications
You must be signed in to change notification settings - Fork 95
Expand file tree
/
Copy pathbanking_agents_api.py
More file actions
634 lines (512 loc) · 24.7 KB
/
Copy pathbanking_agents_api.py
File metadata and controls
634 lines (512 loc) · 24.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
import os
import uuid
import fastapi
from dotenv import load_dotenv
from datetime import datetime
from fastapi import BackgroundTasks
#from azure.monitor.opentelemetry import configure_azure_monitor
from azure.cosmos.exceptions import CosmosHttpResponseError
from fastapi import Depends, HTTPException, Body
from langchain_core.messages import HumanMessage, ToolMessage
from pydantic import BaseModel
from typing import List, Dict
from src.app.services.azure_open_ai import model
# from src.app.services.local_model import model # Use local model
from langgraph_checkpoint_cosmosdb import CosmosDBSaver
from langgraph.graph.state import CompiledStateGraph
from starlette.middleware.cors import CORSMiddleware
from src.app.banking_agents import graph, checkpointer
from src.app.services.azure_cosmos_db import update_chat_container, patch_active_agent, \
fetch_chat_container_by_tenant_and_user, \
fetch_chat_container_by_session, delete_userdata_item, debug_container, update_users_container, \
update_account_container, update_offers_container, store_chat_history, update_active_agent_in_latest_message, \
chat_container, fetch_chat_history_by_session, delete_chat_history_by_session
import logging
# Setup logging
logging.basicConfig(level=logging.ERROR)
load_dotenv(override=False)
#configure_azure_monitor()
endpointTitle = "ChatEndpoints"
dataLoadTitle = "DataLoadEndpoints"
# Mapping for agent function names to standardized names
agent_mapping = {
"coordinator_agent": "Coordinator",
"customer_support_agent": "CustomerSupport",
"transactions_agent": "Transactions",
"sales_agent": "Sales"
}
def get_compiled_graph():
return graph
app = fastapi.FastAPI(title="Cosmos DB Multi-Agent Banking API", openapi_url="/cosmos-multi-agent-api.json")
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class DebugLog(BaseModel):
id: str
sessionId: str
tenantId: str
userId: str
details: str
class Session(BaseModel):
id: str
type: str = "session"
sessionId: str
tenantId: str
userId: str
tokensUsed: int = 0
name: str
messages: List
class MessageModel(BaseModel):
id: str
type: str
sessionId: str
tenantId: str
userId: str
timeStamp: str
sender: str
senderRole: str
text: str
debugLogId: str
tokensUsed: int
rating: bool
completionPromptId: str
class DebugLog(BaseModel):
id: str
messageId: str
type: str
sessionId: str
tenantId: str
userId: str
timeStamp: str
propertyBag: list
def store_debug_log(sessionId, tenantId, userId, response_data):
"""Stores detailed debug log information in Cosmos DB."""
debug_log_id = str(uuid.uuid4())
message_id = str(uuid.uuid4())
timestamp = datetime.utcnow().isoformat()
# Extract relevant debug details
agent_selected = "Unknown"
previous_agent = "Unknown"
finish_reason = "Unknown"
model_name = "Unknown"
system_fingerprint = "Unknown"
input_tokens = 0
output_tokens = 0
total_tokens = 0
cached_tokens = 0
transfer_success = False
tool_calls = []
logprobs = None
content_filter_results = {}
for entry in response_data:
for agent, details in entry.items():
if "messages" in details:
for msg in details["messages"]:
if hasattr(msg, "response_metadata"):
metadata = getattr(msg, "response_metadata", None)
if metadata:
finish_reason = metadata.get("finish_reason", finish_reason)
model_name = metadata.get("model_name", model_name)
system_fingerprint = metadata.get("system_fingerprint", system_fingerprint)
token_usage = metadata.get("token_usage", {}) or {}
input_tokens = token_usage.get("prompt_tokens", input_tokens)
output_tokens = token_usage.get("completion_tokens", output_tokens)
total_tokens = token_usage.get("total_tokens", total_tokens)
prompt_details = token_usage.get("prompt_tokens_details", {}) or {}
cached_tokens = prompt_details.get("cached_tokens", cached_tokens)
logprobs = metadata.get("logprobs", logprobs)
content_filter_results = metadata.get("content_filter_results", content_filter_results)
if "tool_calls" in msg.additional_kwargs:
tool_calls.extend(msg.additional_kwargs["tool_calls"])
transfer_success = any(
call.get("name", "").startswith("transfer_to_") for call in tool_calls)
previous_agent = agent_selected
agent_selected = tool_calls[-1].get("name", "").replace("transfer_to_", "") if tool_calls else agent_selected
property_bag = [
{"key": "agent_selected", "value": agent_selected, "timeStamp": timestamp},
{"key": "previous_agent", "value": previous_agent, "timeStamp": timestamp},
{"key": "finish_reason", "value": finish_reason, "timeStamp": timestamp},
{"key": "model_name", "value": model_name, "timeStamp": timestamp},
{"key": "system_fingerprint", "value": system_fingerprint, "timeStamp": timestamp},
{"key": "input_tokens", "value": input_tokens, "timeStamp": timestamp},
{"key": "output_tokens", "value": output_tokens, "timeStamp": timestamp},
{"key": "total_tokens", "value": total_tokens, "timeStamp": timestamp},
{"key": "cached_tokens", "value": cached_tokens, "timeStamp": timestamp},
{"key": "transfer_success", "value": transfer_success, "timeStamp": timestamp},
{"key": "tool_calls", "value": str(tool_calls), "timeStamp": timestamp},
{"key": "logprobs", "value": str(logprobs), "timeStamp": timestamp},
{"key": "content_filter_results", "value": str(content_filter_results), "timeStamp": timestamp}
]
debug_entry = {
"id": debug_log_id,
"messageId": message_id,
"type": "debug_log",
"sessionId": sessionId,
"tenantId": tenantId,
"userId": userId,
"timeStamp": timestamp,
"propertyBag": property_bag
}
debug_container.create_item(debug_entry)
return debug_log_id
def create_thread(tenantId: str, userId: str):
sessionId = str(uuid.uuid4())
name = userId
age = 30
address = "123 Main St"
activeAgent = "unknown"
ChatName = "New Chat"
messages = []
update_chat_container({
"id": sessionId,
"tenantId": tenantId,
"userId": userId,
"sessionId": sessionId,
"name": name,
"age": age,
"address": address,
"activeAgent": activeAgent,
"ChatName": ChatName,
"messages": messages
})
return Session(id=sessionId, sessionId=sessionId, tenantId=tenantId, userId=userId, name=name, age=age,
address=address, activeAgent=activeAgent, ChatName=ChatName, messages=messages)
@app.get("/status", tags=[endpointTitle], description="Gets the service status", operation_id="GetServiceStatus",
response_description="Success",
response_model=str)
def get_service_status():
return "CosmosDBService: initializing"
# Note: cosmos db checkpointer store is used internally by LangGraph for "memory": to maintain end-to-end state of each
# conversation thread as contextual input to the OpenAI model.
# However, this function is dead code, as we no longer retrieve chat history from the cosmos db checkpointer store to return in the API.
# Abandoned this approach as the checkpointer store does not natively keep a record of which agent responded to the last message.
# Also, retrieving messages from the checkpointer store is not efficient as it requires scanning more records than necessary for chat history.
# Instead, we are now storing chat history in a separate custom cosmos db session container. Keeping this code for reference.
def _fetch_messages_for_session(sessionId: str, tenantId: str, userId: str) -> List[MessageModel]:
messages = []
config = {
"configurable": {
"thread_id": sessionId,
"checkpoint_ns": ""
}
}
logging.debug(f"Fetching messages for sessionId: {sessionId} with config: {config}")
checkpoints = list(checkpointer.list(config))
logging.debug(f"Number of checkpoints retrieved: {len(checkpoints)}")
if checkpoints:
last_checkpoint = checkpoints[-1]
for key, value in last_checkpoint.checkpoint.items():
if key == "channel_values" and "messages" in value:
messages.extend(value["messages"])
selected_human_index = None
for i in range(len(messages) - 1):
if isinstance(messages[i], HumanMessage) and not isinstance(messages[i + 1], HumanMessage):
selected_human_index = i
break
messages = messages[selected_human_index:] if selected_human_index is not None else []
return [
MessageModel(
id=str(uuid.uuid4()),
type="ai_response",
sessionId=sessionId,
tenantId=tenantId,
userId=userId,
timeStamp=msg.response_metadata.get("timestamp", "") if hasattr(msg, "response_metadata") else "",
sender="User" if isinstance(msg, HumanMessage) else "Coordinator",
senderRole="User" if isinstance(msg, HumanMessage) else "Assistant",
text=msg.content if hasattr(msg, "content") else msg.get("content", ""),
debugLogId=str(uuid.uuid4()),
tokensUsed=msg.response_metadata.get("token_usage", {}).get("total_tokens", 0) if hasattr(msg,
"response_metadata") else 0,
rating=True,
completionPromptId=""
)
for msg in messages
if msg.content
]
@app.get("/tenant/{tenantId}/user/{userId}/sessions",
description="Retrieves sessions from the given tenantId and userId", tags=[endpointTitle],
response_model=List[Session])
def get_chat_sessions(tenantId: str, userId: str):
items = fetch_chat_container_by_tenant_and_user(tenantId, userId)
sessions = []
for item in items:
sessionId = item["sessionId"]
messages = fetch_chat_history_by_session(sessionId)
session = {
"id": sessionId,
"type": "Session",
"sessionId": sessionId,
"tenantId": item["tenantId"],
"userId": item["userId"],
"tokensUsed": item.get("tokensUsed", 0),
"name": item.get("ChatName", "New Chat"),
"messages": messages
}
sessions.append(session)
return sessions
@app.get("/tenant/{tenantId}/user/{userId}/sessions/{sessionId}/messages",
description="Retrieves messages from the sessionId", tags=[endpointTitle], response_model=List[MessageModel])
def get_chat_session(tenantId: str, userId: str, sessionId: str):
return fetch_chat_history_by_session(sessionId)
# to be implemented
@app.post("/tenant/{tenantId}/user/{userId}/sessions/{sessionId}/message/{messageId}/rate",
description="Not yet implemented", tags=[endpointTitle],
operation_id="RateMessage", response_description="Success", response_model=MessageModel)
def rate_message(tenantId: str, userId: str, sessionId: str, messageId: str, rating: bool):
return {
"id": messageId,
"type": "ai_response",
"sessionId": sessionId,
"tenantId": tenantId,
"userId": userId,
"timeStamp": "2023-01-01T00:00:00Z",
"sender": "assistant",
"senderRole": "agent",
"text": "This is a rated message",
"debugLogId": str(uuid.uuid4()),
"tokensUsed": 0,
"rating": rating,
"completionPromptId": ""
}
@app.get("/tenant/{tenantId}/user/{userId}/sessions/{sessionId}/completiondetails/{debuglogId}",
description="Retrieves debug information for chat completions", tags=[endpointTitle],
operation_id="GetChatCompletionDetails", response_model=DebugLog)
def get_chat_completion_details(tenantId: str, userId: str, sessionId: str, debuglogId: str):
try:
debug_log = debug_container.read_item(item=debuglogId, partition_key=sessionId)
return debug_log
except Exception:
raise HTTPException(status_code=404, detail="Debug log not found")
# create a post function that renames the ChatName in the user data container
@app.post("/tenant/{tenantId}/user/{userId}/sessions/{sessionId}/rename", description="Renames the chat session",
tags=[endpointTitle], response_model=Session)
def rename_chat_session(tenantId: str, userId: str, sessionId: str, newChatSessionName: str):
items = fetch_chat_container_by_session(tenantId, userId, sessionId)
if not items:
raise HTTPException(status_code=404, detail="Session not found")
item = items[0]
item["ChatName"] = newChatSessionName
update_chat_container(item)
return Session(id=item["sessionId"], sessionId=item["sessionId"], tenantId=item["tenantId"], userId=item["userId"],
name=item["ChatName"], age=item["age"],
address=item["address"], activeAgent=item["activeAgent"], ChatName=newChatSessionName,
messages=item["messages"])
def delete_all_thread_records(cosmos_saver: CosmosDBSaver, thread_id: str) -> None:
"""
Deletes all records related to a given thread in CosmosDB by first identifying all partition keys
and then deleting every record under each partition key.
"""
# Step 1: Identify all partition keys related to the thread
query = "SELECT DISTINCT c.partition_key FROM c WHERE CONTAINS(c.partition_key, @thread_id)"
parameters = [{"name": "@thread_id", "value": thread_id}]
partition_keys = list(cosmos_saver.container.query_items(
query=query, parameters=parameters, enable_cross_partition_query=True
))
if not partition_keys:
print(f"No records found for thread: {thread_id}")
return
print(f"Found {len(partition_keys)} partition keys related to the thread.")
# Step 2: Delete all records under each partition key
for partition in partition_keys:
partition_key = partition["partition_key"]
# Query all records under the current partition
record_query = "SELECT c.id FROM c WHERE c.partition_key=@partition_key"
record_parameters = [{"name": "@partition_key", "value": partition_key}]
records = list(cosmos_saver.container.query_items(
query=record_query, parameters=record_parameters, enable_cross_partition_query=True
))
for record in records:
record_id = record["id"]
try:
cosmos_saver.container.delete_item(record_id, partition_key=partition_key)
print(f"Deleted record: {record_id} from partition: {partition_key}")
except CosmosHttpResponseError as e:
print(f"Error deleting record {record_id} (HTTP {e.status_code}): {e.message}")
print(f"Successfully deleted all records for thread: {thread_id}")
# deletes the session user data container and all messages in the checkpointer store
@app.delete("/tenant/{tenantId}/user/{userId}/sessions/{sessionId}", tags=[endpointTitle], )
def delete_chat_session(tenantId: str, userId: str, sessionId: str, background_tasks: BackgroundTasks):
delete_userdata_item(tenantId, userId, sessionId)
# Delete all messages in the checkpointer store
config = {
"configurable": {
"thread_id": sessionId,
"checkpoint_ns": "" # Ensure this matches the stored data
}
}
delete_chat_history_by_session(sessionId)
# Schedule the delete_all_thread_records function as a background task
background_tasks.add_task(delete_all_thread_records, checkpointer, sessionId)
return {"message": "Session deleted successfully"}
@app.post("/tenant/{tenantId}/user/{userId}/sessions", tags=[endpointTitle], response_model=Session)
def create_chat_session(tenantId: str, userId: str):
return create_thread(tenantId, userId)
def extract_relevant_messages(debug_lod_id, last_active_agent, response_data, tenantId, userId, sessionId):
# Convert last_active_agent to its mapped value
last_active_agent = agent_mapping.get(last_active_agent, last_active_agent)
debug_lod_id = debug_lod_id
if not response_data:
return []
last_agent_node = None
last_agent_name = "unknown"
for i in range(len(response_data) - 1, -1, -1):
if "__interrupt__" in response_data[i]:
if i > 0:
last_agent_node = response_data[i - 1]
last_agent_name = list(last_agent_node.keys())[0]
break
# storing the last active agent in the session container so that we can retrieve it later
# and deterministically route the incoming message directly to the agent that asked the question.
patch_active_agent(tenantId, userId, sessionId, last_agent_name)
if not last_agent_node:
return []
messages = []
for key, value in last_agent_node.items():
if isinstance(value, dict) and "messages" in value:
messages.extend(value["messages"])
last_user_index = -1
for i in range(len(messages) - 1, -1, -1):
if isinstance(messages[i], HumanMessage):
last_user_index = i
break
if last_user_index == -1:
return []
filtered_messages = [msg for msg in messages[last_user_index:] if not isinstance(msg, ToolMessage)]
return [
MessageModel(
id=str(uuid.uuid4()),
type="ai_response",
sessionId=sessionId,
tenantId=tenantId,
userId=userId,
timeStamp=msg.response_metadata.get("timestamp", "") if hasattr(msg, "response_metadata") else "",
sender="User" if isinstance(msg, HumanMessage) else last_active_agent,
senderRole="User" if isinstance(msg, HumanMessage) else "Assistant",
text=msg.content if hasattr(msg, "content") else msg.get("content", ""),
debugLogId=debug_lod_id,
tokensUsed=msg.response_metadata.get("token_usage", {}).get("total_tokens", 0) if hasattr(msg,
"response_metadata") else 0,
rating=True,
completionPromptId=""
)
for msg in filtered_messages
if msg.content
]
def process_messages(messages, userId, tenantId, sessionId):
for message in messages:
item = {
"id": message.id,
"type": message.type,
"sessionId": message.sessionId,
"tenantId": message.tenantId,
"userId": message.userId,
"timeStamp": message.timeStamp,
"sender": message.sender,
"senderRole": message.senderRole,
"text": message.text,
"debugLogId": message.debugLogId,
"tokensUsed": message.tokensUsed,
"rating": message.rating,
"completionPromptId": message.completionPromptId
}
store_chat_history(item)
partition_key = [tenantId, userId, sessionId]
# Get the active agent from Cosmos DB with a point lookup
activeAgent = chat_container.read_item(item=sessionId, partition_key=partition_key).get('activeAgent', 'unknown')
last_active_agent = agent_mapping.get(activeAgent, activeAgent)
update_active_agent_in_latest_message(sessionId, last_active_agent)
@app.post("/tenant/{tenantId}/user/{userId}/sessions/{sessionId}/completion", tags=[endpointTitle],
response_model=List[MessageModel])
async def get_chat_completion(
tenantId: str,
userId: str,
sessionId: str,
background_tasks: BackgroundTasks,
request_body: str = Body(..., media_type="application/json"),
workflow: CompiledStateGraph = Depends(get_compiled_graph),
):
if not request_body.strip():
raise HTTPException(status_code=400, detail="Request body cannot be empty")
# Retrieve last checkpoint
config = {"configurable": {"thread_id": sessionId, "checkpoint_ns": "", "userId": userId, "tenantId": tenantId}}
checkpoints = list(checkpointer.list(config))
last_active_agent = "coordinator_agent" # Default fallback
if not checkpoints:
# No previous state, start fresh
new_state = {"messages": [{"role": "user", "content": request_body}]}
response_data = workflow.invoke(new_state, config, stream_mode="updates")
else:
# Resume from last checkpoint
last_checkpoint = checkpoints[-1]
last_state = last_checkpoint.checkpoint
if "messages" not in last_state:
last_state["messages"] = []
last_state["messages"].append({"role": "user", "content": request_body})
if "channel_versions" in last_state:
for key in reversed(last_state["channel_versions"].keys()):
if "agent" in key:
last_active_agent = key.split(":")[1]
break
last_state["langgraph_triggers"] = [f"resume:{last_active_agent}"]
response_data = workflow.invoke(last_state, config, stream_mode="updates")
debug_log_id = store_debug_log(sessionId, tenantId, userId, response_data)
messages = extract_relevant_messages(debug_log_id, last_active_agent, response_data, tenantId, userId, sessionId)
partition_key = [tenantId, userId, sessionId]
# Get the active agent from Cosmos DB with a point lookup
activeAgent = chat_container.read_item(item=sessionId, partition_key=partition_key).get('activeAgent', 'unknown')
# update last sender in messages to the active agent
messages[-1].sender = agent_mapping.get(activeAgent, activeAgent)
# Schedule storing chat history and updating correct agent in last message as a background task
# to avoid blocking the API response as this is not needed unless retrieving the message history later.
background_tasks.add_task(process_messages, messages, userId, tenantId, sessionId)
return messages
@app.post("/tenant/{tenantId}/user/{userId}/sessions/{sessionId}/summarize-name", tags=[endpointTitle],
operation_id="SummarizeChatSessionName", response_description="Success", response_model=str)
def summarize_chat_session_name(tenantId: str, userId: str, sessionId: str,
request_body: str = Body(..., media_type="application/json")):
"""
Generates a summarized name for a chat session based on the chat text provided.
"""
try:
prompt = (
"Given the following chat transcript, generate a short, meaningful name for the conversation.\n\n"
f"Chat Transcript:\n{request_body}\n\n"
"Summary Name:"
)
response = model.invoke(prompt)
summarized_name = response.content.strip()
return summarized_name
except Exception as e:
return {"error": f"Failed to generate chat session name: {str(e)}"}
@app.post("/tenant/{tenantId}/user/{userId}/semanticcache/reset", tags=[endpointTitle],
operation_id="ResetSemanticCache", response_description="Success",
description="Semantic cache reset - not yet implemented", )
def reset_semantic_cache(tenantId: str, userId: str):
return {"message": "Semantic cache reset not yet implemented"}
@app.put("/userdata", tags=[dataLoadTitle], description="Inserts or updates a single user data record in Cosmos DB")
async def put_userdata(data: Dict):
try:
update_users_container(data)
return {"message": "Inserted user record successfully", "id": data.get("id")}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to insert user data: {str(e)}")
@app.put("/accountdata", tags=[dataLoadTitle],
description="Inserts or updates a single account data record in Cosmos DB")
async def put_accountdata(data: Dict):
try:
update_account_container(data)
return {"message": "Inserted account record successfully", "id": data.get("id")}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to insert account data: {str(e)}")
@app.put("/offerdata", tags=[dataLoadTitle], description="Inserts or updates a single offer data record in Cosmos DB")
async def put_offerdata(data: Dict):
try:
update_offers_container(data)
return {"message": "Inserted offer record successfully", "id": data.get("id")}
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to insert offer data: {str(e)}")