Skip to content

Commit be9be60

Browse files
authored
python(feat): Adds batch rule update/create support to sift_client (#456)
1 parent b146ea4 commit be9be60

4 files changed

Lines changed: 359 additions & 57 deletions

File tree

python/lib/sift_client/_internal/low_level_wrappers/rules.py

Lines changed: 91 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4-
from typing import TYPE_CHECKING, Any, cast
4+
from typing import TYPE_CHECKING, Any, Sequence, cast
55

66
from sift.common.type.v1.resource_identifier_pb2 import ResourceIdentifier, ResourceIdentifiers
77
from sift.rule_evaluation.v1.rule_evaluation_pb2 import (
@@ -97,6 +97,48 @@ async def get_rule(self, rule_id: str | None = None, client_key: str | None = No
9797
grpc_rule = cast("GetRuleResponse", response).rule
9898
return Rule._from_proto(grpc_rule)
9999

100+
def _update_rule_request_from_create(self, create: RuleCreate) -> UpdateRuleRequest:
101+
"""Create an UpdateRuleRequest from a RuleCreate object.
102+
103+
Args:
104+
create: The RuleCreate model with the rule configuration.
105+
106+
Returns:
107+
The UpdateRuleRequest proto message.
108+
"""
109+
expression_proto = RuleConditionExpression(
110+
calculated_channel=CalculatedChannelConfig(
111+
expression=create.expression,
112+
channel_references={
113+
c.channel_reference: ChannelReferenceProto(name=c.channel_identifier)
114+
for c in create.channel_references
115+
},
116+
)
117+
)
118+
conditions_request = [
119+
UpdateConditionRequest(
120+
expression=expression_proto,
121+
actions=[create.action._to_update_request()],
122+
)
123+
]
124+
update_request = UpdateRuleRequest(
125+
name=create.name,
126+
description=create.description,
127+
is_enabled=True,
128+
organization_id=create.organization_id or "",
129+
client_key=create.client_key,
130+
is_external=create.is_external,
131+
conditions=conditions_request,
132+
asset_configuration=RuleAssetConfiguration(
133+
asset_ids=create.asset_ids or [],
134+
tag_ids=create.asset_tag_ids or [],
135+
),
136+
contextual_channels=ContextualChannels(
137+
channels=[ChannelReferenceProto(name=c) for c in create.contextual_channels or []]
138+
), # type: ignore
139+
)
140+
return update_request
141+
100142
async def batch_get_rules(
101143
self, rule_ids: list[str] | None = None, client_keys: list[str] | None = None
102144
) -> list[Rule]:
@@ -138,39 +180,7 @@ async def create_rule(
138180
Returns:
139181
The created Rule.
140182
"""
141-
# Convert rule to UpdateRuleRequest
142-
expression_proto = RuleConditionExpression(
143-
calculated_channel=CalculatedChannelConfig(
144-
expression=create.expression,
145-
channel_references={
146-
c.channel_reference: ChannelReferenceProto(name=c.channel_identifier)
147-
for c in create.channel_references
148-
},
149-
)
150-
)
151-
conditions_request = [
152-
UpdateConditionRequest(
153-
expression=expression_proto,
154-
actions=[create.action._to_update_request()],
155-
)
156-
]
157-
update_request = UpdateRuleRequest(
158-
name=create.name,
159-
description=create.description,
160-
is_enabled=True,
161-
organization_id=create.organization_id or "",
162-
client_key=create.client_key,
163-
is_external=create.is_external,
164-
conditions=conditions_request,
165-
asset_configuration=RuleAssetConfiguration(
166-
asset_ids=create.asset_ids or [],
167-
tag_ids=create.asset_tag_ids or [],
168-
),
169-
contextual_channels=ContextualChannels(
170-
channels=[ChannelReferenceProto(name=c) for c in create.contextual_channels or []]
171-
), # type: ignore
172-
)
173-
183+
update_request = self._update_rule_request_from_create(create)
174184
request = CreateRuleRequest(update=update_request)
175185
created_rule = cast(
176186
"CreateRuleResponse",
@@ -301,22 +311,59 @@ async def update_rule(
301311
# Get the updated rule
302312
return await self.get_rule(rule_id=rule.id_)
303313

304-
async def batch_update_rules(self, rules: list[RuleUpdate]) -> BatchUpdateRulesResponse:
305-
"""Batch update rules.
314+
async def batch_update_rules(
315+
self,
316+
rules: Sequence[RuleCreate | RuleUpdate],
317+
validate_only: bool = False,
318+
override_expression_validation: bool = False,
319+
) -> BatchUpdateRulesResponse:
320+
"""Batch update or create rules.
306321
307322
Args:
308-
rules: List of rule updates to apply.
323+
rules: List of rule creates or updates to apply. RuleUpdate objects must have
324+
resource_id set.
309325
310326
Returns:
311327
The batch update response.
312-
"""
313-
update_requests = []
314-
for rule_update in rules:
315-
rule = await self.get_rule(rule_id=rule_update.resource_id)
316-
request = self._update_rule_request_from_update(rule, rule_update)
317-
update_requests.append(request)
318328
319-
request = BatchUpdateRulesRequest(rules=update_requests) # type: ignore
329+
Raises:
330+
ValueError: If any RuleUpdate objects are missing resource_id or the rule is not found for updating.
331+
"""
332+
# Collect resource_ids from only RuleUpdate objects
333+
rule_ids: list[str] = []
334+
for rule in rules:
335+
if isinstance(rule, RuleUpdate):
336+
if rule.resource_id is None:
337+
raise ValueError("RuleUpdate objects must have resource_id set")
338+
rule_ids.append(rule.resource_id)
339+
340+
# Fetch existing rules for updates
341+
existing_rules = await self.batch_get_rules(rule_ids=rule_ids) if rule_ids else []
342+
existing_rules_by_id = {rule.id_: rule for rule in existing_rules}
343+
344+
# Build update requests maintaining the input order
345+
update_requests: list[UpdateRuleRequest] = []
346+
for rule in rules:
347+
if isinstance(rule, RuleCreate):
348+
# Convert RuleCreate to UpdateRuleRequest
349+
update_request = self._update_rule_request_from_create(rule)
350+
update_requests.append(update_request)
351+
elif isinstance(rule, RuleUpdate):
352+
# Use existing rule + update to create request
353+
existing_rule = existing_rules_by_id.get(rule.resource_id)
354+
if existing_rule is None:
355+
raise ValueError(
356+
f"Rule with resource_id {rule.resource_id} not found for update"
357+
)
358+
update_request = self._update_rule_request_from_update(existing_rule, rule)
359+
update_requests.append(update_request)
360+
361+
# Call the batch update request
362+
request = BatchUpdateRulesRequest(
363+
rules=update_requests,
364+
validate_only=validate_only,
365+
override_expression_validation=override_expression_validation,
366+
) # type: ignore
320367
response = await self._grpc_client.get_stub(RuleServiceStub).BatchUpdateRules(request)
321368
return cast("BatchUpdateRulesResponse", response)
322369

python/lib/sift_client/_tests/resources/test_rules.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,136 @@ async def test_unarchive_rule(self, rules_api_async, new_rule):
616616
finally:
617617
await rules_api_async.archive(new_rule.id_)
618618

619+
class TestBatchUpdate:
620+
"""Tests for the async batch_update_rules method."""
621+
622+
@pytest.mark.asyncio
623+
async def test_batch_update_or_create_rules(self, rules_api_async, nostromo_asset):
624+
"""Test updating multiple rules with different fields."""
625+
from datetime import datetime, timezone
626+
627+
rule1_name = f"test_batch_rule_1_{datetime.now(timezone.utc).isoformat()}"
628+
rule2_name = f"test_batch_rule_2_{datetime.now(timezone.utc).isoformat()}"
629+
630+
rule1 = await rules_api_async.create(
631+
RuleCreate(
632+
name=rule1_name,
633+
client_key=f"test_batch_1_{str(uuid.uuid4())[-8:]}",
634+
description="Test rule 1 for batch update",
635+
expression="$1 > $2",
636+
channel_references=[
637+
ChannelReference(channel_reference="$1", channel_identifier="channel1"),
638+
ChannelReference(channel_reference="$2", channel_identifier="channel2"),
639+
],
640+
action=RuleAction.annotation(
641+
annotation_type=RuleAnnotationType.DATA_REVIEW,
642+
tags=[],
643+
),
644+
asset_ids=[nostromo_asset.id_],
645+
)
646+
)
647+
648+
rule2 = await rules_api_async.create(
649+
RuleCreate(
650+
name=rule2_name,
651+
client_key=f"test_batch_2_{str(uuid.uuid4())[-8:]}",
652+
description="Test rule 2 for batch update",
653+
expression="$1 > 0.5",
654+
channel_references=[
655+
ChannelReference(channel_reference="$1", channel_identifier="channel1"),
656+
],
657+
action=RuleAction.annotation(
658+
annotation_type=RuleAnnotationType.DATA_REVIEW,
659+
tags=[],
660+
),
661+
asset_ids=[nostromo_asset.id_],
662+
)
663+
)
664+
665+
try:
666+
# Batch update both rules
667+
rule1_update = RuleUpdate(description="Updated description 1")
668+
rule1_update.resource_id = rule1.id_
669+
670+
rule2_update = RuleUpdate(description="Updated description 2")
671+
rule2_update.resource_id = rule2.id_
672+
673+
updates = [rule1_update, rule2_update]
674+
675+
updated_rules = await rules_api_async.batch_update_or_create_rules(updates)
676+
677+
assert isinstance(updated_rules, list)
678+
assert len(updated_rules) == 2
679+
680+
# Verify updates were applied
681+
assert updated_rules[0].description == "Updated description 1"
682+
assert updated_rules[1].description == "Updated description 2"
683+
finally:
684+
await rules_api_async.archive(rule1.id_)
685+
await rules_api_async.archive(rule2.id_)
686+
687+
@pytest.mark.asyncio
688+
async def test_batch_update_rules_creates_rules(self, rules_api_async, nostromo_asset):
689+
"""Test batch updating rules that don't already exist."""
690+
from datetime import datetime, timezone
691+
692+
rule1_name = f"test_batch_rule_1_{datetime.now(timezone.utc).isoformat()}"
693+
rule2_name = f"test_batch_rule_2_{datetime.now(timezone.utc).isoformat()}"
694+
695+
rule1 = RuleCreate(
696+
name=rule1_name,
697+
client_key=f"test_batch_1_{str(uuid.uuid4())[-8:]}",
698+
description="Test rule 1 for batch update",
699+
expression="$1 > $2",
700+
channel_references=[
701+
ChannelReference(channel_reference="$1", channel_identifier="channel1"),
702+
ChannelReference(channel_reference="$2", channel_identifier="channel2"),
703+
],
704+
action=RuleAction.annotation(
705+
annotation_type=RuleAnnotationType.DATA_REVIEW,
706+
tags=[],
707+
),
708+
asset_ids=[nostromo_asset.id_],
709+
)
710+
711+
rule2 = RuleCreate(
712+
name=rule2_name,
713+
client_key=f"test_batch_2_{str(uuid.uuid4())[-8:]}",
714+
description="Test rule 2 for batch update",
715+
expression="$1 > 0.5",
716+
channel_references=[
717+
ChannelReference(channel_reference="$1", channel_identifier="channel1"),
718+
],
719+
action=RuleAction.annotation(
720+
annotation_type=RuleAnnotationType.DATA_REVIEW,
721+
tags=[],
722+
),
723+
asset_ids=[nostromo_asset.id_],
724+
)
725+
726+
updated_rules: list[Rule] = []
727+
try:
728+
# Batch update (actually create) both rules
729+
updates = [rule1, rule2]
730+
updated_rules = await rules_api_async.batch_update_or_create_rules(updates)
731+
732+
assert isinstance(updated_rules, list)
733+
assert len(updated_rules) == 2
734+
735+
assert updated_rules[0].client_key == rule1.client_key
736+
assert updated_rules[1].client_key == rule2.client_key
737+
finally:
738+
for rule in updated_rules:
739+
await rules_api_async.archive(rule.id_)
740+
741+
@pytest.mark.asyncio
742+
async def test_batch_update_rules_empty_list(self, rules_api_async):
743+
"""Test handling empty list."""
744+
updated_rules = await rules_api_async.batch_update_or_create_rules([])
745+
746+
assert isinstance(updated_rules, list)
747+
assert len(updated_rules) == 0
748+
619749

620750
class TestRulesAPISync:
621751
"""Test suite for the synchronous Rules API functionality."""

0 commit comments

Comments
 (0)