From a40d8037fa94d1bfec19f566c157d31feeaf06a5 Mon Sep 17 00:00:00 2001 From: hinotoi-agent Date: Sun, 10 May 2026 16:41:50 +0800 Subject: [PATCH] fix: block unsafe OpenAPI action URLs --- app/schemas/tool/action.py | 8 +- app/services/tool/openapi_call.py | 42 +++++++--- app/services/tool/url_security.py | 76 +++++++++++++++++++ tests/unit/test_action_url_security.py | 101 +++++++++++++++++++++++++ 4 files changed, 214 insertions(+), 13 deletions(-) create mode 100644 app/services/tool/url_security.py create mode 100644 tests/unit/test_action_url_security.py diff --git a/app/schemas/tool/action.py b/app/schemas/tool/action.py index ee7f46b..a30e66b 100644 --- a/app/schemas/tool/action.py +++ b/app/schemas/tool/action.py @@ -8,6 +8,7 @@ from app.exceptions.exception import ValidateFailedError from app.schemas.tool.authentication import Authentication, AuthenticationType +from app.services.tool.url_security import validate_openapi_server_url # This function code from the Open Source Project TaskingAI. @@ -31,6 +32,11 @@ def validate_openapi_schema(schema: Dict): if len(schema["servers"]) != 1: raise ValidateFailedError("Exactly one server is allowed in action schema") + server_url = schema["servers"][0].get("url") if isinstance(schema["servers"][0], dict) else None + if not server_url or not isinstance(server_url, str): + raise ValidateFailedError("Action schema server URL is required") + validate_openapi_server_url(server_url) + # check each path method has a valid description and operationId for path, methods in schema["paths"].items(): for method, details in methods.items(): @@ -56,7 +62,7 @@ def validate_openapi_schema(schema: Dict): if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", details["operationId"]): raise ValidateFailedError( - f'Invalid operationId {details["operationId"]} in {method} {path} in action schema' + f"Invalid operationId {details['operationId']} in {method} {path} in action schema" ) return schema diff --git a/app/services/tool/openapi_call.py b/app/services/tool/openapi_call.py index 85229dd..f8d9472 100644 --- a/app/services/tool/openapi_call.py +++ b/app/services/tool/openapi_call.py @@ -7,6 +7,7 @@ from app.schemas.tool.authentication import Authentication, AuthenticationType from app.schemas.tool.action import ActionMethod, ActionBodyType, ActionParam +from app.services.tool.url_security import UnsafeActionURLError, validate_action_url # This function code from the Open Source Project TaskingAI. @@ -139,18 +140,35 @@ def call_action_api( logging.info(f"call_action_api url={url} request kwargs: {request_kwargs}") - with requests.request(method.value, url, **request_kwargs) as response: - response_content_type = response.headers.get("Content-Type", "").lower() - if "application/json" in response_content_type: - data = response.json() - else: - data = response.text - if response.status_code == 500: - error_message = f"API call failed with status {response.status_code}" - if data: - error_message += f": {data}" - return {"status": response.status_code, "error": error_message} - return {"status": response.status_code, "data": data} + timeout = float(os.environ.get("ACTION_HTTP_TIMEOUT", "10")) + request_kwargs["timeout"] = timeout + request_kwargs["allow_redirects"] = False + + current_url = url + for _ in range(5): + validate_action_url(current_url) + with requests.request(method.value, current_url, **request_kwargs) as response: + if response.is_redirect: + redirect_url = response.headers.get("Location") + if not redirect_url: + return {"status": response.status_code, "error": "Redirect response missing Location header"} + current_url = urllib.parse.urljoin(current_url, redirect_url) + continue + + response_content_type = response.headers.get("Content-Type", "").lower() + if "application/json" in response_content_type: + data = response.json() + else: + data = response.text + if response.status_code == 500: + error_message = f"API call failed with status {response.status_code}" + if data: + error_message += f": {data}" + return {"status": response.status_code, "error": error_message} + return {"status": response.status_code, "data": data} + return {"status": 500, "error": "Too many redirects while making the API call"} + except UnsafeActionURLError as e: + return {"status": 400, "error": f"Blocked unsafe action URL: {e}"} except requests.exceptions.RequestException as e: return {"status": 500, "error": f"Failed to make the API call: {e}"} except Exception: diff --git a/app/services/tool/url_security.py b/app/services/tool/url_security.py new file mode 100644 index 0000000..c6fc8a9 --- /dev/null +++ b/app/services/tool/url_security.py @@ -0,0 +1,76 @@ +import ipaddress +import socket +import urllib.parse +from typing import Iterable + +from app.exceptions.exception import ValidateFailedError + + +_ALLOWED_ACTION_SCHEMES = {"http", "https"} +_BLOCKED_HOSTNAMES = {"localhost"} + + +class UnsafeActionURLError(ValueError): + pass + + +def _is_blocked_ip(address: str) -> bool: + ip = ipaddress.ip_address(address) + return any( + [ + ip.is_private, + ip.is_loopback, + ip.is_link_local, + ip.is_multicast, + ip.is_reserved, + ip.is_unspecified, + ] + ) + + +def _resolved_addresses(hostname: str, port: int | None) -> Iterable[str]: + try: + addrinfos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM) + except socket.gaierror as exc: + raise UnsafeActionURLError(f"Unable to resolve action URL host: {hostname}") from exc + + addresses = set() + for addrinfo in addrinfos: + sockaddr = addrinfo[4] + if sockaddr: + addresses.add(sockaddr[0]) + return addresses + + +def validate_action_url(url: str) -> None: + """Validate that an action URL does not target local or private network addresses.""" + parsed = urllib.parse.urlparse(url) + if parsed.scheme.lower() not in _ALLOWED_ACTION_SCHEMES: + raise UnsafeActionURLError("Action URL scheme must be http or https") + + if not parsed.hostname: + raise UnsafeActionURLError("Action URL must include a hostname") + + if parsed.username or parsed.password: + raise UnsafeActionURLError("Action URL must not include user credentials") + + hostname = parsed.hostname.rstrip(".").lower() + if hostname in _BLOCKED_HOSTNAMES or hostname.endswith(".localhost"): + raise UnsafeActionURLError("Action URL host is not allowed") + + try: + ip = ipaddress.ip_address(hostname) + except ValueError: + for address in _resolved_addresses(hostname, parsed.port): + if _is_blocked_ip(address): + raise UnsafeActionURLError("Action URL host resolves to a disallowed address") + else: + if _is_blocked_ip(str(ip)): + raise UnsafeActionURLError("Action URL host is not allowed") + + +def validate_openapi_server_url(url: str) -> None: + try: + validate_action_url(url) + except UnsafeActionURLError as exc: + raise ValidateFailedError(f"Invalid action server URL: {exc}") from exc diff --git a/tests/unit/test_action_url_security.py b/tests/unit/test_action_url_security.py new file mode 100644 index 0000000..efa30d0 --- /dev/null +++ b/tests/unit/test_action_url_security.py @@ -0,0 +1,101 @@ +import socket + +from app.exceptions.exception import ValidateFailedError +from app.schemas.tool.action import ActionBodyType, ActionMethod, validate_openapi_schema +from app.schemas.tool.authentication import Authentication, AuthenticationType +from app.services.tool.openapi_call import call_action_api +from app.services.tool.url_security import UnsafeActionURLError, validate_action_url + + +def _addrinfo(address): + return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (address, 80))] + + +def _schema(server_url): + return { + "openapi": "3.0.0", + "info": {"title": "test", "version": "1.0"}, + "servers": [{"url": server_url}], + "paths": { + "/status": { + "get": { + "operationId": "get_status", + "description": "read status", + "responses": {"200": {"description": "ok"}}, + } + } + }, + } + + +class FakeResponse: + def __init__(self, status_code=200, text="ok", headers=None): + self.status_code = status_code + self.text = text + self.headers = headers or {"Content-Type": "text/plain"} + self.is_redirect = status_code in {301, 302, 303, 307, 308} + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def json(self): + return {"ok": True} + + +def test_validate_action_url_blocks_loopback_ip(): + try: + validate_action_url("http://127.0.0.1:8080/secret") + except UnsafeActionURLError: + return + raise AssertionError("loopback action URL was not blocked") + + +def test_validate_action_url_blocks_dns_to_private_address(monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda *args, **kwargs: _addrinfo("10.0.0.5")) + + try: + validate_action_url("https://api.example.test/status") + except UnsafeActionURLError: + return + raise AssertionError("private DNS target was not blocked") + + +def test_openapi_schema_rejects_private_server_url(): + try: + validate_openapi_schema(_schema("http://169.254.169.254/latest")) + except ValidateFailedError: + return + raise AssertionError("OpenAPI server URL pointing at metadata service was not rejected") + + +def test_call_action_api_blocks_private_redirect(monkeypatch): + monkeypatch.setattr(socket, "getaddrinfo", lambda host, *args, **kwargs: _addrinfo("93.184.216.34")) + + calls = [] + + def fake_request(method, url, **kwargs): + calls.append((method, url, kwargs)) + return FakeResponse(302, headers={"Location": "http://127.0.0.1:8080/admin"}) + + monkeypatch.setattr("app.services.tool.openapi_call.requests.request", fake_request) + + result = call_action_api( + url="https://api.example.test/status", + method=ActionMethod.GET, + path_param_schema={}, + query_param_schema={}, + body_type=ActionBodyType.NONE, + body_param_schema={}, + parameters={}, + headers={}, + authentication=Authentication(type=AuthenticationType.none), + ) + + assert calls == [("GET", "https://api.example.test/status", calls[0][2])] + assert result["status"] == 400 + assert "Blocked unsafe action URL" in result["error"] + assert calls[0][2]["allow_redirects"] is False + assert calls[0][2]["timeout"] == 10