Skip to content

Commit fae4b30

Browse files
authored
Add tags support to rules (#438)
1 parent d54c304 commit fae4b30

4 files changed

Lines changed: 123 additions & 4 deletions

File tree

python/lib/sift_py/rule/_service_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def test_rule_service_load_rules_from_yaml(rule_service):
9393
"assignee": "assignee@abc.com",
9494
"type": "review",
9595
"asset_names": ["asset"],
96+
"tag_names": ["tag1"],
9697
}
9798
with mock.patch.object(RuleService, "create_or_update_rule"):
9899
with mock.patch(
@@ -115,6 +116,7 @@ def test_rule_service_load_rules_from_yaml(rule_service):
115116
assert rule_config.expression == rule_yaml["expression"]
116117
assert rule_config.action.assignee == rule_yaml["assignee"]
117118
assert rule_config.asset_names == rule_yaml["asset_names"]
119+
assert rule_config.tag_names == rule_yaml["tag_names"]
118120
assert isinstance(rule_config.action, RuleActionCreateDataReviewAnnotation)
119121

120122

python/lib/sift_py/rule/config.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class RuleConfig(AsJson):
2424
- `channel_references`: Reference to channel. If an expression is "$1 < 10", then "$1" is the reference and thus should the key in the dict.
2525
- `rule_client_key`: User defined unique string that uniquely identifies this rule.
2626
- `asset_names`: A list of asset names that this rule should be applied to. ONLY VALID if defining rules outside of a telemetry config.
27-
- `tag_names`: A list of asset names that this rule should be applied to. ONLY VALID if defining rules outside of a telemetry config.
27+
- `tag_names`: A list of asset tags that this rule should be applied to. ONLY VALID if defining rules outside of a telemetry config.
2828
- `contextual_channels`: A list of channel names that provide context but aren't directly used in the expression.
2929
- `is_external`: If this is an external rule.
3030
- `is_live`: If set to True then this rule will be evaluated on live data, otherwise live rule evaluation will be disabled.
@@ -38,6 +38,7 @@ class RuleConfig(AsJson):
3838
channel_references: List[ExpressionChannelReference]
3939
rule_client_key: Optional[str]
4040
asset_names: List[str]
41+
tag_names: List[str]
4142
contextual_channels: List[str]
4243
is_external: bool
4344
is_live: bool
@@ -65,6 +66,7 @@ def __init__(
6566

6667
self.name = name
6768
self.asset_names = asset_names or []
69+
self.tag_names = tag_names or []
6870
self.action = action
6971
self.rule_client_key = rule_client_key
7072
self.description = description
@@ -133,6 +135,8 @@ def interpolate_sub_expressions(
133135

134136

135137
class RuleAction(ABC):
138+
tags: Optional[List[str]]
139+
136140
@abstractmethod
137141
def kind(self) -> RuleActionKind:
138142
pass

python/lib/sift_py/rule/service.py

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
UpdateRuleRequest,
3636
)
3737
from sift.rules.v1.rules_pb2_grpc import RuleServiceStub
38+
from sift.tags.v2.tags_pb2 import Tag, TagType
39+
from sift.tags.v2.tags_pb2_grpc import TagServiceStub
3840
from sift.users.v2.users_pb2_grpc import UserServiceStub
3941

4042
from sift_py._internal.cel import cel_in
@@ -55,6 +57,7 @@
5557
RuleActionKind,
5658
RuleConfig,
5759
)
60+
from sift_py.tag._internal.shared import list_tags_impl
5861
from sift_py.yaml.rule import load_rule_modules
5962

6063

@@ -72,13 +75,15 @@ class RuleService:
7275
_channel_service_stub: ChannelServiceStub
7376
_rule_service_stub: RuleServiceStub
7477
_user_service_stub: UserServiceStub
78+
_tag_service_stub: TagServiceStub
7579
_enable_caching: bool
7680

7781
def __init__(self, channel: SiftChannel, enable_caching=False):
7882
self._asset_service_stub = AssetServiceStub(channel)
7983
self._channel_service_stub = ChannelServiceStub(channel)
8084
self._rule_service_stub = RuleServiceStub(channel)
8185
self._user_service_stub = UserServiceStub(channel)
86+
self._tag_service_stub = TagServiceStub(channel)
8287
self._enable_caching = enable_caching
8388

8489
def load_rules_from_yaml(
@@ -244,6 +249,7 @@ def _parse_rules_from_yaml(
244249
channel_references=rule_channel_references,
245250
contextual_channels=contextual_channels,
246251
asset_names=rule_yaml.get("asset_names", []),
252+
tag_names=rule_yaml.get("tag_names", []),
247253
sub_expressions=subexpr,
248254
is_external=rule_yaml.get("is_external", False),
249255
is_live=rule_yaml.get("is_live", False),
@@ -402,8 +408,20 @@ def _update_req_from_rule_config(
402408
"See `sift_py.rule.config.RuleAction` for available actions."
403409
)
404410

405-
# TODO: once we have TagService_ListTags we can do asset-agnostic rules via tags
406411
assets = self._get_assets(names=config.asset_names) if config.asset_names else None
412+
asset_tags = (
413+
self._get_tags(names=config.tag_names, tag_type=TagType.TAG_TYPE_ASSET)
414+
if config.tag_names
415+
else None
416+
)
417+
annotation_tags = (
418+
self._get_tags(
419+
names=[tag for tag in config.action.tags],
420+
tag_type=TagType.TAG_TYPE_ANNOTATION,
421+
)
422+
if config.action.tags
423+
else None
424+
)
407425

408426
actions = []
409427
if config.action.kind() == RuleActionKind.NOTIFICATION:
@@ -412,6 +430,10 @@ def _update_req_from_rule_config(
412430
"Please contact the Sift team for assistance."
413431
)
414432
elif config.action.kind() == RuleActionKind.ANNOTATION:
433+
annotation_tag_ids = (
434+
[tag.tag_id for tag in annotation_tags] if annotation_tags else None
435+
)
436+
415437
if isinstance(config.action, RuleActionCreateDataReviewAnnotation):
416438
assignee = config.action.assignee
417439
user_id = None
@@ -431,7 +453,7 @@ def _update_req_from_rule_config(
431453
annotation=AnnotationActionConfiguration(
432454
assigned_to_user_id=user_id,
433455
annotation_type=AnnotationType.ANNOTATION_TYPE_DATA_REVIEW,
434-
# tag_ids=config.action.tags, # TODO: Requires TagService
456+
tag_ids=annotation_tag_ids,
435457
)
436458
),
437459
)
@@ -442,7 +464,7 @@ def _update_req_from_rule_config(
442464
configuration=RuleActionConfiguration(
443465
annotation=AnnotationActionConfiguration(
444466
annotation_type=AnnotationType.ANNOTATION_TYPE_PHASE,
445-
# tag_ids=config.action.tags, # TODO: Requires TagService
467+
tag_ids=annotation_tag_ids,
446468
)
447469
),
448470
)
@@ -523,6 +545,7 @@ def _update_req_from_rule_config(
523545
],
524546
asset_configuration=RuleAssetConfiguration(
525547
asset_ids=[asset.asset_id for asset in assets] if assets else None,
548+
tag_ids=[tag.tag_id for tag in asset_tags] if asset_tags else None,
526549
),
527550
contextual_channels=ContextualChannels(channels=contextual_channel_names),
528551
is_external=config.is_external,
@@ -574,6 +597,12 @@ def get_rule(self, rule: str) -> Optional[RuleConfig]:
574597
)
575598
asset_names = [asset.name for asset in assets]
576599

600+
asset_tags = self._get_tags(
601+
ids=[tag_id for tag_id in rule_pb.asset_configuration.tag_ids],
602+
tag_type=TagType.TAG_TYPE_ASSET,
603+
)
604+
asset_tag_names = [tag.name for tag in asset_tags]
605+
577606
contextual_channels = []
578607
for channel_ref in rule_pb.contextual_channels.channels:
579608
contextual_channels.append(channel_ref.name)
@@ -585,6 +614,7 @@ def get_rule(self, rule: str) -> Optional[RuleConfig]:
585614
channel_references=channel_references, # type: ignore
586615
contextual_channels=contextual_channels,
587616
asset_names=asset_names,
617+
tag_names=asset_tag_names,
588618
action=action,
589619
expression=expression,
590620
)
@@ -616,6 +646,17 @@ def _get_assets(self, names: List[str] = [], ids: List[str] = []) -> List[Asset]
616646
else:
617647
return list_assets_impl(self._asset_service_stub, names, ids)
618648

649+
def _get_tags(
650+
self,
651+
names: List[str] = [],
652+
ids: List[str] = [],
653+
tag_type: TagType.ValueType = TagType.TAG_TYPE_UNSPECIFIED,
654+
) -> List[Tag]:
655+
if self._enable_caching:
656+
return self._get_tags_cached(tuple(sorted(names)), tuple(sorted(ids)), tag_type)
657+
else:
658+
return list_tags_impl(self._tag_service_stub, names, ids, tag_type)
659+
619660
def _get_channels(self, filter: str) -> List[ChannelPb]:
620661
if self._enable_caching:
621662
return self._get_channels_cached(filter)
@@ -632,6 +673,15 @@ def _get_active_users(self, filter: str) -> List[UserPb]:
632673
def _get_assets_cached(self, names: Tuple[str], ids: Tuple[str]) -> List[Asset]:
633674
return list_assets_impl(self._asset_service_stub, names, ids)
634675

676+
@cache
677+
def _get_tags_cached(
678+
self,
679+
names: Tuple[str],
680+
ids: Tuple[str],
681+
tag_type: TagType.ValueType = TagType.TAG_TYPE_UNSPECIFIED,
682+
) -> List[Tag]:
683+
return list_tags_impl(self._tag_service_stub, names, ids, tag_type)
684+
635685
@cache
636686
def _get_channels_cached(self, filter: str) -> List[ChannelPb]:
637687
return get_channels(channel_service=self._channel_service_stub, filter=filter)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from typing import List, Optional, Tuple, Union, cast
2+
3+
from sift.tags.v2.tags_pb2 import ListTagsRequest, ListTagsResponse, Tag, TagType
4+
from sift.tags.v2.tags_pb2_grpc import TagServiceStub
5+
from sift_py._internal.cel import cel_in
6+
7+
8+
def list_tags_impl(
9+
tag_service_stub: TagServiceStub,
10+
names: Optional[Union[Tuple[str], List[str]]] = None,
11+
ids: Optional[Union[Tuple[str], List[str]]] = None,
12+
tag_type: TagType.ValueType = TagType.TAG_TYPE_UNSPECIFIED,
13+
) -> List[Tag]:
14+
"""
15+
Lists tags in an organization.
16+
17+
Args:
18+
tag_service_stub: The tag service stub to use.
19+
names: Optional collection of names to filter by.
20+
ids: Optional collection of IDs to filter by.
21+
tag_type: Optional tag type to filter by.
22+
23+
Returns:
24+
A list of tags matching the criteria.
25+
"""
26+
27+
def get_tags_with_filter(
28+
tag_service_stub: TagServiceStub,
29+
cel_filter: str,
30+
tag_type: TagType.ValueType,
31+
) -> List[Tag]:
32+
tags: List[Tag] = []
33+
next_page_token = ""
34+
while True:
35+
req = ListTagsRequest(
36+
filter=cel_filter,
37+
page_size=1_000,
38+
page_token=next_page_token,
39+
tag_type=tag_type,
40+
)
41+
res = cast(ListTagsResponse, tag_service_stub.ListTags(req))
42+
tags.extend(res.tags)
43+
44+
if not res.next_page_token:
45+
break
46+
next_page_token = res.next_page_token
47+
48+
return tags
49+
50+
if names is None:
51+
names = []
52+
if ids is None:
53+
ids = []
54+
55+
results: List[Tag] = []
56+
if names:
57+
names_cel = cel_in("name", names)
58+
results.extend(get_tags_with_filter(tag_service_stub, names_cel, tag_type))
59+
if ids:
60+
ids_cel = cel_in("tag_id", ids)
61+
results.extend(get_tags_with_filter(tag_service_stub, ids_cel, tag_type))
62+
63+
return results

0 commit comments

Comments
 (0)