Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions python/lib/sift_client/_internal/low_level_wrappers/exports.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,25 @@
if TYPE_CHECKING:
from datetime import datetime

from sift_client.sift_types.channel import ChannelReference
from sift_client.sift_types.export import ExportOutputFormat
from sift_client.transport.grpc_transport import GrpcClient


def _abstract_ref_to_proto(ref: ChannelReference) -> CalculatedChannelAbstractChannelReference:
# After ChannelReference validation, calculated_channel is always a version_id string.
if ref.calculated_channel:
assert isinstance(ref.calculated_channel, str)
return CalculatedChannelAbstractChannelReference(
channel_reference=ref.channel_reference,
calculated_channel_version_id=ref.calculated_channel,
)
return CalculatedChannelAbstractChannelReference(
channel_reference=ref.channel_reference,
channel_identifier=ref.channel_identifier or "",
)


def _build_calc_channel_configs(
calculated_channels: list[CalculatedChannel | CalculatedChannelCreate] | None,
) -> list[CalculatedChannelConfig]:
Expand All @@ -46,13 +61,7 @@ def _build_calc_channel_configs(
CalculatedChannelConfig(
name=cc.name,
expression=cc.expression or "",
channel_references=[
CalculatedChannelAbstractChannelReference(
channel_reference=ref.channel_reference,
channel_identifier=ref.channel_identifier,
)
for ref in refs
],
channel_references=[_abstract_ref_to_proto(ref) for ref in refs],
units=cc.units,
)
)
Expand Down
14 changes: 12 additions & 2 deletions python/lib/sift_client/_internal/low_level_wrappers/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@
logger = logging.getLogger(__name__)


def _channel_reference_to_proto(ref: ChannelReference) -> ChannelReferenceProto:
# ChannelReference's validator normalizes calculated_channel to a version_id str
# and guarantees exactly one of calculated_channel / channel_identifier is set.
if ref.calculated_channel:
return ChannelReferenceProto(
calculated_channel_version_id=cast("str", ref.calculated_channel)
)
return ChannelReferenceProto(name=cast("str", ref.channel_identifier))


class RulesLowLevelClient(LowLevelClientBase, WithGrpcClient):
"""Low-level client for the RulesAPI.

Expand Down Expand Up @@ -126,7 +136,7 @@ def _update_rule_request_from_create(self, create: RuleCreate) -> UpdateRuleRequ
calculated_channel=CalculatedChannelConfig(
expression=create.expression,
channel_references={
c.channel_reference: ChannelReferenceProto(name=c.channel_identifier)
c.channel_reference: _channel_reference_to_proto(c)
for c in create.channel_references
},
)
Expand Down Expand Up @@ -259,7 +269,7 @@ def _update_rule_request_from_update(
CalculatedChannelConfig(
expression=expression,
channel_references={
c.channel_reference: ChannelReferenceProto(name=c.channel_identifier)
c.channel_reference: _channel_reference_to_proto(c)
for c in channel_references
},
)
Expand Down
3 changes: 3 additions & 0 deletions python/lib/sift_client/_internal/util/channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ async def resolve_calculated_channels(

resolved_refs: list[ChannelReference] = []
for ref in refs:
if ref.calculated_channel:
resolved_refs.append(ref)
continue
channel = await channels_api.find(
name=ref.channel_identifier,
assets=cc.asset_ids,
Expand Down
19 changes: 19 additions & 0 deletions python/lib/sift_client/_tests/_internal/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,25 @@ async def test_resolves_name_to_uuid(self):
assert refs is not None
assert refs[0].channel_identifier == "resolved-uuid"

@pytest.mark.asyncio
async def test_skips_lookup_for_calculated_channel_version_id(self):
api = MagicMock()
api.find = AsyncMock()
cc = CalculatedChannelCreate(
name="nested",
expression="$1 + 1",
expression_channel_references=[
ChannelReference(channel_reference="$1", calculated_channel="v-nested")
],
)
result = await resolve_calculated_channels([cc], channels_api=api)
api.find.assert_not_awaited()
assert result is not None
refs = result[0].expression_channel_references
assert refs is not None
assert refs[0].calculated_channel == "v-nested"
assert refs[0].channel_identifier is None

@pytest.mark.asyncio
async def test_keeps_identifier_when_not_found(self):
api = MagicMock()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from sift_client.sift_types import CalculatedChannel
from sift_client.sift_types.calculated_channel import (
CalculatedChannelCreate,
CalculatedChannelUpdate,
)
from sift_client.sift_types.channel import ChannelReference
Expand Down Expand Up @@ -151,6 +152,68 @@ def test_expression_validator_rejects_references_without_expression(self):
],
)

def test_nested_calculated_channel_reference_serialized_to_version_id_oneof(self):
"""A ChannelReference with calculated_channel set routes to the proto oneof."""
update = CalculatedChannelUpdate(
expression="$1 + $2",
expression_channel_references=[
ChannelReference(channel_reference="$1", channel_identifier="channel1"),
ChannelReference(channel_reference="$2", calculated_channel="v-nested"),
],
)
update.resource_id = "test_calc_channel_id"

proto, _ = update.to_proto_with_mask()

refs = proto.calculated_channel_configuration.query_configuration.sel.expression_channel_references
assert len(refs) == 2
assert refs[0].channel_identifier == "channel1"
assert refs[0].WhichOneof("calculated_channel_reference") is None
assert refs[1].WhichOneof("calculated_channel_reference") == "calculated_channel_version_id"
assert refs[1].calculated_channel_version_id == "v-nested"
assert refs[1].channel_identifier == ""

def test_create_serializes_nested_calculated_channel_reference(self):
"""CalculatedChannelCreate.to_proto routes the version_id into the proto oneof."""
create = CalculatedChannelCreate(
name="nested-cc",
expression="$1 * 2",
expression_channel_references=[
ChannelReference(channel_reference="$1", calculated_channel="v-nested"),
],
all_assets=True,
)

proto = create.to_proto()

refs = proto.calculated_channel_configuration.query_configuration.sel.expression_channel_references
assert len(refs) == 1
assert refs[0].WhichOneof("calculated_channel_reference") == "calculated_channel_version_id"
assert refs[0].calculated_channel_version_id == "v-nested"

def test_create_serializes_nested_reference_from_calculated_channel_object(
self, mock_calculated_channel
):
"""Passing a CalculatedChannel object to ChannelReference also serializes correctly."""
create = CalculatedChannelCreate(
name="nested-cc",
expression="$1 * 2",
expression_channel_references=[
ChannelReference(
channel_reference="$1", calculated_channel=mock_calculated_channel
),
],
all_assets=True,
)

proto = create.to_proto()

refs = proto.calculated_channel_configuration.query_configuration.sel.expression_channel_references
assert len(refs) == 1
assert refs[0].WhichOneof("calculated_channel_reference") == "calculated_channel_version_id"
# mock_calculated_channel fixture has version_id="v1"
assert refs[0].calculated_channel_version_id == "v1"

def test_expression_validator_accepts_both_set(self):
"""Test validator accepts expression and channel references together."""
# Should not raise
Expand Down
96 changes: 95 additions & 1 deletion python/lib/sift_client/_tests/sift_types/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

from sift_client.sift_types import Channel
from sift_client.sift_types.channel import ChannelDataType
from sift_client.sift_types.channel import ChannelDataType, ChannelReference


@pytest.fixture
Expand Down Expand Up @@ -104,6 +104,100 @@ def test_data_method_as_arrow(self, mock_channel, mock_client):
mock_client.channels.get_data.assert_not_called()
assert result == mock_data

def test_channel_reference_requires_one_target(self):
"""ChannelReference must specify exactly one of identifier or calculated_channel."""
with pytest.raises(ValueError, match="exactly one"):
ChannelReference(channel_reference="$1")
with pytest.raises(ValueError, match="exactly one"):
ChannelReference(
channel_reference="$1",
channel_identifier="ch",
calculated_channel="v-id",
)

def test_channel_reference_accepts_version_id_string(self):
"""A plain version_id string is stored as-is."""
ref = ChannelReference(channel_reference="$1", calculated_channel="v-abc")
assert ref.calculated_channel == "v-abc"
assert ref.channel_identifier is None

def test_channel_reference_accepts_calculated_channel_object(self):
"""Passing a CalculatedChannel normalizes to its version_id string."""
from sift_client.sift_types.calculated_channel import CalculatedChannel

cc = CalculatedChannel(
proto=MagicMock(),
id_="cc-id",
name="parent",
description="",
expression="$1",
channel_references=[],
is_archived=False,
units=None,
asset_ids=["asset-1"],
tag_ids=None,
all_assets=False,
organization_id=None,
client_key=None,
archived_date=None,
version_id="v-abc",
version=1,
change_message=None,
user_notes=None,
created_date=datetime.now(timezone.utc),
modified_date=datetime.now(timezone.utc),
created_by_user_id="u",
modified_by_user_id="u",
)

ref = ChannelReference(channel_reference="$1", calculated_channel=cc)
assert ref.calculated_channel == "v-abc"

def test_channel_reference_rejects_calculated_channel_without_version_id(self):
"""A CalculatedChannel missing version_id is unusable as a reference."""
from sift_client.sift_types.calculated_channel import CalculatedChannel

cc = CalculatedChannel(
proto=MagicMock(),
id_="cc-id",
name="parent",
description="",
expression="$1",
channel_references=[],
is_archived=False,
units=None,
asset_ids=["asset-1"],
tag_ids=None,
all_assets=False,
organization_id=None,
client_key=None,
archived_date=None,
version_id=None,
version=None,
change_message=None,
user_notes=None,
created_date=datetime.now(timezone.utc),
modified_date=datetime.now(timezone.utc),
created_by_user_id="u",
modified_by_user_id="u",
)

with pytest.raises(ValueError, match="no version_id"):
ChannelReference(channel_reference="$1", calculated_channel=cc)

def test_channel_reference_from_proto_reads_version_id_oneof(self):
"""_from_proto picks calculated_channel when the proto oneof selects it."""
from sift.calculated_channels.v2.calculated_channels_pb2 import (
CalculatedChannelAbstractChannelReference,
)

proto = CalculatedChannelAbstractChannelReference(
channel_reference="$1", calculated_channel_version_id="v-abc"
)
ref = ChannelReference._from_proto(proto)
assert ref.calculated_channel == "v-abc"
assert ref.channel_identifier is None

def test_data_method_with_minimal_params(self, mock_channel, mock_client):
"""Test that data() method works with minimal parameters."""
mock_data = {"test_channel": MagicMock()}
Expand Down
73 changes: 73 additions & 0 deletions python/lib/sift_client/_tests/sift_types/test_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,76 @@ def test_unarchive_calls_client_and_updates_self(self, mock_rule, mock_client):
mock_update.assert_called_once_with(unarchived_rule)
# Verify it returns self
assert result is mock_rule


class TestRuleChannelReferenceSerialization:
"""Nested CC references flow through the low-level wrapper and back."""

def test_helper_routes_version_id_into_proto(self):
from sift_client._internal.low_level_wrappers.rules import (
_channel_reference_to_proto,
)

proto = _channel_reference_to_proto(
ChannelReference(channel_reference="$1", calculated_channel="v-abc")
)
assert proto.HasField("calculated_channel_version_id")
assert proto.calculated_channel_version_id == "v-abc"
assert proto.name == ""

def test_helper_routes_identifier_into_name(self):
from sift_client._internal.low_level_wrappers.rules import (
_channel_reference_to_proto,
)

proto = _channel_reference_to_proto(
ChannelReference(channel_reference="$1", channel_identifier="my-channel")
)
assert not proto.HasField("calculated_channel_version_id")
assert proto.name == "my-channel"

def test_from_proto_reads_version_id_when_present(self):
from google.protobuf.timestamp_pb2 import Timestamp
from sift.rules.v1.rules_pb2 import (
CalculatedChannelConfig,
RuleCondition,
RuleConditionExpression,
)
from sift.rules.v1.rules_pb2 import (
ChannelReference as ChannelReferenceProto,
)
from sift.rules.v1.rules_pb2 import (
Rule as RuleProto,
)
from sift.rules.v1.rules_pb2 import (
RuleAction as RuleActionProto,
)

ts = Timestamp()
ts.GetCurrentTime()
proto = RuleProto(
rule_id="r1",
name="r",
description="",
created_date=ts,
modified_date=ts,
conditions=[
RuleCondition(
expression=RuleConditionExpression(
calculated_channel=CalculatedChannelConfig(
expression="$1 > 0",
channel_references={
"$1": ChannelReferenceProto(calculated_channel_version_id="v-xyz"),
},
)
),
actions=[RuleActionProto(created_date=ts, modified_date=ts)],
)
],
)
rule = Rule._from_proto(proto)
assert rule.channel_references is not None
assert len(rule.channel_references) == 1
ref = rule.channel_references[0]
assert ref.calculated_channel == "v-xyz"
assert ref.channel_identifier is None
Loading
Loading