-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchatbot_with_hitl.py
More file actions
148 lines (120 loc) · 4.23 KB
/
Copy pathchatbot_with_hitl.py
File metadata and controls
148 lines (120 loc) · 4.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# backend.py
import os
from typing import Annotated, TypedDict

import requests
from dotenv import load_dotenv
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.types import Command, interrupt
# Load variables from a local .env file (if present) into the process
# environment; this runs before the chat model below is constructed.
load_dotenv()
# -------------------
# 1. LLM
# -------------------
# Chat model built with no explicit arguments, so model choice and
# credentials come from ChatOpenAI's own defaults / environment
# (presumably OPENAI_API_KEY from the .env loaded above — confirm).
llm = ChatOpenAI()
# -------------------
# 2. Tools
# -------------------
@tool
def get_stock_price(symbol: str) -> dict:
    """
    Fetch latest stock price for a given symbol (e.g. 'AAPL', 'TSLA')
    using Alpha Vantage.
    """
    # NOTE(review): an Alpha Vantage key used to be hard-coded in this
    # function and was committed to source; it should be rotated. The key now
    # comes from the environment (load_dotenv() at import time also reads .env).
    api_key = os.getenv("ALPHAVANTAGE_API_KEY")
    if not api_key:
        return {"error": "ALPHAVANTAGE_API_KEY is not set; cannot fetch stock price."}

    try:
        # params= URL-encodes the symbol; the timeout keeps a stalled request
        # from hanging the whole graph run.
        response = requests.get(
            "https://www.alphavantage.co/query",
            params={"function": "GLOBAL_QUOTE", "symbol": symbol, "apikey": api_key},
            timeout=10,
        )
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as exc:
        # Report network/HTTP/JSON failures as data so the LLM can relay them,
        # rather than depending on ToolNode's exception handling.
        return {"error": f"Failed to fetch quote for {symbol}: {exc}"}
@tool
def purchase_stock(symbol: str, quantity: int) -> dict:
    """
    Simulate purchasing a given quantity of a stock symbol.
    The quantity must be a positive integer.
    HUMAN-IN-THE-LOOP:
    Before confirming the purchase, this tool will interrupt
    and wait for a human decision ("yes" / anything else).
    """
    # Fields echoed back in every result so the LLM knows which order it was.
    order = {"symbol": symbol, "quantity": quantity}

    # Refuse nonsensical orders before involving the human: "buy -5 shares"
    # should never be put up for approval.
    if quantity <= 0:
        return {
            "status": "rejected",
            "message": f"Quantity must be a positive integer, got {quantity}.",
            **order,
        }

    # This pauses the graph and returns control to the caller; the value the
    # caller passes back via Command(resume=...) becomes `decision`.
    decision = interrupt(f"Approve buying {quantity} shares of {symbol}? (yes/no)")

    # Tolerate surrounding whitespace and any letter case (" Yes ", "YES") so
    # callers other than the CLI (which already strips) are handled too.
    if isinstance(decision, str) and decision.strip().lower() == "yes":
        return {
            "status": "success",
            "message": f"Purchase order placed for {quantity} shares of {symbol}.",
            **order,
        }
    return {
        "status": "cancelled",
        "message": f"Purchase of {quantity} shares of {symbol} was declined by human.",
        **order,
    }
# Every tool the model is allowed to call; the same list is handed to
# ToolNode further down so the advertised and executable tools match.
tools = [
    get_stock_price,
    purchase_stock,
]
# Chat model bound to those tools so it can respond with tool calls.
llm_with_tools = llm.bind_tools(tools)
# -------------------
# 3. State
# -------------------
class ChatState(TypedDict):
    """Graph state shared by every node: just the running conversation."""

    # Conversation history. add_messages is the reducer LangGraph applies to
    # node updates on this key; per LangGraph docs it appends/merges returned
    # messages instead of replacing the list, which is why chat_node returns
    # only its new reply and the CLI sends only the new HumanMessage.
    messages: Annotated[list[BaseMessage], add_messages]
# -------------------
# 4. Nodes
# -------------------
def chat_node(state: ChatState):
    """Run the tool-enabled LLM over the whole conversation.

    The model either answers directly or asks for tool calls; either way its
    single reply is returned as a partial state update for "messages".
    """
    reply = llm_with_tools.invoke(state["messages"])
    return {"messages": [reply]}


# Prebuilt LangGraph node that runs the tools requested by chat_node.
tool_node = ToolNode(tools)
# -------------------
# 5. Checkpointer (in-memory)
# -------------------
# Checkpointer holding per-thread graph state in process memory; nothing in
# this file persists it beyond the running process.
memory = MemorySaver()
# -------------------
# 6. Graph
# -------------------
graph = StateGraph(ChatState)
# Two nodes: the LLM step and the prebuilt tool executor.
graph.add_node("chat_node", chat_node)
graph.add_node("tools", tool_node)
# Every run starts at the LLM.
graph.add_edge(START, "chat_node")
# Per LangGraph docs, tools_condition routes to the node named "tools" when
# the last AI message requests tool calls, otherwise to END — so the tool
# node must be registered under exactly that name.
graph.add_conditional_edges("chat_node", tools_condition)
# After tools run, hand their results back to the LLM.
graph.add_edge("tools", "chat_node")
# A checkpointer is required for interrupt()/Command(resume=...) and for
# thread_id-scoped history (per LangGraph docs).
chatbot = graph.compile(checkpointer=memory)
# -------------------
# 7. Simple usage example (CLI with HITL)
# -------------------
if __name__ == "__main__":
    # Use a fixed thread_id so the conversation is persisted in memory
    # (by the MemorySaver checkpointer) across turns of this CLI session.
    thread_id = "demo-thread"
    # Shared by the initial run and every resume so they all address the
    # same checkpointed thread.
    config = {"configurable": {"thread_id": thread_id}}

    while True:
        try:
            user_input = input("You: ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C: leave cleanly instead of dumping a traceback.
            print("\nGoodbye!")
            break

        if user_input.lower().strip() in {"exit", "quit"}:
            print("Goodbye!")
            break
        if not user_input.strip():
            # Don't push empty HumanMessages into the persisted history.
            continue

        # Build the input for this turn: only the new message; earlier
        # history for this thread comes from the checkpointer.
        state = {"messages": [HumanMessage(content=user_input)]}

        # Run the graph (may stop at an interrupt raised by purchase_stock)
        result = chatbot.invoke(state, config=config)

        # A single turn can pause more than once (e.g. the model asks to buy
        # again after the first decision), so keep resuming until the graph
        # finishes with no pending interrupt.
        interrupts = result.get("__interrupt__", [])
        while interrupts:
            # Our interrupt payload is the string we passed to interrupt(...)
            prompt_to_human = interrupts[0].value
            print(f"HITL: {prompt_to_human}")
            try:
                decision = input("Your decision: ").strip().lower()
            except (EOFError, KeyboardInterrupt):
                # No answer from the human: treat it as a refusal so nothing
                # is bought without explicit approval.
                print()
                decision = "no"

            # Resume graph with the human decision ("yes" / "no" / whatever)
            result = chatbot.invoke(Command(resume=decision), config=config)
            interrupts = result.get("__interrupt__", [])

        # Get the latest message from the assistant
        last_msg = result["messages"][-1]
        print(f"Bot: {last_msg.content}\n")