-
Notifications
You must be signed in to change notification settings - Fork 27
Expand file tree
/
Copy pathgraph.py
More file actions
64 lines (45 loc) · 1.93 KB
/
graph.py
File metadata and controls
64 lines (45 loc) · 1.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from typing import Literal
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import START
from langgraph.graph.state import CompiledStateGraph, StateGraph
from pydantic import ValidationError
from examples.ex006.state import State
from examples.ex006.tools import TOOLS, TOOLS_BY_NAME
from examples.ex006.utils import load_llm
def call_llm(state: State) -> State:
    """Run the tool-enabled chat model on the conversation so far.

    The model returned by ``load_llm()`` gets ``TOOLS`` bound to it and is
    invoked with ``state["messages"]``.

    Returns:
        A partial state update whose ``"messages"`` list holds only the
        model's reply (presumably appended to the history by the reducer
        declared on ``State`` -- confirm in ``examples.ex006.state``).
    """
    print("> call llm")
    base_model = load_llm()
    model = base_model.bind_tools(TOOLS)
    reply = model.invoke(state["messages"])
    return {"messages": [reply]}
def tool_node(state: State) -> State:
    """Execute every tool call requested by the latest model message.

    The last entry of ``state["messages"]`` is inspected. If it is not an
    ``AIMessage`` carrying tool calls, the state is returned unchanged.
    Otherwise each requested tool is looked up in ``TOOLS_BY_NAME`` and
    invoked with the arguments supplied by the model.

    All tool calls are answered, not just the last one: a model may request
    several tools in a single turn, and each ``tool_call_id`` needs its own
    ``ToolMessage`` before the conversation is sent back to the model.

    Returns:
        A partial state update whose ``"messages"`` list contains one
        ``ToolMessage`` per tool call, in the order the model requested them.
        Failed calls yield a message with ``status="error"`` describing the
        problem so the model can correct itself on the next turn.
    """
    print("> tool node")
    llm_response = state["messages"][-1]

    if not isinstance(llm_response, AIMessage) or not getattr(
        llm_response, "tool_calls", None
    ):
        return state

    tool_messages: list[ToolMessage] = []
    for call in llm_response.tool_calls:
        name, args, id_ = call["name"], call["args"], call["id"]

        try:
            # KeyError covers an unknown tool name; the remaining exception
            # types cover bad or invalid arguments raised by the tool itself.
            content = TOOLS_BY_NAME[name].invoke(args)
            status = "success"
        except (KeyError, IndexError, TypeError, ValidationError, ValueError) as error:
            # Report the failure back to the model instead of crashing the graph.
            content = f"Please, fix your mistakes: {error}"
            status = "error"

        tool_messages.append(
            ToolMessage(content=content, tool_call_id=id_, status=status)
        )

    return {"messages": tool_messages}
def router(state: State) -> Literal["tool_node", "__end__"]:
    """Pick the next node based on the most recent message.

    Returns:
        ``"tool_node"`` when the last message in ``state["messages"]`` has a
        truthy ``tool_calls`` attribute, otherwise ``"__end__"``.
    """
    print("> router")
    latest = state["messages"][-1]
    # A missing attribute and an empty list are both treated as "no tool calls".
    wants_tools = bool(getattr(latest, "tool_calls", None))
    return "tool_node" if wants_tools else "__end__"
def build_graph() -> CompiledStateGraph[State, None, State, State]:
    """Assemble and compile the LLM/tool loop graph.

    Topology:
        START -> call_llm -> (router) -> tool_node -> call_llm -> ...
                                     \\-> __end__

    Returns:
        The compiled graph, using a fresh ``InMemorySaver`` as checkpointer.
    """
    graph = StateGraph(State)

    # Register the nodes first; the edges below refer to them by name.
    nodes = (("call_llm", call_llm), ("tool_node", tool_node))
    for node_name, node_fn in nodes:
        graph.add_node(node_name, node_fn)

    graph.add_edge(START, "call_llm")
    # After each model turn, `router` chooses between running tools and ending.
    graph.add_conditional_edges("call_llm", router, ["tool_node", "__end__"])
    # Tool results always go back to the model.
    graph.add_edge("tool_node", "call_llm")

    checkpointer = InMemorySaver()
    return graph.compile(checkpointer=checkpointer)