Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions src/giskard_hub/resources/_check_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Request-side helpers for converting `CheckConfigParam` into the wire format."""

from typing import Any, Dict, Iterable, Optional, cast

from .._models import BaseModel
from ..types.check import CheckConfigParam

# Maps a built-in check identifier to its `kind` discriminator.
IDENTIFIER_TO_KIND: Dict[str, str] = {
"correctness": "hub_correctness",
"conformity": "hub_conformity",
"groundedness": "hub_groundedness",
"string_match": "string_matching",
"metadata": "hub_metadata",
"semantic_similarity": "semantic_similarity",
}


def check_param_to_spec(identifier: Optional[str], params: Any) -> Dict[str, Any]:
"""Build a `spec` dict, deriving `kind` from `params["type"]` then `identifier`."""
if isinstance(params, BaseModel):
params_dict: Dict[str, Any] = params.model_dump(exclude_none=True)
elif isinstance(params, dict):
params_dict = dict(cast(Dict[str, Any], params))
else:
params_dict = {}
type_from_params = params_dict.pop("type", None)
type_str = type_from_params or identifier or ""
if not type_str:
raise ValueError(
"Cannot derive check kind: provide 'identifier' or include 'type' in 'params', "
"or pass 'spec' directly with an explicit 'kind'."
)
kind = IDENTIFIER_TO_KIND.get(type_str, type_str)
return {"kind": kind, **params_dict}


def check_params_to_specs(
checks: Iterable[CheckConfigParam],
*,
flat: bool = False,
) -> list[Dict[str, Any]]:
"""Convert checks to the wire format.

`flat=False` (default) wraps params under a `spec` key:
`{identifier, enabled, spec: {kind, ...params}}`.

`flat=True` spreads params alongside `identifier`:
`{identifier, ...params}` (with the redundant `type` key stripped).
"""
result: list[Dict[str, Any]] = []
for check in checks:
identifier = check["identifier"]
params = check.get("params") or {}
if flat:
result.append({"identifier": identifier, **{k: v for k, v in params.items() if k != "type"}})
else:
entry: Dict[str, Any] = {"identifier": identifier, "enabled": check.get("enabled", True)}
if params:
entry["spec"] = check_param_to_spec(identifier, params)
result.append(entry)
Comment on lines +56 to +61
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The flat=True branch currently assumes params is a dictionary and calls .items(). However, params can also be a Pydantic BaseModel (as handled in check_param_to_spec). This will cause an AttributeError at runtime if a model is passed. Reusing check_param_to_spec ensures consistent handling of both dictionaries and models while also correctly stripping the type key.

Suggested change
result.append({"identifier": identifier, **{k: v for k, v in params.items() if k != "type"}})
else:
entry: Dict[str, Any] = {"identifier": identifier, "enabled": check.get("enabled", True)}
if params:
entry["spec"] = check_param_to_spec(identifier, params)
result.append(entry)
if flat:
spec = check_param_to_spec(identifier, params)
spec.pop("kind", None)
result.append({"identifier": identifier, **spec})
else:
entry: Dict[str, Any] = {"identifier": identifier, "enabled": check.get("enabled", True)}
if params:
entry["spec"] = check_param_to_spec(identifier, params)
result.append(entry)

return result
10 changes: 5 additions & 5 deletions src/giskard_hub/resources/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
CheckCreateParams,
CheckUpdateParams,
CheckBulkDeleteParams,
_check_param_to_spec,
)
from .._base_client import make_request_options
from ..types.common import APIResponse
from ._check_helpers import check_param_to_spec

__all__ = ["ChecksResource", "AsyncChecksResource"]

Expand Down Expand Up @@ -113,7 +113,7 @@ def create(
if not params_provided and not spec_provided:
raise ValueError("Must provide either 'params' or 'spec'.")

api_spec = spec if spec_provided else _check_param_to_spec(identifier, params)
api_spec = spec if spec_provided else check_param_to_spec(identifier, params)
response = self._post(
"/v2/checks",
body=maybe_transform(
Expand Down Expand Up @@ -256,7 +256,7 @@ def update(
api_spec = None
else:
type_or_id = identifier if isinstance(identifier, str) else None
api_spec = _check_param_to_spec(type_or_id, params)
api_spec = check_param_to_spec(type_or_id, params)
else:
api_spec = omit
response = self._patch(
Expand Down Expand Up @@ -509,7 +509,7 @@ async def create(
if not params_provided and not spec_provided:
raise ValueError("Must provide either 'params' or 'spec'.")

api_spec = spec if spec_provided else _check_param_to_spec(identifier, params)
api_spec = spec if spec_provided else check_param_to_spec(identifier, params)
response = await self._post(
"/v2/checks",
body=await async_maybe_transform(
Expand Down Expand Up @@ -652,7 +652,7 @@ async def update(
api_spec = None
else:
type_or_id = identifier if isinstance(identifier, str) else None
api_spec = _check_param_to_spec(type_or_id, params)
api_spec = check_param_to_spec(type_or_id, params)
else:
api_spec = omit
response = await self._patch(
Expand Down
19 changes: 3 additions & 16 deletions src/giskard_hub/resources/evaluations/evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from ..._base_client import make_request_options
from ...types.common import APIResponse, APIResponseWithIncluded
from ...types.dataset import DatasetSubsetParam
from .._check_helpers import check_params_to_specs
from ...types.evaluation import (
Evaluation,
EvaluationListParams,
Expand Down Expand Up @@ -70,20 +71,6 @@ def _validate_dataset_or_old_evaluation(
raise ValueError("Exactly one of `dataset_id` or `old_evaluation_id` must be provided")


def _check_params_to_api(
checks: Iterable[CheckConfigParam],
) -> Iterable[dict[str, object]]:
# Flat shape for /v2/evaluations/run-single (FlatCheckSpec). `type` is stripped because it
# duplicates `identifier` and would leak into the spec extras.
return [
{
"identifier": check["identifier"],
**{k: v for k, v in check.get("params", {}).items() if k != "type"},
}
for check in checks
]


def _normalize_agent_output(
agent_output: AgentOutputParam | str,
) -> AgentOutputParam:
Expand Down Expand Up @@ -721,7 +708,7 @@ def run_single(
# Use input_data if provided, otherwise fall back to messages
final_input_data = input_data if input_data_provided else messages

api_checks: Iterable[dict[str, object]] = _check_params_to_api(checks)
api_checks: Iterable[dict[str, object]] = check_params_to_specs(checks, flat=True)

response = self._post(
"/v2/evaluations/run-single",
Expand Down Expand Up @@ -1373,7 +1360,7 @@ async def run_single(
# Use input_data if provided, otherwise fall back to messages
final_input_data = input_data if input_data_provided else messages

api_checks: Iterable[dict[str, object]] = _check_params_to_api(checks)
api_checks: Iterable[dict[str, object]] = check_params_to_specs(checks, flat=True)

response = await self._post(
"/v2/evaluations/run-single",
Expand Down
11 changes: 6 additions & 5 deletions src/giskard_hub/resources/test_cases/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@
async_to_streamed_response_wrapper,
)
from ...types.chat import ChatMessageParam, ChatMessageWithMetadataParam
from ...types.check import CheckConfigParam, _check_params_to_api
from ...types.check import CheckConfigParam
from ..._base_client import make_request_options
from ...types.common import APIResponse
from .._check_helpers import check_params_to_specs
from ...types.test_case import (
TestCase,
TestCaseCreateParams,
Expand Down Expand Up @@ -159,7 +160,7 @@ def create(
# Use input_data if provided, otherwise fall back to messages
final_input_data = input_data if input_data_provided else messages

api_checks: Iterable[object] | Omit = _check_params_to_api(checks) if not isinstance(checks, Omit) else omit
api_checks: Iterable[object] | Omit = check_params_to_specs(checks) if not isinstance(checks, Omit) else omit
api_demo_output = _normalize_demo_output(demo_output)
response = self._post(
"/v2/test-cases",
Expand Down Expand Up @@ -321,7 +322,7 @@ def update(
if checks is None or isinstance(checks, Omit):
api_checks = checks # type: ignore[assignment]
else:
api_checks = _check_params_to_api(checks)
api_checks = check_params_to_specs(checks)
api_demo_output = _normalize_demo_output(demo_output)
response = self._patch(
f"/v2/test-cases/{test_case_id}",
Expand Down Expand Up @@ -683,7 +684,7 @@ async def create(
# Use input_data if provided, otherwise fall back to messages
final_input_data = input_data if input_data_provided else messages

api_checks: Iterable[object] | Omit = _check_params_to_api(checks) if not isinstance(checks, Omit) else omit
api_checks: Iterable[object] | Omit = check_params_to_specs(checks) if not isinstance(checks, Omit) else omit
api_demo_output = _normalize_demo_output(demo_output)
response = await self._post(
"/v2/test-cases",
Expand Down Expand Up @@ -845,7 +846,7 @@ async def update(
if checks is None or isinstance(checks, Omit):
api_checks = checks # type: ignore[assignment]
else:
api_checks = _check_params_to_api(checks)
api_checks = check_params_to_specs(checks)
api_demo_output = _normalize_demo_output(demo_output)
response = await self._patch(
f"/v2/test-cases/{test_case_id}",
Expand Down
48 changes: 1 addition & 47 deletions src/giskard_hub/types/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,18 +240,8 @@ class TestCaseCheckConfigParam(TypedDict, total=False):
# ---------------------------------------------------------------------------


# Mirrors the backend
_IDENTIFIER_TO_KIND: Dict[str, str] = {
"correctness": "hub_correctness",
"conformity": "hub_conformity",
"groundedness": "hub_groundedness",
"string_match": "string_matching",
"metadata": "hub_metadata",
"semantic_similarity": "semantic_similarity",
}


def _extract_check_params(check: Dict[str, Any]) -> Dict[str, Any]:
"""Strip `kind` from `spec` to derive the user-facing `params` dict."""
spec: Any = check.get("spec") or {}
if isinstance(spec, BaseModel):
return spec.model_dump(exclude={"kind"}, exclude_none=True)
Expand All @@ -260,25 +250,6 @@ def _extract_check_params(check: Dict[str, Any]) -> Dict[str, Any]:
return {}


def _check_param_to_spec(identifier: Optional[str], params: Any) -> Dict[str, Any]:
"""Build a `spec` dict, deriving `kind` from `params["type"]` then `identifier`."""
if isinstance(params, BaseModel):
params_dict: Dict[str, Any] = params.model_dump(exclude_none=True)
elif isinstance(params, dict):
params_dict = dict(cast(Dict[str, Any], params))
else:
params_dict = {}
type_from_params = params_dict.pop("type", None)
type_str = type_from_params or identifier or ""
if not type_str:
raise ValueError(
"Cannot derive check kind: provide 'identifier' or include 'type' in 'params', "
"or pass 'spec' directly with an explicit 'kind'."
)
kind = _IDENTIFIER_TO_KIND.get(type_str, type_str)
return {"kind": kind, **params_dict}


class CheckConfig(BaseModel):
identifier: str
enabled: Optional[bool] = None
Expand All @@ -303,23 +274,6 @@ class CheckConfigParam(TypedDict, total=False):
params: Dict[str, Any]


def _check_params_to_api( # pyright: ignore[reportUnusedFunction]
checks: Iterable[CheckConfigParam],
) -> list[Dict[str, Any]]:
return [
{
"identifier": check["identifier"],
"enabled": check.get("enabled", True),
**(
{"spec": _check_param_to_spec(check["identifier"], check.get("params", {}))}
if check.get("params")
else {}
),
}
for check in checks
]


# ---------------------------------------------------------------------------
# Check params
# ---------------------------------------------------------------------------
Expand Down
34 changes: 17 additions & 17 deletions tests/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
CorrectnessParams,
JsonPathRuleParam,
)
from giskard_hub.types.check import (
_IDENTIFIER_TO_KIND,
_check_param_to_spec,
_check_params_to_api,
_extract_check_params,
from giskard_hub.types.check import _extract_check_params
from giskard_hub.resources._check_helpers import (
IDENTIFIER_TO_KIND,
check_param_to_spec,
check_params_to_specs,
)

# ---------------------------------------------------------------------------
Expand All @@ -28,7 +28,7 @@


def test_identifier_to_kind_mapping() -> None:
assert _IDENTIFIER_TO_KIND == {
assert IDENTIFIER_TO_KIND == {
"correctness": "hub_correctness",
"conformity": "hub_conformity",
"groundedness": "hub_groundedness",
Expand All @@ -39,27 +39,27 @@ def test_identifier_to_kind_mapping() -> None:


def test_check_param_to_spec_prefers_params_type_over_identifier() -> None:
spec = _check_param_to_spec("custom_name", {"type": "conformity", "rules": ["r"]})
spec = check_param_to_spec("custom_name", {"type": "conformity", "rules": ["r"]})
assert spec == {"kind": "hub_conformity", "rules": ["r"]}


def test_check_param_to_spec_falls_back_to_identifier() -> None:
spec = _check_param_to_spec("correctness", {"reference": "x"})
spec = check_param_to_spec("correctness", {"reference": "x"})
assert spec == {"kind": "hub_correctness", "reference": "x"}


def test_check_param_to_spec_passes_through_unknown_kind() -> None:
spec = _check_param_to_spec("future_kind", {"foo": 1})
spec = check_param_to_spec("future_kind", {"foo": 1})
assert spec == {"kind": "future_kind", "foo": 1}


def test_check_param_to_spec_raises_when_no_kind_derivable() -> None:
with pytest.raises(ValueError, match="Cannot derive check kind"):
_check_param_to_spec(None, {"reference": "x"})
check_param_to_spec(None, {"reference": "x"})


def test_check_param_to_spec_accepts_basemodel() -> None:
spec = _check_param_to_spec("correctness", CorrectnessParams(reference="x"))
spec = check_param_to_spec("correctness", CorrectnessParams(reference="x"))
assert spec == {"kind": "hub_correctness", "reference": "x"}


Expand All @@ -73,8 +73,8 @@ def test_extract_check_params_empty_when_no_spec() -> None:
assert _extract_check_params({"spec": None}) == {}


def test_check_params_to_api_emits_spec_with_kind() -> None:
api = _check_params_to_api([{"identifier": "correctness", "params": {"reference": "x"}}])
def test_check_params_to_specs_emits_nested_with_kind() -> None:
api = check_params_to_specs([{"identifier": "correctness", "params": {"reference": "x"}}])
assert api == [
{
"identifier": "correctness",
Expand All @@ -84,13 +84,13 @@ def test_check_params_to_api_emits_spec_with_kind() -> None:
]


def test_check_params_to_api_omits_spec_when_no_params() -> None:
api = _check_params_to_api([{"identifier": "tone_pro_xyz"}])
def test_check_params_to_specs_omits_spec_when_no_params() -> None:
api = check_params_to_specs([{"identifier": "tone_pro_xyz"}])
assert api == [{"identifier": "tone_pro_xyz", "enabled": True}]


def test_check_params_to_api_strips_redundant_type() -> None:
api = _check_params_to_api([{"identifier": "string_match", "params": {"type": "string_match", "keyword": "k"}}])
def test_check_params_to_specs_strips_redundant_type() -> None:
api = check_params_to_specs([{"identifier": "string_match", "params": {"type": "string_match", "keyword": "k"}}])
assert api == [
{
"identifier": "string_match",
Expand Down
17 changes: 8 additions & 9 deletions tests/test_evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,23 @@
import pytest

from giskard_hub import HubClient, AsyncHubClient
from giskard_hub.resources.evaluations.evaluations import (
_check_params_to_api,
_normalize_agent_output,
)
from giskard_hub.resources._check_helpers import check_params_to_specs
from giskard_hub.resources.evaluations.evaluations import _normalize_agent_output

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def test_check_params_to_api_emits_flat_shape() -> None:
api = list(_check_params_to_api([{"identifier": "correctness", "params": {"reference": "x"}}]))
def test_check_params_to_specs_emits_flat_shape() -> None:
api = check_params_to_specs([{"identifier": "correctness", "params": {"reference": "x"}}], flat=True)
assert api == [{"identifier": "correctness", "reference": "x"}]


def test_check_params_to_api_strips_redundant_type() -> None:
api = list(
_check_params_to_api([{"identifier": "string_match", "params": {"type": "string_match", "keyword": "k"}}])
def test_check_params_to_specs_strips_redundant_type_when_flat() -> None:
api = check_params_to_specs(
[{"identifier": "string_match", "params": {"type": "string_match", "keyword": "k"}}],
flat=True,
)
assert api == [{"identifier": "string_match", "keyword": "k"}]

Expand Down