-
Notifications
You must be signed in to change notification settings - Fork 95
Expand file tree
/
Copy pathbanking_agents.py
More file actions
221 lines (179 loc) · 8.25 KB
/
Copy pathbanking_agents.py
File metadata and controls
221 lines (179 loc) · 8.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import logging
import os
import uuid
from langchain.schema import AIMessage
from typing import Literal
from langgraph.graph import StateGraph, START, MessagesState
from langgraph.prebuilt import create_react_agent
from langgraph.types import Command, interrupt
from langgraph_checkpoint_cosmosdb import CosmosDBSaver
from langsmith import traceable
from src.app.services.azure_open_ai import model
# from src.app.services.local_model import model # Use local model
from src.app.services.azure_cosmos_db import DATABASE_NAME, checkpoint_container, chat_container, \
update_chat_container, patch_active_agent
from src.app.tools.sales import get_offer_information, calculate_monthly_payment, create_account
from src.app.tools.transactions import bank_balance, bank_transfer, get_transaction_history
from src.app.tools.support import service_request, get_branch_location
from src.app.tools.coordinator import create_agent_transfer
# Flipped to True by interactive_chat(); the agent nodes consult it to decide
# whether they must create/patch the CLI test chat document themselves.
local_interactive_mode = False

# Root logger at DEBUG so the routing decisions logged by the nodes are visible.
logging.basicConfig(level=logging.DEBUG)

# The *.prompty prompt files live in a "prompts" folder beside this module.
PROMPT_DIR = os.path.join(os.path.dirname(__file__), "prompts")
def load_prompt(agent_name, prompt_dir=None):
    """Return the system prompt text for ``agent_name``.

    Reads ``<prompt_dir>/<agent_name>.prompty`` as UTF-8 and returns its
    contents with surrounding whitespace stripped. When the file does not
    exist, a generic banking-assistant prompt is returned and a warning is
    logged, so a missing prompt shows up in the logs instead of silently
    degrading the agent.

    NOTE(review): the whole file is returned verbatim; if these .prompty files
    carry YAML front matter, it is sent to the model as well — confirm.

    Args:
        agent_name: Base name of the prompt file, e.g. ``"sales_agent"``.
        prompt_dir: Directory containing the prompt files. Defaults to the
            module-level ``PROMPT_DIR``.

    Returns:
        The prompt text, or the fallback prompt if the file is missing.
    """
    if prompt_dir is None:
        prompt_dir = PROMPT_DIR
    file_path = os.path.join(prompt_dir, f"{agent_name}.prompty")
    # Use the module's logging setup rather than print, like the rest of the file.
    logging.debug("Loading prompt for %s from %s", agent_name, file_path)
    try:
        with open(file_path, "r", encoding="utf-8") as file:
            return file.read().strip()
    except FileNotFoundError:
        logging.warning(
            "Prompt file not found for %s (%s), using default placeholder.",
            agent_name,
            file_path,
        )
        return "You are an AI banking assistant."  # Fallback default prompt
# Coordinator: the graph's entry node. Its only tools are transfer tools
# targeting the customer-support and sales agents.
coordinator_agent_tools = [
    create_agent_transfer(agent_name=target)
    for target in ("customer_support_agent", "sales_agent")
]
coordinator_agent = create_react_agent(
    model, coordinator_agent_tools, state_modifier=load_prompt("coordinator_agent")
)
# Customer support: branch lookup and service requests, plus transfer tools
# targeting the sales and transactions agents.
customer_support_agent_tools = [get_branch_location, service_request] + [
    create_agent_transfer(agent_name=target)
    for target in ("sales_agent", "transactions_agent")
]
customer_support_agent = create_react_agent(
    model, customer_support_agent_tools, state_modifier=load_prompt("customer_support_agent")
)
# Transactions: balance, transfer and history tools, plus a transfer tool
# targeting the customer-support agent.
transactions_agent_tools = [bank_balance, bank_transfer, get_transaction_history] + [
    create_agent_transfer(agent_name="customer_support_agent")
]
transactions_agent = create_react_agent(
    model, transactions_agent_tools, state_modifier=load_prompt("transactions_agent")
)
# Sales: offer information, payment calculation and account creation, plus
# transfer tools targeting the customer-support and transactions agents.
sales_agent_tools = [get_offer_information, calculate_monthly_payment, create_account] + [
    create_agent_transfer(agent_name=target)
    for target in ("customer_support_agent", "transactions_agent")
]
sales_agent = create_react_agent(
    model, sales_agent_tools, state_modifier=load_prompt("sales_agent")
)
@traceable(run_type="llm")
def call_coordinator_agent(
    state: MessagesState, config
) -> Command[Literal["customer_support_agent", "sales_agent", "transactions_agent", "human"]]:
    """Route a turn to the thread's active specialist, or let the coordinator answer.

    The chat document for this thread is fetched from ``chat_container`` with a
    point lookup. If its ``activeAgent`` names one of the specialist nodes, the
    turn is forwarded there unchanged. Otherwise the coordinator agent handles
    it and control passes to ``human``.

    Args:
        state: Current graph state (the conversation messages).
        config: Runnable config; ``config["configurable"]`` may supply
            ``thread_id``, ``userId`` and ``tenantId``.

    Returns:
        A ``Command`` routing to the active specialist node, or to ``human``
        with the coordinator's output applied.
    """
    configurable = config["configurable"]
    thread_id = configurable.get("thread_id", "UNKNOWN_THREAD_ID")
    userId = configurable.get("userId", "UNKNOWN_USER_ID")
    tenantId = configurable.get("tenantId", "UNKNOWN_TENANT_ID")
    logging.debug(f"Calling coordinator agent with Thread ID: {thread_id}")

    # In local CLI mode the chat document is created below and patched by the
    # specialist nodes under the "cli-test" tenant/user, so it must be read
    # from that partition too; with the config's tenant/user the lookup never
    # finds it and routing is never sticky.
    if local_interactive_mode:
        partition_key = ["cli-test", "cli-test", thread_id]
    else:
        partition_key = [tenantId, userId, thread_id]

    # Get the active agent from Cosmos DB with a point lookup
    activeAgent = None
    try:
        activeAgent = chat_container.read_item(item=thread_id, partition_key=partition_key).get(
            'activeAgent', 'unknown')
    except Exception as e:
        # Best effort: a missing document (or any read failure) just means
        # there is no active agent yet.
        logging.debug(f"No active agent found: {e}")

    if activeAgent is None and local_interactive_mode:
        update_chat_container({
            "id": thread_id,
            "tenantId": "cli-test",
            "userId": "cli-test",
            "sessionId": thread_id,
            "name": "cli-test",
            "age": "cli-test",
            "address": "cli-test",
            "activeAgent": "unknown",
            "ChatName": "cli-test",
            "messages": []
        })
    logging.debug(f"Active agent from point lookup: {activeAgent}")

    # Only forward to nodes that exist in the graph. Anything else (None,
    # "unknown", "coordinator_agent", or a stale/unexpected value) is handled
    # by the coordinator instead of failing on an invalid goto target.
    if activeAgent in ("customer_support_agent", "sales_agent", "transactions_agent"):
        logging.debug(f"Routing straight to last active agent: {activeAgent}")
        return Command(update=state, goto=activeAgent)

    # Pass config through explicitly, as call_sales_agent does.
    response = coordinator_agent.invoke(state, config)
    return Command(update=response, goto="human")
@traceable(run_type="llm")
def call_customer_support_agent(state: MessagesState, config) -> Command[Literal["customer_support_agent", "human"]]:
    """Run the customer-support agent for one turn, then hand control to ``human``.

    In local CLI mode the chat document's ``activeAgent`` is first patched to
    ``"customer_support_agent"``, so the coordinator's later point lookup can
    route follow-up turns straight here.

    Args:
        state: Current graph state (the conversation messages).
        config: Runnable config; ``thread_id`` is read from ``config["configurable"]``.

    Returns:
        A ``Command`` applying the agent's output and routing to ``human``.
    """
    thread_id = config["configurable"].get("thread_id", "UNKNOWN_THREAD_ID")
    if local_interactive_mode:
        patch_active_agent(tenantId="cli-test", userId="cli-test", sessionId=thread_id,
                           activeAgent="customer_support_agent")
    # Pass config explicitly (as call_sales_agent does) so the nested agent run,
    # and any tool reading RunnableConfig, gets this node's config without
    # relying on implicit context propagation (unavailable before Python 3.11).
    response = customer_support_agent.invoke(state, config)
    return Command(update=response, goto="human")
@traceable(run_type="llm")
def call_sales_agent(state: MessagesState, config) -> Command[Literal["sales_agent", "human"]]:
    """Let the sales agent answer this turn, then yield to the ``human`` node.

    When running from the local CLI, the thread's chat document is first marked
    with ``activeAgent="sales_agent"`` under the "cli-test" tenant/user.
    The node's config is forwarded to the agent invocation.
    """
    session_id = config["configurable"].get("thread_id", "UNKNOWN_THREAD_ID")
    if local_interactive_mode:
        patch_active_agent(
            tenantId="cli-test",
            userId="cli-test",
            sessionId=session_id,
            activeAgent="sales_agent",
        )
    return Command(update=sales_agent.invoke(state, config), goto="human")
@traceable(run_type="llm")
def call_transactions_agent(state: MessagesState, config) -> Command[Literal["transactions_agent", "human"]]:
    """Run the transactions agent for one turn, then hand control to ``human``.

    In local CLI mode the chat document's ``activeAgent`` is first patched to
    ``"transactions_agent"``, so the coordinator's later point lookup can route
    follow-up turns straight here.

    Args:
        state: Current graph state (the conversation messages).
        config: Runnable config; ``thread_id`` is read from ``config["configurable"]``.

    Returns:
        A ``Command`` applying the agent's output and routing to ``human``.
    """
    thread_id = config["configurable"].get("thread_id", "UNKNOWN_THREAD_ID")
    if local_interactive_mode:
        patch_active_agent(tenantId="cli-test", userId="cli-test", sessionId=thread_id,
                           activeAgent="transactions_agent")
    # Pass config explicitly (as call_sales_agent does) so the nested agent run,
    # and any tool reading RunnableConfig, gets this node's config without
    # relying on implicit context propagation (unavailable before Python 3.11).
    response = transactions_agent.invoke(state, config)
    return Command(update=response, goto="human")
# Graph node that pauses execution via interrupt() so the caller can collect
# the next user message; this is what makes the conversation multi-turn.
@traceable
def human_node(state: MessagesState, config) -> None:
    """Suspend the graph until the next user input.

    Whatever value the graph is resumed with is ignored, and the node
    contributes no state update.
    """
    interrupt(value="Ready for user input.")
# Assemble the multi-agent graph. Every run enters at the coordinator; the
# agent nodes pick their successor at runtime by returning a Command, so no
# static edges are needed beyond the entry edge.
builder = StateGraph(MessagesState)
for _node_name, _node_fn in (
    ("coordinator_agent", call_coordinator_agent),
    ("customer_support_agent", call_customer_support_agent),
    ("sales_agent", call_sales_agent),
    ("transactions_agent", call_transactions_agent),
    ("human", human_node),
):
    builder.add_node(_node_name, _node_fn)
del _node_name, _node_fn
builder.add_edge(START, "coordinator_agent")

# Checkpoints are stored through CosmosDBSaver in the configured checkpoint container.
checkpointer = CosmosDBSaver(database_name=DATABASE_NAME, container_name=checkpoint_container)
graph = builder.compile(checkpointer=checkpointer)
def interactive_chat():
    """Run a local command-line chat session against the compiled ``graph``.

    Sets the module-level ``local_interactive_mode`` flag, which makes the
    agent nodes create/patch the chat document under the "cli-test"
    tenant/user themselves. The loop then:

    1. reads a line from the user;
    2. streams the graph's per-node updates for it;
    3. prints the last AI message emitted by each node.

    The session ends on "exit" (case- and surrounding-whitespace-insensitive),
    end-of-file (Ctrl-D) or Ctrl-C at the prompt.
    """
    global local_interactive_mode
    local_interactive_mode = True

    # A fresh thread id per session; userId/tenantId reach the graph through
    # the run config.
    thread_config = {"configurable": {"thread_id": str(uuid.uuid4()), "userId": "Mark", "tenantId": "Contoso"}}

    print("Welcome to the interactive multi-agent banking assistant.")
    print("Type 'exit' to end the conversation.\n")

    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # End the session like "exit" instead of dying with a traceback.
            print()
            break
        if user_input.strip().lower() == "exit":
            break
        if not user_input.strip():
            # Don't send an empty message to the model.
            continue

        input_message = {"messages": [{"role": "user", "content": user_input}]}
        response_found = False  # Track if we received an AI response
        for update in graph.stream(input_message, config=thread_config, stream_mode="updates"):
            for node_id, value in update.items():
                # Only dict updates carrying messages are of interest; non-dict
                # values (e.g. from human_node, which returns None) are skipped.
                if isinstance(value, dict) and value.get("messages"):
                    last_message = value["messages"][-1]
                    if isinstance(last_message, AIMessage):
                        print(f"{node_id}: {last_message.content}\n")
                        response_found = True
        if not response_found:
            print("DEBUG: No AI response received.")
if __name__ == "__main__":
    # Run as a script (not imported): start the local CLI chat loop.
    interactive_chat()