|
3 | 3 | import asyncio |
4 | 4 | import json |
5 | 5 | import uuid |
| 6 | +from collections.abc import Mapping |
6 | 7 | from datetime import datetime, timezone |
7 | 8 | from typing import Annotated, Any, AsyncIterator, MutableMapping, Optional |
8 | 9 |
|
|
46 | 47 | from models.requests import QueryRequest |
47 | 48 | from utils.mcp_headers import mcp_headers_dependency, McpHeaders |
48 | 49 | from utils.responses import ( |
49 | | - extract_text_from_output_item, |
| 50 | + extract_text_from_response_item, |
50 | 51 | prepare_responses_params, |
51 | 52 | ) |
52 | 53 | from utils.suid import normalize_conversation_id |
@@ -107,7 +108,7 @@ def _convert_responses_content_to_a2a_parts(output: list[Any]) -> list[Part]: |
107 | 108 | parts: list[Part] = [] |
108 | 109 |
|
109 | 110 | for output_item in output: |
110 | | - text = extract_text_from_output_item(output_item) |
| 111 | + text = extract_text_from_response_item(output_item) |
111 | 112 | if text: |
112 | 113 | parts.append(Part(root=TextPart(text=text))) |
113 | 114 |
|
@@ -184,15 +185,22 @@ class A2AAgentExecutor(AgentExecutor): |
184 | 185 | routing queries to the LLM backend using the Responses API. |
185 | 186 | """ |
186 | 187 |
|
187 | | - def __init__(self, auth_token: str, mcp_headers: Optional[McpHeaders] = None): |
| 188 | + def __init__( |
| 189 | + self, |
| 190 | + auth_token: str, |
| 191 | + mcp_headers: Optional[McpHeaders] = None, |
| 192 | + request_headers: Optional[Mapping[str, str]] = None, |
| 193 | + ): |
188 | 194 | """Initialize the A2A agent executor. |
189 | 195 |
|
190 | 196 | Args: |
191 | 197 | auth_token: Authentication token for the request |
192 | 198 | mcp_headers: MCP headers for context propagation |
| 199 | + request_headers: Incoming HTTP request headers for allowlist propagation |
193 | 200 | """ |
194 | 201 | self.auth_token: str = auth_token |
195 | 202 | self.mcp_headers: McpHeaders = mcp_headers or {} |
| 203 | + self.request_headers: Optional[Mapping[str, str]] = request_headers |
196 | 204 |
|
197 | 205 | async def execute( |
198 | 206 | self, |
@@ -326,6 +334,7 @@ async def _process_task_streaming( # pylint: disable=too-many-locals |
326 | 334 | self.mcp_headers, |
327 | 335 | stream=True, |
328 | 336 | store=True, |
| 337 | + request_headers=self.request_headers, |
329 | 338 | ) |
330 | 339 | # Stream response from LLM using the Responses API |
331 | 340 | stream = await client.responses.create(**responses_params.model_dump()) |
@@ -649,17 +658,26 @@ async def get_agent_card( # pylint: disable=unused-argument |
649 | 658 | raise |
650 | 659 |
|
651 | 660 |
|
652 | | -async def _create_a2a_app(auth_token: str, mcp_headers: McpHeaders) -> Any: |
| 661 | +async def _create_a2a_app( |
| 662 | + auth_token: str, |
| 663 | + mcp_headers: McpHeaders, |
| 664 | + request_headers: Optional[Mapping[str, str]] = None, |
| 665 | +) -> Any: |
653 | 666 | """Create an A2A Starlette application instance with auth context. |
654 | 667 |
|
655 | 668 | Args: |
656 | 669 | auth_token: Authentication token for the request |
657 | 670 | mcp_headers: MCP headers for context propagation |
| 671 | + request_headers: Incoming HTTP request headers for allowlist propagation |
658 | 672 |
|
659 | 673 | Returns: |
660 | 674 | A2A Starlette ASGI application |
661 | 675 | """ |
662 | | - agent_executor = A2AAgentExecutor(auth_token=auth_token, mcp_headers=mcp_headers) |
| 676 | + agent_executor = A2AAgentExecutor( |
| 677 | + auth_token=auth_token, |
| 678 | + mcp_headers=mcp_headers, |
| 679 | + request_headers=request_headers, |
| 680 | + ) |
663 | 681 | task_store = await _get_task_store() |
664 | 682 |
|
665 | 683 | request_handler = DefaultRequestHandler( |
@@ -713,7 +731,7 @@ async def handle_a2a_jsonrpc( # pylint: disable=too-many-locals,too-many-statem |
713 | 731 | auth_token = "" |
714 | 732 |
|
715 | 733 | # Create A2A app with auth context |
716 | | - a2a_app = await _create_a2a_app(auth_token, mcp_headers) |
| 734 | + a2a_app = await _create_a2a_app(auth_token, mcp_headers, request.headers) |
717 | 735 |
|
718 | 736 | # Detect if this is a streaming request by checking the JSON-RPC method |
719 | 737 | is_streaming_request = False |
|
0 commit comments