Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/giskard_hub/resources/_helpers_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from pydantic import TypeAdapter

from .._types import SequenceNotStr
from .._models import BaseModel
from ..types.chat import ChatMessage
from ..types.scan import Scan, ScanProbe
Expand Down Expand Up @@ -67,6 +68,28 @@ def state(self) -> TaskState: ...
# ---------------------------------------------------------------------------


def build_local_scan_body(
project_id: str,
agent_name: str,
agent_description: str,
supported_languages: SequenceNotStr[str],
knowledge_base_id: str | None,
tags: SequenceNotStr[str] | None,
) -> dict[str, Any]:
"""Build the request body for POST /v2/scans/create-local."""
body: dict[str, Any] = {
"project_id": project_id,
"agent_name": agent_name,
"agent_description": agent_description,
"supported_languages": supported_languages,
}
if knowledge_base_id:
body["knowledge_base_id"] = knowledge_base_id
if tags:
body["tags"] = list(tags)
return body


def normalize_agent_output(value: Any) -> AgentOutput:
"""Validate and normalize arbitrary agent return values into an `AgentOutput`."""
parsed: AgentReturn = agent_return_adapter.validate_python(value)
Expand Down
99 changes: 99 additions & 0 deletions src/giskard_hub/resources/_poll_scan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
"""Polling client for local-agent scan execution."""

import time
import asyncio
import inspect
import logging
from typing import Any, Callable, Awaitable

import httpx

from ..types.chat import ChatMessage

logger = logging.getLogger(__name__)

_TERMINAL_STATES = frozenset({"finished", "error", "canceled"})


def _normalize_output(value: Any) -> dict[str, object]:
from ._helpers_types import normalize_agent_output

return normalize_agent_output(value).to_dict()


def _parse_messages(raw_messages: list[dict[str, Any]]) -> list[ChatMessage]:
return [ChatMessage(role=m["role"], content=m.get("content", "")) for m in raw_messages]


def run_poll_scan(
base_url: str,
api_key: str,
scan_id: str,
agent: Callable[[list[ChatMessage]], Any],
http_client: httpx.Client,
poll_interval: float = 0.5,
) -> None:
"""Poll for invocations, call the local agent, submit responses.

Blocks until the scan reaches a terminal state.
"""
headers = {"X-API-Key": api_key}
url = f"{base_url}/v2/scans/{scan_id}/invocations"

while True:
resp = http_client.get(url, params={"status": "pending"}, headers=headers)
resp.raise_for_status()
data = resp.json()["data"]

for inv in data["invocations"]:
messages = _parse_messages(inv["messages"])
try:
body: dict[str, Any] = {"output": _normalize_output(agent(messages))}
except Exception as exc:
logger.exception("Agent invocation failed")
body = {"error": {"message": str(exc)}}
Comment thread
henchaves marked this conversation as resolved.

http_client.post(f"{url}/{inv['id']}/respond", json=body, headers=headers).raise_for_status()

if data["scan_status"] in _TERMINAL_STATES:
break

time.sleep(poll_interval)


async def arun_poll_scan(
base_url: str,
api_key: str,
scan_id: str,
agent: Callable[[list[ChatMessage]], Any | Awaitable[Any]],
http_client: httpx.AsyncClient,
poll_interval: float = 0.5,
) -> None:
"""Async version of the polling loop."""
headers = {"X-API-Key": api_key}
url = f"{base_url}/v2/scans/{scan_id}/invocations"

while True:
resp = await http_client.get(url, params={"status": "pending"}, headers=headers)
resp.raise_for_status()
data = resp.json()["data"]

async def _process(inv: dict[str, Any]) -> None:
messages = _parse_messages(inv["messages"])
try:
result = agent(messages)
if inspect.isawaitable(result):
result = await result
body: dict[str, Any] = {"output": _normalize_output(result)}
except Exception as exc:
logger.exception("Agent invocation failed")
body = {"error": {"message": str(exc)}}
Comment thread
henchaves marked this conversation as resolved.

(await http_client.post(f"{url}/{inv['id']}/respond", json=body, headers=headers)).raise_for_status()

await asyncio.gather(*(_process(inv) for inv in data["invocations"]))

if data["scan_status"] in _TERMINAL_STATES:
break

await asyncio.sleep(poll_interval)
Loading