3535 UpdateRuleRequest ,
3636)
3737from sift .rules .v1 .rules_pb2_grpc import RuleServiceStub
38+ from sift .tags .v2 .tags_pb2 import Tag , TagType
39+ from sift .tags .v2 .tags_pb2_grpc import TagServiceStub
3840from sift .users .v2 .users_pb2_grpc import UserServiceStub
3941
4042from sift_py ._internal .cel import cel_in
5557 RuleActionKind ,
5658 RuleConfig ,
5759)
60+ from sift_py .tag ._internal .shared import list_tags_impl
5861from sift_py .yaml .rule import load_rule_modules
5962
6063
@@ -72,13 +75,15 @@ class RuleService:
7275 _channel_service_stub : ChannelServiceStub
7376 _rule_service_stub : RuleServiceStub
7477 _user_service_stub : UserServiceStub
78+ _tag_service_stub : TagServiceStub
7579 _enable_caching : bool
7680
7781 def __init__ (self , channel : SiftChannel , enable_caching = False ):
7882 self ._asset_service_stub = AssetServiceStub (channel )
7983 self ._channel_service_stub = ChannelServiceStub (channel )
8084 self ._rule_service_stub = RuleServiceStub (channel )
8185 self ._user_service_stub = UserServiceStub (channel )
86+ self ._tag_service_stub = TagServiceStub (channel )
8287 self ._enable_caching = enable_caching
8388
8489 def load_rules_from_yaml (
@@ -244,6 +249,7 @@ def _parse_rules_from_yaml(
244249 channel_references = rule_channel_references ,
245250 contextual_channels = contextual_channels ,
246251 asset_names = rule_yaml .get ("asset_names" , []),
252+ tag_names = rule_yaml .get ("tag_names" , []),
247253 sub_expressions = subexpr ,
248254 is_external = rule_yaml .get ("is_external" , False ),
249255 is_live = rule_yaml .get ("is_live" , False ),
@@ -402,8 +408,20 @@ def _update_req_from_rule_config(
402408 "See `sift_py.rule.config.RuleAction` for available actions."
403409 )
404410
405- # TODO: once we have TagService_ListTags we can do asset-agnostic rules via tags
406411 assets = self ._get_assets (names = config .asset_names ) if config .asset_names else None
412+ asset_tags = (
413+ self ._get_tags (names = config .tag_names , tag_type = TagType .TAG_TYPE_ASSET )
414+ if config .tag_names
415+ else None
416+ )
417+ annotation_tags = (
418+ self ._get_tags (
419+ names = [tag for tag in config .action .tags ],
420+ tag_type = TagType .TAG_TYPE_ANNOTATION ,
421+ )
422+ if config .action .tags
423+ else None
424+ )
407425
408426 actions = []
409427 if config .action .kind () == RuleActionKind .NOTIFICATION :
@@ -412,6 +430,10 @@ def _update_req_from_rule_config(
412430 "Please contact the Sift team for assistance."
413431 )
414432 elif config .action .kind () == RuleActionKind .ANNOTATION :
433+ annotation_tag_ids = (
434+ [tag .tag_id for tag in annotation_tags ] if annotation_tags else None
435+ )
436+
415437 if isinstance (config .action , RuleActionCreateDataReviewAnnotation ):
416438 assignee = config .action .assignee
417439 user_id = None
@@ -431,7 +453,7 @@ def _update_req_from_rule_config(
431453 annotation = AnnotationActionConfiguration (
432454 assigned_to_user_id = user_id ,
433455 annotation_type = AnnotationType .ANNOTATION_TYPE_DATA_REVIEW ,
434- # tag_ids=config.action.tags, # TODO: Requires TagService
456+ tag_ids = annotation_tag_ids ,
435457 )
436458 ),
437459 )
@@ -442,7 +464,7 @@ def _update_req_from_rule_config(
442464 configuration = RuleActionConfiguration (
443465 annotation = AnnotationActionConfiguration (
444466 annotation_type = AnnotationType .ANNOTATION_TYPE_PHASE ,
445- # tag_ids=config.action.tags, # TODO: Requires TagService
467+ tag_ids = annotation_tag_ids ,
446468 )
447469 ),
448470 )
@@ -523,6 +545,7 @@ def _update_req_from_rule_config(
523545 ],
524546 asset_configuration = RuleAssetConfiguration (
525547 asset_ids = [asset .asset_id for asset in assets ] if assets else None ,
548+ tag_ids = [tag .tag_id for tag in asset_tags ] if asset_tags else None ,
526549 ),
527550 contextual_channels = ContextualChannels (channels = contextual_channel_names ),
528551 is_external = config .is_external ,
@@ -574,6 +597,12 @@ def get_rule(self, rule: str) -> Optional[RuleConfig]:
574597 )
575598 asset_names = [asset .name for asset in assets ]
576599
600+ asset_tags = self ._get_tags (
601+ ids = [tag_id for tag_id in rule_pb .asset_configuration .tag_ids ],
602+ tag_type = TagType .TAG_TYPE_ASSET ,
603+ )
604+ asset_tag_names = [tag .name for tag in asset_tags ]
605+
577606 contextual_channels = []
578607 for channel_ref in rule_pb .contextual_channels .channels :
579608 contextual_channels .append (channel_ref .name )
@@ -585,6 +614,7 @@ def get_rule(self, rule: str) -> Optional[RuleConfig]:
585614 channel_references = channel_references , # type: ignore
586615 contextual_channels = contextual_channels ,
587616 asset_names = asset_names ,
617+ tag_names = asset_tag_names ,
588618 action = action ,
589619 expression = expression ,
590620 )
@@ -616,6 +646,17 @@ def _get_assets(self, names: List[str] = [], ids: List[str] = []) -> List[Asset]
616646 else :
617647 return list_assets_impl (self ._asset_service_stub , names , ids )
618648
649+ def _get_tags (
650+ self ,
651+ names : List [str ] = [],
652+ ids : List [str ] = [],
653+ tag_type : TagType .ValueType = TagType .TAG_TYPE_UNSPECIFIED ,
654+ ) -> List [Tag ]:
655+ if self ._enable_caching :
656+ return self ._get_tags_cached (tuple (sorted (names )), tuple (sorted (ids )), tag_type )
657+ else :
658+ return list_tags_impl (self ._tag_service_stub , names , ids , tag_type )
659+
619660 def _get_channels (self , filter : str ) -> List [ChannelPb ]:
620661 if self ._enable_caching :
621662 return self ._get_channels_cached (filter )
@@ -632,6 +673,15 @@ def _get_active_users(self, filter: str) -> List[UserPb]:
632673 def _get_assets_cached (self , names : Tuple [str ], ids : Tuple [str ]) -> List [Asset ]:
633674 return list_assets_impl (self ._asset_service_stub , names , ids )
634675
676+ @cache
677+ def _get_tags_cached (
678+ self ,
679+ names : Tuple [str ],
680+ ids : Tuple [str ],
681+ tag_type : TagType .ValueType = TagType .TAG_TYPE_UNSPECIFIED ,
682+ ) -> List [Tag ]:
683+ return list_tags_impl (self ._tag_service_stub , names , ids , tag_type )
684+
635685 @cache
636686 def _get_channels_cached (self , filter : str ) -> List [ChannelPb ]:
637687 return get_channels (channel_service = self ._channel_service_stub , filter = filter )
0 commit comments