|
1 | | -"""AG-UI LangGraph agent with Gateway MCP tools, Memory, and Code Interpreter. |
| 1 | +"""AG-UI LangGraph agent with Gateway MCP tools, Memory, and Code Interpreter.""" |
2 | 2 |
|
3 | | -Uses copilotkit's LangGraphAGUIAgent to produce native AG-UI SSE events. |
4 | | -AgentCore proxies these unchanged when deployed with --protocol AGUI. |
5 | | -""" |
| 3 | +from __future__ import annotations |
6 | 4 |
|
7 | 5 | import logging |
8 | 6 | import os |
|
12 | 10 | from copilotkit import CopilotKitMiddleware, LangGraphAGUIAgent |
13 | 11 | from langchain.agents import create_agent |
14 | 12 | from langchain_aws import ChatBedrock |
15 | | -from langgraph.graph import END, START, StateGraph |
16 | 13 | from langgraph_checkpoint_aws import AgentCoreMemorySaver |
17 | 14 | from tools.gateway import create_gateway_mcp_client |
18 | 15 | from utils.auth import extract_user_id_from_context |
|
28 | 25 | "When asked about your tools, list them and explain what they do." |
29 | 26 | ) |
30 | 27 |
|
| 28 | +REGION = os.environ.get("AWS_REGION", "us-east-1") |
| 29 | +MEMORY_ID = os.environ.get("MEMORY_ID") |
| 30 | +MODEL = ChatBedrock( |
| 31 | + model_id="us.anthropic.claude-sonnet-4-5-20250929-v1:0", |
| 32 | + temperature=0.1, |
| 33 | + streaming=True, |
| 34 | + beta_use_converse_api=True, |
| 35 | +) |
| 36 | +CODE_INTERPRETER = LangGraphCodeInterpreterTools(REGION).execute_python_securely |
31 | 37 |
|
32 | | -def _build_model() -> ChatBedrock: |
33 | | - return ChatBedrock( |
34 | | - model_id="us.anthropic.claude-sonnet-4-5-20250929-v1:0", |
35 | | - temperature=0.1, |
36 | | - streaming=True, |
37 | | - beta_use_converse_api=True, |
38 | | - ) |
39 | 38 |
|
| 39 | +def get_memory_saver() -> AgentCoreMemorySaver | None: |
| 40 | + """Return an AgentCore Memory checkpointer, or None when MEMORY_ID is unset.""" |
| 41 | + if not MEMORY_ID: |
| 42 | + return None |
| 43 | + return AgentCoreMemorySaver(memory_id=MEMORY_ID, region_name=REGION) |
40 | 44 |
|
41 | | -def _create_checkpointer() -> AgentCoreMemorySaver: |
42 | | - memory_id = os.environ.get("MEMORY_ID") |
43 | | - if not memory_id: |
44 | | - raise ValueError("MEMORY_ID environment variable is required") |
45 | | - return AgentCoreMemorySaver( |
46 | | - memory_id=memory_id, |
47 | | - region_name=os.environ.get("AWS_DEFAULT_REGION", "us-east-1"), |
48 | | - ) |
49 | 45 |
|
50 | | - |
51 | | -async def create_langgraph_agent(user_id: str): |
52 | | - """Create a LangGraph agent with Gateway tools, Memory, and Code Interpreter.""" |
53 | | - mcp_client = await create_gateway_mcp_client(user_id) |
| 46 | +async def build_graph(actor_id: str): |
| 47 | + """Build a LangGraph compiled graph with Gateway tools and Memory.""" |
| 48 | + mcp_client = await create_gateway_mcp_client(actor_id) |
54 | 49 | tools = await mcp_client.get_tools() |
55 | | - |
56 | | - region = os.environ.get("AWS_DEFAULT_REGION", "us-east-1") |
57 | | - code_tools = LangGraphCodeInterpreterTools(region) |
58 | | - tools.append(code_tools.execute_python_securely) |
| 50 | + tools.append(CODE_INTERPRETER) |
59 | 51 |
|
60 | 52 | return create_agent( |
61 | | - model=_build_model(), |
| 53 | + model=MODEL, |
62 | 54 | tools=tools, |
63 | | - checkpointer=_create_checkpointer(), |
| 55 | + checkpointer=get_memory_saver(), |
64 | 56 | middleware=[CopilotKitMiddleware()], |
65 | 57 | system_prompt=SYSTEM_PROMPT, |
66 | 58 | ) |
67 | 59 |
|
68 | 60 |
|
69 | | -class ActorAwareLangGraphAgent(LangGraphAGUIAgent): |
70 | | - """LangGraphAGUIAgent that creates the graph per-request with fresh tokens.""" |
71 | | - |
72 | | - def __init__(self, *, user_id: str, **kwargs): |
73 | | - self._user_id = user_id |
74 | | - # Create a minimal placeholder graph to satisfy validation in newer |
75 | | - # copilotkit/ag_ui_langgraph versions that inspect self.graph.nodes |
76 | | - # during __init__. The placeholder is overwritten in run(). |
77 | | - if kwargs.get("graph") is None: |
78 | | - builder = StateGraph(dict) |
79 | | - builder.add_node("placeholder", lambda x: x) |
80 | | - builder.add_edge(START, "placeholder") |
81 | | - builder.add_edge("placeholder", END) |
82 | | - kwargs["graph"] = builder.compile() |
83 | | - super().__init__(**kwargs) |
84 | | - |
85 | | - async def run(self, input: RunAgentInput): |
86 | | - self.graph = await create_langgraph_agent(self._user_id) |
87 | | - async for event in super().run(input): |
88 | | - yield event |
89 | | - |
90 | | - |
91 | 61 | @app.entrypoint |
92 | 62 | async def invocations(payload: dict, context: RequestContext): |
93 | 63 | input_data = RunAgentInput.model_validate(payload) |
| 64 | + actor_id = extract_user_id_from_context(context) |
94 | 65 |
|
95 | | - user_id = extract_user_id_from_context(context) |
96 | | - |
97 | | - agent = ActorAwareLangGraphAgent( |
| 66 | + graph = await build_graph(actor_id) |
| 67 | + agui_agent = LangGraphAGUIAgent( |
98 | 68 | name="agui_langgraph_agent", |
99 | 69 | description="AG-UI LangGraph agent with Gateway MCP tools and Memory", |
100 | | - graph=None, |
101 | | - config={"configurable": {"actor_id": user_id}}, |
102 | | - user_id=user_id, |
| 70 | + graph=graph, |
| 71 | + config={"configurable": {"actor_id": actor_id}}, |
103 | 72 | ) |
104 | 73 |
|
105 | 74 | try: |
106 | | - async for event in agent.run(input_data): |
| 75 | + async for event in agui_agent.run(input_data): |
107 | 76 | if event is not None: |
108 | 77 | yield event.model_dump(mode="json", by_alias=True, exclude_none=True) |
109 | 78 | except Exception as exc: |
|
0 commit comments