diff --git a/src/giskard_hub/resources/_helpers_types.py b/src/giskard_hub/resources/_helpers_types.py index abacaa6..df15510 100644 --- a/src/giskard_hub/resources/_helpers_types.py +++ b/src/giskard_hub/resources/_helpers_types.py @@ -7,6 +7,7 @@ from pydantic import TypeAdapter +from .._types import SequenceNotStr from .._models import BaseModel from ..types.chat import ChatMessage from ..types.scan import Scan, ScanProbe @@ -67,6 +68,28 @@ def state(self) -> TaskState: ... # --------------------------------------------------------------------------- +def build_local_scan_body( + project_id: str, + agent_name: str, + agent_description: str, + supported_languages: SequenceNotStr[str], + knowledge_base_id: str | None, + tags: SequenceNotStr[str] | None, +) -> dict[str, Any]: + """Build the request body for POST /v2/scans/create-local.""" + body: dict[str, Any] = { + "project_id": project_id, + "agent_name": agent_name, + "agent_description": agent_description, + "supported_languages": supported_languages, + } + if knowledge_base_id: + body["knowledge_base_id"] = knowledge_base_id + if tags: + body["tags"] = list(tags) + return body + + def normalize_agent_output(value: Any) -> AgentOutput: """Validate and normalize arbitrary agent return values into an `AgentOutput`.""" parsed: AgentReturn = agent_return_adapter.validate_python(value) diff --git a/src/giskard_hub/resources/_poll_scan.py b/src/giskard_hub/resources/_poll_scan.py new file mode 100644 index 0000000..4684c84 --- /dev/null +++ b/src/giskard_hub/resources/_poll_scan.py @@ -0,0 +1,99 @@ +"""Polling client for local-agent scan execution.""" + +import time +import asyncio +import inspect +import logging +from typing import Any, Callable, Awaitable + +import httpx + +from ..types.chat import ChatMessage + +logger = logging.getLogger(__name__) + +_TERMINAL_STATES = frozenset({"finished", "error", "canceled"}) + + +def _normalize_output(value: Any) -> dict[str, object]: + from ._helpers_types import normalize_agent_output + + return normalize_agent_output(value).to_dict() + + +def _parse_messages(raw_messages: list[dict[str, Any]]) -> list[ChatMessage]: + return [ChatMessage(role=m["role"], content=m.get("content", "")) for m in raw_messages] + + +def run_poll_scan( + base_url: str, + api_key: str, + scan_id: str, + agent: Callable[[list[ChatMessage]], Any], + http_client: httpx.Client, + poll_interval: float = 0.5, +) -> None: + """Poll for invocations, call the local agent, submit responses. + + Blocks until the scan reaches a terminal state. + """ + headers = {"X-API-Key": api_key} + url = f"{base_url}/v2/scans/{scan_id}/invocations" + + while True: + resp = http_client.get(url, params={"status": "pending"}, headers=headers) + resp.raise_for_status() + data = resp.json()["data"] + + for inv in data["invocations"]: + messages = _parse_messages(inv["messages"]) + try: + body: dict[str, Any] = {"output": _normalize_output(agent(messages))} + except Exception as exc: + logger.exception("Agent invocation failed") + body = {"error": {"message": str(exc)}} + + http_client.post(f"{url}/{inv['id']}/respond", json=body, headers=headers).raise_for_status() + + if data["scan_status"] in _TERMINAL_STATES: + break + + time.sleep(poll_interval) + + +async def arun_poll_scan( + base_url: str, + api_key: str, + scan_id: str, + agent: Callable[[list[ChatMessage]], Any | Awaitable[Any]], + http_client: httpx.AsyncClient, + poll_interval: float = 0.5, +) -> None: + """Async version of the polling loop.""" + headers = {"X-API-Key": api_key} + url = f"{base_url}/v2/scans/{scan_id}/invocations" + + while True: + resp = await http_client.get(url, params={"status": "pending"}, headers=headers) + resp.raise_for_status() + data = resp.json()["data"] + + async def _process(inv: dict[str, Any]) -> None: + messages = _parse_messages(inv["messages"]) + try: + result = agent(messages) + if inspect.isawaitable(result): + result = await result + body: dict[str, Any] = {"output": _normalize_output(result)} + except Exception as exc: + logger.exception("Agent invocation failed") + body = {"error": {"message": str(exc)}} + + (await http_client.post(f"{url}/{inv['id']}/respond", json=body, headers=headers)).raise_for_status() + + await asyncio.gather(*(_process(inv) for inv in data["invocations"])) + + if data["scan_status"] in _TERMINAL_STATES: + break + + await asyncio.sleep(poll_interval) diff --git a/src/giskard_hub/resources/helpers.py b/src/giskard_hub/resources/helpers.py index bef56e9..8beafc9 100644 --- a/src/giskard_hub/resources/helpers.py +++ b/src/giskard_hub/resources/helpers.py @@ -8,7 +8,7 @@ import time import asyncio import inspect -from typing import Callable, Optional, Awaitable, Collection, cast +from typing import Any, Callable, Optional, Awaitable, Collection, cast from concurrent.futures import ThreadPoolExecutor from .._types import Omit, SequenceNotStr, omit @@ -20,7 +20,7 @@ ) from .._resource import SyncAPIResource, AsyncAPIResource from ..types.chat import ChatMessage -from ..types.scan import ScanProbe, ScanProbeAttempt +from ..types.scan import Scan, ScanProbe, ScanProbeAttempt from ..types.agent import Agent, AgentOutputParam from ..types.dataset import Dataset from ..types.project import Project @@ -31,10 +31,12 @@ AsyncRetriever, PrintMetricsEntity, make_retriever, + build_local_scan_body, normalize_agent_output, ) from ..types.test_case import TestCase from ..types.evaluation import Evaluation, TestCaseEvaluation +from ..types.knowledge_base import KnowledgeBase __all__ = ["HelpersResource", "AsyncHelpersResource"] @@ -157,6 +159,83 @@ def evaluate( return self._evaluate_local(agent=agent, dataset_id=dataset_id, name=name, tags=tags) + def scan( + self, + *, + agent: str | Agent | Callable[[list[ChatMessage]], AgentReturn], + project: str | Project, + knowledge_base: Optional[str | KnowledgeBase] | Omit = omit, + tags: Optional[SequenceNotStr[str]] | Omit = omit, + agent_name: Optional[str] = None, + agent_description: Optional[str] = None, + supported_languages: Optional[list[str]] = None, + poll_interval: float = 5.0, + ) -> Scan: + """Run a vulnerability scan for a given agent. + + Handles both remote agents (referenced by ID or ``Agent``, which must + already be registered in the Hub) and local Python callables. + + When a local callable is provided, the Hub orchestrates LIDAR + server-side and the SDK polls for agent invocation requests via + stateless HTTP endpoints. The scan runs in the Hub's worker + process, same as remote scans. + + Parameters + ---------- + agent : + Either a remote agent identifier (``str`` or ``Agent``) or a + callable with signature ``(messages: list[ChatMessage]) -> AgentReturn``. + project : + Project identifier or ``Project`` instance. + knowledge_base : + Optional knowledge base identifier. + tags : + Optional list of tags to filter which probes to run. + agent_name : + Name for the agent (used only for local callables; defaults to + the function name). + agent_description : + Description of the agent (used only for local callables). + supported_languages : + Languages the agent supports (default ``["en"]``). + poll_interval : + Seconds between status polls when running a remote scan. + + Returns + ------- + Scan + The completed scan result with grade and probe information. + """ + + project_id = project if isinstance(project, str) else project.id + kb_id = ( + knowledge_base + if isinstance(knowledge_base, str) + else knowledge_base.id + if isinstance(knowledge_base, KnowledgeBase) + else None + ) + + if isinstance(agent, (str, Agent)): + return self._scan_remote( + agent=agent, + project_id=project_id, + knowledge_base_id=kb_id, + tags=tags, + poll_interval=poll_interval, + ) + + return self._scan_local( + agent=agent, + project_id=project_id, + knowledge_base_id=kb_id if isinstance(kb_id, str) else None, + tags=None if isinstance(tags, Omit) or tags is None else tags, + agent_name=agent_name if agent_name is not None else getattr(agent, "__name__", "local_agent"), + agent_description=agent_description or getattr(agent, "__doc__", None) or "", + supported_languages=supported_languages or ["en"], + ) + def print_metrics(self, entity: PrintMetricsEntity) -> None: """Print metrics for an evaluation or scan result to the console. @@ -171,6 +250,62 @@ def print_metrics(self, entity: PrintMetricsEntity) -> None: # -- Private helpers ----------------------------------------------------- + def _scan_remote( + self, + *, + agent: str | Agent, + project_id: str, + knowledge_base_id: str | Omit | None, + tags: Optional[SequenceNotStr[str]] | Omit, + poll_interval: float, + ) -> Scan: + agent_id = agent if isinstance(agent, str) else agent.id + + create_kwargs: dict[str, Any] = { + "project_id": project_id, + "agent_id": agent_id, + } + if knowledge_base_id is not omit and knowledge_base_id is not None: + create_kwargs["knowledge_base_id"] = knowledge_base_id + if tags is not omit: + create_kwargs["tags"] = tags + + scan = self._client.scans.create(**create_kwargs) + return self.wait_for_completion(scan, poll_interval=poll_interval) # type: ignore[return-value] + + def _scan_local( + self, + *, + agent: Callable[[list[ChatMessage]], AgentReturn], + project_id: str, + knowledge_base_id: str | None, + tags: SequenceNotStr[str] | None, + agent_name: str, + agent_description: str, + supported_languages: list[str], + ) -> Scan: + from ._poll_scan import run_poll_scan + from ..types.common import APIResponse as _APIResponse + + body = build_local_scan_body( + project_id, + agent_name, + agent_description, + supported_languages, + knowledge_base_id, + tags, + ) + scan = self._unwrap(self._post("/v2/scans/create-local", body=body, cast_to=_APIResponse[Scan])) + + run_poll_scan( + base_url=str(self._client.base_url), + api_key=self._client.api_key, + scan_id=scan.id, + agent=agent, + http_client=self._client._client, + ) + return self._client.scans.retrieve(scan.id) + def _evaluate_remote( self, *, @@ -232,9 +367,7 @@ def _evaluate_local( return evaluation def _print_scan_metrics(self, entity: object) -> None: - from ..types.scan import Scan as _Scan - - scan = cast(_Scan, entity) + scan = cast(Scan, entity) category_map = {cat.id: cat.title for cat in self._client.scans.list_categories()} probe_results = self._client.scans.list_probes(scan_id=scan.id) @@ -367,6 +500,51 @@ async def evaluate( return await self._evaluate_local(agent=agent, dataset_id=dataset_id, name=name, tags=tags) + async def scan( + self, + *, + agent: str | Agent | Callable[[list[ChatMessage]], AgentReturn | Awaitable[AgentReturn]], + project: str | Project, + knowledge_base: Optional[str | KnowledgeBase] | Omit = omit, + tags: Optional[SequenceNotStr[str]] | Omit = omit, + agent_name: Optional[str] = None, + agent_description: Optional[str] = None, + supported_languages: Optional[list[str]] = None, + poll_interval: float = 5.0, + ) -> Scan: + """Asynchronously run a vulnerability scan for a given agent. + + See :meth:`HelpersResource.scan` for full parameter documentation. + """ + + project_id = project if isinstance(project, str) else project.id + kb_id = ( + knowledge_base + if isinstance(knowledge_base, str) + else knowledge_base.id + if isinstance(knowledge_base, KnowledgeBase) + else None + ) + + if isinstance(agent, (str, Agent)): + return await self._scan_remote( + agent=agent, + project_id=project_id, + knowledge_base_id=kb_id, + tags=tags, + poll_interval=poll_interval, + ) + + return await self._scan_local( + agent=agent, + project_id=project_id, + knowledge_base_id=kb_id if isinstance(kb_id, str) else None, + tags=None if isinstance(tags, Omit) or tags is None else tags, + agent_name=agent_name if agent_name is not None else getattr(agent, "__name__", "local_agent"), + agent_description=agent_description or getattr(agent, "__doc__", None) or "", + supported_languages=supported_languages or ["en"], + ) + async def print_metrics(self, entity: PrintMetricsEntity) -> None: """Print metrics for an evaluation or scan result to the console (async). @@ -381,6 +559,62 @@ async def print_metrics(self, entity: PrintMetricsEntity) -> None: # -- Private helpers ----------------------------------------------------- + async def _scan_remote( + self, + *, + agent: str | Agent, + project_id: str, + knowledge_base_id: str | None, + tags: Optional[SequenceNotStr[str]] | Omit, + poll_interval: float, + ) -> Scan: + agent_id = agent if isinstance(agent, str) else agent.id + + create_kwargs: dict[str, Any] = { + "project_id": project_id, + "agent_id": agent_id, + } + if knowledge_base_id is not None: + create_kwargs["knowledge_base_id"] = knowledge_base_id + if tags is not omit: + create_kwargs["tags"] = tags + + scan = await self._client.scans.create(**create_kwargs) + return await self.wait_for_completion(scan, poll_interval=poll_interval) # type: ignore[return-value] + + async def _scan_local( + self, + *, + agent: Callable[[list[ChatMessage]], AgentReturn | Awaitable[AgentReturn]], + project_id: str, + knowledge_base_id: str | None, + tags: SequenceNotStr[str] | None, + agent_name: str, + agent_description: str, + supported_languages: list[str], + ) -> Scan: + from ._poll_scan import arun_poll_scan + from ..types.common import APIResponse as _APIResponse + + body = build_local_scan_body( + project_id, + agent_name, + agent_description, + supported_languages, + knowledge_base_id, + tags, + ) + scan = self._unwrap(await self._post("/v2/scans/create-local", body=body, cast_to=_APIResponse[Scan])) + + await arun_poll_scan( + base_url=str(self._client.base_url), + api_key=self._client.api_key, + scan_id=scan.id, + agent=agent, + http_client=self._client._client, + ) + return await self._client.scans.retrieve(scan.id) + async def _evaluate_remote( self, *, @@ -451,9 +685,7 @@ async def _process_entry(entry: TestCaseEvaluation) -> None: return evaluation async def _print_scan_metrics(self, entity: object) -> None: - from ..types.scan import Scan as _Scan - - scan = cast(_Scan, entity) + scan = cast(Scan, entity) category_map = {cat.id: cat.title for cat in await self._client.scans.list_categories()} probe_results = await self._client.scans.list_probes(scan_id=scan.id) attempts_list = await asyncio.gather(