Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions src/tool_system/tools/web_fetch.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
PermissionPassthroughResult,
PermissionResult,
)
from src.utils.abortable_net import abortable_read, call_with_abort


# -- HTML to Markdown ----------------------------------------------------------
Expand Down Expand Up @@ -231,9 +232,9 @@ def _charset_from_content_type(content_type: str) -> str | None:
return match.group(1).strip().strip('"\'') if match else None


def _read_response_body(resp) -> str:
def _read_response_body(resp, abort_signal=None) -> str:
"""Read, transparently decompress (gzip/deflate), and decode a response body."""
raw = resp.read(_MAX_FETCH_BYTES)
raw = abortable_read(resp, _MAX_FETCH_BYTES, abort_signal)
encoding = (resp.headers.get("Content-Encoding") or "").lower()
if "gzip" in encoding:
try:
Expand Down Expand Up @@ -264,20 +265,30 @@ def _is_cloudflare_challenge(e: urllib.error.HTTPError) -> bool:


def _fetch_with_redirect_handling(
url: str, timeout: float = 15, fmt: str = "markdown", user_agent: str = _BROWSER_UA
url: str,
timeout: float = 15,
fmt: str = "markdown",
user_agent: str = _BROWSER_UA,
abort_signal=None,
) -> tuple[str, str, int]:
opener = urllib.request.build_opener(_NoRedirectHandler)
current_url = url
for _ in range(_MAX_REDIRECTS):
if abort_signal is not None:
abort_signal.throw_if_aborted()
req = urllib.request.Request(current_url, headers=_request_headers(fmt, user_agent))
try:
resp = opener.open(req, timeout=timeout)
resp = call_with_abort(
lambda: opener.open(req, timeout=timeout), abort_signal
)
content_type = resp.headers.get("Content-Type", "")
return _read_response_body(resp), content_type, resp.status
return _read_response_body(resp, abort_signal), content_type, resp.status
except urllib.error.HTTPError as e:
# Cloudflare challenged the browser UA -> retry once with a bot UA.
if _is_cloudflare_challenge(e) and user_agent != _FALLBACK_UA:
return _fetch_with_redirect_handling(url, timeout, fmt, _FALLBACK_UA)
return _fetch_with_redirect_handling(
url, timeout, fmt, _FALLBACK_UA, abort_signal
)
if e.code in (301, 302, 303, 307, 308):
redirect_url = e.headers.get("Location", "")
if not redirect_url:
Expand Down Expand Up @@ -476,7 +487,9 @@ def _web_fetch_call(tool_input: dict[str, Any], context: ToolContext) -> ToolRes
if cached:
content, content_type, status = cached
else:
raw, content_type, status = _fetch_with_redirect_handling(url, fmt=fmt)
raw, content_type, status = _fetch_with_redirect_handling(
url, fmt=fmt, abort_signal=context.abort_controller.signal
)
content = _convert(raw, content_type, fmt)
_cache_set(cache_key, content, content_type, status)

Expand Down
20 changes: 16 additions & 4 deletions src/tool_system/tools/web_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from typing import Any
from urllib.parse import urlparse

from src.utils.abortable_net import call_with_abort

from ..build_tool import Tool, ValidationResult, build_tool
from ..context import ToolContext
from ..errors import ToolInputError
Expand Down Expand Up @@ -130,7 +132,9 @@ def is_web_search_configured() -> bool:
return _tavily_api_key() is not None


def _tavily_search(query: str, num: int = 15) -> list[dict[str, str]]:
def _tavily_search(
query: str, num: int = 15, abort_signal=None
) -> list[dict[str, str]]:
"""Search the web via Tavily.

Raises ``ToolInputError`` when ``TAVILY_API_KEY`` is unset (so the model and
Expand All @@ -153,9 +157,15 @@ def _tavily_search(query: str, num: int = 15) -> list[dict[str, str]]:
method="POST",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {key}"},
)
try:

def _request() -> str:
with urllib.request.urlopen(req, timeout=20) as resp:
raw = resp.read(2_000_000).decode("utf-8", errors="replace")
return resp.read(2_000_000).decode("utf-8", errors="replace")

try:
# ESC unblocks the caller immediately; the worker dies at the
# 20s socket timeout (#276).
raw = call_with_abort(_request, abort_signal)
except urllib.error.HTTPError as exc:
detail = ""
try:
Expand Down Expand Up @@ -333,7 +343,9 @@ def _web_search_call(tool_input: dict[str, Any], context: ToolContext) -> ToolRe
start_time = time.monotonic()

# Search via Tavily (requires TAVILY_API_KEY).
results = _tavily_search(query, num=15)
results = _tavily_search(
query, num=15, abort_signal=context.abort_controller.signal
)

# Apply domain filters
results = _apply_domain_filters(
Expand Down
124 changes: 124 additions & 0 deletions src/utils/abortable_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Abort-aware wrappers for blocking ``urllib`` network calls.

ESC-cancel support for WebFetch/WebSearch (#276): ``urllib`` has no
cancellation primitive, so cancellation is built from two mechanisms:

- ``call_with_abort``: run the blocking call on a daemon worker thread and
poll the abort signal from the caller; on abort the CALLER unblocks
immediately (raises ``AbortError``) while the worker dies at its socket
timeout. A late-arriving response is closed so the socket isn't leaked.
- ``abortable_read``: chunked body read with an abort listener that closes
the response — closing the underlying socket unblocks a ``read()`` that
is mid-await between bytes, which polling alone cannot do.
"""

from __future__ import annotations

import socket
import threading
from typing import Any, Callable, TypeVar

from .abort_controller import AbortError, AbortSignal

T = TypeVar("T")

_POLL_INTERVAL_S = 0.05
_READ_CHUNK_BYTES = 65536


def _safe_close(obj: Any) -> None:
# ``close()`` alone does NOT interrupt a ``recv`` blocked on another
# thread — the fd stays referenced by the in-flight read. Shut the
# socket down first (http.client internals, best-effort) so the
# blocked read raises immediately instead of waiting out the timeout.
try:
sock = obj.fp.raw._sock
sock.shutdown(socket.SHUT_RDWR)
except Exception:
pass
try:
obj.close()
except Exception:
pass


def call_with_abort(fn: Callable[[], T], abort_signal: AbortSignal | None) -> T:
"""Run blocking ``fn`` and return its result, raising ``AbortError``
the moment ``abort_signal`` trips.

On abort the worker thread is abandoned (it exits at its socket
timeout, bounded by the caller's ``timeout=`` argument to urllib); if
its result arrives after the abort it is closed and discarded.
"""
if abort_signal is None:
return fn()
abort_signal.throw_if_aborted()

result: list[T] = []
error: list[BaseException] = []
done = threading.Event()

def _worker() -> None:
try:
value = fn()
if abort_signal.aborted:
_safe_close(value)
else:
result.append(value)
except BaseException as exc: # noqa: BLE001 — relayed to the caller
error.append(exc)
finally:
done.set()

thread = threading.Thread(
target=_worker, name="abortable-net-call", daemon=True
)
thread.start()
while not done.wait(_POLL_INTERVAL_S):
if abort_signal.aborted:
raise AbortError(abort_signal.reason or "user_interrupt")
if abort_signal.aborted:
raise AbortError(abort_signal.reason or "user_interrupt")
if error:
raise error[0]
return result[0]


def abortable_read(
resp: Any, max_bytes: int, abort_signal: AbortSignal | None
) -> bytes:
"""Read up to ``max_bytes`` from ``resp`` in chunks, raising
``AbortError`` if ``abort_signal`` trips mid-read.

An abort listener closes ``resp`` so a read blocked between bytes
unblocks immediately instead of waiting out the socket timeout.
"""
if abort_signal is None:
return resp.read(max_bytes)
abort_signal.throw_if_aborted()

def _close_on_abort() -> None:
_safe_close(resp)

registered = abort_signal.add_listener(_close_on_abort, once=True)
chunks: list[bytes] = []
remaining = max_bytes
try:
while remaining > 0:
abort_signal.throw_if_aborted()
try:
chunk = resp.read(min(_READ_CHUNK_BYTES, remaining))
except Exception:
if abort_signal.aborted:
raise AbortError(
abort_signal.reason or "user_interrupt"
) from None
raise
if not chunk:
break
chunks.append(chunk)
remaining -= len(chunk)
abort_signal.throw_if_aborted()
finally:
abort_signal.remove_listener(registered)
return b"".join(chunks)
Loading