|
1 | 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 |
|
4 | | -from langgraph.prebuilt import create_react_agent |
5 | | -from langchain_aws import ChatBedrock |
6 | | -from langchain_mcp_adapters.client import MultiServerMCPClient |
| 4 | +from __future__ import annotations |
| 5 | + |
| 6 | +import base64 |
| 7 | +import json |
7 | 8 | import os |
8 | | -import boto3 |
| 9 | +import logging |
| 10 | + |
| 11 | +from ag_ui.core import RunAgentInput, RunErrorEvent, RunFinishedEvent |
9 | 12 | from bedrock_agentcore.identity.auth import requires_access_token |
10 | 13 | from bedrock_agentcore.runtime import BedrockAgentCoreApp, RequestContext |
11 | | -import traceback |
12 | | - |
13 | | -# Use official LangGraph AWS integration for memory |
| 14 | +from copilotkit import CopilotKitMiddleware, LangGraphAGUIAgent |
| 15 | +from langchain.agents import create_agent |
| 16 | +from langchain_aws import ChatBedrock |
| 17 | +from langchain_mcp_adapters.client import MultiServerMCPClient |
14 | 18 | from langgraph_checkpoint_aws import AgentCoreMemorySaver |
15 | 19 |
|
16 | | -from utils.auth import extract_user_id_from_context |
17 | 20 | from utils.ssm import get_ssm_parameter |
| 21 | +from tools import query_data, AgentState, todo_tools |
18 | 22 |
|
# AgentCore Runtime application object; the @app.entrypoint decorator below
# registers the request handler with the runtime.
app = BedrockAgentCoreApp()

|
# Keys probed in forwardedProps, in priority order, to identify the caller.
ACTOR_ID_KEYS = ("actor_id", "actorId", "user_id", "userId", "sub")

# System prompt handed to the agent; referenced by create_langgraph_agent().
SYSTEM_PROMPT = """You are a helpful assistant with access to tools via the Gateway and built-in data tools.

When demonstrating charts, always call the query_data tool first to fetch data from the database before calling any chart tool.
When managing todos, use manage_todos to update the list and get_todos to read the current list.
When asked about your tools, list them and explain what they do."""

| 33 | + |
def decode_jwt_sub(authorization_header: str | None) -> str | None:
    """Best-effort extraction of the ``sub`` claim from a Bearer JWT.

    Returns None for anything that is not a well-formed ``Bearer <jwt>``
    header, or whose payload segment cannot be base64url-decoded and parsed
    as JSON with a non-empty string ``sub``. Never raises; no signature
    verification is performed here.
    """
    if not authorization_header:
        return None

    scheme_and_token = authorization_header.strip().split()
    if len(scheme_and_token) != 2:
        return None
    scheme, token = scheme_and_token
    if scheme.lower() != "bearer":
        return None

    segments = token.split(".")
    if len(segments) < 2:
        return None

    # JWT payloads are base64url without '=' padding; restore it first.
    raw_payload = segments[1]
    padded = raw_payload + "=" * (-len(raw_payload) % 4)
    try:
        claims = json.loads(base64.urlsafe_b64decode(padded.encode("utf-8")))
        sub = claims.get("sub")
    except Exception:
        return None

    if isinstance(sub, str) and sub:
        return sub
    return None
| 54 | + |
| 55 | + |
def resolve_actor_id(
    input_data: RunAgentInput, authorization_header: str | None
) -> str | None:
    """Resolve the caller's actor identity for memory scoping.

    Checks every known identity key in ``forwardedProps`` first, then falls
    back to the ``sub`` claim of the bearer token. Returns None when no
    identity can be determined.
    """
    props = input_data.forwarded_props
    if not isinstance(props, dict):
        props = {}

    for key in ACTOR_ID_KEYS:
        candidate = props.get(key)
        if isinstance(candidate, str) and candidate:
            return candidate

    return decode_jwt_sub(authorization_header)
| 71 | + |
| 72 | + |
@requires_access_token(
    provider_name=os.environ["GATEWAY_CREDENTIAL_PROVIDER_NAME"],
    auth_flow="M2M",
    scopes=[],
)
async def _fetch_gateway_token(access_token: str) -> str:
    """Return a fresh OAuth2 access token for the AgentCore Gateway.

    The ``@requires_access_token`` decorator runs the M2M (client
    credentials) flow against the named credential provider and injects the
    resulting token as ``access_token``; the body only hands it back.
    Called without arguments by create_gateway_mcp_client().
    """
    return access_token
44 | 80 |
|
45 | 81 |
|
async def create_gateway_mcp_client() -> MultiServerMCPClient:
    """Build an MCP client for the AgentCore Gateway with a fresh token.

    Invoked per request so every client carries a freshly fetched OAuth2
    bearer token rather than a stale one captured in a closure.

    Raises:
        ValueError: if STACK_NAME is unset or contains characters other
            than alphanumerics, dashes, and underscores.
    """
    stack_name = os.environ.get("STACK_NAME")
    if not stack_name:
        raise ValueError("STACK_NAME environment variable is required")

    # Guard against injection into the SSM parameter path built below.
    if not stack_name.replace("-", "").replace("_", "").isalnum():
        raise ValueError("Invalid STACK_NAME format")

    gateway_url = get_ssm_parameter(f"/{stack_name}/gateway_url")
    fresh_token = await _fetch_gateway_token()

    return MultiServerMCPClient(
        {
            "gateway": {
                "transport": "streamable_http",
                "url": gateway_url,
                "headers": {
                    "Authorization": f"Bearer {fresh_token}",
                },
            }
        }
    )
| 104 | + |
| 105 | + |
def _build_model(streaming: bool) -> ChatBedrock:
    """Construct the Bedrock chat model backing the agent.

    Args:
        streaming: when True, enables token-level streaming from the model.
    """
    return ChatBedrock(
        model_id="us.anthropic.claude-sonnet-4-5-20250929-v1:0",
        temperature=0.1,
        max_tokens=16384,
        streaming=streaming,
        # NOTE(review): opts into the Converse API path — confirm this is
        # required for the tool-calling setup used here.
        beta_use_converse_api=True,
    )
106 | 114 |
|
def _build_checkpointer() -> AgentCoreMemorySaver:
    """Create the AgentCore memory checkpointer for conversation state.

    Raises:
        ValueError: if the MEMORY_ID environment variable is unset.
    """
    memory_id = os.environ.get("MEMORY_ID")
    if not memory_id:
        raise ValueError("MEMORY_ID environment variable is required")

    return AgentCoreMemorySaver(
        memory_id=memory_id,
        # Defaults to us-east-1 when AWS_DEFAULT_REGION is not set.
        region_name=os.environ.get("AWS_DEFAULT_REGION", "us-east-1"),
    )
117 | 125 |
|
| 126 | + |
async def create_langgraph_agent(tools: list):
    """Assemble the LangGraph agent from Gateway MCP tools plus local tools.

    Combines the MCP tools fetched from the Gateway with the built-in
    query_data and todo tools, wires in AgentCore memory via the
    checkpointer, and applies the CopilotKit middleware.

    Args:
        tools: MCP tool objects loaded from the Gateway client.

    Raises:
        Re-raises any exception from agent construction after logging it.
    """
    try:
        return create_agent(
            model=_build_model(streaming=True),
            tools=[*tools, query_data, *todo_tools],  # MCP tools + data + todo tools
            checkpointer=_build_checkpointer(),
            middleware=[CopilotKitMiddleware()],
            system_prompt=SYSTEM_PROMPT,
            state_schema=AgentState,  # extends BaseAgentState with todos: list[Todo]
        )
    except Exception:
        logging.exception("Error creating LangGraph agent")
        raise
137 | 140 |
|
138 | 141 |
|
async def create_runtime_graph():
    """Build the per-request graph: fresh Gateway MCP client, tools, agent."""
    mcp_client = await create_gateway_mcp_client()
    tools = await mcp_client.get_tools()
    return await create_langgraph_agent(tools)
| 146 | + |
| 147 | + |
class ActorAwareLangGraphAgent(LangGraphAGUIAgent):
    # Per-request agent wrapper: builds a fresh graph on each run and pins
    # actor_id into the run config so memory is scoped to the caller.

    async def run(self, input: RunAgentInput):
        # actor_id was injected through the constructor's config by the
        # entrypoint; re-validated here as a defensive check.
        actor_id = (
            self.config.get("configurable", {}).get("actor_id") if self.config else None
        )
        if not actor_id:
            raise ValueError(
                "Missing actor identity. Provide forwardedProps.actor_id/user_id "
                "or include sub claim in the bearer token."
            )

        # Fresh graph per run so the Gateway MCP client carries a fresh token.
        self.graph = await create_runtime_graph()
        # NOTE(review): this replaces the entire config; any other
        # configurable keys (e.g. thread_id) would be dropped — confirm the
        # parent class supplies thread_id separately from the run input.
        self.config = {"configurable": {"actor_id": actor_id}}
        async for event in super().run(input):
            yield event
| 163 | + |
| 164 | + |
@app.entrypoint
async def invocations(payload: dict, context: RequestContext):
    """AgentCore Runtime entrypoint: run one AG-UI request and stream events.

    Validates the payload as an AG-UI RunAgentInput, resolves the caller's
    identity (forwardedProps first, then the bearer token's sub claim),
    runs the agent, and yields JSON-serializable AG-UI event dicts.
    Failures during the run are surfaced as a terminal RunErrorEvent
    rather than raised to the runtime.

    Raises:
        ValueError: when no actor identity can be resolved (before the run
            starts, so no error event is emitted for this case).
    """
    input_data = RunAgentInput.model_validate(payload)
    authorization_header = None
    if context.request_headers:
        # NOTE(review): exact-case lookup — confirm the runtime normalizes
        # the header name to "Authorization".
        authorization_header = context.request_headers.get("Authorization")

    actor_id = resolve_actor_id(input_data, authorization_header)
    if not actor_id:
        raise ValueError(
            "Missing actor identity. Provide forwardedProps.actor_id/user_id "
            "or include sub claim in the bearer token."
        )

    # NOTE(review): constructed with graph=None; run() assigns the real
    # graph before delegating — confirm the parent tolerates a None graph
    # at construction time.
    request_agent = ActorAwareLangGraphAgent(
        name="LangGraphSingleAgent",
        description="LangGraph single agent exposed via AG-UI",
        graph=None,
        config={"configurable": {"actor_id": actor_id}},
    )

    try:
        async for event in request_agent.run(input_data):
            if event is not None:
                # by_alias keeps AG-UI's camelCase field names on the wire.
                yield event.model_dump(mode="json", by_alias=True, exclude_none=True)
    except Exception as exc:
        logging.exception("Agent run failed")
        yield RunErrorEvent(
            message=str(exc) or type(exc).__name__,
            code=type(exc).__name__,
        ).model_dump(mode="json", by_alias=True, exclude_none=True)
211 | 196 |
|
212 | 197 |
|
213 | 198 | if __name__ == "__main__": |
|
0 commit comments