-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathhttp.py
More file actions
234 lines (197 loc) · 8.55 KB
/
Copy pathhttp.py
File metadata and controls
234 lines (197 loc) · 8.55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
"""HTTPDriver: execute capabilities against HTTP APIs using httpx."""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import Any, Literal
import httpx
from ..errors import DriverError
from ..models import RawResult
from .base import ExecutionContext
# Pool sizing for the shared AsyncClient when the caller passes no ``limits``:
# up to 100 concurrent connections, with at most 20 kept alive for reuse.
_DEFAULT_LIMITS = httpx.Limits(
    max_keepalive_connections=20,
    max_connections=100,
)
@dataclass
class HTTPEndpoint:
    """Describes an HTTP endpoint for a capability operation.

    Attributes:
        url: The URL the request is sent to.
        method: HTTP method name; the driver upper-cases it before sending.
        headers: Per-endpoint request headers.
        timeout: Per-endpoint timeout in seconds. Falls back to the driver's
            ``default_timeout`` when ``None``.
        response_format: How to read a successful body: parse as JSON
            (default) or keep it as text.

    Raises:
        ValueError: If ``response_format`` is not ``"json"`` or ``"text"``.
    """

    url: str
    method: str = "GET"
    # default_factory gives every endpoint its own dict (no shared mutable default).
    headers: dict[str, str] = field(default_factory=dict)
    timeout: float | None = None
    response_format: Literal["json", "text"] = "json"

    def __post_init__(self) -> None:
        # ``Literal`` is not enforced at runtime, and any value other than
        # "text" would otherwise be treated as JSON — reject typos up front
        # instead of silently JSON-parsing a text endpoint.
        if self.response_format not in ("json", "text"):
            raise ValueError(
                "response_format must be 'json' or 'text', "
                f"got {self.response_format!r}"
            )
class HTTPDriver:
    """A driver that invokes capabilities via HTTP using :mod:`httpx`.

    Each operation must be registered with an :class:`HTTPEndpoint`. The driver
    holds a single long-lived :class:`httpx.AsyncClient` so requests reuse the
    connection pool and keep-alive instead of paying a fresh TLS handshake on
    every call (#194); call :meth:`aclose` on shutdown to release it, or use the
    driver as an ``async with`` context manager. Bodies are size-bounded
    (``max_response_bytes``) and parsed defensively — a non-JSON body from a
    JSON endpoint raises :class:`DriverError` rather than leaking a raw decode
    error (#197). Error responses are never buffered in full: only a short
    prefix is read to build the error message.
    """

    # HTTP methods whose arguments are sent as query parameters instead of a
    # JSON body. HEAD is included because content in a HEAD request has no
    # defined semantics and may cause servers to reject the request.
    _QUERY_METHODS = frozenset({"GET", "DELETE", "HEAD"})

    # Maximum bytes read from an error response for its message. 1024 bytes
    # covers the 200-character snippet for encodings of up to 4 bytes/char.
    _ERROR_SNIPPET_BYTES = 1024
    _ERROR_SNIPPET_CHARS = 200

    def __init__(
        self,
        driver_id: str = "http",
        *,
        base_headers: dict[str, str] | None = None,
        default_timeout: float = 30.0,
        limits: httpx.Limits | None = None,
        max_response_bytes: int | None = None,
    ) -> None:
        """Create a driver; the HTTP client itself is built lazily.

        Args:
            driver_id: Unique identifier for this driver.
            base_headers: Headers sent with every request.
            default_timeout: Timeout in seconds for endpoints without their own.
            limits: Connection-pool limits; defaults to ``_DEFAULT_LIMITS``.
            max_response_bytes: Upper bound on a success body's size, or
                ``None`` for no bound.
        """
        self._driver_id = driver_id
        self._endpoints: dict[str, HTTPEndpoint] = {}
        # Copy so later mutation of the caller's dict cannot alter the headers
        # of a client that has not been built yet.
        self._base_headers = dict(base_headers) if base_headers else {}
        self._default_timeout = default_timeout
        self._limits = limits or _DEFAULT_LIMITS
        self._max_response_bytes = max_response_bytes
        self._client: httpx.AsyncClient | None = None

    @property
    def driver_id(self) -> str:
        """Unique identifier for this driver."""
        return self._driver_id

    def register_endpoint(self, operation: str, endpoint: HTTPEndpoint) -> None:
        """Register an HTTP endpoint for an operation.

        Args:
            operation: The operation name to handle.
            endpoint: The :class:`HTTPEndpoint` configuration.
        """
        self._endpoints[operation] = endpoint

    def _get_client(self) -> httpx.AsyncClient:
        """Return the shared client, creating it on first use.

        Built lazily so the connection pool, default headers, and limits are
        established once and reused across invocations (#194).
        """
        if self._client is None:
            self._client = httpx.AsyncClient(
                headers=self._base_headers,
                timeout=self._default_timeout,
                limits=self._limits,
            )
        return self._client

    async def aclose(self) -> None:
        """Close the shared client and release its connection pool.

        Idempotent — safe to call more than once. Callers that construct an
        :class:`HTTPDriver` own its lifecycle and should call this on shutdown
        (e.g. in a ``finally`` block or async-context teardown).
        """
        if self._client is not None:
            await self._client.aclose()
            self._client = None

    async def __aenter__(self) -> HTTPDriver:
        """Return the driver itself so it can be used with ``async with``."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Close the shared client when leaving an ``async with`` block."""
        await self.aclose()

    async def execute(self, ctx: ExecutionContext) -> RawResult:
        """Execute an HTTP request for the given context.

        The operation is resolved from ``ctx.args.get("operation")`` first,
        then falls back to ``ctx.capability_id``.

        Args:
            ctx: The execution context.

        Returns:
            :class:`RawResult` containing the parsed JSON response, or the
            text (decoded with the response charset) when the endpoint's
            ``response_format`` is ``"text"``.

        Raises:
            DriverError: If the endpoint is not registered, the request fails
                or its URL is invalid, the response exceeds
                ``max_response_bytes``, or a JSON endpoint returns a body that
                cannot be parsed as JSON.
        """
        operation = str(ctx.args.get("operation", ctx.capability_id))
        endpoint = self._endpoints.get(operation)
        if endpoint is None:
            raise DriverError(
                f"HTTPDriver '{self._driver_id}' has no endpoint for operation='{operation}'."
            )

        method = endpoint.method.upper()
        # "operation" only routes to the endpoint; it is not a request argument.
        payload = {k: v for k, v in ctx.args.items() if k != "operation"}
        params: dict[str, Any] = {}
        json_body: dict[str, Any] | None = None
        if method in self._QUERY_METHODS:
            params = payload
        else:
            json_body = payload

        effective_timeout = (
            endpoint.timeout if endpoint.timeout is not None else self._default_timeout
        )

        client = self._get_client()
        try:
            async with client.stream(
                method,
                endpoint.url,
                params=params,
                json=json_body,
                headers=endpoint.headers,
                timeout=effective_timeout,
            ) as response:
                if response.is_error:
                    # Read only a bounded prefix: a full aread() here would
                    # buffer an arbitrarily large error body.
                    snippet = await self._read_error_snippet(response)
                    raise DriverError(
                        f"HTTPDriver '{self._driver_id}': HTTP {response.status_code} "
                        f"from {endpoint.url}: {snippet}"
                    )
                body = await self._read_bounded(response, url=endpoint.url)
                status_code = response.status_code
                content_type = response.headers.get("content-type", "")
                charset = response.charset_encoding
        except (httpx.RequestError, httpx.InvalidURL) as exc:
            # InvalidURL is not a RequestError subclass, so it is listed explicitly.
            raise DriverError(
                f"HTTPDriver '{self._driver_id}': Request to {endpoint.url} failed: {exc}"
            ) from exc

        data = self._decode_body(
            body,
            response_format=endpoint.response_format,
            url=endpoint.url,
            content_type=content_type,
            encoding=charset,
        )
        return RawResult(
            capability_id=ctx.capability_id,
            data=data,
            metadata={"status_code": status_code, "url": endpoint.url},
        )

    async def _read_bounded(self, response: httpx.Response, *, url: str) -> bytes:
        """Read the response body, aborting if it exceeds ``max_response_bytes``.

        Streams chunks so an oversized upstream body is rejected before it is
        fully buffered — the firewall's budget only applies *after* a
        :class:`RawResult` exists, so the guard has to live here (#194).

        Args:
            response: The open streaming response.
            url: The request URL, used in the error message.

        Returns:
            The full response body as bytes.

        Raises:
            DriverError: If the accumulated body exceeds ``max_response_bytes``.
        """
        limit = self._max_response_bytes
        if limit is None:
            return await response.aread()
        body = bytearray()
        async for chunk in response.aiter_bytes():
            body.extend(chunk)
            if len(body) > limit:
                raise DriverError(
                    f"HTTPDriver '{self._driver_id}': response from {url} exceeded "
                    f"max_response_bytes ({limit})."
                )
        return bytes(body)

    async def _read_error_snippet(self, response: httpx.Response) -> str:
        """Read and decode a short prefix of an error response body.

        Stops after ``_ERROR_SNIPPET_BYTES`` so an oversized error body is
        never buffered in full.

        Args:
            response: The open streaming response with an error status.

        Returns:
            At most ``_ERROR_SNIPPET_CHARS`` characters of the body, decoded
            with the response charset (UTF-8 when absent or unknown).
        """
        head = bytearray()
        async for chunk in response.aiter_bytes():
            head.extend(chunk)
            if len(head) >= self._ERROR_SNIPPET_BYTES:
                break
        raw = bytes(head[: self._ERROR_SNIPPET_BYTES])
        text = self._decode_text(raw, response.charset_encoding)
        return text[: self._ERROR_SNIPPET_CHARS]

    @staticmethod
    def _decode_text(raw: bytes, encoding: str | None) -> str:
        """Decode *raw* as text, replacing undecodable bytes with U+FFFD.

        Args:
            raw: The bytes to decode.
            encoding: Charset from the response, or ``None``. A missing
                charset, or one Python does not recognize as a text codec,
                falls back to UTF-8 instead of raising.

        Returns:
            The decoded string.
        """
        try:
            return raw.decode(encoding or "utf-8", "replace")
        except LookupError:
            return raw.decode("utf-8", "replace")

    def _decode_body(
        self,
        body: bytes,
        *,
        response_format: Literal["json", "text"],
        url: str,
        content_type: str,
        encoding: str | None = None,
    ) -> Any:
        """Decode a response body per the endpoint's ``response_format``.

        Args:
            body: The raw response bytes.
            response_format: ``"json"`` to parse, ``"text"`` to decode as a string.
            url: The request URL, used in the error message.
            content_type: The response ``Content-Type``, used in the error message.
            encoding: Response charset used for ``"text"`` bodies; UTF-8 when
                ``None`` or unrecognized. Ignored for JSON, whose encoding
                ``json.loads`` detects itself.

        Returns:
            The parsed JSON value (``None`` for an empty body), or the decoded text.

        Raises:
            DriverError: If ``response_format`` is ``"json"`` and the body is not
                valid JSON (#197) or is nested too deeply to parse.
        """
        if response_format == "text":
            return self._decode_text(body, encoding)
        if not body:
            return None
        try:
            return json.loads(body)
        except RecursionError as exc:
            # Valid but pathologically nested JSON exhausts the parser's stack;
            # surface it as a DriverError rather than a raw RecursionError.
            raise DriverError(
                f"HTTPDriver '{self._driver_id}': JSON response from {url} is "
                "nested too deeply to parse."
            ) from exc
        except (json.JSONDecodeError, ValueError) as exc:
            # ValueError also covers UnicodeDecodeError for undecodable bytes.
            snippet = body[:200].decode("utf-8", "replace")
            raise DriverError(
                f"HTTPDriver '{self._driver_id}': non-JSON response from {url} "
                f"(content-type: {content_type or 'unknown'}): {snippet}"
            ) from exc