Skip to content

Commit 95e25bf

Browse files
committed
Use rule_evaluation protos for rule evaluation.
1 parent 402331a commit 95e25bf

12 files changed

Lines changed: 238 additions & 56 deletions

File tree

python/lib/sift_client/_internal/low_level_wrappers/rules.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@
33
import logging
44
from typing import TYPE_CHECKING, Any, cast
55

6+
from sift.common.type.v1.resource_identifier_pb2 import ResourceIdentifier
7+
from sift.rule_evaluation.v1.rule_evaluation_pb2 import (
8+
EvaluateRulesRequest,
9+
EvaluateRulesResponse,
10+
RunTimeRange,
11+
)
12+
from sift.rule_evaluation.v1.rule_evaluation_pb2_grpc import RuleEvaluationServiceStub
613
from sift.rules.v1.rules_pb2 import (
714
BatchDeleteRulesRequest,
815
BatchGetRulesRequest,
@@ -31,15 +38,20 @@
3138
from sift.rules.v1.rules_pb2_grpc import RuleServiceStub
3239

3340
from sift_client._internal.low_level_wrappers.base import LowLevelClientBase
41+
from sift_client._internal.low_level_wrappers.reports import ReportsLowLevelClient
3442
from sift_client.sift_types.rule import (
3543
Rule,
3644
RuleAction,
3745
RuleUpdate,
3846
)
3947
from sift_client.transport import GrpcClient, WithGrpcClient
48+
from sift_client.util.util import count_non_none
4049

4150
if TYPE_CHECKING:
51+
from datetime import datetime
52+
4253
from sift_client.sift_types.channel import ChannelReference
54+
from sift_client.sift_types.report import Report
4355

4456
# Configure logging
4557
logger = logging.getLogger(__name__)
@@ -445,3 +457,76 @@ async def list_all_rules(
445457
order_by=order_by,
446458
max_results=max_results,
447459
)
460+
461+
async def evaluate_rules(
462+
self,
463+
*,
464+
run_id: str | None = None,
465+
assets: list[str] | None = None,
466+
all_applicable_rules: bool | None = None,
467+
run_start_time: datetime | None = None,
468+
run_end_time: datetime | None = None,
469+
rule_ids: list[str] | None = None,
470+
rule_version_ids: list[str] | None = None,
471+
report_template_id: str | None = None,
472+
tags: list[str] | None = None,
473+
) -> Report | None:
474+
"""Evaluate a rule.
475+
476+
Args:
477+
run_id: The run ID to evaluate.
478+
assets: The assets to evaluate.
479+
run_start_time: The start time of the run.
480+
run_end_time: The end time of the run.
481+
all_applicable_rules: Whether to evaluate all rules applicable to the selected run, assets, or time range.
482+
rule_ids: The rule IDs to evaluate.
483+
rule_version_ids: The rule version IDs to evaluate.
484+
report_template_id: The report template ID to evaluate.
485+
tags: Optional tags to add to generated annotations.
486+
487+
Returns:
488+
The result of the rule execution.
489+
"""
490+
if count_non_none(run_id, assets, run_start_time, run_end_time) > 1:
491+
raise ValueError(
492+
"Pick only one run_id, assets, or (run_start_time and run_end_time) to select what to evaluate against."
493+
)
494+
495+
all_applicable_rules = (
496+
None if not all_applicable_rules else True
497+
) # Cast to None if False so we don't count it against other filters if they aren't opting in.
498+
if count_non_none(rule_ids, rule_version_ids, report_template_id, all_applicable_rules) > 1:
499+
raise ValueError(
500+
"Pick only one rule_ids, rule_version_ids, report_template_id, or all_applicable_rules to further filter which rules to evaluate."
501+
)
502+
503+
kwargs: dict[str, Any] = {}
504+
if run_start_time and run_end_time:
505+
kwargs["run_time_range"] = RunTimeRange(
506+
run=run_id, start_time=run_start_time, end_time=run_end_time
507+
)
508+
if run_id:
509+
kwargs["run"] = ResourceIdentifier(id=run_id)
510+
if assets:
511+
kwargs["assets"] = assets
512+
if all_applicable_rules:
513+
kwargs["all_applicable_rules"] = all_applicable_rules
514+
if rule_ids:
515+
kwargs["rules"] = rule_ids
516+
if rule_version_ids:
517+
kwargs["rule_versions"] = rule_version_ids
518+
if report_template_id:
519+
kwargs["report_template"] = report_template_id
520+
if tags:
521+
kwargs["tags"] = tags
522+
523+
request = EvaluateRulesRequest(**kwargs)
524+
response = await self._grpc_client.get_stub(RuleEvaluationServiceStub).EvaluateRules(
525+
request
526+
)
527+
response = cast("EvaluateRulesResponse", response)
528+
report_id = response.report_id
529+
if report_id:
530+
report = await ReportsLowLevelClient(self._grpc_client).get_report(report_id=report_id)
531+
return report
532+
return None

python/lib/sift_client/_internal/low_level_wrappers/tags.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sift.tags.v2.tags_pb2_grpc import TagServiceStub
1313

1414
from sift_client._internal.low_level_wrappers.base import LowLevelClientBase
15-
from sift_client.sift_types.tag import Tag, TagUpdate
15+
from sift_client.sift_types.tag import Tag
1616
from sift_client.transport import WithGrpcClient
1717

1818
if TYPE_CHECKING:
@@ -114,4 +114,4 @@ async def list_all_tags(
114114
kwargs={"query_filter": query_filter},
115115
order_by=order_by,
116116
max_results=max_results,
117-
)
117+
)

python/lib/sift_client/_tests/integrated/reports.py

Lines changed: 15 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99

1010
import asyncio
1111
import os
12-
from datetime import datetime, timedelta, timezone
12+
from datetime import datetime
13+
1314
from zoneinfo import ZoneInfo
1415

1516
from sift_client import SiftClient
16-
from sift_client.sift_types import report
1717

1818

1919
async def main():
@@ -23,6 +23,7 @@ async def main():
2323
grpc_url = os.getenv("SIFT_GRPC_URI", "localhost:50051")
2424
rest_url = os.getenv("SIFT_REST_URI", "localhost:8080")
2525
api_key = os.getenv("SIFT_API_KEY", "")
26+
2627
client = SiftClient(
2728
api_key=api_key,
2829
grpc_url=grpc_url,
@@ -35,49 +36,24 @@ async def main():
3536
limit=100,
3637
)
3738

38-
asset_ids = []
39-
asset_tags_names = []
4039
rules = []
41-
reports = []
40+
failed_runs = []
4241
for run in runs:
4342
print("run.name: ", run.name)
4443
print(" client_key: ", run.client_key)
45-
if run.client_key:
46-
# rules = client.rules.list_(
47-
# client_key=run.client_key,
48-
# limit=100,
49-
# )
50-
raise Exception("client_key is not None! Let's add these rules")
51-
run_assets = run.assets
52-
print(" assets: ", [asset.name for asset in run_assets])
53-
asset_ids.extend([asset.id_ for asset in run_assets])
54-
asset_tags_names.extend([tag for asset in run_assets for tag in asset.tags])
55-
per_run_reports = client.reports.list_(
56-
run_id=run.id_,
57-
)
58-
print(" reports: ", [report.name for report in per_run_reports])
59-
reports.extend(per_run_reports)
60-
61-
asset_ids = list(set(asset_ids))
62-
asset_tags_names = list(set(asset_tags_names))
63-
asset_tags = client.tags.list_(
64-
names=asset_tags_names,
65-
)
66-
print(" asset_tags: ", [(tag.name, tag.id_) for tag in asset_tags])
67-
print("Number of runs: ", len(runs))
68-
print("Number of assets: ", len(asset_ids))
44+
try:
45+
report = client.rules.evaluate(
46+
run_id=run.id_,
47+
all_applicable_rules=True,
48+
)
49+
except Exception as e:
50+
failed_runs.append(run.id_)
51+
print(f"Failed to evaluate rules for run {run.id_}: {e}")
52+
53+
print("Number of successful runs: ", len(runs) - len(failed_runs))
54+
print("Number of failed runs: ", len(failed_runs))
6955

70-
rules = client.rules.list_(
71-
asset_ids=asset_ids,
72-
asset_tags_ids=[tag.id_ for tag in asset_tags],
73-
)
74-
print("reports: ", [report.name for report in reports])
75-
if len(rules) < 10:
76-
print("rules: ", [rule.name for rule in rules])
77-
else:
78-
print("number of rules: ", len(rules))
7956

80-
8157

8258
if __name__ == "__main__":
8359
asyncio.run(main())

python/lib/sift_client/resources/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from sift_client.resources.rules import RulesAPIAsync
88
from sift_client.resources.runs import RunsAPIAsync
99
from sift_client.resources.tags import TagsAPIAsync
10+
11+
# ruff: noqa TagsAPIAsync needs to be imported before sync_stubs to avoid circular import
1012
from sift_client.resources.sync_stubs import (
1113
AssetsAPI,
1214
CalculatedChannelsAPI,

python/lib/sift_client/resources/reports.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,14 @@
66
from sift_client._internal.low_level_wrappers.reports import ReportsLowLevelClient
77
from sift_client.resources._base import ResourceBase
88
from sift_client.sift_types.report import Report
9-
from sift_client.util.cel_utils import contains, equals, equals_null, match, not_
9+
from sift_client.util.cel_utils import contains, equals, match
1010

1111
if TYPE_CHECKING:
1212
from sift_client.client import SiftClient
1313

1414

1515
class ReportsAPIAsync(ResourceBase):
16-
"""High-level API for interacting with reports.
17-
"""
16+
"""High-level API for interacting with reports."""
1817

1918
def __init__(self, sift_client: SiftClient):
2019
"""Initialize the ReportsAPI.
@@ -201,7 +200,6 @@ async def create_from_rules(
201200
)
202201
return self._apply_client_to_instance(created_report)
203202

204-
205203
async def rerun(
206204
self,
207205
*,
@@ -233,4 +231,4 @@ async def cancel(
233231
report_id = report.id_ if isinstance(report, Report) else report
234232
if not isinstance(report_id, str):
235233
raise TypeError(f"report_id must be a string not {type(report_id)}")
236-
await self._low_level_client.cancel_report(report_id=report_id)
234+
await self._low_level_client.cancel_report(report_id=report_id)

python/lib/sift_client/resources/rules.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99

1010
if TYPE_CHECKING:
1111
import re
12+
from datetime import datetime
1213

1314
from sift_client.client import SiftClient
1415
from sift_client.sift_types.channel import ChannelReference
16+
from sift_client.sift_types.report import Report
1517

1618

1719
class RulesAPIAsync(ResourceBase):
@@ -273,3 +275,56 @@ async def batch_get(
273275
rule_ids=rule_ids, client_keys=client_keys
274276
)
275277
return self._apply_client_to_instances(rules)
278+
279+
async def evaluate(
280+
self,
281+
*,
282+
run_id: str | None = None,
283+
assets: list[str] | None = None,
284+
all_applicable_rules: bool | None = None,
285+
run_start_time: datetime | None = None,
286+
run_end_time: datetime | None = None,
287+
rule_ids: list[str] | None = None,
288+
rule_version_ids: list[str] | None = None,
289+
report_template_id: str | None = None,
290+
tags: list[str] | None = None,
291+
) -> Report | None:
292+
"""Evaluate a rule.
293+
294+
Pick one of the following grouping of rules to evaluate against:
295+
- run_id
296+
- assets
297+
- run_start_time and run_end_time
298+
And one of the following filters to select which rules to evaluate:
299+
- rule_ids
300+
- rule_version_ids
301+
- report_template_id
302+
- all_applicable_rules
303+
304+
Args:
305+
run_id: The run ID to evaluate.
306+
assets: The assets to evaluate.
307+
all_applicable_rules: Whether to evaluate all rules applicable to the selected run, assets, or time range.
308+
run_start_time: The start time of the run.
309+
run_end_time: The end time of the run.
310+
rule_ids: The rule IDs to evaluate.
311+
rule_version_ids: The rule version IDs to evaluate.
312+
report_template_id: The report template ID to evaluate.
313+
tags: Optional tags to add to generated annotations.
314+
315+
Returns:
316+
The result of the rule evaluation.
317+
"""
318+
report = await self._low_level_client.evaluate_rules(
319+
run_id=run_id,
320+
assets=assets,
321+
all_applicable_rules=all_applicable_rules,
322+
run_start_time=run_start_time,
323+
run_end_time=run_end_time,
324+
rule_ids=rule_ids,
325+
rule_version_ids=rule_version_ids,
326+
report_template_id=report_template_id,
327+
tags=tags,
328+
)
329+
if report:
330+
return self._apply_client_to_instance(report)

python/lib/sift_client/resources/runs.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,16 @@
66
from sift_client._internal.low_level_wrappers.runs import RunsLowLevelClient
77
from sift_client.resources._base import ResourceBase
88
from sift_client.sift_types.run import Run, RunUpdate
9-
from sift_client.util.cel_utils import contains, equals, equals_null, greater_than, less_than, match, not_
9+
from sift_client.util.cel_utils import (
10+
contains,
11+
equals,
12+
equals_null,
13+
greater_than,
14+
in_,
15+
less_than,
16+
match,
17+
not_,
18+
)
1019

1120
if TYPE_CHECKING:
1221
from datetime import datetime
@@ -55,6 +64,7 @@ async def list_(
5564
name: str | None = None,
5665
name_contains: str | None = None,
5766
name_regex: str | re.Pattern | None = None,
67+
run_ids: list[str] | None = None,
5868
description: str | None = None,
5969
description_contains: str | None = None,
6070
duration_seconds: int | None = None,
@@ -82,6 +92,7 @@ async def list_(
8292
name: Exact name of the run.
8393
name_contains: Partial name of the run.
8494
name_regex: Regular expression string to filter runs by name.
95+
run_ids: List of run IDs to filter by.
8596
description: Exact description of the run.
8697
description_contains: Partial description of the run.
8798
duration_seconds: Duration of the run in seconds.
@@ -118,6 +129,9 @@ async def list_(
118129
name_regex = name_regex.pattern
119130
filter_parts.append(match("name", name_regex)) # type: ignore
120131

132+
if run_ids:
133+
filter_parts.append(in_("run_id", run_ids))
134+
121135
if description:
122136
filter_parts.append(equals("description", description))
123137
elif description_contains:

0 commit comments

Comments
 (0)