Skip to content

Commit dd80a51

Browse files
committed
Update unti tests + formatting
1 parent c3a3457 commit dd80a51

4 files changed

Lines changed: 43 additions & 22 deletions

File tree

python/lib/sift_py/rule/_service_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def test_rule_service_load_rules_from_yaml(rule_service):
9393
"assignee": "assignee@abc.com",
9494
"type": "review",
9595
"asset_names": ["asset"],
96+
"tag_names": ["tag1"],
9697
}
9798
with mock.patch.object(RuleService, "create_or_update_rule"):
9899
with mock.patch(
@@ -115,6 +116,7 @@ def test_rule_service_load_rules_from_yaml(rule_service):
115116
assert rule_config.expression == rule_yaml["expression"]
116117
assert rule_config.action.assignee == rule_yaml["assignee"]
117118
assert rule_config.asset_names == rule_yaml["asset_names"]
119+
assert rule_config.tag_names == rule_yaml["tag1"]
118120
assert isinstance(rule_config.action, RuleActionCreateDataReviewAnnotation)
119121

120122

python/lib/sift_py/rule/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ def interpolate_sub_expressions(
135135

136136

137137
class RuleAction(ABC):
138+
tags: Optional[List[str]]
139+
138140
@abstractmethod
139141
def kind(self) -> RuleActionKind:
140142
pass

python/lib/sift_py/rule/service.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -414,10 +414,14 @@ def _update_req_from_rule_config(
414414
if config.tag_names
415415
else None
416416
)
417-
annotation_tags = list_tags_impl(
418-
self._tag_service_stub,
419-
names=[tag for tag in config.action.tags],
420-
tag_type=TagType.TAG_TYPE_ANNOTATION,
417+
annotation_tags = (
418+
list_tags_impl(
419+
self._tag_service_stub,
420+
names=[tag for tag in config.action.tags],
421+
tag_type=TagType.TAG_TYPE_ANNOTATION,
422+
)
423+
if config.action.tags
424+
else None
421425
)
422426

423427
actions = []
@@ -427,6 +431,10 @@ def _update_req_from_rule_config(
427431
"Please contact the Sift team for assistance."
428432
)
429433
elif config.action.kind() == RuleActionKind.ANNOTATION:
434+
annotation_tag_ids = (
435+
[tag.tag_id for tag in annotation_tags] if annotation_tags else None
436+
)
437+
430438
if isinstance(config.action, RuleActionCreateDataReviewAnnotation):
431439
assignee = config.action.assignee
432440
user_id = None
@@ -446,7 +454,7 @@ def _update_req_from_rule_config(
446454
annotation=AnnotationActionConfiguration(
447455
assigned_to_user_id=user_id,
448456
annotation_type=AnnotationType.ANNOTATION_TYPE_DATA_REVIEW,
449-
tag_ids=annotation_tags,
457+
tag_ids=annotation_tag_ids,
450458
)
451459
),
452460
)
@@ -457,7 +465,7 @@ def _update_req_from_rule_config(
457465
configuration=RuleActionConfiguration(
458466
annotation=AnnotationActionConfiguration(
459467
annotation_type=AnnotationType.ANNOTATION_TYPE_PHASE,
460-
tag_ids=annotation_tags,
468+
tag_ids=annotation_tag_ids,
461469
)
462470
),
463471
)
@@ -640,7 +648,10 @@ def _get_assets(self, names: List[str] = [], ids: List[str] = []) -> List[Asset]
640648
return list_assets_impl(self._asset_service_stub, names, ids)
641649

642650
def _get_tags(
643-
self, names: List[str] = [], ids: List[str] = [], tag_type: Optional[TagType] = None
651+
self,
652+
names: List[str] = [],
653+
ids: List[str] = [],
654+
tag_type: TagType.ValueType = TagType.TAG_TYPE_UNSPECIFIED,
644655
) -> List[Tag]:
645656
if self._enable_caching:
646657
return self._get_tags_cached(tuple(sorted(names)), tuple(sorted(ids)), tag_type)
@@ -665,8 +676,11 @@ def _get_assets_cached(self, names: Tuple[str], ids: Tuple[str]) -> List[Asset]:
665676

666677
@cache
667678
def _get_tags_cached(
668-
self, names: Tuple[str], ids: Tuple[str], tag_type: Optional[TagType] = None
669-
) -> List[Asset]:
679+
self,
680+
names: Tuple[str],
681+
ids: Tuple[str],
682+
tag_type: TagType.ValueType = TagType.TAG_TYPE_UNSPECIFIED,
683+
) -> List[Tag]:
670684
return list_tags_impl(self._tag_service_stub, names, ids, tag_type)
671685

672686
@cache

python/lib/sift_py/tag/_internal/shared.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def list_tags_impl(
99
tag_service_stub: TagServiceStub,
1010
names: Optional[Union[Tuple[str], List[str]]] = None,
1111
ids: Optional[Union[Tuple[str], List[str]]] = None,
12-
tag_type: Optional[TagType] = None,
12+
tag_type: TagType.ValueType = TagType.TAG_TYPE_UNSPECIFIED,
1313
) -> List[Tag]:
1414
"""
1515
Lists tags in an organization.
@@ -25,26 +25,27 @@ def list_tags_impl(
2525
"""
2626

2727
def get_tags_with_filter(
28-
tag_service_stub: TagServiceStub, cel_filter: str, tag_type: Optional[TagType]
28+
tag_service_stub: TagServiceStub,
29+
cel_filter: str,
30+
tag_type: TagType.ValueType,
2931
) -> List[Tag]:
3032
tags: List[Tag] = []
3133
next_page_token = ""
3234
while True:
33-
req_kwargs = {
34-
"filter": cel_filter,
35-
"page_size": 1_000,
36-
"page_token": next_page_token,
37-
}
38-
if tag_type is not None:
39-
req_kwargs["tag_type"] = tag_type
40-
req = ListTagsRequest(**req_kwargs)
35+
req = ListTagsRequest(
36+
filter=cel_filter,
37+
page_size=1_000,
38+
page_token=next_page_token,
39+
tag_type=tag_type,
40+
)
4141
res = cast(ListTagsResponse, tag_service_stub.ListTags(req))
4242
tags.extend(res.tags)
4343

4444
if not res.next_page_token:
4545
break
4646
next_page_token = res.next_page_token
4747

48+
print(tags)
4849
return tags
4950

5051
if names is None:
@@ -55,13 +56,15 @@ def get_tags_with_filter(
5556
results: List[Tag] = []
5657
if names:
5758
names_cel = cel_in("name", names)
58-
results.append(get_tags_with_filter(tag_service_stub, names_cel, tag_type))
59+
results.extend(get_tags_with_filter(tag_service_stub, names_cel, tag_type))
5960
if ids:
6061
ids_cel = cel_in("tag_id", ids)
61-
results.append(get_tags_with_filter(tag_service_stub, ids_cel, tag_type))
62+
results.extend(get_tags_with_filter(tag_service_stub, ids_cel, tag_type))
6263
if not names and not ids:
6364
# If no filter, but tag_type is specified, fetch all tags of that type
6465
if tag_type is not None:
65-
results.append(get_tags_with_filter(tag_service_stub, "", tag_type))
66+
results.extend(get_tags_with_filter(tag_service_stub, "", tag_type))
67+
68+
print(names, ids, results)
6669

6770
return results

0 commit comments

Comments
 (0)