Skip to content

Commit e73894a

Browse files
authored
python(feat): support nested calculated channel references (#580)
1 parent 83f2cc0 commit e73894a

10 files changed

Lines changed: 362 additions & 21 deletions

File tree

python/lib/sift_client/_internal/low_level_wrappers/exports.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,25 @@
2626
if TYPE_CHECKING:
2727
from datetime import datetime
2828

29+
from sift_client.sift_types.channel import ChannelReference
2930
from sift_client.sift_types.export import ExportOutputFormat
3031
from sift_client.transport.grpc_transport import GrpcClient
3132

3233

34+
def _abstract_ref_to_proto(ref: ChannelReference) -> CalculatedChannelAbstractChannelReference:
35+
# After ChannelReference validation, calculated_channel is always a version_id string.
36+
if ref.calculated_channel:
37+
assert isinstance(ref.calculated_channel, str)
38+
return CalculatedChannelAbstractChannelReference(
39+
channel_reference=ref.channel_reference,
40+
calculated_channel_version_id=ref.calculated_channel,
41+
)
42+
return CalculatedChannelAbstractChannelReference(
43+
channel_reference=ref.channel_reference,
44+
channel_identifier=ref.channel_identifier or "",
45+
)
46+
47+
3348
def _build_calc_channel_configs(
3449
calculated_channels: list[CalculatedChannel | CalculatedChannelCreate] | None,
3550
) -> list[CalculatedChannelConfig]:
@@ -46,13 +61,7 @@ def _build_calc_channel_configs(
4661
CalculatedChannelConfig(
4762
name=cc.name,
4863
expression=cc.expression or "",
49-
channel_references=[
50-
CalculatedChannelAbstractChannelReference(
51-
channel_reference=ref.channel_reference,
52-
channel_identifier=ref.channel_identifier,
53-
)
54-
for ref in refs
55-
],
64+
channel_references=[_abstract_ref_to_proto(ref) for ref in refs],
5665
units=cc.units,
5766
)
5867
)

python/lib/sift_client/_internal/low_level_wrappers/rules.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,16 @@
7575
logger = logging.getLogger(__name__)
7676

7777

78+
def _channel_reference_to_proto(ref: ChannelReference) -> ChannelReferenceProto:
79+
# ChannelReference's validator normalizes calculated_channel to a version_id str
80+
# and guarantees exactly one of calculated_channel / channel_identifier is set.
81+
if ref.calculated_channel:
82+
return ChannelReferenceProto(
83+
calculated_channel_version_id=cast("str", ref.calculated_channel)
84+
)
85+
return ChannelReferenceProto(name=cast("str", ref.channel_identifier))
86+
87+
7888
class RulesLowLevelClient(LowLevelClientBase, WithGrpcClient):
7989
"""Low-level client for the RulesAPI.
8090
@@ -126,7 +136,7 @@ def _update_rule_request_from_create(self, create: RuleCreate) -> UpdateRuleRequ
126136
calculated_channel=CalculatedChannelConfig(
127137
expression=create.expression,
128138
channel_references={
129-
c.channel_reference: ChannelReferenceProto(name=c.channel_identifier)
139+
c.channel_reference: _channel_reference_to_proto(c)
130140
for c in create.channel_references
131141
},
132142
)
@@ -259,7 +269,7 @@ def _update_rule_request_from_update(
259269
CalculatedChannelConfig(
260270
expression=expression,
261271
channel_references={
262-
c.channel_reference: ChannelReferenceProto(name=c.channel_identifier)
272+
c.channel_reference: _channel_reference_to_proto(c)
263273
for c in channel_references
264274
},
265275
)

python/lib/sift_client/_internal/util/channels.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ async def resolve_calculated_channels(
3232

3333
resolved_refs: list[ChannelReference] = []
3434
for ref in refs:
35+
if ref.calculated_channel:
36+
resolved_refs.append(ref)
37+
continue
3538
channel = await channels_api.find(
3639
name=ref.channel_identifier,
3740
assets=cc.asset_ids,

python/lib/sift_client/_tests/_internal/test_channels.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,25 @@ async def test_resolves_name_to_uuid(self):
4141
assert refs is not None
4242
assert refs[0].channel_identifier == "resolved-uuid"
4343

44+
@pytest.mark.asyncio
45+
async def test_skips_lookup_for_calculated_channel_version_id(self):
46+
api = MagicMock()
47+
api.find = AsyncMock()
48+
cc = CalculatedChannelCreate(
49+
name="nested",
50+
expression="$1 + 1",
51+
expression_channel_references=[
52+
ChannelReference(channel_reference="$1", calculated_channel="v-nested")
53+
],
54+
)
55+
result = await resolve_calculated_channels([cc], channels_api=api)
56+
api.find.assert_not_awaited()
57+
assert result is not None
58+
refs = result[0].expression_channel_references
59+
assert refs is not None
60+
assert refs[0].calculated_channel == "v-nested"
61+
assert refs[0].channel_identifier is None
62+
4463
@pytest.mark.asyncio
4564
async def test_keeps_identifier_when_not_found(self):
4665
api = MagicMock()

python/lib/sift_client/_tests/sift_types/test_calculated_channel.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from sift_client.sift_types import CalculatedChannel
99
from sift_client.sift_types.calculated_channel import (
10+
CalculatedChannelCreate,
1011
CalculatedChannelUpdate,
1112
)
1213
from sift_client.sift_types.channel import ChannelReference
@@ -151,6 +152,68 @@ def test_expression_validator_rejects_references_without_expression(self):
151152
],
152153
)
153154

155+
def test_nested_calculated_channel_reference_serialized_to_version_id_oneof(self):
156+
"""A ChannelReference with calculated_channel set routes to the proto oneof."""
157+
update = CalculatedChannelUpdate(
158+
expression="$1 + $2",
159+
expression_channel_references=[
160+
ChannelReference(channel_reference="$1", channel_identifier="channel1"),
161+
ChannelReference(channel_reference="$2", calculated_channel="v-nested"),
162+
],
163+
)
164+
update.resource_id = "test_calc_channel_id"
165+
166+
proto, _ = update.to_proto_with_mask()
167+
168+
refs = proto.calculated_channel_configuration.query_configuration.sel.expression_channel_references
169+
assert len(refs) == 2
170+
assert refs[0].channel_identifier == "channel1"
171+
assert refs[0].WhichOneof("calculated_channel_reference") is None
172+
assert refs[1].WhichOneof("calculated_channel_reference") == "calculated_channel_version_id"
173+
assert refs[1].calculated_channel_version_id == "v-nested"
174+
assert refs[1].channel_identifier == ""
175+
176+
def test_create_serializes_nested_calculated_channel_reference(self):
177+
"""CalculatedChannelCreate.to_proto routes the version_id into the proto oneof."""
178+
create = CalculatedChannelCreate(
179+
name="nested-cc",
180+
expression="$1 * 2",
181+
expression_channel_references=[
182+
ChannelReference(channel_reference="$1", calculated_channel="v-nested"),
183+
],
184+
all_assets=True,
185+
)
186+
187+
proto = create.to_proto()
188+
189+
refs = proto.calculated_channel_configuration.query_configuration.sel.expression_channel_references
190+
assert len(refs) == 1
191+
assert refs[0].WhichOneof("calculated_channel_reference") == "calculated_channel_version_id"
192+
assert refs[0].calculated_channel_version_id == "v-nested"
193+
194+
def test_create_serializes_nested_reference_from_calculated_channel_object(
195+
self, mock_calculated_channel
196+
):
197+
"""Passing a CalculatedChannel object to ChannelReference also serializes correctly."""
198+
create = CalculatedChannelCreate(
199+
name="nested-cc",
200+
expression="$1 * 2",
201+
expression_channel_references=[
202+
ChannelReference(
203+
channel_reference="$1", calculated_channel=mock_calculated_channel
204+
),
205+
],
206+
all_assets=True,
207+
)
208+
209+
proto = create.to_proto()
210+
211+
refs = proto.calculated_channel_configuration.query_configuration.sel.expression_channel_references
212+
assert len(refs) == 1
213+
assert refs[0].WhichOneof("calculated_channel_reference") == "calculated_channel_version_id"
214+
# mock_calculated_channel fixture has version_id="v1"
215+
assert refs[0].calculated_channel_version_id == "v1"
216+
154217
def test_expression_validator_accepts_both_set(self):
155218
"""Test validator accepts expression and channel references together."""
156219
# Should not raise

python/lib/sift_client/_tests/sift_types/test_channel.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytest
77

88
from sift_client.sift_types import Channel
9-
from sift_client.sift_types.channel import ChannelDataType
9+
from sift_client.sift_types.channel import ChannelDataType, ChannelReference
1010

1111

1212
@pytest.fixture
@@ -104,6 +104,100 @@ def test_data_method_as_arrow(self, mock_channel, mock_client):
104104
mock_client.channels.get_data.assert_not_called()
105105
assert result == mock_data
106106

107+
def test_channel_reference_requires_one_target(self):
108+
"""ChannelReference must specify exactly one of identifier or calculated_channel."""
109+
with pytest.raises(ValueError, match="exactly one"):
110+
ChannelReference(channel_reference="$1")
111+
with pytest.raises(ValueError, match="exactly one"):
112+
ChannelReference(
113+
channel_reference="$1",
114+
channel_identifier="ch",
115+
calculated_channel="v-id",
116+
)
117+
118+
def test_channel_reference_accepts_version_id_string(self):
119+
"""A plain version_id string is stored as-is."""
120+
ref = ChannelReference(channel_reference="$1", calculated_channel="v-abc")
121+
assert ref.calculated_channel == "v-abc"
122+
assert ref.channel_identifier is None
123+
124+
def test_channel_reference_accepts_calculated_channel_object(self):
125+
"""Passing a CalculatedChannel normalizes to its version_id string."""
126+
from sift_client.sift_types.calculated_channel import CalculatedChannel
127+
128+
cc = CalculatedChannel(
129+
proto=MagicMock(),
130+
id_="cc-id",
131+
name="parent",
132+
description="",
133+
expression="$1",
134+
channel_references=[],
135+
is_archived=False,
136+
units=None,
137+
asset_ids=["asset-1"],
138+
tag_ids=None,
139+
all_assets=False,
140+
organization_id=None,
141+
client_key=None,
142+
archived_date=None,
143+
version_id="v-abc",
144+
version=1,
145+
change_message=None,
146+
user_notes=None,
147+
created_date=datetime.now(timezone.utc),
148+
modified_date=datetime.now(timezone.utc),
149+
created_by_user_id="u",
150+
modified_by_user_id="u",
151+
)
152+
153+
ref = ChannelReference(channel_reference="$1", calculated_channel=cc)
154+
assert ref.calculated_channel == "v-abc"
155+
156+
def test_channel_reference_rejects_calculated_channel_without_version_id(self):
157+
"""A CalculatedChannel missing version_id is unusable as a reference."""
158+
from sift_client.sift_types.calculated_channel import CalculatedChannel
159+
160+
cc = CalculatedChannel(
161+
proto=MagicMock(),
162+
id_="cc-id",
163+
name="parent",
164+
description="",
165+
expression="$1",
166+
channel_references=[],
167+
is_archived=False,
168+
units=None,
169+
asset_ids=["asset-1"],
170+
tag_ids=None,
171+
all_assets=False,
172+
organization_id=None,
173+
client_key=None,
174+
archived_date=None,
175+
version_id=None,
176+
version=None,
177+
change_message=None,
178+
user_notes=None,
179+
created_date=datetime.now(timezone.utc),
180+
modified_date=datetime.now(timezone.utc),
181+
created_by_user_id="u",
182+
modified_by_user_id="u",
183+
)
184+
185+
with pytest.raises(ValueError, match="no version_id"):
186+
ChannelReference(channel_reference="$1", calculated_channel=cc)
187+
188+
def test_channel_reference_from_proto_reads_version_id_oneof(self):
189+
"""_from_proto picks calculated_channel when the proto oneof selects it."""
190+
from sift.calculated_channels.v2.calculated_channels_pb2 import (
191+
CalculatedChannelAbstractChannelReference,
192+
)
193+
194+
proto = CalculatedChannelAbstractChannelReference(
195+
channel_reference="$1", calculated_channel_version_id="v-abc"
196+
)
197+
ref = ChannelReference._from_proto(proto)
198+
assert ref.calculated_channel == "v-abc"
199+
assert ref.channel_identifier is None
200+
107201
def test_data_method_with_minimal_params(self, mock_channel, mock_client):
108202
"""Test that data() method works with minimal parameters."""
109203
mock_data = {"test_channel": MagicMock()}

python/lib/sift_client/_tests/sift_types/test_rule.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,76 @@ def test_unarchive_calls_client_and_updates_self(self, mock_rule, mock_client):
147147
mock_update.assert_called_once_with(unarchived_rule)
148148
# Verify it returns self
149149
assert result is mock_rule
150+
151+
152+
class TestRuleChannelReferenceSerialization:
153+
"""Nested CC references flow through the low-level wrapper and back."""
154+
155+
def test_helper_routes_version_id_into_proto(self):
156+
from sift_client._internal.low_level_wrappers.rules import (
157+
_channel_reference_to_proto,
158+
)
159+
160+
proto = _channel_reference_to_proto(
161+
ChannelReference(channel_reference="$1", calculated_channel="v-abc")
162+
)
163+
assert proto.HasField("calculated_channel_version_id")
164+
assert proto.calculated_channel_version_id == "v-abc"
165+
assert proto.name == ""
166+
167+
def test_helper_routes_identifier_into_name(self):
168+
from sift_client._internal.low_level_wrappers.rules import (
169+
_channel_reference_to_proto,
170+
)
171+
172+
proto = _channel_reference_to_proto(
173+
ChannelReference(channel_reference="$1", channel_identifier="my-channel")
174+
)
175+
assert not proto.HasField("calculated_channel_version_id")
176+
assert proto.name == "my-channel"
177+
178+
def test_from_proto_reads_version_id_when_present(self):
179+
from google.protobuf.timestamp_pb2 import Timestamp
180+
from sift.rules.v1.rules_pb2 import (
181+
CalculatedChannelConfig,
182+
RuleCondition,
183+
RuleConditionExpression,
184+
)
185+
from sift.rules.v1.rules_pb2 import (
186+
ChannelReference as ChannelReferenceProto,
187+
)
188+
from sift.rules.v1.rules_pb2 import (
189+
Rule as RuleProto,
190+
)
191+
from sift.rules.v1.rules_pb2 import (
192+
RuleAction as RuleActionProto,
193+
)
194+
195+
ts = Timestamp()
196+
ts.GetCurrentTime()
197+
proto = RuleProto(
198+
rule_id="r1",
199+
name="r",
200+
description="",
201+
created_date=ts,
202+
modified_date=ts,
203+
conditions=[
204+
RuleCondition(
205+
expression=RuleConditionExpression(
206+
calculated_channel=CalculatedChannelConfig(
207+
expression="$1 > 0",
208+
channel_references={
209+
"$1": ChannelReferenceProto(calculated_channel_version_id="v-xyz"),
210+
},
211+
)
212+
),
213+
actions=[RuleActionProto(created_date=ts, modified_date=ts)],
214+
)
215+
],
216+
)
217+
rule = Rule._from_proto(proto)
218+
assert rule.channel_references is not None
219+
assert len(rule.channel_references) == 1
220+
ref = rule.channel_references[0]
221+
assert ref.calculated_channel == "v-xyz"
222+
assert ref.channel_identifier is None

0 commit comments

Comments
 (0)