Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sift_client._internal.low_level_wrappers.runs import RunsLowLevelClient
from sift_client._internal.low_level_wrappers.tags import TagsLowLevelClient
from sift_client._internal.low_level_wrappers.test_results import TestResultsLowLevelClient
from sift_client._internal.low_level_wrappers.units import UnitsLowLevelClient
from sift_client._internal.low_level_wrappers.upload import UploadLowLevelClient

__all__ = [
Expand All @@ -27,5 +28,6 @@
"RunsLowLevelClient",
"TagsLowLevelClient",
"TestResultsLowLevelClient",
"UnitsLowLevelClient",
"UploadLowLevelClient",
]
93 changes: 93 additions & 0 deletions python/lib/sift_client/_internal/low_level_wrappers/units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, cast

from sift.unit.v2.unit_pb2 import (
CreateUnitRequest,
CreateUnitResponse,
ListUnitsRequest,
ListUnitsResponse,
)
from sift.unit.v2.unit_pb2 import Unit as UnitProto
from sift.unit.v2.unit_pb2_grpc import UnitServiceStub

from sift_client._internal.low_level_wrappers.base import DEFAULT_PAGE_SIZE, LowLevelClientBase
from sift_client.transport import WithGrpcClient

if TYPE_CHECKING:
from sift_client.transport.grpc_transport import GrpcClient

# Configure logging
logger = logging.getLogger(__name__)


class UnitsLowLevelClient(LowLevelClientBase, WithGrpcClient):
"""Low-level client for the Units service.

This class provides a thin wrapper around the autogenerated bindings for the Units service.
"""

def __init__(self, grpc_client: GrpcClient):
"""Initialize the UnitsLowLevelClient.

Args:
grpc_client: The gRPC client to use for making API calls.
"""
super().__init__(grpc_client)

async def create_unit(self, name: str) -> UnitProto:
"""Create a new unit.

If a unit with the same name already exists, it is returned instead of creating a duplicate.

Args:
name: The name of the unit.

Returns:
The created unit proto, whose unit_id is used to reference the unit.

Raises:
ValueError: If name is not provided.
"""
if not name:
raise ValueError("name must be provided")

request = CreateUnitRequest(name=name)
response = await self._grpc_client.get_stub(UnitServiceStub).CreateUnit(request)
return cast("CreateUnitResponse", response).unit

async def list_units(
self,
*,
page_size: int | None = DEFAULT_PAGE_SIZE,
page_token: str | None = None,
query_filter: str | None = None,
order_by: str | None = None,
) -> tuple[list[UnitProto], str]:
"""List units with optional filtering and pagination.

Args:
page_size: The maximum number of units to return.
page_token: A page token for pagination.
query_filter: A CEL filter string (e.g. filtering on unit_id or abbreviated_name).
order_by: How to order the retrieved units.

Returns:
A tuple of (unit protos, next_page_token).
"""
request_kwargs: dict[str, Any] = {}
if page_size is not None:
request_kwargs["page_size"] = page_size
if page_token is not None:
request_kwargs["page_token"] = page_token
if query_filter is not None:
request_kwargs["filter"] = query_filter
if order_by is not None:
request_kwargs["order_by"] = order_by

request = ListUnitsRequest(**request_kwargs)
response = await self._grpc_client.get_stub(UnitServiceStub).ListUnits(request)
response = cast("ListUnitsResponse", response)

return list(response.units), response.next_page_token
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""Tests for the Units low-level wrapper."""

from unittest.mock import AsyncMock, MagicMock

import pytest
from sift.unit.v2.unit_pb2 import CreateUnitResponse
from sift.unit.v2.unit_pb2 import Unit as UnitProto

from sift_client._internal.low_level_wrappers.units import UnitsLowLevelClient


@pytest.mark.asyncio
async def test_create_unit_rejects_empty_name():
"""create_unit raises before making a request when name is empty."""
client = UnitsLowLevelClient(grpc_client=MagicMock())

with pytest.raises(ValueError, match="name must be provided"):
await client.create_unit("")


@pytest.mark.asyncio
async def test_create_unit_returns_created_unit_proto():
"""create_unit unwraps the response and returns the unit proto (unit_id + abbreviated_name)."""
stub = MagicMock()
stub.CreateUnit = AsyncMock(
return_value=CreateUnitResponse(unit=UnitProto(unit_id="u1", abbreviated_name="volts"))
)
grpc_client = MagicMock()
grpc_client.get_stub.return_value = stub

unit = await UnitsLowLevelClient(grpc_client).create_unit("volts")

assert unit.unit_id == "u1"
assert unit.abbreviated_name == "volts"
Loading