From 6bc639727181d18a27184aa69e9aa8952230e827 Mon Sep 17 00:00:00 2001 From: Henrique Chaves Date: Wed, 18 Mar 2026 18:02:44 +0100 Subject: [PATCH 1/7] feat: add local scan with WebSocket support --- pyproject.toml | 1 + src/giskard_hub/resources/_ws_scan.py | 216 ++++++++++++++++++++ src/giskard_hub/resources/helpers.py | 271 ++++++++++++++++++++++++++ uv.lock | 61 ++++++ 4 files changed, 549 insertions(+) create mode 100644 src/giskard_hub/resources/_ws_scan.py diff --git a/pyproject.toml b/pyproject.toml index 72fd3da..cd62b5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "distro>=1.7.0, <2", "rich", "sniffio", + "websockets>=13.0, <15", ] requires-python = ">=3.10, <3.15" classifiers = [ diff --git a/src/giskard_hub/resources/_ws_scan.py b/src/giskard_hub/resources/_ws_scan.py new file mode 100644 index 0000000..c7a28ba --- /dev/null +++ b/src/giskard_hub/resources/_ws_scan.py @@ -0,0 +1,216 @@ +"""WebSocket client for local-agent scan execution. + +This module provides both sync and async functions that connect to the Hub's +``/v2/scans/{scan_id}/ws`` WebSocket endpoint, receive agent invocation +requests from LIDAR (running on the Hub), execute the local agent callable, +and send the response back. +""" + +import asyncio +import inspect +import json +import logging +import os +import ssl +from typing import Any, Awaitable, Callable +from urllib.parse import urlencode, urlparse, urlunparse + +import httpx + +from ..types.chat import ChatMessage + +logger = logging.getLogger(__name__) + + +def _ssl_context_from_httpx(http_client: httpx.Client | httpx.AsyncClient | None) -> ssl.SSLContext | None: + """Derive an ``ssl.SSLContext`` that mirrors the httpx client's TLS config. + + Walks the internal transport chain + (``httpx.Client._transport._pool._ssl_context``) to extract the exact + ``ssl.SSLContext`` that httpx uses. This means ``verify=False`` on the + httpx client automatically disables verification for the WebSocket too, + and custom CA bundles are preserved. + + Returns ``None`` (use system defaults) when extraction fails. + """ + if http_client is None: + return None + + # httpx.Client._transport → HTTPTransport._pool → httpcore.ConnectionPool._ssl_context + try: + ctx = http_client._transport._pool._ssl_context # type: ignore[union-attr] + if isinstance(ctx, ssl.SSLContext): + return ctx + except AttributeError: + pass + + # Fallback: check common CA-bundle environment variables. + for env_var in ("SSL_CERT_FILE", "REQUESTS_CA_BUNDLE", "CURL_CA_BUNDLE"): + ca_path = os.environ.get(env_var) + if ca_path: + return ssl.create_default_context(cafile=ca_path) + + return None + + +def _http_to_ws_url(http_url: str) -> str: + """Convert an HTTP(S) URL to a WS(S) URL.""" + parsed = urlparse(http_url) + if parsed.scheme == "https": + scheme = "wss" + elif parsed.scheme == "http": + scheme = "ws" + else: + scheme = parsed.scheme + return urlunparse(parsed._replace(scheme=scheme)) + + +def _build_ws_url(base_url: str, scan_id: str, api_key: str) -> str: + """Build the full WebSocket URL for a local scan session.""" + ws_base = _http_to_ws_url(base_url.rstrip("/")) + query = urlencode({"api_key": api_key}) + return f"{ws_base}/v2/scans/{scan_id}/ws?{query}" + + +AgentCallable = Callable[[list[ChatMessage]], Any] +AsyncAgentCallable = Callable[[list[ChatMessage]], Any | Awaitable[Any]] + + +def _normalize_output(value: Any) -> dict: + """Turn the agent return value into a dict matching the Hub protocol.""" + from ._helpers_types import normalize_agent_output + + output: AgentOutput = normalize_agent_output(value) + return output.to_dict() + + +async def _arun_ws_scan( + base_url: str, + api_key: str, + scan_id: str, + agent: AsyncAgentCallable, + on_progress: Callable[[dict], Any] | None = None, + ssl_context: ssl.SSLContext | bool | None = None, +) -> dict | None: + """Async implementation of the WebSocket scan loop. + + Returns the ``complete`` message payload, or ``None`` if the connection + closed before completion. + """ + try: + from websockets.asyncio.client import connect + except ImportError as exc: + raise ImportError( + "The 'websockets' package is required for local scan execution. " + "Install it with: pip install 'giskard-hub[websockets]' or pip install websockets" + ) from exc + + url = _build_ws_url(base_url, scan_id, api_key) + logger.info("Connecting to scan WebSocket: %s", url.split("?")[0]) + + connect_kwargs: dict[str, Any] = {} + if ssl_context is not None: + connect_kwargs["ssl"] = ssl_context + + async with connect(url, **connect_kwargs) as ws: + async for raw_msg in ws: + try: + msg = json.loads(raw_msg) + except json.JSONDecodeError: + logger.warning("Non-JSON WebSocket message received, ignoring") + continue + + msg_type = msg.get("type") + + if msg_type == "invoke": + request_id = msg["request_id"] + messages = [ + ChatMessage( + role=m.get("role", "user"), + content=m.get("content", ""), + ) + for m in msg.get("messages", []) + ] + + try: + result = agent(messages) + if inspect.isawaitable(result): + result = await result + output = _normalize_output(result) + await ws.send( + json.dumps( + { + "type": "response", + "request_id": request_id, + "output": output, + } + ) + ) + except Exception as exc: + logger.error("Agent invocation failed: %s", exc) + await ws.send( + json.dumps( + { + "type": "error", + "request_id": request_id, + "error": {"message": str(exc)}, + } + ) + ) + + elif msg_type == "progress": + if on_progress: + status = msg.get("status", {}) + cb_result = on_progress(status) + if inspect.isawaitable(cb_result): + await cb_result + + elif msg_type == "complete": + logger.info( + "Scan %s completed with grade: %s", + scan_id, + msg.get("grade"), + ) + return msg + + elif msg_type == "error": + error_msg = msg.get("message", "Unknown server error") + raise RuntimeError(f"Scan error from Hub: {error_msg}") + + else: + logger.warning("Unknown WebSocket message type: %s", msg_type) + + return None + + +def run_ws_scan( + base_url: str, + api_key: str, + scan_id: str, + agent: AgentCallable, + on_progress: Callable[[dict], Any] | None = None, + ssl_context: ssl.SSLContext | bool | None = None, +) -> dict | None: + """Synchronous wrapper around the async WebSocket scan loop. + + Works correctly even when called from an environment that already has + a running event loop (e.g. Jupyter notebooks) by executing the async + code in a dedicated thread with its own event loop. + """ + import concurrent.futures + + def _run_in_thread() -> dict | None: + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete( + _arun_ws_scan( + base_url, api_key, scan_id, agent, on_progress, + ssl_context=ssl_context, + ) + ) + finally: + loop.close() + + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: + future = pool.submit(_run_in_thread) + return future.result() diff --git a/src/giskard_hub/resources/helpers.py b/src/giskard_hub/resources/helpers.py index 1a9acc9..ea3fcd1 100644 --- a/src/giskard_hub/resources/helpers.py +++ b/src/giskard_hub/resources/helpers.py @@ -157,6 +157,81 @@ def evaluate( return self._evaluate_local(agent=agent, dataset_id=dataset_id, name=name, tags=tags) + def scan( + self, + *, + agent: str | Agent | Callable[[list[ChatMessage]], AgentReturn], + project: str | Project, + knowledge_base: Optional[str] | Omit = omit, + tags: Optional[SequenceNotStr[str]] | Omit = omit, + agent_name: Optional[str] = None, + agent_description: Optional[str] = None, + supported_languages: Optional[list[str]] = None, + poll_interval: float = 5.0, + ) -> "Scan": + """Run a vulnerability scan for a given agent. + + Handles both remote agents (referenced by ID or ``Agent``, which must + already be registered in the Hub) and local Python callables. + + When a local callable is provided, the scan is executed via WebSocket: + the Hub orchestrates LIDAR server-side and sends agent invocation + requests to this process through the WebSocket connection. The TLS + configuration (certificate verification) is automatically inherited + from the ``httpx.Client`` passed to ``HubClient``. + + Parameters + ---------- + agent : + Either a remote agent identifier (``str`` or ``Agent``) or a + callable with signature ``(messages: list[ChatMessage]) -> AgentReturn``. + project : + Project identifier or ``Project`` instance. + knowledge_base : + Optional knowledge base identifier. + tags : + Optional list of tags to filter which probes to run. + agent_name : + Name for the agent (used only for local callables; defaults to + the function name). + agent_description : + Description of the agent (used only for local callables). + supported_languages : + Languages the agent supports (default ``["en"]``). + poll_interval : + Seconds between status polls when running a remote scan. + + Returns + ------- + Scan + The completed scan result with grade and probe information. + """ + from ..types.scan import Scan as _Scan + + project_id = project if isinstance(project, str) else project.id + kb_id = knowledge_base if isinstance(knowledge_base, str) else ( + knowledge_base if knowledge_base is omit else knowledge_base + ) + + if isinstance(agent, (str, Agent)): + return self._scan_remote( + agent=agent, + project_id=project_id, + knowledge_base_id=kb_id, + tags=tags, + poll_interval=poll_interval, + ) + + return self._scan_local( + agent=agent, + project_id=project_id, + knowledge_base_id=kb_id if kb_id is not omit else None, + tags=tags if tags is not omit else None, + agent_name=agent_name or getattr(agent, "__name__", "local_agent"), + agent_description=agent_description or getattr(agent, "__doc__", None) or "", + supported_languages=supported_languages or ["en"], + ) + def print_metrics(self, entity: PrintMetricsEntity) -> None: """Print metrics for an evaluation or scan result to the console. @@ -171,6 +246,83 @@ def print_metrics(self, entity: PrintMetricsEntity) -> None: # -- Private helpers ----------------------------------------------------- + def _scan_remote( + self, + *, + agent: str | Agent, + project_id: str, + knowledge_base_id: "str | Omit | None", + tags: "Optional[SequenceNotStr[str]] | Omit", + poll_interval: float, + ) -> "Scan": + from ..types.scan import Scan as _Scan + + agent_id = agent if isinstance(agent, str) else agent.id + + create_kwargs: dict = { + "project_id": project_id, + "agent_id": agent_id, + } + if knowledge_base_id is not omit and knowledge_base_id is not None: + create_kwargs["knowledge_base_id"] = knowledge_base_id + if tags is not omit: + create_kwargs["tags"] = tags + + scan = self._client.scans.create(**create_kwargs) + return cast(_Scan, self.wait_for_completion(scan, poll_interval=poll_interval)) + + def _scan_local( + self, + *, + agent: Callable[[list[ChatMessage]], AgentReturn], + project_id: str, + knowledge_base_id: str | None, + tags: "list[str] | SequenceNotStr[str] | None", + agent_name: str, + agent_description: str, + supported_languages: list[str], + ) -> "Scan": + from ..types.scan import Scan as _Scan + from ..types.common import APIResponse as _APIResponse + from ._ws_scan import run_ws_scan, _ssl_context_from_httpx + + # 1. Create the scan record on the Hub (no worker enqueue). + body: dict = { + "project_id": project_id, + "agent_name": agent_name, + "agent_description": agent_description, + "supported_languages": supported_languages, + } + if knowledge_base_id: + body["knowledge_base_id"] = knowledge_base_id + if tags: + body["tags"] = list(tags) + + response = self._post( + "/v2/scans/create-local", + body=body, + cast_to=_APIResponse[_Scan], + ) + scan = self._unwrap(response) + scan_id = scan.id + + # 2. Connect via WebSocket and drive the scan. + # Inherit TLS config (verify/no-verify) from the httpx client. + base_url = str(self._client.base_url) + api_key = self._client.api_key + ssl_ctx = _ssl_context_from_httpx(self._client._client) + + run_ws_scan( + base_url=base_url, + api_key=api_key, + scan_id=scan_id, + agent=agent, + ssl_context=ssl_ctx, + ) + + # 3. Retrieve the final scan result. + return self._client.scans.retrieve(scan_id) + def _evaluate_remote( self, *, @@ -367,6 +519,48 @@ async def evaluate( return await self._evaluate_local(agent=agent, dataset_id=dataset_id, name=name, tags=tags) + async def scan( + self, + *, + agent: str | Agent | Callable[[list[ChatMessage]], AgentReturn | Awaitable[AgentReturn]], + project: str | Project, + knowledge_base: Optional[str] | Omit = omit, + tags: Optional[SequenceNotStr[str]] | Omit = omit, + agent_name: Optional[str] = None, + agent_description: Optional[str] = None, + supported_languages: Optional[list[str]] = None, + poll_interval: float = 5.0, + ) -> "Scan": + """Asynchronously run a vulnerability scan for a given agent. + + See :meth:`HelpersResource.scan` for full parameter documentation. + """ + from ..types.scan import Scan as _Scan + + project_id = project if isinstance(project, str) else project.id + kb_id = knowledge_base if isinstance(knowledge_base, str) else ( + knowledge_base if knowledge_base is omit else knowledge_base + ) + + if isinstance(agent, (str, Agent)): + return await self._scan_remote( + agent=agent, + project_id=project_id, + knowledge_base_id=kb_id, + tags=tags, + poll_interval=poll_interval, + ) + + return await self._scan_local( + agent=agent, + project_id=project_id, + knowledge_base_id=kb_id if kb_id is not omit else None, + tags=tags if tags is not omit else None, + agent_name=agent_name or getattr(agent, "__name__", "local_agent"), + agent_description=agent_description or getattr(agent, "__doc__", None) or "", + supported_languages=supported_languages or ["en"], + ) + async def print_metrics(self, entity: PrintMetricsEntity) -> None: """Print metrics for an evaluation or scan result to the console (async). @@ -381,6 +575,83 @@ async def print_metrics(self, entity: PrintMetricsEntity) -> None: # -- Private helpers ----------------------------------------------------- + async def _scan_remote( + self, + *, + agent: str | Agent, + project_id: str, + knowledge_base_id: "str | Omit | None", + tags: "Optional[SequenceNotStr[str]] | Omit", + poll_interval: float, + ) -> "Scan": + from ..types.scan import Scan as _Scan + + agent_id = agent if isinstance(agent, str) else agent.id + + create_kwargs: dict = { + "project_id": project_id, + "agent_id": agent_id, + } + if knowledge_base_id is not omit and knowledge_base_id is not None: + create_kwargs["knowledge_base_id"] = knowledge_base_id + if tags is not omit: + create_kwargs["tags"] = tags + + scan = await self._client.scans.create(**create_kwargs) + return cast(_Scan, await self.wait_for_completion(scan, poll_interval=poll_interval)) + + async def _scan_local( + self, + *, + agent: Callable[[list[ChatMessage]], AgentReturn | Awaitable[AgentReturn]], + project_id: str, + knowledge_base_id: str | None, + tags: "list[str] | SequenceNotStr[str] | None", + agent_name: str, + agent_description: str, + supported_languages: list[str], + ) -> "Scan": + from ..types.scan import Scan as _Scan + from ..types.common import APIResponse as _APIResponse + from ._ws_scan import _arun_ws_scan, _ssl_context_from_httpx + + # 1. Create the scan record on the Hub (no worker enqueue). + body: dict = { + "project_id": project_id, + "agent_name": agent_name, + "agent_description": agent_description, + "supported_languages": supported_languages, + } + if knowledge_base_id: + body["knowledge_base_id"] = knowledge_base_id + if tags: + body["tags"] = list(tags) + + response = await self._post( + "/v2/scans/create-local", + body=body, + cast_to=_APIResponse[_Scan], + ) + scan = self._unwrap(response) + scan_id = scan.id + + # 2. Connect via WebSocket and drive the scan. + # Inherit TLS config (verify/no-verify) from the httpx client. + base_url = str(self._client.base_url) + api_key = self._client.api_key + ssl_ctx = _ssl_context_from_httpx(self._client._client) + + await _arun_ws_scan( + base_url=base_url, + api_key=api_key, + scan_id=scan_id, + agent=agent, + ssl_context=ssl_ctx, + ) + + # 3. Retrieve the final scan result. + return await self._client.scans.retrieve(scan_id) + async def _evaluate_remote( self, *, diff --git a/uv.lock b/uv.lock index 39f2596..b4af34e 100644 --- a/uv.lock +++ b/uv.lock @@ -384,6 +384,7 @@ dependencies = [ { name = "rich" }, { name = "sniffio" }, { name = "typing-extensions" }, + { name = "websockets" }, ] [package.optional-dependencies] @@ -421,6 +422,7 @@ requires-dist = [ { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" }, { name = "sniffio" }, { name = "typing-extensions", specifier = ">=4.10,<5" }, + { name = "websockets", specifier = ">=13.0,<15" }, ] provides-extras = ["aiohttp", "dev"] @@ -1264,6 +1266,65 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dc/9b/47798a6c91d8bdb567fe2698fe81e0c6b7cb7ef4d13da4114b41d239f65d/typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7", size = 14611, upload-time = "2025-10-01T02:14:40.154Z" }, ] +[[package]] +name = "websockets" +version = "14.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/54/8359678c726243d19fae38ca14a334e740782336c9f19700858c4eb64a1e/websockets-14.2.tar.gz", hash = "sha256:5059ed9c54945efb321f097084b4c7e52c246f2c869815876a69d1efc4ad6eb5", size = 164394, upload-time = "2025-01-19T21:00:56.431Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/28/fa/76607eb7dcec27b2d18d63f60a32e60e2b8629780f343bb83a4dbb9f4350/websockets-14.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e8179f95323b9ab1c11723e5d91a89403903f7b001828161b480a7810b334885", size = 163089, upload-time = "2025-01-19T20:58:43.399Z" }, + { url = "https://files.pythonhosted.org/packages/9e/00/ad2246b5030575b79e7af0721810fdaecaf94c4b2625842ef7a756fa06dd/websockets-14.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0d8c3e2cdb38f31d8bd7d9d28908005f6fa9def3324edb9bf336d7e4266fd397", size = 160741, upload-time = "2025-01-19T20:58:45.309Z" }, + { url = "https://files.pythonhosted.org/packages/72/f7/60f10924d333a28a1ff3fcdec85acf226281331bdabe9ad74947e1b7fc0a/websockets-14.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:714a9b682deb4339d39ffa674f7b674230227d981a37d5d174a4a83e3978a610", size = 160996, upload-time = "2025-01-19T20:58:47.563Z" }, + { url = "https://files.pythonhosted.org/packages/63/7c/c655789cf78648c01ac6ecbe2d6c18f91b75bdc263ffee4d08ce628d12f0/websockets-14.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2e53c72052f2596fb792a7acd9704cbc549bf70fcde8a99e899311455974ca3", size = 169974, upload-time = "2025-01-19T20:58:51.023Z" }, + { url = "https://files.pythonhosted.org/packages/fb/5b/013ed8b4611857ac92ac631079c08d9715b388bd1d88ec62e245f87a39df/websockets-14.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e3fbd68850c837e57373d95c8fe352203a512b6e49eaae4c2f4088ef8cf21980", size = 168985, upload-time = "2025-01-19T20:58:52.698Z" }, + { url = "https://files.pythonhosted.org/packages/cd/33/aa3e32fd0df213a5a442310754fe3f89dd87a0b8e5b4e11e0991dd3bcc50/websockets-14.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b27ece32f63150c268593d5fdb82819584831a83a3f5809b7521df0685cd5d8", size = 169297, upload-time = "2025-01-19T20:58:54.898Z" }, + { url = "https://files.pythonhosted.org/packages/93/17/dae0174883d6399f57853ac44abf5f228eaba86d98d160f390ffabc19b6e/websockets-14.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4daa0faea5424d8713142b33825fff03c736f781690d90652d2c8b053345b0e7", size = 169677, upload-time = "2025-01-19T20:58:56.36Z" }, + { url = "https://files.pythonhosted.org/packages/42/e2/0375af7ac00169b98647c804651c515054b34977b6c1354f1458e4116c1e/websockets-14.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:bc63cee8596a6ec84d9753fd0fcfa0452ee12f317afe4beae6b157f0070c6c7f", size = 169089, upload-time = "2025-01-19T20:58:58.824Z" }, + { url = "https://files.pythonhosted.org/packages/73/8d/80f71d2a351a44b602859af65261d3dde3a0ce4e76cf9383738a949e0cc3/websockets-14.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7a570862c325af2111343cc9b0257b7119b904823c675b22d4ac547163088d0d", size = 169026, upload-time = "2025-01-19T20:59:01.089Z" }, + { url = "https://files.pythonhosted.org/packages/48/97/173b1fa6052223e52bb4054a141433ad74931d94c575e04b654200b98ca4/websockets-14.2-cp310-cp310-win32.whl", hash = "sha256:75862126b3d2d505e895893e3deac0a9339ce750bd27b4ba515f008b5acf832d", size = 163967, upload-time = "2025-01-19T20:59:02.662Z" }, + { url = "https://files.pythonhosted.org/packages/c0/5b/2fcf60f38252a4562b28b66077e0d2b48f91fef645d5f78874cd1dec807b/websockets-14.2-cp310-cp310-win_amd64.whl", hash = "sha256:cc45afb9c9b2dc0852d5c8b5321759cf825f82a31bfaf506b65bf4668c96f8b2", size = 164413, upload-time = "2025-01-19T20:59:05.071Z" }, + { url = "https://files.pythonhosted.org/packages/15/b6/504695fb9a33df0ca56d157f5985660b5fc5b4bf8c78f121578d2d653392/websockets-14.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3bdc8c692c866ce5fefcaf07d2b55c91d6922ac397e031ef9b774e5b9ea42166", size = 163088, upload-time = "2025-01-19T20:59:06.435Z" }, + { url = "https://files.pythonhosted.org/packages/81/26/ebfb8f6abe963c795122439c6433c4ae1e061aaedfc7eff32d09394afbae/websockets-14.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c93215fac5dadc63e51bcc6dceca72e72267c11def401d6668622b47675b097f", size = 160745, upload-time = "2025-01-19T20:59:09.109Z" }, + { url = "https://files.pythonhosted.org/packages/a1/c6/1435ad6f6dcbff80bb95e8986704c3174da8866ddb751184046f5c139ef6/websockets-14.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1c9b6535c0e2cf8a6bf938064fb754aaceb1e6a4a51a80d884cd5db569886910", size = 160995, upload-time = "2025-01-19T20:59:12.816Z" }, + { url = "https://files.pythonhosted.org/packages/96/63/900c27cfe8be1a1f2433fc77cd46771cf26ba57e6bdc7cf9e63644a61863/websockets-14.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a52a6d7cf6938e04e9dceb949d35fbdf58ac14deea26e685ab6368e73744e4c", size = 170543, upload-time = "2025-01-19T20:59:15.026Z" }, + { url = "https://files.pythonhosted.org/packages/00/8b/bec2bdba92af0762d42d4410593c1d7d28e9bfd952c97a3729df603dc6ea/websockets-14.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f05702e93203a6ff5226e21d9b40c037761b2cfb637187c9802c10f58e40473", size = 169546, upload-time = "2025-01-19T20:59:17.156Z" }, + { url = "https://files.pythonhosted.org/packages/6b/a9/37531cb5b994f12a57dec3da2200ef7aadffef82d888a4c29a0d781568e4/websockets-14.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22441c81a6748a53bfcb98951d58d1af0661ab47a536af08920d129b4d1c3473", size = 169911, upload-time = "2025-01-19T20:59:18.623Z" }, + { url = "https://files.pythonhosted.org/packages/60/d5/a6eadba2ed9f7e65d677fec539ab14a9b83de2b484ab5fe15d3d6d208c28/websockets-14.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:efd9b868d78b194790e6236d9cbc46d68aba4b75b22497eb4ab64fa640c3af56", size = 170183, upload-time = "2025-01-19T20:59:20.743Z" }, + { url = "https://files.pythonhosted.org/packages/76/57/a338ccb00d1df881c1d1ee1f2a20c9c1b5b29b51e9e0191ee515d254fea6/websockets-14.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1a5a20d5843886d34ff8c57424cc65a1deda4375729cbca4cb6b3353f3ce4142", size = 169623, upload-time = "2025-01-19T20:59:22.286Z" }, + { url = "https://files.pythonhosted.org/packages/64/22/e5f7c33db0cb2c1d03b79fd60d189a1da044e2661f5fd01d629451e1db89/websockets-14.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:34277a29f5303d54ec6468fb525d99c99938607bc96b8d72d675dee2b9f5bf1d", size = 169583, upload-time = "2025-01-19T20:59:23.656Z" }, + { url = "https://files.pythonhosted.org/packages/aa/2e/2b4662237060063a22e5fc40d46300a07142afe30302b634b4eebd717c07/websockets-14.2-cp311-cp311-win32.whl", hash = "sha256:02687db35dbc7d25fd541a602b5f8e451a238ffa033030b172ff86a93cb5dc2a", size = 163969, upload-time = "2025-01-19T20:59:26.004Z" }, + { url = "https://files.pythonhosted.org/packages/94/a5/0cda64e1851e73fc1ecdae6f42487babb06e55cb2f0dc8904b81d8ef6857/websockets-14.2-cp311-cp311-win_amd64.whl", hash = "sha256:862e9967b46c07d4dcd2532e9e8e3c2825e004ffbf91a5ef9dde519ee2effb0b", size = 164408, upload-time = "2025-01-19T20:59:28.105Z" }, + { url = "https://files.pythonhosted.org/packages/c1/81/04f7a397653dc8bec94ddc071f34833e8b99b13ef1a3804c149d59f92c18/websockets-14.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1f20522e624d7ffbdbe259c6b6a65d73c895045f76a93719aa10cd93b3de100c", size = 163096, upload-time = "2025-01-19T20:59:29.763Z" }, + { url = "https://files.pythonhosted.org/packages/ec/c5/de30e88557e4d70988ed4d2eabd73fd3e1e52456b9f3a4e9564d86353b6d/websockets-14.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:647b573f7d3ada919fd60e64d533409a79dcf1ea21daeb4542d1d996519ca967", size = 160758, upload-time = "2025-01-19T20:59:32.095Z" }, + { url = "https://files.pythonhosted.org/packages/e5/8c/d130d668781f2c77d106c007b6c6c1d9db68239107c41ba109f09e6c218a/websockets-14.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6af99a38e49f66be5a64b1e890208ad026cda49355661549c507152113049990", size = 160995, upload-time = "2025-01-19T20:59:33.527Z" }, + { url = "https://files.pythonhosted.org/packages/a6/bc/f6678a0ff17246df4f06765e22fc9d98d1b11a258cc50c5968b33d6742a1/websockets-14.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:091ab63dfc8cea748cc22c1db2814eadb77ccbf82829bac6b2fbe3401d548eda", size = 170815, upload-time = "2025-01-19T20:59:35.837Z" }, + { url = "https://files.pythonhosted.org/packages/d8/b2/8070cb970c2e4122a6ef38bc5b203415fd46460e025652e1ee3f2f43a9a3/websockets-14.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b374e8953ad477d17e4851cdc66d83fdc2db88d9e73abf755c94510ebddceb95", size = 169759, upload-time = "2025-01-19T20:59:38.216Z" }, + { url = "https://files.pythonhosted.org/packages/81/da/72f7caabd94652e6eb7e92ed2d3da818626e70b4f2b15a854ef60bf501ec/websockets-14.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a39d7eceeea35db85b85e1169011bb4321c32e673920ae9c1b6e0978590012a3", size = 170178, upload-time = "2025-01-19T20:59:40.423Z" }, + { url = "https://files.pythonhosted.org/packages/31/e0/812725b6deca8afd3a08a2e81b3c4c120c17f68c9b84522a520b816cda58/websockets-14.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0a6f3efd47ffd0d12080594f434faf1cd2549b31e54870b8470b28cc1d3817d9", size = 170453, upload-time = "2025-01-19T20:59:41.996Z" }, + { url = "https://files.pythonhosted.org/packages/66/d3/8275dbc231e5ba9bb0c4f93144394b4194402a7a0c8ffaca5307a58ab5e3/websockets-14.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:065ce275e7c4ffb42cb738dd6b20726ac26ac9ad0a2a48e33ca632351a737267", size = 169830, upload-time = "2025-01-19T20:59:44.669Z" }, + { url = "https://files.pythonhosted.org/packages/a3/ae/e7d1a56755ae15ad5a94e80dd490ad09e345365199600b2629b18ee37bc7/websockets-14.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e9d0e53530ba7b8b5e389c02282f9d2aa47581514bd6049d3a7cffe1385cf5fe", size = 169824, upload-time = "2025-01-19T20:59:46.932Z" }, + { url = "https://files.pythonhosted.org/packages/b6/32/88ccdd63cb261e77b882e706108d072e4f1c839ed723bf91a3e1f216bf60/websockets-14.2-cp312-cp312-win32.whl", hash = "sha256:20e6dd0984d7ca3037afcb4494e48c74ffb51e8013cac71cf607fffe11df7205", size = 163981, upload-time = "2025-01-19T20:59:49.228Z" }, + { url = "https://files.pythonhosted.org/packages/b3/7d/32cdb77990b3bdc34a306e0a0f73a1275221e9a66d869f6ff833c95b56ef/websockets-14.2-cp312-cp312-win_amd64.whl", hash = "sha256:44bba1a956c2c9d268bdcdf234d5e5ff4c9b6dc3e300545cbe99af59dda9dcce", size = 164421, upload-time = "2025-01-19T20:59:50.674Z" }, + { url = "https://files.pythonhosted.org/packages/82/94/4f9b55099a4603ac53c2912e1f043d6c49d23e94dd82a9ce1eb554a90215/websockets-14.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6f1372e511c7409a542291bce92d6c83320e02c9cf392223272287ce55bc224e", size = 163102, upload-time = "2025-01-19T20:59:52.177Z" }, + { url = "https://files.pythonhosted.org/packages/8e/b7/7484905215627909d9a79ae07070057afe477433fdacb59bf608ce86365a/websockets-14.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4da98b72009836179bb596a92297b1a61bb5a830c0e483a7d0766d45070a08ad", size = 160766, upload-time = "2025-01-19T20:59:54.368Z" }, + { url = "https://files.pythonhosted.org/packages/a3/a4/edb62efc84adb61883c7d2c6ad65181cb087c64252138e12d655989eec05/websockets-14.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f8a86a269759026d2bde227652b87be79f8a734e582debf64c9d302faa1e9f03", size = 160998, upload-time = "2025-01-19T20:59:56.671Z" }, + { url = "https://files.pythonhosted.org/packages/f5/79/036d320dc894b96af14eac2529967a6fc8b74f03b83c487e7a0e9043d842/websockets-14.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86cf1aaeca909bf6815ea714d5c5736c8d6dd3a13770e885aafe062ecbd04f1f", size = 170780, upload-time = "2025-01-19T20:59:58.085Z" }, + { url = "https://files.pythonhosted.org/packages/63/75/5737d21ee4dd7e4b9d487ee044af24a935e36a9ff1e1419d684feedcba71/websockets-14.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9b0f6c3ba3b1240f602ebb3971d45b02cc12bd1845466dd783496b3b05783a5", size = 169717, upload-time = "2025-01-19T20:59:59.545Z" }, + { url = "https://files.pythonhosted.org/packages/2c/3c/bf9b2c396ed86a0b4a92ff4cdaee09753d3ee389be738e92b9bbd0330b64/websockets-14.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:669c3e101c246aa85bc8534e495952e2ca208bd87994650b90a23d745902db9a", size = 170155, upload-time = "2025-01-19T21:00:01.887Z" }, + { url = "https://files.pythonhosted.org/packages/75/2d/83a5aca7247a655b1da5eb0ee73413abd5c3a57fc8b92915805e6033359d/websockets-14.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:eabdb28b972f3729348e632ab08f2a7b616c7e53d5414c12108c29972e655b20", size = 170495, upload-time = "2025-01-19T21:00:04.064Z" }, + { url = "https://files.pythonhosted.org/packages/79/dd/699238a92761e2f943885e091486378813ac8f43e3c84990bc394c2be93e/websockets-14.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2066dc4cbcc19f32c12a5a0e8cc1b7ac734e5b64ac0a325ff8353451c4b15ef2", size = 169880, upload-time = "2025-01-19T21:00:05.695Z" }, + { url = "https://files.pythonhosted.org/packages/c8/c9/67a8f08923cf55ce61aadda72089e3ed4353a95a3a4bc8bf42082810e580/websockets-14.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ab95d357cd471df61873dadf66dd05dd4709cae001dd6342edafc8dc6382f307", size = 169856, upload-time = "2025-01-19T21:00:07.192Z" }, + { url = "https://files.pythonhosted.org/packages/17/b1/1ffdb2680c64e9c3921d99db460546194c40d4acbef999a18c37aa4d58a3/websockets-14.2-cp313-cp313-win32.whl", hash = "sha256:a9e72fb63e5f3feacdcf5b4ff53199ec8c18d66e325c34ee4c551ca748623bbc", size = 163974, upload-time = "2025-01-19T21:00:08.698Z" }, + { url = "https://files.pythonhosted.org/packages/14/13/8b7fc4cb551b9cfd9890f0fd66e53c18a06240319915533b033a56a3d520/websockets-14.2-cp313-cp313-win_amd64.whl", hash = "sha256:b439ea828c4ba99bb3176dc8d9b933392a2413c0f6b149fdcba48393f573377f", size = 164420, upload-time = "2025-01-19T21:00:10.182Z" }, + { url = "https://files.pythonhosted.org/packages/10/3d/91d3d2bb1325cd83e8e2c02d0262c7d4426dc8fa0831ef1aa4d6bf2041af/websockets-14.2-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:d7d9cafbccba46e768be8a8ad4635fa3eae1ffac4c6e7cb4eb276ba41297ed29", size = 160773, upload-time = "2025-01-19T21:00:32.225Z" }, + { url = "https://files.pythonhosted.org/packages/33/7c/cdedadfef7381939577858b1b5718a4ab073adbb584e429dd9d9dc9bfe16/websockets-14.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:c76193c1c044bd1e9b3316dcc34b174bbf9664598791e6fb606d8d29000e070c", size = 161007, upload-time = "2025-01-19T21:00:33.784Z" }, + { url = "https://files.pythonhosted.org/packages/ca/35/7a20a3c450b27c04e50fbbfc3dfb161ed8e827b2a26ae31c4b59b018b8c6/websockets-14.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd475a974d5352390baf865309fe37dec6831aafc3014ffac1eea99e84e83fc2", size = 162264, upload-time = "2025-01-19T21:00:35.255Z" }, + { url = "https://files.pythonhosted.org/packages/e8/9c/e3f9600564b0c813f2448375cf28b47dc42c514344faed3a05d71fb527f9/websockets-14.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c6c0097a41968b2e2b54ed3424739aab0b762ca92af2379f152c1aef0187e1c", size = 161873, upload-time = "2025-01-19T21:00:37.377Z" }, + { url = "https://files.pythonhosted.org/packages/3f/37/260f189b16b2b8290d6ae80c9f96d8b34692cf1bb3475df54c38d3deb57d/websockets-14.2-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d7ff794c8b36bc402f2e07c0b2ceb4a2424147ed4785ff03e2a7af03711d60a", size = 161818, upload-time = "2025-01-19T21:00:38.952Z" }, + { url = "https://files.pythonhosted.org/packages/ff/1e/e47dedac8bf7140e59aa6a679e850c4df9610ae844d71b6015263ddea37b/websockets-14.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dec254fcabc7bd488dab64846f588fc5b6fe0d78f641180030f8ea27b76d72c3", size = 164465, upload-time = "2025-01-19T21:00:40.456Z" }, + { url = "https://files.pythonhosted.org/packages/7b/c8/d529f8a32ce40d98309f4470780631e971a5a842b60aec864833b3615786/websockets-14.2-py3-none-any.whl", hash = "sha256:7a6ceec4ea84469f15cf15807a747e9efe57e369c384fa86e022b3bea679b79b", size = 157416, upload-time = "2025-01-19T21:00:54.843Z" }, +] + [[package]] name = "yarl" version = "1.22.0" From eb4bbb8a2250115e1c3301a73342b09f9547f755 Mon Sep 17 00:00:00 2001 From: Henrique Chaves Date: Wed, 18 Mar 2026 19:48:06 +0100 Subject: [PATCH 2/7] lint files --- src/giskard_hub/resources/_ws_scan.py | 31 ++++++----- src/giskard_hub/resources/helpers.py | 76 ++++++++++++--------------- 2 files changed, 53 insertions(+), 54 deletions(-) diff --git a/src/giskard_hub/resources/_ws_scan.py b/src/giskard_hub/resources/_ws_scan.py index c7a28ba..600a366 100644 --- a/src/giskard_hub/resources/_ws_scan.py +++ b/src/giskard_hub/resources/_ws_scan.py @@ -6,21 +6,24 @@ and send the response back. """ +import os +import ssl +import json import asyncio import inspect -import json import logging -import os -import ssl -from typing import Any, Awaitable, Callable -from urllib.parse import urlencode, urlparse, urlunparse +from typing import Any, Callable, Awaitable +from urllib.parse import urlparse, urlencode, urlunparse import httpx from ..types.chat import ChatMessage +from ..types.agent import AgentOutput logger = logging.getLogger(__name__) +__all__ = ["run_ws_scan", "_arun_ws_scan", "_ssl_context_from_httpx"] + def _ssl_context_from_httpx(http_client: httpx.Client | httpx.AsyncClient | None) -> ssl.SSLContext | None: """Derive an ``ssl.SSLContext`` that mirrors the httpx client's TLS config. @@ -76,7 +79,7 @@ def _build_ws_url(base_url: str, scan_id: str, api_key: str) -> str: AsyncAgentCallable = Callable[[list[ChatMessage]], Any | Awaitable[Any]] -def _normalize_output(value: Any) -> dict: +def _normalize_output(value: Any) -> dict[str, object]: """Turn the agent return value into a dict matching the Hub protocol.""" from ._helpers_types import normalize_agent_output @@ -89,9 +92,9 @@ async def _arun_ws_scan( api_key: str, scan_id: str, agent: AsyncAgentCallable, - on_progress: Callable[[dict], Any] | None = None, + on_progress: Callable[[dict[str, Any]], Any] | None = None, ssl_context: ssl.SSLContext | bool | None = None, -) -> dict | None: +) -> dict[str, Any] | None: """Async implementation of the WebSocket scan loop. Returns the ``complete`` message payload, or ``None`` if the connection @@ -188,9 +191,9 @@ def run_ws_scan( api_key: str, scan_id: str, agent: AgentCallable, - on_progress: Callable[[dict], Any] | None = None, + on_progress: Callable[[dict[str, Any]], Any] | None = None, ssl_context: ssl.SSLContext | bool | None = None, -) -> dict | None: +) -> dict[str, Any] | None: """Synchronous wrapper around the async WebSocket scan loop. Works correctly even when called from an environment that already has @@ -199,12 +202,16 @@ def run_ws_scan( """ import concurrent.futures - def _run_in_thread() -> dict | None: + def _run_in_thread() -> dict[str, Any] | None: loop = asyncio.new_event_loop() try: return loop.run_until_complete( _arun_ws_scan( - base_url, api_key, scan_id, agent, on_progress, + base_url, + api_key, + scan_id, + agent, + on_progress, ssl_context=ssl_context, ) ) diff --git a/src/giskard_hub/resources/helpers.py b/src/giskard_hub/resources/helpers.py index ea3fcd1..54a530b 100644 --- a/src/giskard_hub/resources/helpers.py +++ b/src/giskard_hub/resources/helpers.py @@ -8,7 +8,7 @@ import time import asyncio import inspect -from typing import Callable, Optional, Awaitable, Collection, cast +from typing import Any, Callable, Optional, Awaitable, Collection, cast from concurrent.futures import ThreadPoolExecutor from .._types import Omit, SequenceNotStr, omit @@ -20,7 +20,7 @@ ) from .._resource import SyncAPIResource, AsyncAPIResource from ..types.chat import ChatMessage -from ..types.scan import ScanProbe, ScanProbeAttempt +from ..types.scan import Scan, ScanProbe, ScanProbeAttempt from ..types.agent import Agent, AgentOutputParam from ..types.dataset import Dataset from ..types.project import Project @@ -168,7 +168,7 @@ def scan( agent_description: Optional[str] = None, supported_languages: Optional[list[str]] = None, poll_interval: float = 5.0, - ) -> "Scan": + ) -> Scan: """Run a vulnerability scan for a given agent. Handles both remote agents (referenced by ID or ``Agent``, which must @@ -206,11 +206,12 @@ def scan( Scan The completed scan result with grade and probe information. """ - from ..types.scan import Scan as _Scan project_id = project if isinstance(project, str) else project.id - kb_id = knowledge_base if isinstance(knowledge_base, str) else ( - knowledge_base if knowledge_base is omit else knowledge_base + kb_id = ( + knowledge_base + if isinstance(knowledge_base, str) + else (knowledge_base if knowledge_base is omit else knowledge_base) ) if isinstance(agent, (str, Agent)): @@ -225,9 +226,9 @@ def scan( return self._scan_local( agent=agent, project_id=project_id, - knowledge_base_id=kb_id if kb_id is not omit else None, - tags=tags if tags is not omit else None, - agent_name=agent_name or getattr(agent, "__name__", "local_agent"), + knowledge_base_id=kb_id if isinstance(kb_id, str) else None, + tags=None if isinstance(tags, Omit) or tags is None else tags, + agent_name=agent_name if agent_name is not None else getattr(agent, "__name__", "local_agent"), agent_description=agent_description or getattr(agent, "__doc__", None) or "", supported_languages=supported_languages or ["en"], ) @@ -254,12 +255,10 @@ def _scan_remote( knowledge_base_id: "str | Omit | None", tags: "Optional[SequenceNotStr[str]] | Omit", poll_interval: float, - ) -> "Scan": - from ..types.scan import Scan as _Scan - + ) -> Scan: agent_id = agent if isinstance(agent, str) else agent.id - create_kwargs: dict = { + create_kwargs: dict[str, Any] = { "project_id": project_id, "agent_id": agent_id, } @@ -269,7 +268,7 @@ def _scan_remote( create_kwargs["tags"] = tags scan = self._client.scans.create(**create_kwargs) - return cast(_Scan, self.wait_for_completion(scan, poll_interval=poll_interval)) + return self.wait_for_completion(scan, poll_interval=poll_interval) # type: ignore[return-value] def _scan_local( self, @@ -281,13 +280,12 @@ def _scan_local( agent_name: str, agent_description: str, supported_languages: list[str], - ) -> "Scan": - from ..types.scan import Scan as _Scan - from ..types.common import APIResponse as _APIResponse + ) -> Scan: from ._ws_scan import run_ws_scan, _ssl_context_from_httpx + from ..types.common import APIResponse as _APIResponse # 1. Create the scan record on the Hub (no worker enqueue). - body: dict = { + body: dict[str, Any] = { "project_id": project_id, "agent_name": agent_name, "agent_description": agent_description, @@ -301,7 +299,7 @@ def _scan_local( response = self._post( "/v2/scans/create-local", body=body, - cast_to=_APIResponse[_Scan], + cast_to=_APIResponse[Scan], ) scan = self._unwrap(response) scan_id = scan.id @@ -384,9 +382,7 @@ def _evaluate_local( return evaluation def _print_scan_metrics(self, entity: object) -> None: - from ..types.scan import Scan as _Scan - - scan = cast(_Scan, entity) + scan = cast(Scan, entity) category_map = {cat.id: cat.title for cat in self._client.scans.list_categories()} probe_results = self._client.scans.list_probes(scan_id=scan.id) @@ -530,16 +526,17 @@ async def scan( agent_description: Optional[str] = None, supported_languages: Optional[list[str]] = None, poll_interval: float = 5.0, - ) -> "Scan": + ) -> Scan: """Asynchronously run a vulnerability scan for a given agent. See :meth:`HelpersResource.scan` for full parameter documentation. """ - from ..types.scan import Scan as _Scan project_id = project if isinstance(project, str) else project.id - kb_id = knowledge_base if isinstance(knowledge_base, str) else ( - knowledge_base if knowledge_base is omit else knowledge_base + kb_id = ( + knowledge_base + if isinstance(knowledge_base, str) + else (knowledge_base if knowledge_base is omit else knowledge_base) ) if isinstance(agent, (str, Agent)): @@ -554,9 +551,9 @@ async def scan( return await self._scan_local( agent=agent, project_id=project_id, - knowledge_base_id=kb_id if kb_id is not omit else None, - tags=tags if tags is not omit else None, - agent_name=agent_name or getattr(agent, "__name__", "local_agent"), + knowledge_base_id=kb_id if isinstance(kb_id, str) else None, + tags=None if isinstance(tags, Omit) or tags is None else tags, + agent_name=agent_name if agent_name is not None else getattr(agent, "__name__", "local_agent"), agent_description=agent_description or getattr(agent, "__doc__", None) or "", supported_languages=supported_languages or ["en"], ) @@ -583,12 +580,10 @@ async def _scan_remote( knowledge_base_id: "str | Omit | None", tags: "Optional[SequenceNotStr[str]] | Omit", poll_interval: float, - ) -> "Scan": - from ..types.scan import Scan as _Scan - + ) -> Scan: agent_id = agent if isinstance(agent, str) else agent.id - create_kwargs: dict = { + create_kwargs: dict[str, Any] = { "project_id": project_id, "agent_id": agent_id, } @@ -598,7 +593,7 @@ async def _scan_remote( create_kwargs["tags"] = tags scan = await self._client.scans.create(**create_kwargs) - return cast(_Scan, await self.wait_for_completion(scan, poll_interval=poll_interval)) + return await self.wait_for_completion(scan, poll_interval=poll_interval) # type: ignore[return-value] async def _scan_local( self, @@ -610,13 +605,12 @@ async def _scan_local( agent_name: str, agent_description: str, supported_languages: list[str], - ) -> "Scan": - from ..types.scan import Scan as _Scan - from ..types.common import APIResponse as _APIResponse + ) -> Scan: from ._ws_scan import _arun_ws_scan, _ssl_context_from_httpx + from ..types.common import APIResponse as _APIResponse # 1. Create the scan record on the Hub (no worker enqueue). - body: dict = { + body: dict[str, Any] = { "project_id": project_id, "agent_name": agent_name, "agent_description": agent_description, @@ -630,7 +624,7 @@ async def _scan_local( response = await self._post( "/v2/scans/create-local", body=body, - cast_to=_APIResponse[_Scan], + cast_to=_APIResponse[Scan], ) scan = self._unwrap(response) scan_id = scan.id @@ -722,9 +716,7 @@ async def _process_entry(entry: TestCaseEvaluation) -> None: return evaluation async def _print_scan_metrics(self, entity: object) -> None: - from ..types.scan import Scan as _Scan - - scan = cast(_Scan, entity) + scan = cast(Scan, entity) category_map = {cat.id: cat.title for cat in await self._client.scans.list_categories()} probe_results = await self._client.scans.list_probes(scan_id=scan.id) attempts_list = await asyncio.gather( From a39c61b26d9bcbbdcb8434ab3fc768bbf4cb895c Mon Sep 17 00:00:00 2001 From: Henrique Chaves Date: Thu, 19 Mar 2026 13:28:03 +0100 Subject: [PATCH 3/7] replace websocket with polling mechanism --- pyproject.toml | 1 - src/giskard_hub/resources/_poll_scan.py | 168 ++++++++++++++++++ src/giskard_hub/resources/_ws_scan.py | 223 ------------------------ src/giskard_hub/resources/helpers.py | 33 ++-- uv.lock | 61 ------- 5 files changed, 182 insertions(+), 304 deletions(-) create mode 100644 src/giskard_hub/resources/_poll_scan.py delete mode 100644 src/giskard_hub/resources/_ws_scan.py diff --git a/pyproject.toml b/pyproject.toml index cd62b5b..72fd3da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ dependencies = [ "distro>=1.7.0, <2", "rich", "sniffio", - "websockets>=13.0, <15", ] requires-python = ">=3.10, <3.15" classifiers = [ diff --git a/src/giskard_hub/resources/_poll_scan.py b/src/giskard_hub/resources/_poll_scan.py new file mode 100644 index 0000000..7d8dfbb --- /dev/null +++ b/src/giskard_hub/resources/_poll_scan.py @@ -0,0 +1,168 @@ +"""Polling client for local-agent scan execution. + +The SDK polls ``GET /v2/scans/{scan_id}/invocations?status=pending`` for +invocation requests from LIDAR, executes the local agent callable, and +POSTs the response back via +``POST /v2/scans/{scan_id}/invocations/{id}/respond``. + +All communication is plain HTTP — no WebSocket, no long-lived connections. +""" + +import asyncio +import inspect +import logging +import time +from typing import Any, Awaitable, Callable + +import httpx + +from ..types.chat import ChatMessage + +logger = logging.getLogger(__name__) + +AgentCallable = Callable[[list[ChatMessage]], Any] +AsyncAgentCallable = Callable[[list[ChatMessage]], Any | Awaitable[Any]] + +_TERMINAL_STATES = frozenset({"finished", "error", "canceled"}) + + +def _normalize_output(value: Any) -> dict[str, object]: + """Turn the agent return value into a dict matching the Hub protocol.""" + from ._helpers_types import normalize_agent_output + + output = normalize_agent_output(value) + return output.to_dict() + + +def run_poll_scan( + base_url: str, + api_key: str, + scan_id: str, + agent: AgentCallable, + http_client: httpx.Client, + poll_interval: float = 0.5, +) -> None: + """Synchronous polling loop for local scan execution. + + Blocks until the scan finishes, is cancelled, or errors. + Works in Jupyter notebooks without event-loop conflicts. + """ + headers = {"X-API-Key": api_key} + invocations_url = f"{base_url}/v2/scans/{scan_id}/invocations" + + while True: + resp = http_client.get( + invocations_url, + params={"status": "pending"}, + headers=headers, + ) + resp.raise_for_status() + payload = resp.json()["data"] + + for inv in payload["invocations"]: + _process_invocation_sync( + inv, agent, base_url, scan_id, headers, http_client + ) + + if payload["scan_status"] in _TERMINAL_STATES: + break + + time.sleep(poll_interval) + + +def _process_invocation_sync( + inv: dict[str, Any], + agent: AgentCallable, + base_url: str, + scan_id: str, + headers: dict[str, str], + http_client: httpx.Client, +) -> None: + invocation_id = inv["id"] + messages = [ + ChatMessage(role=m.get("role", "user"), content=m.get("content", "")) + for m in inv.get("messages", []) + ] + + try: + result = agent(messages) + output = _normalize_output(result) + body: dict[str, Any] = {"output": output} + except Exception as exc: + logger.error("Agent invocation failed: %s", exc) + body = {"error": {"message": str(exc)}} + + resp = http_client.post( + f"{base_url}/v2/scans/{scan_id}/invocations/{invocation_id}/respond", + json=body, + headers=headers, + ) + resp.raise_for_status() + + +async def arun_poll_scan( + base_url: str, + api_key: str, + scan_id: str, + agent: AsyncAgentCallable, + http_client: httpx.AsyncClient, + poll_interval: float = 0.5, +) -> None: + """Async polling loop for local scan execution.""" + headers = {"X-API-Key": api_key} + invocations_url = f"{base_url}/v2/scans/{scan_id}/invocations" + + while True: + resp = await http_client.get( + invocations_url, + params={"status": "pending"}, + headers=headers, + ) + resp.raise_for_status() + payload = resp.json()["data"] + + await asyncio.gather( + *( + _process_invocation_async( + inv, agent, base_url, scan_id, headers, http_client + ) + for inv in payload["invocations"] + ) + ) + + if payload["scan_status"] in _TERMINAL_STATES: + break + + await asyncio.sleep(poll_interval) + + +async def _process_invocation_async( + inv: dict[str, Any], + agent: AsyncAgentCallable, + base_url: str, + scan_id: str, + headers: dict[str, str], + http_client: httpx.AsyncClient, +) -> None: + invocation_id = inv["id"] + messages = [ + ChatMessage(role=m.get("role", "user"), content=m.get("content", "")) + for m in inv.get("messages", []) + ] + + try: + result = agent(messages) + if inspect.isawaitable(result): + result = await result + output = _normalize_output(result) + body: dict[str, Any] = {"output": output} + except Exception as exc: + logger.error("Agent invocation failed: %s", exc) + body = {"error": {"message": str(exc)}} + + resp = await http_client.post( + f"{base_url}/v2/scans/{scan_id}/invocations/{invocation_id}/respond", + json=body, + headers=headers, + ) + resp.raise_for_status() diff --git a/src/giskard_hub/resources/_ws_scan.py b/src/giskard_hub/resources/_ws_scan.py deleted file mode 100644 index 600a366..0000000 --- a/src/giskard_hub/resources/_ws_scan.py +++ /dev/null @@ -1,223 +0,0 @@ -"""WebSocket client for local-agent scan execution. - -This module provides both sync and async functions that connect to the Hub's -``/v2/scans/{scan_id}/ws`` WebSocket endpoint, receive agent invocation -requests from LIDAR (running on the Hub), execute the local agent callable, -and send the response back. -""" - -import os -import ssl -import json -import asyncio -import inspect -import logging -from typing import Any, Callable, Awaitable -from urllib.parse import urlparse, urlencode, urlunparse - -import httpx - -from ..types.chat import ChatMessage -from ..types.agent import AgentOutput - -logger = logging.getLogger(__name__) - -__all__ = ["run_ws_scan", "_arun_ws_scan", "_ssl_context_from_httpx"] - - -def _ssl_context_from_httpx(http_client: httpx.Client | httpx.AsyncClient | None) -> ssl.SSLContext | None: - """Derive an ``ssl.SSLContext`` that mirrors the httpx client's TLS config. - - Walks the internal transport chain - (``httpx.Client._transport._pool._ssl_context``) to extract the exact - ``ssl.SSLContext`` that httpx uses. This means ``verify=False`` on the - httpx client automatically disables verification for the WebSocket too, - and custom CA bundles are preserved. - - Returns ``None`` (use system defaults) when extraction fails. - """ - if http_client is None: - return None - - # httpx.Client._transport → HTTPTransport._pool → httpcore.ConnectionPool._ssl_context - try: - ctx = http_client._transport._pool._ssl_context # type: ignore[union-attr] - if isinstance(ctx, ssl.SSLContext): - return ctx - except AttributeError: - pass - - # Fallback: check common CA-bundle environment variables. - for env_var in ("SSL_CERT_FILE", "REQUESTS_CA_BUNDLE", "CURL_CA_BUNDLE"): - ca_path = os.environ.get(env_var) - if ca_path: - return ssl.create_default_context(cafile=ca_path) - - return None - - -def _http_to_ws_url(http_url: str) -> str: - """Convert an HTTP(S) URL to a WS(S) URL.""" - parsed = urlparse(http_url) - if parsed.scheme == "https": - scheme = "wss" - elif parsed.scheme == "http": - scheme = "ws" - else: - scheme = parsed.scheme - return urlunparse(parsed._replace(scheme=scheme)) - - -def _build_ws_url(base_url: str, scan_id: str, api_key: str) -> str: - """Build the full WebSocket URL for a local scan session.""" - ws_base = _http_to_ws_url(base_url.rstrip("/")) - query = urlencode({"api_key": api_key}) - return f"{ws_base}/v2/scans/{scan_id}/ws?{query}" - - -AgentCallable = Callable[[list[ChatMessage]], Any] -AsyncAgentCallable = Callable[[list[ChatMessage]], Any | Awaitable[Any]] - - -def _normalize_output(value: Any) -> dict[str, object]: - """Turn the agent return value into a dict matching the Hub protocol.""" - from ._helpers_types import normalize_agent_output - - output: AgentOutput = normalize_agent_output(value) - return output.to_dict() - - -async def _arun_ws_scan( - base_url: str, - api_key: str, - scan_id: str, - agent: AsyncAgentCallable, - on_progress: Callable[[dict[str, Any]], Any] | None = None, - ssl_context: ssl.SSLContext | bool | None = None, -) -> dict[str, Any] | None: - """Async implementation of the WebSocket scan loop. - - Returns the ``complete`` message payload, or ``None`` if the connection - closed before completion. - """ - try: - from websockets.asyncio.client import connect - except ImportError as exc: - raise ImportError( - "The 'websockets' package is required for local scan execution. " - "Install it with: pip install 'giskard-hub[websockets]' or pip install websockets" - ) from exc - - url = _build_ws_url(base_url, scan_id, api_key) - logger.info("Connecting to scan WebSocket: %s", url.split("?")[0]) - - connect_kwargs: dict[str, Any] = {} - if ssl_context is not None: - connect_kwargs["ssl"] = ssl_context - - async with connect(url, **connect_kwargs) as ws: - async for raw_msg in ws: - try: - msg = json.loads(raw_msg) - except json.JSONDecodeError: - logger.warning("Non-JSON WebSocket message received, ignoring") - continue - - msg_type = msg.get("type") - - if msg_type == "invoke": - request_id = msg["request_id"] - messages = [ - ChatMessage( - role=m.get("role", "user"), - content=m.get("content", ""), - ) - for m in msg.get("messages", []) - ] - - try: - result = agent(messages) - if inspect.isawaitable(result): - result = await result - output = _normalize_output(result) - await ws.send( - json.dumps( - { - "type": "response", - "request_id": request_id, - "output": output, - } - ) - ) - except Exception as exc: - logger.error("Agent invocation failed: %s", exc) - await ws.send( - json.dumps( - { - "type": "error", - "request_id": request_id, - "error": {"message": str(exc)}, - } - ) - ) - - elif msg_type == "progress": - if on_progress: - status = msg.get("status", {}) - cb_result = on_progress(status) - if inspect.isawaitable(cb_result): - await cb_result - - elif msg_type == "complete": - logger.info( - "Scan %s completed with grade: %s", - scan_id, - msg.get("grade"), - ) - return msg - - elif msg_type == "error": - error_msg = msg.get("message", "Unknown server error") - raise RuntimeError(f"Scan error from Hub: {error_msg}") - - else: - logger.warning("Unknown WebSocket message type: %s", msg_type) - - return None - - -def run_ws_scan( - base_url: str, - api_key: str, - scan_id: str, - agent: AgentCallable, - on_progress: Callable[[dict[str, Any]], Any] | None = None, - ssl_context: ssl.SSLContext | bool | None = None, -) -> dict[str, Any] | None: - """Synchronous wrapper around the async WebSocket scan loop. - - Works correctly even when called from an environment that already has - a running event loop (e.g. Jupyter notebooks) by executing the async - code in a dedicated thread with its own event loop. - """ - import concurrent.futures - - def _run_in_thread() -> dict[str, Any] | None: - loop = asyncio.new_event_loop() - try: - return loop.run_until_complete( - _arun_ws_scan( - base_url, - api_key, - scan_id, - agent, - on_progress, - ssl_context=ssl_context, - ) - ) - finally: - loop.close() - - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool: - future = pool.submit(_run_in_thread) - return future.result() diff --git a/src/giskard_hub/resources/helpers.py b/src/giskard_hub/resources/helpers.py index 54a530b..3a31240 100644 --- a/src/giskard_hub/resources/helpers.py +++ b/src/giskard_hub/resources/helpers.py @@ -174,11 +174,10 @@ def scan( Handles both remote agents (referenced by ID or ``Agent``, which must already be registered in the Hub) and local Python callables. - When a local callable is provided, the scan is executed via WebSocket: - the Hub orchestrates LIDAR server-side and sends agent invocation - requests to this process through the WebSocket connection. The TLS - configuration (certificate verification) is automatically inherited - from the ``httpx.Client`` passed to ``HubClient``. + When a local callable is provided, the Hub orchestrates LIDAR + server-side and the SDK polls for agent invocation requests via + stateless HTTP endpoints. The scan runs in the Hub's worker + process, same as remote scans. Parameters ---------- @@ -281,10 +280,10 @@ def _scan_local( agent_description: str, supported_languages: list[str], ) -> Scan: - from ._ws_scan import run_ws_scan, _ssl_context_from_httpx + from ._poll_scan import run_poll_scan from ..types.common import APIResponse as _APIResponse - # 1. Create the scan record on the Hub (no worker enqueue). + # 1. Create the scan record (enqueued to worker automatically). body: dict[str, Any] = { "project_id": project_id, "agent_name": agent_name, @@ -304,18 +303,16 @@ def _scan_local( scan = self._unwrap(response) scan_id = scan.id - # 2. Connect via WebSocket and drive the scan. - # Inherit TLS config (verify/no-verify) from the httpx client. + # 2. Poll for invocations and execute the local agent. base_url = str(self._client.base_url) api_key = self._client.api_key - ssl_ctx = _ssl_context_from_httpx(self._client._client) - run_ws_scan( + run_poll_scan( base_url=base_url, api_key=api_key, scan_id=scan_id, agent=agent, - ssl_context=ssl_ctx, + http_client=self._client._client, ) # 3. Retrieve the final scan result. @@ -606,10 +603,10 @@ async def _scan_local( agent_description: str, supported_languages: list[str], ) -> Scan: - from ._ws_scan import _arun_ws_scan, _ssl_context_from_httpx + from ._poll_scan import arun_poll_scan from ..types.common import APIResponse as _APIResponse - # 1. Create the scan record on the Hub (no worker enqueue). + # 1. Create the scan record (enqueued to worker automatically). body: dict[str, Any] = { "project_id": project_id, "agent_name": agent_name, @@ -629,18 +626,16 @@ async def _scan_local( scan = self._unwrap(response) scan_id = scan.id - # 2. Connect via WebSocket and drive the scan. - # Inherit TLS config (verify/no-verify) from the httpx client. + # 2. Poll for invocations and execute the local agent. base_url = str(self._client.base_url) api_key = self._client.api_key - ssl_ctx = _ssl_context_from_httpx(self._client._client) - await _arun_ws_scan( + await arun_poll_scan( base_url=base_url, api_key=api_key, scan_id=scan_id, agent=agent, - ssl_context=ssl_ctx, + http_client=self._client._client, ) # 3. Retrieve the final scan result. diff --git a/uv.lock b/uv.lock index b4af34e..39f2596 100644 --- a/uv.lock +++ b/uv.lock @@ -384,7 +384,6 @@ dependencies = [ { name = "rich" }, { name = "sniffio" }, { name = "typing-extensions" }, - { name = "websockets" }, ] [package.optional-dependencies] @@ -422,7 +421,6 @@ requires-dist = [ { name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" }, { name = "sniffio" }, { name = "typing-extensions", specifier = ">=4.10,<5" }, - { name = "websockets", specifier = ">=13.0,<15" }, ] provides-extras = ["aiohttp", "dev"] @@ -1266,65 +1264,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dc/9b/47798a6c91d8bdb567fe2698fe81e0c6b7cb7ef4d13da4114b41d239f65d/typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7", size = 14611, upload-time = "2025-10-01T02:14:40.154Z" }, ] -[[package]] -name = "websockets" -version = "14.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/94/54/8359678c726243d19fae38ca14a334e740782336c9f19700858c4eb64a1e/websockets-14.2.tar.gz", hash = "sha256:5059ed9c54945efb321f097084b4c7e52c246f2c869815876a69d1efc4ad6eb5", size = 164394, upload-time = "2025-01-19T21:00:56.431Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/28/fa/76607eb7dcec27b2d18d63f60a32e60e2b8629780f343bb83a4dbb9f4350/websockets-14.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e8179f95323b9ab1c11723e5d91a89403903f7b001828161b480a7810b334885", size = 163089, upload-time = "2025-01-19T20:58:43.399Z" }, - { url = "https://files.pythonhosted.org/packages/9e/00/ad2246b5030575b79e7af0721810fdaecaf94c4b2625842ef7a756fa06dd/websockets-14.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0d8c3e2cdb38f31d8bd7d9d28908005f6fa9def3324edb9bf336d7e4266fd397", size = 160741, upload-time = "2025-01-19T20:58:45.309Z" }, - { url = "https://files.pythonhosted.org/packages/72/f7/60f10924d333a28a1ff3fcdec85acf226281331bdabe9ad74947e1b7fc0a/websockets-14.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:714a9b682deb4339d39ffa674f7b674230227d981a37d5d174a4a83e3978a610", size = 160996, upload-time = "2025-01-19T20:58:47.563Z" }, - { url = "https://files.pythonhosted.org/packages/63/7c/c655789cf78648c01ac6ecbe2d6c18f91b75bdc263ffee4d08ce628d12f0/websockets-14.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2e53c72052f2596fb792a7acd9704cbc549bf70fcde8a99e899311455974ca3", size = 169974, upload-time = "2025-01-19T20:58:51.023Z" }, - { url = "https://files.pythonhosted.org/packages/fb/5b/013ed8b4611857ac92ac631079c08d9715b388bd1d88ec62e245f87a39df/websockets-14.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e3fbd68850c837e57373d95c8fe352203a512b6e49eaae4c2f4088ef8cf21980", size = 168985, upload-time = "2025-01-19T20:58:52.698Z" }, - { url = "https://files.pythonhosted.org/packages/cd/33/aa3e32fd0df213a5a442310754fe3f89dd87a0b8e5b4e11e0991dd3bcc50/websockets-14.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b27ece32f63150c268593d5fdb82819584831a83a3f5809b7521df0685cd5d8", size = 169297, upload-time = "2025-01-19T20:58:54.898Z" }, - { url = "https://files.pythonhosted.org/packages/93/17/dae0174883d6399f57853ac44abf5f228eaba86d98d160f390ffabc19b6e/websockets-14.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4daa0faea5424d8713142b33825fff03c736f781690d90652d2c8b053345b0e7", size = 169677, upload-time = "2025-01-19T20:58:56.36Z" }, - { url = "https://files.pythonhosted.org/packages/42/e2/0375af7ac00169b98647c804651c515054b34977b6c1354f1458e4116c1e/websockets-14.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:bc63cee8596a6ec84d9753fd0fcfa0452ee12f317afe4beae6b157f0070c6c7f", size = 169089, upload-time = "2025-01-19T20:58:58.824Z" }, - { url = "https://files.pythonhosted.org/packages/73/8d/80f71d2a351a44b602859af65261d3dde3a0ce4e76cf9383738a949e0cc3/websockets-14.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7a570862c325af2111343cc9b0257b7119b904823c675b22d4ac547163088d0d", size = 169026, upload-time = "2025-01-19T20:59:01.089Z" }, - { url = "https://files.pythonhosted.org/packages/48/97/173b1fa6052223e52bb4054a141433ad74931d94c575e04b654200b98ca4/websockets-14.2-cp310-cp310-win32.whl", hash = "sha256:75862126b3d2d505e895893e3deac0a9339ce750bd27b4ba515f008b5acf832d", size = 163967, upload-time = "2025-01-19T20:59:02.662Z" }, - { url = "https://files.pythonhosted.org/packages/c0/5b/2fcf60f38252a4562b28b66077e0d2b48f91fef645d5f78874cd1dec807b/websockets-14.2-cp310-cp310-win_amd64.whl", hash = "sha256:cc45afb9c9b2dc0852d5c8b5321759cf825f82a31bfaf506b65bf4668c96f8b2", size = 164413, upload-time = "2025-01-19T20:59:05.071Z" }, - { url = "https://files.pythonhosted.org/packages/15/b6/504695fb9a33df0ca56d157f5985660b5fc5b4bf8c78f121578d2d653392/websockets-14.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3bdc8c692c866ce5fefcaf07d2b55c91d6922ac397e031ef9b774e5b9ea42166", size = 163088, upload-time = "2025-01-19T20:59:06.435Z" }, - { url = "https://files.pythonhosted.org/packages/81/26/ebfb8f6abe963c795122439c6433c4ae1e061aaedfc7eff32d09394afbae/websockets-14.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c93215fac5dadc63e51bcc6dceca72e72267c11def401d6668622b47675b097f", size = 160745, upload-time = "2025-01-19T20:59:09.109Z" }, - { url = "https://files.pythonhosted.org/packages/a1/c6/1435ad6f6dcbff80bb95e8986704c3174da8866ddb751184046f5c139ef6/websockets-14.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1c9b6535c0e2cf8a6bf938064fb754aaceb1e6a4a51a80d884cd5db569886910", size = 160995, upload-time = "2025-01-19T20:59:12.816Z" }, - { url = "https://files.pythonhosted.org/packages/96/63/900c27cfe8be1a1f2433fc77cd46771cf26ba57e6bdc7cf9e63644a61863/websockets-14.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a52a6d7cf6938e04e9dceb949d35fbdf58ac14deea26e685ab6368e73744e4c", size = 170543, upload-time = "2025-01-19T20:59:15.026Z" }, - { url = "https://files.pythonhosted.org/packages/00/8b/bec2bdba92af0762d42d4410593c1d7d28e9bfd952c97a3729df603dc6ea/websockets-14.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f05702e93203a6ff5226e21d9b40c037761b2cfb637187c9802c10f58e40473", size = 169546, upload-time = "2025-01-19T20:59:17.156Z" }, - { url = "https://files.pythonhosted.org/packages/6b/a9/37531cb5b994f12a57dec3da2200ef7aadffef82d888a4c29a0d781568e4/websockets-14.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22441c81a6748a53bfcb98951d58d1af0661ab47a536af08920d129b4d1c3473", size = 169911, upload-time = "2025-01-19T20:59:18.623Z" }, - { url = "https://files.pythonhosted.org/packages/60/d5/a6eadba2ed9f7e65d677fec539ab14a9b83de2b484ab5fe15d3d6d208c28/websockets-14.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:efd9b868d78b194790e6236d9cbc46d68aba4b75b22497eb4ab64fa640c3af56", size = 170183, upload-time = "2025-01-19T20:59:20.743Z" }, - { url = "https://files.pythonhosted.org/packages/76/57/a338ccb00d1df881c1d1ee1f2a20c9c1b5b29b51e9e0191ee515d254fea6/websockets-14.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1a5a20d5843886d34ff8c57424cc65a1deda4375729cbca4cb6b3353f3ce4142", size = 169623, upload-time = "2025-01-19T20:59:22.286Z" }, - { url = "https://files.pythonhosted.org/packages/64/22/e5f7c33db0cb2c1d03b79fd60d189a1da044e2661f5fd01d629451e1db89/websockets-14.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:34277a29f5303d54ec6468fb525d99c99938607bc96b8d72d675dee2b9f5bf1d", size = 169583, upload-time = "2025-01-19T20:59:23.656Z" }, - { url = "https://files.pythonhosted.org/packages/aa/2e/2b4662237060063a22e5fc40d46300a07142afe30302b634b4eebd717c07/websockets-14.2-cp311-cp311-win32.whl", hash = "sha256:02687db35dbc7d25fd541a602b5f8e451a238ffa033030b172ff86a93cb5dc2a", size = 163969, upload-time = "2025-01-19T20:59:26.004Z" }, - { url = "https://files.pythonhosted.org/packages/94/a5/0cda64e1851e73fc1ecdae6f42487babb06e55cb2f0dc8904b81d8ef6857/websockets-14.2-cp311-cp311-win_amd64.whl", hash = "sha256:862e9967b46c07d4dcd2532e9e8e3c2825e004ffbf91a5ef9dde519ee2effb0b", size = 164408, upload-time = "2025-01-19T20:59:28.105Z" }, - { url = "https://files.pythonhosted.org/packages/c1/81/04f7a397653dc8bec94ddc071f34833e8b99b13ef1a3804c149d59f92c18/websockets-14.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1f20522e624d7ffbdbe259c6b6a65d73c895045f76a93719aa10cd93b3de100c", size = 163096, upload-time = "2025-01-19T20:59:29.763Z" }, - { url = "https://files.pythonhosted.org/packages/ec/c5/de30e88557e4d70988ed4d2eabd73fd3e1e52456b9f3a4e9564d86353b6d/websockets-14.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:647b573f7d3ada919fd60e64d533409a79dcf1ea21daeb4542d1d996519ca967", size = 160758, upload-time = "2025-01-19T20:59:32.095Z" }, - { url = "https://files.pythonhosted.org/packages/e5/8c/d130d668781f2c77d106c007b6c6c1d9db68239107c41ba109f09e6c218a/websockets-14.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6af99a38e49f66be5a64b1e890208ad026cda49355661549c507152113049990", size = 160995, upload-time = "2025-01-19T20:59:33.527Z" }, - { url = "https://files.pythonhosted.org/packages/a6/bc/f6678a0ff17246df4f06765e22fc9d98d1b11a258cc50c5968b33d6742a1/websockets-14.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:091ab63dfc8cea748cc22c1db2814eadb77ccbf82829bac6b2fbe3401d548eda", size = 170815, upload-time = "2025-01-19T20:59:35.837Z" }, - { url = "https://files.pythonhosted.org/packages/d8/b2/8070cb970c2e4122a6ef38bc5b203415fd46460e025652e1ee3f2f43a9a3/websockets-14.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b374e8953ad477d17e4851cdc66d83fdc2db88d9e73abf755c94510ebddceb95", size = 169759, upload-time = "2025-01-19T20:59:38.216Z" }, - { url = "https://files.pythonhosted.org/packages/81/da/72f7caabd94652e6eb7e92ed2d3da818626e70b4f2b15a854ef60bf501ec/websockets-14.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a39d7eceeea35db85b85e1169011bb4321c32e673920ae9c1b6e0978590012a3", size = 170178, upload-time = "2025-01-19T20:59:40.423Z" }, - { url = "https://files.pythonhosted.org/packages/31/e0/812725b6deca8afd3a08a2e81b3c4c120c17f68c9b84522a520b816cda58/websockets-14.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0a6f3efd47ffd0d12080594f434faf1cd2549b31e54870b8470b28cc1d3817d9", size = 170453, upload-time = "2025-01-19T20:59:41.996Z" }, - { url = "https://files.pythonhosted.org/packages/66/d3/8275dbc231e5ba9bb0c4f93144394b4194402a7a0c8ffaca5307a58ab5e3/websockets-14.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:065ce275e7c4ffb42cb738dd6b20726ac26ac9ad0a2a48e33ca632351a737267", size = 169830, upload-time = "2025-01-19T20:59:44.669Z" }, - { url = "https://files.pythonhosted.org/packages/a3/ae/e7d1a56755ae15ad5a94e80dd490ad09e345365199600b2629b18ee37bc7/websockets-14.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e9d0e53530ba7b8b5e389c02282f9d2aa47581514bd6049d3a7cffe1385cf5fe", size = 169824, upload-time = "2025-01-19T20:59:46.932Z" }, - { url = "https://files.pythonhosted.org/packages/b6/32/88ccdd63cb261e77b882e706108d072e4f1c839ed723bf91a3e1f216bf60/websockets-14.2-cp312-cp312-win32.whl", hash = "sha256:20e6dd0984d7ca3037afcb4494e48c74ffb51e8013cac71cf607fffe11df7205", size = 163981, upload-time = "2025-01-19T20:59:49.228Z" }, - { url = "https://files.pythonhosted.org/packages/b3/7d/32cdb77990b3bdc34a306e0a0f73a1275221e9a66d869f6ff833c95b56ef/websockets-14.2-cp312-cp312-win_amd64.whl", hash = "sha256:44bba1a956c2c9d268bdcdf234d5e5ff4c9b6dc3e300545cbe99af59dda9dcce", size = 164421, upload-time = "2025-01-19T20:59:50.674Z" }, - { url = "https://files.pythonhosted.org/packages/82/94/4f9b55099a4603ac53c2912e1f043d6c49d23e94dd82a9ce1eb554a90215/websockets-14.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6f1372e511c7409a542291bce92d6c83320e02c9cf392223272287ce55bc224e", size = 163102, upload-time = "2025-01-19T20:59:52.177Z" }, - { url = "https://files.pythonhosted.org/packages/8e/b7/7484905215627909d9a79ae07070057afe477433fdacb59bf608ce86365a/websockets-14.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4da98b72009836179bb596a92297b1a61bb5a830c0e483a7d0766d45070a08ad", size = 160766, upload-time = "2025-01-19T20:59:54.368Z" }, - { url = "https://files.pythonhosted.org/packages/a3/a4/edb62efc84adb61883c7d2c6ad65181cb087c64252138e12d655989eec05/websockets-14.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f8a86a269759026d2bde227652b87be79f8a734e582debf64c9d302faa1e9f03", size = 160998, upload-time = "2025-01-19T20:59:56.671Z" }, - { url = "https://files.pythonhosted.org/packages/f5/79/036d320dc894b96af14eac2529967a6fc8b74f03b83c487e7a0e9043d842/websockets-14.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86cf1aaeca909bf6815ea714d5c5736c8d6dd3a13770e885aafe062ecbd04f1f", size = 170780, upload-time = "2025-01-19T20:59:58.085Z" }, - { url = "https://files.pythonhosted.org/packages/63/75/5737d21ee4dd7e4b9d487ee044af24a935e36a9ff1e1419d684feedcba71/websockets-14.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9b0f6c3ba3b1240f602ebb3971d45b02cc12bd1845466dd783496b3b05783a5", size = 169717, upload-time = "2025-01-19T20:59:59.545Z" }, - { url = "https://files.pythonhosted.org/packages/2c/3c/bf9b2c396ed86a0b4a92ff4cdaee09753d3ee389be738e92b9bbd0330b64/websockets-14.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:669c3e101c246aa85bc8534e495952e2ca208bd87994650b90a23d745902db9a", size = 170155, upload-time = "2025-01-19T21:00:01.887Z" }, - { url = "https://files.pythonhosted.org/packages/75/2d/83a5aca7247a655b1da5eb0ee73413abd5c3a57fc8b92915805e6033359d/websockets-14.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:eabdb28b972f3729348e632ab08f2a7b616c7e53d5414c12108c29972e655b20", size = 170495, upload-time = "2025-01-19T21:00:04.064Z" }, - { url = "https://files.pythonhosted.org/packages/79/dd/699238a92761e2f943885e091486378813ac8f43e3c84990bc394c2be93e/websockets-14.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2066dc4cbcc19f32c12a5a0e8cc1b7ac734e5b64ac0a325ff8353451c4b15ef2", size = 169880, upload-time = "2025-01-19T21:00:05.695Z" }, - { url = "https://files.pythonhosted.org/packages/c8/c9/67a8f08923cf55ce61aadda72089e3ed4353a95a3a4bc8bf42082810e580/websockets-14.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ab95d357cd471df61873dadf66dd05dd4709cae001dd6342edafc8dc6382f307", size = 169856, upload-time = "2025-01-19T21:00:07.192Z" }, - { url = "https://files.pythonhosted.org/packages/17/b1/1ffdb2680c64e9c3921d99db460546194c40d4acbef999a18c37aa4d58a3/websockets-14.2-cp313-cp313-win32.whl", hash = "sha256:a9e72fb63e5f3feacdcf5b4ff53199ec8c18d66e325c34ee4c551ca748623bbc", size = 163974, upload-time = "2025-01-19T21:00:08.698Z" }, - { url = "https://files.pythonhosted.org/packages/14/13/8b7fc4cb551b9cfd9890f0fd66e53c18a06240319915533b033a56a3d520/websockets-14.2-cp313-cp313-win_amd64.whl", hash = "sha256:b439ea828c4ba99bb3176dc8d9b933392a2413c0f6b149fdcba48393f573377f", size = 164420, upload-time = "2025-01-19T21:00:10.182Z" }, - { url = "https://files.pythonhosted.org/packages/10/3d/91d3d2bb1325cd83e8e2c02d0262c7d4426dc8fa0831ef1aa4d6bf2041af/websockets-14.2-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:d7d9cafbccba46e768be8a8ad4635fa3eae1ffac4c6e7cb4eb276ba41297ed29", size = 160773, upload-time = "2025-01-19T21:00:32.225Z" }, - { url = "https://files.pythonhosted.org/packages/33/7c/cdedadfef7381939577858b1b5718a4ab073adbb584e429dd9d9dc9bfe16/websockets-14.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:c76193c1c044bd1e9b3316dcc34b174bbf9664598791e6fb606d8d29000e070c", size = 161007, upload-time = "2025-01-19T21:00:33.784Z" }, - { url = "https://files.pythonhosted.org/packages/ca/35/7a20a3c450b27c04e50fbbfc3dfb161ed8e827b2a26ae31c4b59b018b8c6/websockets-14.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd475a974d5352390baf865309fe37dec6831aafc3014ffac1eea99e84e83fc2", size = 162264, upload-time = "2025-01-19T21:00:35.255Z" }, - { url = "https://files.pythonhosted.org/packages/e8/9c/e3f9600564b0c813f2448375cf28b47dc42c514344faed3a05d71fb527f9/websockets-14.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c6c0097a41968b2e2b54ed3424739aab0b762ca92af2379f152c1aef0187e1c", size = 161873, upload-time = "2025-01-19T21:00:37.377Z" }, - { url = "https://files.pythonhosted.org/packages/3f/37/260f189b16b2b8290d6ae80c9f96d8b34692cf1bb3475df54c38d3deb57d/websockets-14.2-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d7ff794c8b36bc402f2e07c0b2ceb4a2424147ed4785ff03e2a7af03711d60a", size = 161818, upload-time = "2025-01-19T21:00:38.952Z" }, - { url = "https://files.pythonhosted.org/packages/ff/1e/e47dedac8bf7140e59aa6a679e850c4df9610ae844d71b6015263ddea37b/websockets-14.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dec254fcabc7bd488dab64846f588fc5b6fe0d78f641180030f8ea27b76d72c3", size = 164465, upload-time = "2025-01-19T21:00:40.456Z" }, - { url = "https://files.pythonhosted.org/packages/7b/c8/d529f8a32ce40d98309f4470780631e971a5a842b60aec864833b3615786/websockets-14.2-py3-none-any.whl", hash = "sha256:7a6ceec4ea84469f15cf15807a747e9efe57e369c384fa86e022b3bea679b79b", size = 157416, upload-time = "2025-01-19T21:00:54.843Z" }, -] - [[package]] name = "yarl" version = "1.22.0" From 138a2744535f82ffed15dc4b96e4a3519df9c485 Mon Sep 17 00:00:00 2001 From: Henrique Chaves Date: Thu, 19 Mar 2026 14:12:27 +0100 Subject: [PATCH 4/7] clean up poll mechanism --- src/giskard_hub/resources/_poll_scan.py | 155 +++++++----------------- 1 file changed, 43 insertions(+), 112 deletions(-) diff --git a/src/giskard_hub/resources/_poll_scan.py b/src/giskard_hub/resources/_poll_scan.py index 7d8dfbb..8dd88bb 100644 --- a/src/giskard_hub/resources/_poll_scan.py +++ b/src/giskard_hub/resources/_poll_scan.py @@ -1,18 +1,14 @@ """Polling client for local-agent scan execution. -The SDK polls ``GET /v2/scans/{scan_id}/invocations?status=pending`` for -invocation requests from LIDAR, executes the local agent callable, and -POSTs the response back via -``POST /v2/scans/{scan_id}/invocations/{id}/respond``. - -All communication is plain HTTP — no WebSocket, no long-lived connections. +The SDK polls for pending invocation requests, executes the +local agent callable, and POSTs the response back. """ +import time import asyncio import inspect import logging -import time -from typing import Any, Awaitable, Callable +from typing import Any, Callable, Awaitable import httpx @@ -20,149 +16,84 @@ logger = logging.getLogger(__name__) -AgentCallable = Callable[[list[ChatMessage]], Any] -AsyncAgentCallable = Callable[[list[ChatMessage]], Any | Awaitable[Any]] - _TERMINAL_STATES = frozenset({"finished", "error", "canceled"}) def _normalize_output(value: Any) -> dict[str, object]: - """Turn the agent return value into a dict matching the Hub protocol.""" from ._helpers_types import normalize_agent_output - output = normalize_agent_output(value) - return output.to_dict() + return normalize_agent_output(value).to_dict() def run_poll_scan( base_url: str, api_key: str, scan_id: str, - agent: AgentCallable, + agent: Callable[[list[ChatMessage]], Any], http_client: httpx.Client, poll_interval: float = 0.5, ) -> None: - """Synchronous polling loop for local scan execution. + """Poll for invocations, call the local agent, submit responses. - Blocks until the scan finishes, is cancelled, or errors. - Works in Jupyter notebooks without event-loop conflicts. + Blocks until the scan reaches a terminal state. """ headers = {"X-API-Key": api_key} - invocations_url = f"{base_url}/v2/scans/{scan_id}/invocations" + url = f"{base_url}/v2/scans/{scan_id}/invocations" while True: - resp = http_client.get( - invocations_url, - params={"status": "pending"}, - headers=headers, - ) + resp = http_client.get(url, params={"status": "pending"}, headers=headers) resp.raise_for_status() - payload = resp.json()["data"] + data = resp.json()["data"] + + for inv in data["invocations"]: + messages = [ChatMessage(role=m["role"], content=m.get("content", "")) for m in inv["messages"]] + try: + body: dict[str, Any] = {"output": _normalize_output(agent(messages))} + except Exception as exc: + logger.error("Agent invocation failed: %s", exc) + body = {"error": {"message": str(exc)}} - for inv in payload["invocations"]: - _process_invocation_sync( - inv, agent, base_url, scan_id, headers, http_client - ) + http_client.post(f"{url}/{inv['id']}/respond", json=body, headers=headers).raise_for_status() - if payload["scan_status"] in _TERMINAL_STATES: + if data["scan_status"] in _TERMINAL_STATES: break time.sleep(poll_interval) -def _process_invocation_sync( - inv: dict[str, Any], - agent: AgentCallable, - base_url: str, - scan_id: str, - headers: dict[str, str], - http_client: httpx.Client, -) -> None: - invocation_id = inv["id"] - messages = [ - ChatMessage(role=m.get("role", "user"), content=m.get("content", "")) - for m in inv.get("messages", []) - ] - - try: - result = agent(messages) - output = _normalize_output(result) - body: dict[str, Any] = {"output": output} - except Exception as exc: - logger.error("Agent invocation failed: %s", exc) - body = {"error": {"message": str(exc)}} - - resp = http_client.post( - f"{base_url}/v2/scans/{scan_id}/invocations/{invocation_id}/respond", - json=body, - headers=headers, - ) - resp.raise_for_status() - - async def arun_poll_scan( base_url: str, api_key: str, scan_id: str, - agent: AsyncAgentCallable, + agent: Callable[[list[ChatMessage]], Any | Awaitable[Any]], http_client: httpx.AsyncClient, poll_interval: float = 0.5, ) -> None: - """Async polling loop for local scan execution.""" + """Async version of the polling loop.""" headers = {"X-API-Key": api_key} - invocations_url = f"{base_url}/v2/scans/{scan_id}/invocations" + url = f"{base_url}/v2/scans/{scan_id}/invocations" while True: - resp = await http_client.get( - invocations_url, - params={"status": "pending"}, - headers=headers, - ) + resp = await http_client.get(url, params={"status": "pending"}, headers=headers) resp.raise_for_status() - payload = resp.json()["data"] - - await asyncio.gather( - *( - _process_invocation_async( - inv, agent, base_url, scan_id, headers, http_client - ) - for inv in payload["invocations"] - ) - ) - - if payload["scan_status"] in _TERMINAL_STATES: - break + data = resp.json()["data"] - await asyncio.sleep(poll_interval) + async def _process(inv: dict[str, Any]) -> None: + messages = [ChatMessage(role=m["role"], content=m.get("content", "")) for m in inv["messages"]] + try: + result = agent(messages) + if inspect.isawaitable(result): + result = await result + body: dict[str, Any] = {"output": _normalize_output(result)} + except Exception as exc: + logger.error("Agent invocation failed: %s", exc) + body = {"error": {"message": str(exc)}} + (await http_client.post(f"{url}/{inv['id']}/respond", json=body, headers=headers)).raise_for_status() -async def _process_invocation_async( - inv: dict[str, Any], - agent: AsyncAgentCallable, - base_url: str, - scan_id: str, - headers: dict[str, str], - http_client: httpx.AsyncClient, -) -> None: - invocation_id = inv["id"] - messages = [ - ChatMessage(role=m.get("role", "user"), content=m.get("content", "")) - for m in inv.get("messages", []) - ] - - try: - result = agent(messages) - if inspect.isawaitable(result): - result = await result - output = _normalize_output(result) - body: dict[str, Any] = {"output": output} - except Exception as exc: - logger.error("Agent invocation failed: %s", exc) - body = {"error": {"message": str(exc)}} - - resp = await http_client.post( - f"{base_url}/v2/scans/{scan_id}/invocations/{invocation_id}/respond", - json=body, - headers=headers, - ) - resp.raise_for_status() + await asyncio.gather(*(_process(inv) for inv in data["invocations"])) + + if data["scan_status"] in _TERMINAL_STATES: + break + + await asyncio.sleep(poll_interval) From 8ca73f76763b03f21ef73e29e4fdf4e15068d7ec Mon Sep 17 00:00:00 2001 From: Henrique Chaves Date: Thu, 19 Mar 2026 14:23:29 +0100 Subject: [PATCH 5/7] update docstrings and comments --- src/giskard_hub/resources/_poll_scan.py | 6 +----- src/giskard_hub/resources/helpers.py | 12 ++++++------ 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/giskard_hub/resources/_poll_scan.py b/src/giskard_hub/resources/_poll_scan.py index 8dd88bb..5ff063f 100644 --- a/src/giskard_hub/resources/_poll_scan.py +++ b/src/giskard_hub/resources/_poll_scan.py @@ -1,8 +1,4 @@ -"""Polling client for local-agent scan execution. - -The SDK polls for pending invocation requests, executes the -local agent callable, and POSTs the response back. -""" +"""Polling client for local-agent scan execution.""" import time import asyncio diff --git a/src/giskard_hub/resources/helpers.py b/src/giskard_hub/resources/helpers.py index 3a31240..770ab48 100644 --- a/src/giskard_hub/resources/helpers.py +++ b/src/giskard_hub/resources/helpers.py @@ -283,7 +283,7 @@ def _scan_local( from ._poll_scan import run_poll_scan from ..types.common import APIResponse as _APIResponse - # 1. Create the scan record (enqueued to worker automatically). + # Create the scan record (enqueued to worker). body: dict[str, Any] = { "project_id": project_id, "agent_name": agent_name, @@ -303,7 +303,7 @@ def _scan_local( scan = self._unwrap(response) scan_id = scan.id - # 2. Poll for invocations and execute the local agent. + # Poll for invocations and execute the local agent. base_url = str(self._client.base_url) api_key = self._client.api_key @@ -315,7 +315,7 @@ def _scan_local( http_client=self._client._client, ) - # 3. Retrieve the final scan result. + # Retrieve the final scan result. return self._client.scans.retrieve(scan_id) def _evaluate_remote( @@ -606,7 +606,7 @@ async def _scan_local( from ._poll_scan import arun_poll_scan from ..types.common import APIResponse as _APIResponse - # 1. Create the scan record (enqueued to worker automatically). + # Create the scan record (enqueued to worker). body: dict[str, Any] = { "project_id": project_id, "agent_name": agent_name, @@ -626,7 +626,7 @@ async def _scan_local( scan = self._unwrap(response) scan_id = scan.id - # 2. Poll for invocations and execute the local agent. + # Poll for invocations and execute the local agent. base_url = str(self._client.base_url) api_key = self._client.api_key @@ -638,7 +638,7 @@ async def _scan_local( http_client=self._client._client, ) - # 3. Retrieve the final scan result. + # Retrieve the final scan result. return await self._client.scans.retrieve(scan_id) async def _evaluate_remote( From b50da72022460e8f6b79c5bb88991d9912bc5f05 Mon Sep 17 00:00:00 2001 From: Henrique Chaves Date: Thu, 19 Mar 2026 14:41:01 +0100 Subject: [PATCH 6/7] reduce code duplication --- src/giskard_hub/resources/_helpers_types.py | 25 +++++- src/giskard_hub/resources/_poll_scan.py | 8 +- src/giskard_hub/resources/helpers.py | 93 +++++++-------------- 3 files changed, 60 insertions(+), 66 deletions(-) diff --git a/src/giskard_hub/resources/_helpers_types.py b/src/giskard_hub/resources/_helpers_types.py index 481f682..854f9eb 100644 --- a/src/giskard_hub/resources/_helpers_types.py +++ b/src/giskard_hub/resources/_helpers_types.py @@ -1,11 +1,10 @@ """Protocols, type aliases, and utility functions for the helpers resource.""" -from __future__ import annotations - from typing import TYPE_CHECKING, Any, TypeVar, Protocol, runtime_checkable from pydantic import TypeAdapter +from .._types import SequenceNotStr from .._models import BaseModel from .._resource import SyncAPIResource, AsyncAPIResource from ..types.chat import ChatMessage @@ -72,6 +71,28 @@ async def retrieve(self, id: str) -> StatefulEntity: ... # --------------------------------------------------------------------------- +def build_local_scan_body( + project_id: str, + agent_name: str, + agent_description: str, + supported_languages: list[str], + knowledge_base_id: str | None, + tags: SequenceNotStr[str] | None, +) -> dict[str, Any]: + """Build the request body for POST /v2/scans/create-local.""" + body: dict[str, Any] = { + "project_id": project_id, + "agent_name": agent_name, + "agent_description": agent_description, + "supported_languages": supported_languages, + } + if knowledge_base_id: + body["knowledge_base_id"] = knowledge_base_id + if tags: + body["tags"] = list(tags) + return body + + def normalize_agent_output(value: Any) -> AgentOutput: """Validate and normalize arbitrary agent return values into an ``AgentOutput``.""" parsed: AgentReturn = agent_return_adapter.validate_python(value) diff --git a/src/giskard_hub/resources/_poll_scan.py b/src/giskard_hub/resources/_poll_scan.py index 5ff063f..fa7f03f 100644 --- a/src/giskard_hub/resources/_poll_scan.py +++ b/src/giskard_hub/resources/_poll_scan.py @@ -21,6 +21,10 @@ def _normalize_output(value: Any) -> dict[str, object]: return normalize_agent_output(value).to_dict() +def _parse_messages(raw_messages: list[dict[str, Any]]) -> list[ChatMessage]: + return [ChatMessage(role=m["role"], content=m.get("content", "")) for m in raw_messages] + + def run_poll_scan( base_url: str, api_key: str, @@ -42,7 +46,7 @@ def run_poll_scan( data = resp.json()["data"] for inv in data["invocations"]: - messages = [ChatMessage(role=m["role"], content=m.get("content", "")) for m in inv["messages"]] + messages = _parse_messages(inv["messages"]) try: body: dict[str, Any] = {"output": _normalize_output(agent(messages))} except Exception as exc: @@ -75,7 +79,7 @@ async def arun_poll_scan( data = resp.json()["data"] async def _process(inv: dict[str, Any]) -> None: - messages = [ChatMessage(role=m["role"], content=m.get("content", "")) for m in inv["messages"]] + messages = _parse_messages(inv["messages"]) try: result = agent(messages) if inspect.isawaitable(result): diff --git a/src/giskard_hub/resources/helpers.py b/src/giskard_hub/resources/helpers.py index 770ab48..3e5741e 100644 --- a/src/giskard_hub/resources/helpers.py +++ b/src/giskard_hub/resources/helpers.py @@ -30,6 +30,7 @@ PrintMetricsEntity, RetrievableResource, AsyncRetrievableResource, + build_local_scan_body, map_entity_to_resource, normalize_agent_output, ) @@ -251,8 +252,8 @@ def _scan_remote( *, agent: str | Agent, project_id: str, - knowledge_base_id: "str | Omit | None", - tags: "Optional[SequenceNotStr[str]] | Omit", + knowledge_base_id: str | Omit | None, + tags: Optional[SequenceNotStr[str]] | Omit, poll_interval: float, ) -> Scan: agent_id = agent if isinstance(agent, str) else agent.id @@ -275,7 +276,7 @@ def _scan_local( agent: Callable[[list[ChatMessage]], AgentReturn], project_id: str, knowledge_base_id: str | None, - tags: "list[str] | SequenceNotStr[str] | None", + tags: SequenceNotStr[str] | None, agent_name: str, agent_description: str, supported_languages: list[str], @@ -283,40 +284,24 @@ def _scan_local( from ._poll_scan import run_poll_scan from ..types.common import APIResponse as _APIResponse - # Create the scan record (enqueued to worker). - body: dict[str, Any] = { - "project_id": project_id, - "agent_name": agent_name, - "agent_description": agent_description, - "supported_languages": supported_languages, - } - if knowledge_base_id: - body["knowledge_base_id"] = knowledge_base_id - if tags: - body["tags"] = list(tags) - - response = self._post( - "/v2/scans/create-local", - body=body, - cast_to=_APIResponse[Scan], + body = build_local_scan_body( + project_id, + agent_name, + agent_description, + supported_languages, + knowledge_base_id, + tags, ) - scan = self._unwrap(response) - scan_id = scan.id - - # Poll for invocations and execute the local agent. - base_url = str(self._client.base_url) - api_key = self._client.api_key + scan = self._unwrap(self._post("/v2/scans/create-local", body=body, cast_to=_APIResponse[Scan])) run_poll_scan( - base_url=base_url, - api_key=api_key, - scan_id=scan_id, + base_url=str(self._client.base_url), + api_key=self._client.api_key, + scan_id=scan.id, agent=agent, http_client=self._client._client, ) - - # Retrieve the final scan result. - return self._client.scans.retrieve(scan_id) + return self._client.scans.retrieve(scan.id) def _evaluate_remote( self, @@ -574,8 +559,8 @@ async def _scan_remote( *, agent: str | Agent, project_id: str, - knowledge_base_id: "str | Omit | None", - tags: "Optional[SequenceNotStr[str]] | Omit", + knowledge_base_id: str | Omit | None, + tags: Optional[SequenceNotStr[str]] | Omit, poll_interval: float, ) -> Scan: agent_id = agent if isinstance(agent, str) else agent.id @@ -598,7 +583,7 @@ async def _scan_local( agent: Callable[[list[ChatMessage]], AgentReturn | Awaitable[AgentReturn]], project_id: str, knowledge_base_id: str | None, - tags: "list[str] | SequenceNotStr[str] | None", + tags: SequenceNotStr[str] | None, agent_name: str, agent_description: str, supported_languages: list[str], @@ -606,40 +591,24 @@ async def _scan_local( from ._poll_scan import arun_poll_scan from ..types.common import APIResponse as _APIResponse - # Create the scan record (enqueued to worker). - body: dict[str, Any] = { - "project_id": project_id, - "agent_name": agent_name, - "agent_description": agent_description, - "supported_languages": supported_languages, - } - if knowledge_base_id: - body["knowledge_base_id"] = knowledge_base_id - if tags: - body["tags"] = list(tags) - - response = await self._post( - "/v2/scans/create-local", - body=body, - cast_to=_APIResponse[Scan], + body = build_local_scan_body( + project_id, + agent_name, + agent_description, + supported_languages, + knowledge_base_id, + tags, ) - scan = self._unwrap(response) - scan_id = scan.id - - # Poll for invocations and execute the local agent. - base_url = str(self._client.base_url) - api_key = self._client.api_key + scan = self._unwrap(await self._post("/v2/scans/create-local", body=body, cast_to=_APIResponse[Scan])) await arun_poll_scan( - base_url=base_url, - api_key=api_key, - scan_id=scan_id, + base_url=str(self._client.base_url), + api_key=self._client.api_key, + scan_id=scan.id, agent=agent, http_client=self._client._client, ) - - # Retrieve the final scan result. - return await self._client.scans.retrieve(scan_id) + return await self._client.scans.retrieve(scan.id) async def _evaluate_remote( self, From 77599c80e8a8172df6c21c15fe7818cd4f7166a7 Mon Sep 17 00:00:00 2001 From: Henrique Chaves Date: Thu, 19 Mar 2026 14:59:51 +0100 Subject: [PATCH 7/7] fix issues raised by gemini --- src/giskard_hub/resources/_helpers_types.py | 2 +- src/giskard_hub/resources/_poll_scan.py | 4 ++-- src/giskard_hub/resources/helpers.py | 17 +++++++++++------ 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/giskard_hub/resources/_helpers_types.py b/src/giskard_hub/resources/_helpers_types.py index 854f9eb..a1ca080 100644 --- a/src/giskard_hub/resources/_helpers_types.py +++ b/src/giskard_hub/resources/_helpers_types.py @@ -75,7 +75,7 @@ def build_local_scan_body( project_id: str, agent_name: str, agent_description: str, - supported_languages: list[str], + supported_languages: SequenceNotStr[str], knowledge_base_id: str | None, tags: SequenceNotStr[str] | None, ) -> dict[str, Any]: diff --git a/src/giskard_hub/resources/_poll_scan.py b/src/giskard_hub/resources/_poll_scan.py index fa7f03f..4684c84 100644 --- a/src/giskard_hub/resources/_poll_scan.py +++ b/src/giskard_hub/resources/_poll_scan.py @@ -50,7 +50,7 @@ def run_poll_scan( try: body: dict[str, Any] = {"output": _normalize_output(agent(messages))} except Exception as exc: - logger.error("Agent invocation failed: %s", exc) + logger.exception("Agent invocation failed") body = {"error": {"message": str(exc)}} http_client.post(f"{url}/{inv['id']}/respond", json=body, headers=headers).raise_for_status() @@ -86,7 +86,7 @@ async def _process(inv: dict[str, Any]) -> None: result = await result body: dict[str, Any] = {"output": _normalize_output(result)} except Exception as exc: - logger.error("Agent invocation failed: %s", exc) + logger.exception("Agent invocation failed") body = {"error": {"message": str(exc)}} (await http_client.post(f"{url}/{inv['id']}/respond", json=body, headers=headers)).raise_for_status() diff --git a/src/giskard_hub/resources/helpers.py b/src/giskard_hub/resources/helpers.py index 3e5741e..6af7100 100644 --- a/src/giskard_hub/resources/helpers.py +++ b/src/giskard_hub/resources/helpers.py @@ -36,6 +36,7 @@ ) from ..types.test_case import TestCase from ..types.evaluation import Evaluation, TestCaseEvaluation +from ..types.knowledge_base import KnowledgeBase __all__ = ["HelpersResource", "AsyncHelpersResource"] @@ -163,7 +164,7 @@ def scan( *, agent: str | Agent | Callable[[list[ChatMessage]], AgentReturn], project: str | Project, - knowledge_base: Optional[str] | Omit = omit, + knowledge_base: Optional[str | KnowledgeBase] | Omit = omit, tags: Optional[SequenceNotStr[str]] | Omit = omit, agent_name: Optional[str] = None, agent_description: Optional[str] = None, @@ -211,7 +212,9 @@ def scan( kb_id = ( knowledge_base if isinstance(knowledge_base, str) - else (knowledge_base if knowledge_base is omit else knowledge_base) + else knowledge_base.id + if isinstance(knowledge_base, KnowledgeBase) + else None ) if isinstance(agent, (str, Agent)): @@ -502,7 +505,7 @@ async def scan( *, agent: str | Agent | Callable[[list[ChatMessage]], AgentReturn | Awaitable[AgentReturn]], project: str | Project, - knowledge_base: Optional[str] | Omit = omit, + knowledge_base: Optional[str | KnowledgeBase] | Omit = omit, tags: Optional[SequenceNotStr[str]] | Omit = omit, agent_name: Optional[str] = None, agent_description: Optional[str] = None, @@ -518,7 +521,9 @@ async def scan( kb_id = ( knowledge_base if isinstance(knowledge_base, str) - else (knowledge_base if knowledge_base is omit else knowledge_base) + else knowledge_base.id + if isinstance(knowledge_base, KnowledgeBase) + else None ) if isinstance(agent, (str, Agent)): @@ -559,7 +564,7 @@ async def _scan_remote( *, agent: str | Agent, project_id: str, - knowledge_base_id: str | Omit | None, + knowledge_base_id: str | None, tags: Optional[SequenceNotStr[str]] | Omit, poll_interval: float, ) -> Scan: @@ -569,7 +574,7 @@ async def _scan_remote( "project_id": project_id, "agent_id": agent_id, } - if knowledge_base_id is not omit and knowledge_base_id is not None: + if knowledge_base_id is not None: create_kwargs["knowledge_base_id"] = knowledge_base_id if tags is not omit: create_kwargs["tags"] = tags