diff --git a/python/lib/sift_client/_internal/low_level_wrappers/exports.py b/python/lib/sift_client/_internal/low_level_wrappers/exports.py index 63aa200cd..fd438fac0 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/exports.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/exports.py @@ -26,10 +26,25 @@ if TYPE_CHECKING: from datetime import datetime + from sift_client.sift_types.channel import ChannelReference from sift_client.sift_types.export import ExportOutputFormat from sift_client.transport.grpc_transport import GrpcClient +def _abstract_ref_to_proto(ref: ChannelReference) -> CalculatedChannelAbstractChannelReference: + # After ChannelReference validation, calculated_channel is always a version_id string. + if ref.calculated_channel: + assert isinstance(ref.calculated_channel, str) + return CalculatedChannelAbstractChannelReference( + channel_reference=ref.channel_reference, + calculated_channel_version_id=ref.calculated_channel, + ) + return CalculatedChannelAbstractChannelReference( + channel_reference=ref.channel_reference, + channel_identifier=ref.channel_identifier or "", + ) + + def _build_calc_channel_configs( calculated_channels: list[CalculatedChannel | CalculatedChannelCreate] | None, ) -> list[CalculatedChannelConfig]: @@ -46,13 +61,7 @@ def _build_calc_channel_configs( CalculatedChannelConfig( name=cc.name, expression=cc.expression or "", - channel_references=[ - CalculatedChannelAbstractChannelReference( - channel_reference=ref.channel_reference, - channel_identifier=ref.channel_identifier, - ) - for ref in refs - ], + channel_references=[_abstract_ref_to_proto(ref) for ref in refs], units=cc.units, ) ) diff --git a/python/lib/sift_client/_internal/low_level_wrappers/rules.py b/python/lib/sift_client/_internal/low_level_wrappers/rules.py index 3658a0698..364ee37e1 100644 --- a/python/lib/sift_client/_internal/low_level_wrappers/rules.py +++ b/python/lib/sift_client/_internal/low_level_wrappers/rules.py @@ -75,6 +75,16 @@ logger = logging.getLogger(__name__) +def _channel_reference_to_proto(ref: ChannelReference) -> ChannelReferenceProto: + # ChannelReference's validator normalizes calculated_channel to a version_id str + # and guarantees exactly one of calculated_channel / channel_identifier is set. + if ref.calculated_channel: + return ChannelReferenceProto( + calculated_channel_version_id=cast("str", ref.calculated_channel) + ) + return ChannelReferenceProto(name=cast("str", ref.channel_identifier)) + + class RulesLowLevelClient(LowLevelClientBase, WithGrpcClient): """Low-level client for the RulesAPI. @@ -126,7 +136,7 @@ def _update_rule_request_from_create(self, create: RuleCreate) -> UpdateRuleRequ calculated_channel=CalculatedChannelConfig( expression=create.expression, channel_references={ - c.channel_reference: ChannelReferenceProto(name=c.channel_identifier) + c.channel_reference: _channel_reference_to_proto(c) for c in create.channel_references }, ) @@ -259,7 +269,7 @@ def _update_rule_request_from_update( CalculatedChannelConfig( expression=expression, channel_references={ - c.channel_reference: ChannelReferenceProto(name=c.channel_identifier) + c.channel_reference: _channel_reference_to_proto(c) for c in channel_references }, ) diff --git a/python/lib/sift_client/_internal/util/channels.py b/python/lib/sift_client/_internal/util/channels.py index 8c3d39d82..e7c37396b 100644 --- a/python/lib/sift_client/_internal/util/channels.py +++ b/python/lib/sift_client/_internal/util/channels.py @@ -32,6 +32,9 @@ async def resolve_calculated_channels( resolved_refs: list[ChannelReference] = [] for ref in refs: + if ref.calculated_channel: + resolved_refs.append(ref) + continue channel = await channels_api.find( name=ref.channel_identifier, assets=cc.asset_ids, diff --git a/python/lib/sift_client/_tests/_internal/test_channels.py b/python/lib/sift_client/_tests/_internal/test_channels.py index 3b6be637d..792f82ec5 100644 --- a/python/lib/sift_client/_tests/_internal/test_channels.py +++ b/python/lib/sift_client/_tests/_internal/test_channels.py @@ -41,6 +41,25 @@ async def test_resolves_name_to_uuid(self): assert refs is not None assert refs[0].channel_identifier == "resolved-uuid" + @pytest.mark.asyncio + async def test_skips_lookup_for_calculated_channel_version_id(self): + api = MagicMock() + api.find = AsyncMock() + cc = CalculatedChannelCreate( + name="nested", + expression="$1 + 1", + expression_channel_references=[ + ChannelReference(channel_reference="$1", calculated_channel="v-nested") + ], + ) + result = await resolve_calculated_channels([cc], channels_api=api) + api.find.assert_not_awaited() + assert result is not None + refs = result[0].expression_channel_references + assert refs is not None + assert refs[0].calculated_channel == "v-nested" + assert refs[0].channel_identifier is None + @pytest.mark.asyncio async def test_keeps_identifier_when_not_found(self): api = MagicMock() diff --git a/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py b/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py index e5b1059c5..d0ac06c93 100644 --- a/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py +++ b/python/lib/sift_client/_tests/sift_types/test_calculated_channel.py @@ -7,6 +7,7 @@ from sift_client.sift_types import CalculatedChannel from sift_client.sift_types.calculated_channel import ( + CalculatedChannelCreate, CalculatedChannelUpdate, ) from sift_client.sift_types.channel import ChannelReference @@ -151,6 +152,68 @@ def test_expression_validator_rejects_references_without_expression(self): ], ) + def test_nested_calculated_channel_reference_serialized_to_version_id_oneof(self): + """A ChannelReference with calculated_channel set routes to the proto oneof.""" + update = CalculatedChannelUpdate( + expression="$1 + $2", + expression_channel_references=[ + ChannelReference(channel_reference="$1", channel_identifier="channel1"), + ChannelReference(channel_reference="$2", calculated_channel="v-nested"), + ], + ) + update.resource_id = "test_calc_channel_id" + + proto, _ = update.to_proto_with_mask() + + refs = proto.calculated_channel_configuration.query_configuration.sel.expression_channel_references + assert len(refs) == 2 + assert refs[0].channel_identifier == "channel1" + assert refs[0].WhichOneof("calculated_channel_reference") is None + assert refs[1].WhichOneof("calculated_channel_reference") == "calculated_channel_version_id" + assert refs[1].calculated_channel_version_id == "v-nested" + assert refs[1].channel_identifier == "" + + def test_create_serializes_nested_calculated_channel_reference(self): + """CalculatedChannelCreate.to_proto routes the version_id into the proto oneof.""" + create = CalculatedChannelCreate( + name="nested-cc", + expression="$1 * 2", + expression_channel_references=[ + ChannelReference(channel_reference="$1", calculated_channel="v-nested"), + ], + all_assets=True, + ) + + proto = create.to_proto() + + refs = proto.calculated_channel_configuration.query_configuration.sel.expression_channel_references + assert len(refs) == 1 + assert refs[0].WhichOneof("calculated_channel_reference") == "calculated_channel_version_id" + assert refs[0].calculated_channel_version_id == "v-nested" + + def test_create_serializes_nested_reference_from_calculated_channel_object( + self, mock_calculated_channel + ): + """Passing a CalculatedChannel object to ChannelReference also serializes correctly.""" + create = CalculatedChannelCreate( + name="nested-cc", + expression="$1 * 2", + expression_channel_references=[ + ChannelReference( + channel_reference="$1", calculated_channel=mock_calculated_channel + ), + ], + all_assets=True, + ) + + proto = create.to_proto() + + refs = proto.calculated_channel_configuration.query_configuration.sel.expression_channel_references + assert len(refs) == 1 + assert refs[0].WhichOneof("calculated_channel_reference") == "calculated_channel_version_id" + # mock_calculated_channel fixture has version_id="v1" + assert refs[0].calculated_channel_version_id == "v1" + def test_expression_validator_accepts_both_set(self): """Test validator accepts expression and channel references together.""" # Should not raise diff --git a/python/lib/sift_client/_tests/sift_types/test_channel.py b/python/lib/sift_client/_tests/sift_types/test_channel.py index ab8fa01c2..76bbafc78 100644 --- a/python/lib/sift_client/_tests/sift_types/test_channel.py +++ b/python/lib/sift_client/_tests/sift_types/test_channel.py @@ -6,7 +6,7 @@ import pytest from sift_client.sift_types import Channel -from sift_client.sift_types.channel import ChannelDataType +from sift_client.sift_types.channel import ChannelDataType, ChannelReference @pytest.fixture @@ -104,6 +104,100 @@ def test_data_method_as_arrow(self, mock_channel, mock_client): mock_client.channels.get_data.assert_not_called() assert result == mock_data + def test_channel_reference_requires_one_target(self): + """ChannelReference must specify exactly one of identifier or calculated_channel.""" + with pytest.raises(ValueError, match="exactly one"): + ChannelReference(channel_reference="$1") + with pytest.raises(ValueError, match="exactly one"): + ChannelReference( + channel_reference="$1", + channel_identifier="ch", + calculated_channel="v-id", + ) + + def test_channel_reference_accepts_version_id_string(self): + """A plain version_id string is stored as-is.""" + ref = ChannelReference(channel_reference="$1", calculated_channel="v-abc") + assert ref.calculated_channel == "v-abc" + assert ref.channel_identifier is None + + def test_channel_reference_accepts_calculated_channel_object(self): + """Passing a CalculatedChannel normalizes to its version_id string.""" + from sift_client.sift_types.calculated_channel import CalculatedChannel + + cc = CalculatedChannel( + proto=MagicMock(), + id_="cc-id", + name="parent", + description="", + expression="$1", + channel_references=[], + is_archived=False, + units=None, + asset_ids=["asset-1"], + tag_ids=None, + all_assets=False, + organization_id=None, + client_key=None, + archived_date=None, + version_id="v-abc", + version=1, + change_message=None, + user_notes=None, + created_date=datetime.now(timezone.utc), + modified_date=datetime.now(timezone.utc), + created_by_user_id="u", + modified_by_user_id="u", + ) + + ref = ChannelReference(channel_reference="$1", calculated_channel=cc) + assert ref.calculated_channel == "v-abc" + + def test_channel_reference_rejects_calculated_channel_without_version_id(self): + """A CalculatedChannel missing version_id is unusable as a reference.""" + from sift_client.sift_types.calculated_channel import CalculatedChannel + + cc = CalculatedChannel( + proto=MagicMock(), + id_="cc-id", + name="parent", + description="", + expression="$1", + channel_references=[], + is_archived=False, + units=None, + asset_ids=["asset-1"], + tag_ids=None, + all_assets=False, + organization_id=None, + client_key=None, + archived_date=None, + version_id=None, + version=None, + change_message=None, + user_notes=None, + created_date=datetime.now(timezone.utc), + modified_date=datetime.now(timezone.utc), + created_by_user_id="u", + modified_by_user_id="u", + ) + + with pytest.raises(ValueError, match="no version_id"): + ChannelReference(channel_reference="$1", calculated_channel=cc) + + def test_channel_reference_from_proto_reads_version_id_oneof(self): + """_from_proto picks calculated_channel when the proto oneof selects it.""" + from sift.calculated_channels.v2.calculated_channels_pb2 import ( + CalculatedChannelAbstractChannelReference, + ) + + proto = CalculatedChannelAbstractChannelReference( + channel_reference="$1", calculated_channel_version_id="v-abc" + ) + ref = ChannelReference._from_proto(proto) + assert ref.calculated_channel == "v-abc" + assert ref.channel_identifier is None + def test_data_method_with_minimal_params(self, mock_channel, mock_client): """Test that data() method works with minimal parameters.""" mock_data = {"test_channel": MagicMock()} diff --git a/python/lib/sift_client/_tests/sift_types/test_rule.py b/python/lib/sift_client/_tests/sift_types/test_rule.py index 05a15381f..f91443c6a 100644 --- a/python/lib/sift_client/_tests/sift_types/test_rule.py +++ b/python/lib/sift_client/_tests/sift_types/test_rule.py @@ -147,3 +147,76 @@ def test_unarchive_calls_client_and_updates_self(self, mock_rule, mock_client): mock_update.assert_called_once_with(unarchived_rule) # Verify it returns self assert result is mock_rule + + +class TestRuleChannelReferenceSerialization: + """Nested CC references flow through the low-level wrapper and back.""" + + def test_helper_routes_version_id_into_proto(self): + from sift_client._internal.low_level_wrappers.rules import ( + _channel_reference_to_proto, + ) + + proto = _channel_reference_to_proto( + ChannelReference(channel_reference="$1", calculated_channel="v-abc") + ) + assert proto.HasField("calculated_channel_version_id") + assert proto.calculated_channel_version_id == "v-abc" + assert proto.name == "" + + def test_helper_routes_identifier_into_name(self): + from sift_client._internal.low_level_wrappers.rules import ( + _channel_reference_to_proto, + ) + + proto = _channel_reference_to_proto( + ChannelReference(channel_reference="$1", channel_identifier="my-channel") + ) + assert not proto.HasField("calculated_channel_version_id") + assert proto.name == "my-channel" + + def test_from_proto_reads_version_id_when_present(self): + from google.protobuf.timestamp_pb2 import Timestamp + from sift.rules.v1.rules_pb2 import ( + CalculatedChannelConfig, + RuleCondition, + RuleConditionExpression, + ) + from sift.rules.v1.rules_pb2 import ( + ChannelReference as ChannelReferenceProto, + ) + from sift.rules.v1.rules_pb2 import ( + Rule as RuleProto, + ) + from sift.rules.v1.rules_pb2 import ( + RuleAction as RuleActionProto, + ) + + ts = Timestamp() + ts.GetCurrentTime() + proto = RuleProto( + rule_id="r1", + name="r", + description="", + created_date=ts, + modified_date=ts, + conditions=[ + RuleCondition( + expression=RuleConditionExpression( + calculated_channel=CalculatedChannelConfig( + expression="$1 > 0", + channel_references={ + "$1": ChannelReferenceProto(calculated_channel_version_id="v-xyz"), + }, + ) + ), + actions=[RuleActionProto(created_date=ts, modified_date=ts)], + ) + ], + ) + rule = Rule._from_proto(proto) + assert rule.channel_references is not None + assert len(rule.channel_references) == 1 + ref = rule.channel_references[0] + assert ref.calculated_channel == "v-xyz" + assert ref.channel_identifier is None diff --git a/python/lib/sift_client/sift_types/calculated_channel.py b/python/lib/sift_client/sift_types/calculated_channel.py index d1fca2523..f150431b4 100644 --- a/python/lib/sift_client/sift_types/calculated_channel.py +++ b/python/lib/sift_client/sift_types/calculated_channel.py @@ -26,6 +26,28 @@ from sift_client.client import SiftClient +def _channel_reference_to_proto( + channel_reference: str, + channel_identifier: str | None = None, + calculated_channel: str | None = None, +) -> CalculatedChannelAbstractChannelReference: + """Convert a ChannelReference dict (from model_dump) into its proto form. + + Maps the ``calculated_channel`` field onto the proto's ``calculated_channel_version_id`` + oneof. After ChannelReference validation, ``calculated_channel`` is always a + version_id string (or None). + """ + if calculated_channel: + return CalculatedChannelAbstractChannelReference( + channel_reference=channel_reference, + calculated_channel_version_id=calculated_channel, + ) + return CalculatedChannelAbstractChannelReference( + channel_reference=channel_reference, + channel_identifier=channel_identifier or "", + ) + + class CalculatedChannel(BaseType[CalculatedChannelProto, "CalculatedChannel"]): """Model of the Sift Calculated Channel.""" @@ -108,10 +130,7 @@ def _from_proto( description=proto.description, expression=proto.calculated_channel_configuration.query_configuration.sel.expression, channel_references=[ - ChannelReference( - channel_reference=ref_proto.channel_reference, - channel_identifier=ref_proto.channel_identifier, - ) + ChannelReference._from_proto(ref_proto) for ref_proto in proto.calculated_channel_configuration.query_configuration.sel.expression_channel_references ], organization_id=proto.organization_id, @@ -163,7 +182,7 @@ class CalculatedChannelBase(ModelCreateUpdateBase): "expression_channel_references": MappingHelper( proto_attr_path="calculated_channel_configuration.query_configuration.sel.expression_channel_references", update_field="query_configuration", - converter=CalculatedChannelAbstractChannelReference, + converter=_channel_reference_to_proto, ), "tag_ids": MappingHelper( proto_attr_path="calculated_channel_configuration.asset_configuration.selection.tag_ids", @@ -225,3 +244,8 @@ def _add_resource_id_to_proto(self, proto_msg: CalculatedChannelProto): if self._resource_id is None: raise ValueError("Resource ID must be set before adding to proto") proto_msg.calculated_channel_id = self._resource_id + + +# Resolve the forward reference to CalculatedChannel in ChannelReference now that +# CalculatedChannel is defined in this module's namespace. +ChannelReference.model_rebuild(_types_namespace={"CalculatedChannel": CalculatedChannel}) # type: ignore[arg-type] diff --git a/python/lib/sift_client/sift_types/channel.py b/python/lib/sift_client/sift_types/channel.py index 69ba4b8ed..402f5d25a 100644 --- a/python/lib/sift_client/sift_types/channel.py +++ b/python/lib/sift_client/sift_types/channel.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING import sift.common.type.v1.channel_data_type_pb2 as channel_pb -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from sift.channels.v3.channels_pb2 import Channel as ChannelProto from sift.common.type.v1.channel_bit_field_element_pb2 import ( ChannelBitFieldElement as ChannelBitFieldElementPb, @@ -32,6 +32,7 @@ from sift_client.client import SiftClient from sift_client.sift_types.asset import Asset + from sift_client.sift_types.calculated_channel import CalculatedChannel from sift_client.sift_types.run import Run @@ -344,13 +345,53 @@ def runs(self) -> list[Run]: class ChannelReference(BaseModel): - """Channel reference for calculated channel or rule.""" - - channel_reference: str # The key of the channel in the expression i.e. $1, $2, etc. - channel_identifier: str # The name of the channel + """Channel reference for a calculated channel or rule expression. + + Exactly one of `channel_identifier` or `calculated_channel` must be set. + To reference another calculated channel (nested CC), pass either a fetched + `CalculatedChannel` (its `version_id` is used) or a `version_id` string + directly. Both normalize to the `version_id` string after validation. + + Attributes: + channel_reference: The key of the channel in the expression (e.g. ``$1``, ``$2``). + channel_identifier: The name (or ID) of an existing channel. + calculated_channel: A ``CalculatedChannel`` or its ``version_id``. Normalized + to the ``version_id`` string after validation. + """ + + channel_reference: str + channel_identifier: str | None = None + calculated_channel: CalculatedChannel | str | None = None + + @model_validator(mode="after") + def _normalize_and_validate(self) -> ChannelReference: + # Lazy import avoids a circular dependency at module load time. + from sift_client.sift_types.calculated_channel import CalculatedChannel + + if isinstance(self.calculated_channel, CalculatedChannel): + if not self.calculated_channel.version_id: + raise ValueError( + "ChannelReference: provided CalculatedChannel has no version_id. " + "Fetch it via client.calculated_channels.get(...) first." + ) + self.calculated_channel = self.calculated_channel.version_id + + has_identifier = bool(self.channel_identifier) + has_calc_channel = bool(self.calculated_channel) + if has_identifier == has_calc_channel: + raise ValueError( + "ChannelReference requires exactly one of channel_identifier or " + "calculated_channel to be set" + ) + return self @classmethod def _from_proto(cls, proto) -> ChannelReference: + if proto.WhichOneof("calculated_channel_reference") == "calculated_channel_version_id": + return cls( + channel_reference=proto.channel_reference, + calculated_channel=proto.calculated_channel_version_id, + ) return cls( channel_reference=proto.channel_reference, channel_identifier=proto.channel_identifier, diff --git a/python/lib/sift_client/sift_types/rule.py b/python/lib/sift_client/sift_types/rule.py index 0b241625a..5112b964c 100644 --- a/python/lib/sift_client/sift_types/rule.py +++ b/python/lib/sift_client/sift_types/rule.py @@ -130,7 +130,12 @@ def _from_proto(cls, proto: RuleProto, sift_client: SiftClient | None = None) -> description=proto.description, expression=expression, channel_references=[ - ChannelReference(channel_reference=ref, channel_identifier=c.name) + ChannelReference( + channel_reference=ref, + calculated_channel=c.calculated_channel_version_id, + ) + if c.HasField("calculated_channel_version_id") + else ChannelReference(channel_reference=ref, channel_identifier=c.name) for ref, c in proto.conditions[ 0 ].expression.calculated_channel.channel_references.items()