Skip to content

Commit 6878c1f

Browse files
tylerslatonblovemme
committed
feat(agent): integrate CopilotKit with tools, shared state, and generative UI
- Wire CopilotKitMiddleware and LangGraphAGUIAgent into the LangGraph agent - Add query_data tool with sample financial CSV for generative chart UI - Add manage_todos/get_todos tools and AgentState with todos field for shared state - Add Excalidraw MCP middleware and suggestion - Add unit test infrastructure for agent tools Co-Authored-By: Brian Love <brian@liveloveapp.com> Co-Authored-By: Markus Ecker <markus.ecker@gmail.com>
1 parent 4bb6dc7 commit 6878c1f

File tree

12 files changed

+2702
-170
lines changed

12 files changed

+2702
-170
lines changed

patterns/langgraph-single-agent/Dockerfile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,9 @@ USER bedrock_agentcore
3131

3232
EXPOSE 8080
3333

34-
# Copy only the agent code
34+
# Copy agent code
3535
COPY patterns/langgraph-single-agent/langgraph_agent.py .
36+
COPY patterns/langgraph-single-agent/tools/ tools/
3637
COPY patterns/utils/ utils/
3738

3839
# Healthcheck using Python (no extra dependencies needed)

patterns/langgraph-single-agent/langgraph_agent.py

Lines changed: 151 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -1,213 +1,198 @@
11
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
from langgraph.prebuilt import create_react_agent
5-
from langchain_aws import ChatBedrock
6-
from langchain_mcp_adapters.client import MultiServerMCPClient
4+
from __future__ import annotations
5+
6+
import base64
7+
import json
78
import os
8-
import boto3
9+
import logging
10+
11+
from ag_ui.core import RunAgentInput, RunErrorEvent, RunFinishedEvent
912
from bedrock_agentcore.identity.auth import requires_access_token
1013
from bedrock_agentcore.runtime import BedrockAgentCoreApp, RequestContext
11-
import traceback
12-
13-
# Use official LangGraph AWS integration for memory
14+
from copilotkit import CopilotKitMiddleware, LangGraphAGUIAgent
15+
from langchain.agents import create_agent
16+
from langchain_aws import ChatBedrock
17+
from langchain_mcp_adapters.client import MultiServerMCPClient
1418
from langgraph_checkpoint_aws import AgentCoreMemorySaver
1519

16-
from utils.auth import extract_user_id_from_context
1720
from utils.ssm import get_ssm_parameter
21+
from tools import query_data, AgentState, todo_tools
1822

1923
app = BedrockAgentCoreApp()
2024

21-
# OAuth2 Credential Provider decorator from AgentCore Identity SDK.
22-
# Automatically retrieves OAuth2 access tokens from the Token Vault (with caching)
23-
# or fetches fresh tokens from the configured OAuth2 provider when expired.
24-
# The provider_name references an OAuth2 Credential Provider registered in AgentCore Identity.
25+
ACTOR_ID_KEYS = ("actor_id", "actorId", "user_id", "userId", "sub")
26+
27+
SYSTEM_PROMPT = """You are a helpful assistant with access to tools via the Gateway and built-in data tools.
28+
29+
When demonstrating charts, always call the query_data tool first to fetch data from the database before calling any chart tool.
30+
When managing todos, use manage_todos to update the list and get_todos to read the current list.
31+
When asked about your tools, list them and explain what they do."""
32+
33+
34+
def decode_jwt_sub(authorization_header: str | None) -> str | None:
35+
if not authorization_header:
36+
return None
37+
38+
parts = authorization_header.strip().split()
39+
if len(parts) != 2 or parts[0].lower() != "bearer":
40+
return None
41+
42+
token_parts = parts[1].split(".")
43+
if len(token_parts) < 2:
44+
return None
45+
46+
try:
47+
payload = token_parts[1]
48+
payload += "=" * ((4 - len(payload) % 4) % 4)
49+
decoded = base64.urlsafe_b64decode(payload.encode("utf-8"))
50+
sub = json.loads(decoded).get("sub")
51+
return sub if isinstance(sub, str) and sub else None
52+
except Exception:
53+
return None
54+
55+
56+
def resolve_actor_id(
57+
input_data: RunAgentInput, authorization_header: str | None
58+
) -> str | None:
59+
forwarded_props = (
60+
input_data.forwarded_props
61+
if isinstance(input_data.forwarded_props, dict)
62+
else {}
63+
)
64+
65+
for key in ACTOR_ID_KEYS:
66+
value = forwarded_props.get(key)
67+
if isinstance(value, str) and value:
68+
return value
69+
70+
return decode_jwt_sub(authorization_header)
71+
72+
2573
@requires_access_token(
2674
provider_name=os.environ["GATEWAY_CREDENTIAL_PROVIDER_NAME"],
2775
auth_flow="M2M",
28-
scopes=[]
76+
scopes=[],
2977
)
3078
async def _fetch_gateway_token(access_token: str) -> str:
31-
"""
32-
Fetch fresh OAuth2 token for AgentCore Gateway authentication.
33-
34-
This is async because it's called with 'await' in create_gateway_mcp_client().
35-
The @requires_access_token decorator handles token retrieval and refresh:
36-
1. Token Retrieval: Calls GetResourceOauth2Token API to fetch token from Token Vault
37-
2. Automatic Refresh: Uses refresh tokens to renew expired access tokens
38-
3. Error Orchestration: Handles missing tokens and OAuth flow management
39-
40-
For M2M (Machine-to-Machine) flows, the decorator uses Client Credentials grant type.
41-
The provider_name must match the Name field in the CDK OAuth2CredentialProvider resource.
42-
"""
4379
return access_token
4480

4581

4682
async def create_gateway_mcp_client() -> MultiServerMCPClient:
47-
"""
48-
Create an MCP client connected to the AgentCore Gateway with OAuth2 authentication.
49-
50-
MCP (Model Context Protocol) is how agents communicate with tool providers.
51-
This creates a client that can talk to the AgentCore Gateway using OAuth2
52-
authentication. The Gateway then provides access to Lambda-based tools.
53-
54-
This implementation avoids the "closure trap" by calling _fetch_gateway_token()
55-
on every invocation of create_gateway_mcp_client(). Since this function is called
56-
per-request in agent_stream(), it ensures fresh tokens for each request.
57-
"""
58-
stack_name = os.environ.get('STACK_NAME')
83+
stack_name = os.environ.get("STACK_NAME")
5984
if not stack_name:
6085
raise ValueError("STACK_NAME environment variable is required")
61-
62-
# Validate stack name format to prevent injection
63-
if not stack_name.replace('-', '').replace('_', '').isalnum():
86+
87+
if not stack_name.replace("-", "").replace("_", "").isalnum():
6488
raise ValueError("Invalid STACK_NAME format")
65-
66-
print(f"[AGENT] Creating Gateway MCP client for stack: {stack_name}")
67-
68-
# Fetch Gateway URL from SSM
69-
gateway_url = get_ssm_parameter(f'/{stack_name}/gateway_url')
70-
print(f"[AGENT] Gateway URL from SSM: {gateway_url}")
71-
72-
# Fetch fresh token on every call to avoid closure trap
89+
90+
gateway_url = get_ssm_parameter(f"/{stack_name}/gateway_url")
7391
fresh_token = await _fetch_gateway_token()
74-
75-
# Create MCP client with Bearer token authentication
76-
gateway_client = MultiServerMCPClient({
77-
"gateway": {
78-
"transport": "streamable_http",
79-
"url": gateway_url,
80-
"headers": {
81-
"Authorization": f"Bearer {fresh_token}"
92+
93+
return MultiServerMCPClient(
94+
{
95+
"gateway": {
96+
"transport": "streamable_http",
97+
"url": gateway_url,
98+
"headers": {
99+
"Authorization": f"Bearer {fresh_token}",
100+
},
82101
}
83102
}
84-
})
85-
86-
print("[AGENT] Gateway MCP client created successfully")
87-
return gateway_client
88-
89-
90-
async def create_langgraph_agent(user_id: str, session_id: str, tools: list):
91-
"""
92-
Create a LangGraph agent with AgentCore Gateway MCP tools and memory integration.
93-
94-
This function sets up a LangGraph StateGraph that can access tools through
95-
the AgentCore Gateway and maintains conversation memory.
96-
"""
97-
system_prompt = """You are a helpful assistant with access to tools via the Gateway.
98-
When asked about your tools, list them and explain what they do."""
99-
100-
# Create Bedrock model
101-
bedrock_model = ChatBedrock(
103+
)
104+
105+
106+
def _build_model(streaming: bool) -> ChatBedrock:
107+
return ChatBedrock(
102108
model_id="us.anthropic.claude-sonnet-4-5-20250929-v1:0",
103109
temperature=0.1,
104-
streaming=True
110+
max_tokens=16384,
111+
streaming=streaming,
112+
beta_use_converse_api=True,
105113
)
106114

107-
# Get and validate Memory ID
115+
116+
def _build_checkpointer() -> AgentCoreMemorySaver:
108117
memory_id = os.environ.get("MEMORY_ID")
109118
if not memory_id:
110119
raise ValueError("MEMORY_ID environment variable is required")
111-
112-
# Configure AgentCore Memory using official LangGraph AWS integration
113-
checkpointer = AgentCoreMemorySaver(
120+
121+
return AgentCoreMemorySaver(
114122
memory_id=memory_id,
115-
region_name=os.environ.get("AWS_DEFAULT_REGION", "us-east-1")
123+
region_name=os.environ.get("AWS_DEFAULT_REGION", "us-east-1"),
116124
)
117125

126+
127+
async def create_langgraph_agent(tools: list):
118128
try:
119-
print("[AGENT] Creating LangGraph agent with Gateway tools...")
120-
121-
graph = create_react_agent(
122-
model=bedrock_model,
123-
tools=tools,
124-
checkpointer=checkpointer,
125-
prompt=system_prompt
129+
return create_agent(
130+
model=_build_model(streaming=True),
131+
tools=[*tools, query_data, *todo_tools], # MCP tools + data + todo tools
132+
checkpointer=_build_checkpointer(),
133+
middleware=[CopilotKitMiddleware()],
134+
system_prompt=SYSTEM_PROMPT,
135+
state_schema=AgentState, # extends BaseAgentState with todos: list[Todo]
126136
)
127-
128-
print("[AGENT] Agent created successfully with Gateway tools")
129-
return graph
130-
131-
except Exception as e:
132-
print(f"[AGENT ERROR] Error creating LangGraph agent: {e}")
133-
print(f"[AGENT ERROR] Exception type: {type(e).__name__}")
134-
print(f"[AGENT ERROR] Traceback:")
135-
traceback.print_exc()
137+
except Exception:
138+
logging.exception("Error creating LangGraph agent")
136139
raise
137140

138141

142+
async def create_runtime_graph():
143+
mcp_client = await create_gateway_mcp_client()
144+
tools = await mcp_client.get_tools()
145+
return await create_langgraph_agent(tools)
146+
147+
148+
class ActorAwareLangGraphAgent(LangGraphAGUIAgent):
149+
async def run(self, input: RunAgentInput):
150+
actor_id = (
151+
self.config.get("configurable", {}).get("actor_id") if self.config else None
152+
)
153+
if not actor_id:
154+
raise ValueError(
155+
"Missing actor identity. Provide forwardedProps.actor_id/user_id "
156+
"or include sub claim in the bearer token."
157+
)
158+
159+
self.graph = await create_runtime_graph()
160+
self.config = {"configurable": {"actor_id": actor_id}}
161+
async for event in super().run(input):
162+
yield event
163+
164+
139165
@app.entrypoint
140-
async def agent_stream(payload, context: RequestContext):
141-
"""
142-
Main entrypoint for the LangGraph agent using streaming with Gateway integration.
143-
144-
This is the function that AgentCore Runtime calls when the agent receives a request.
145-
It extracts the user's query from the payload, securely obtains the user ID from
146-
the validated JWT token in the request context, creates a LangGraph agent with Gateway
147-
tools and memory, and streams the response back. This function handles the complete
148-
request lifecycle with token-level streaming. The user ID is extracted from the
149-
JWT token (via RequestContext).
150-
"""
151-
user_query = payload.get("prompt")
152-
session_id = payload.get("runtimeSessionId")
153-
154-
if not all([user_query, session_id]):
155-
yield {
156-
"status": "error",
157-
"error": "Missing required fields: prompt or runtimeSessionId"
158-
}
159-
return
160-
166+
async def invocations(payload: dict, context: RequestContext):
167+
input_data = RunAgentInput.model_validate(payload)
168+
authorization_header = None
169+
if context.request_headers:
170+
authorization_header = context.request_headers.get("Authorization")
171+
172+
actor_id = resolve_actor_id(input_data, authorization_header)
173+
if not actor_id:
174+
raise ValueError(
175+
"Missing actor identity. Provide forwardedProps.actor_id/user_id "
176+
"or include sub claim in the bearer token."
177+
)
178+
179+
request_agent = ActorAwareLangGraphAgent(
180+
name="LangGraphSingleAgent",
181+
description="LangGraph single agent exposed via AG-UI",
182+
graph=None,
183+
config={"configurable": {"actor_id": actor_id}},
184+
)
185+
161186
try:
162-
# Extract user ID securely from the validated JWT token
163-
# instead of trusting the payload body (which could be manipulated)
164-
user_id = extract_user_id_from_context(context)
165-
166-
print(f"[STREAM] Starting streaming invocation for user: {user_id}, session: {session_id}")
167-
print(f"[STREAM] Query: {user_query}")
168-
169-
# Get OAuth2 access token and create Gateway MCP client
170-
# The @requires_access_token decorator handles token fetching automatically
171-
print("[STREAM] Creating Gateway MCP client (decorator handles OAuth2)...")
172-
mcp_client = await create_gateway_mcp_client()
173-
print("[STREAM] Gateway MCP client created successfully")
174-
175-
print("[STREAM] Loading Gateway tools...")
176-
tools = await mcp_client.get_tools()
177-
print(f"[STREAM] Loaded {len(tools)} tools from Gateway")
178-
179-
# Create agent with the loaded tools
180-
graph = await create_langgraph_agent(user_id, session_id, tools)
181-
182-
# Configure streaming with actor_id and thread_id for memory
183-
config = {
184-
"configurable": {
185-
"thread_id": session_id,
186-
"actor_id": user_id
187-
}
188-
}
189-
190-
# Stream messages using LangGraph's astream with stream_mode="messages"
191-
async for event in graph.astream(
192-
{"messages": [("user", user_query)]},
193-
config=config,
194-
stream_mode="messages"
195-
):
196-
# event is a tuple: (message_chunk, metadata)
197-
message_chunk, metadata = event
198-
yield message_chunk.model_dump()
199-
200-
print("[STREAM] Streaming completed successfully")
201-
202-
except Exception as e:
203-
error_msg = str(e) if str(e) else f"{type(e).__name__}: {repr(e)}"
204-
print(f"[STREAM ERROR] Error in agent_stream: {error_msg}")
205-
print(f"[STREAM ERROR] Exception type: {type(e).__name__}")
206-
traceback.print_exc()
207-
yield {
208-
"status": "error",
209-
"error": error_msg
210-
}
187+
async for event in request_agent.run(input_data):
188+
if event is not None:
189+
yield event.model_dump(mode="json", by_alias=True, exclude_none=True)
190+
except Exception as exc:
191+
logging.exception("Agent run failed")
192+
yield RunErrorEvent(
193+
message=str(exc) or type(exc).__name__,
194+
code=type(exc).__name__,
195+
).model_dump(mode="json", by_alias=True, exclude_none=True)
211196

212197

213198
if __name__ == "__main__":

patterns/langgraph-single-agent/requirements.txt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
11
# LangGraph agent dependencies with pinned versions
2+
fastapi==0.115.12
3+
uvicorn==0.34.2
4+
ag-ui-protocol>=0.1.10
5+
ag-ui-langgraph==0.0.28
6+
copilotkit==0.1.83
7+
partialjson>=0.0.8
28
langgraph==1.0.10rc1
9+
langchain>=0.3.0
310
langchain-aws==1.0.0
411
langchain-mcp-adapters==0.1.13
5-
langgraph-checkpoint-aws==1.0.1
12+
langgraph-checkpoint-aws==1.0.5
613
mcp==1.23.1
714
bedrock-agentcore==1.0.6
8-
PyJWT[crypto]>=2.10.1

0 commit comments

Comments
 (0)