|
3 | 3 | import logging |
4 | 4 | from typing import TYPE_CHECKING, Any, cast |
5 | 5 |
|
| 6 | +from sift.common.type.v1.resource_identifier_pb2 import ResourceIdentifier |
| 7 | +from sift.rule_evaluation.v1.rule_evaluation_pb2 import ( |
| 8 | + EvaluateRulesRequest, |
| 9 | + EvaluateRulesResponse, |
| 10 | + RunTimeRange, |
| 11 | +) |
| 12 | +from sift.rule_evaluation.v1.rule_evaluation_pb2_grpc import RuleEvaluationServiceStub |
6 | 13 | from sift.rules.v1.rules_pb2 import ( |
7 | 14 | BatchDeleteRulesRequest, |
8 | 15 | BatchGetRulesRequest, |
|
31 | 38 | from sift.rules.v1.rules_pb2_grpc import RuleServiceStub |
32 | 39 |
|
33 | 40 | from sift_client._internal.low_level_wrappers.base import LowLevelClientBase |
| 41 | +from sift_client._internal.low_level_wrappers.reports import ReportsLowLevelClient |
34 | 42 | from sift_client.sift_types.rule import ( |
35 | 43 | Rule, |
36 | 44 | RuleAction, |
37 | 45 | RuleUpdate, |
38 | 46 | ) |
39 | 47 | from sift_client.transport import GrpcClient, WithGrpcClient |
| 48 | +from sift_client.util.util import count_non_none |
40 | 49 |
|
41 | 50 | if TYPE_CHECKING: |
| 51 | + from datetime import datetime |
| 52 | + |
42 | 53 | from sift_client.sift_types.channel import ChannelReference |
| 54 | + from sift_client.sift_types.report import Report |
43 | 55 |
|
44 | 56 | # Configure logging |
45 | 57 | logger = logging.getLogger(__name__) |
@@ -445,3 +457,76 @@ async def list_all_rules( |
445 | 457 | order_by=order_by, |
446 | 458 | max_results=max_results, |
447 | 459 | ) |
| 460 | + |
| 461 | + async def evaluate_rules( |
| 462 | + self, |
| 463 | + *, |
| 464 | + run_id: str | None = None, |
| 465 | + assets: list[str] | None = None, |
| 466 | + all_applicable_rules: bool | None = None, |
| 467 | + run_start_time: datetime | None = None, |
| 468 | + run_end_time: datetime | None = None, |
| 469 | + rule_ids: list[str] | None = None, |
| 470 | + rule_version_ids: list[str] | None = None, |
| 471 | + report_template_id: str | None = None, |
| 472 | + tags: list[str] | None = None, |
| 473 | + ) -> Report | None: |
| 474 | + """Evaluate a rule. |
| 475 | +
|
| 476 | + Args: |
| 477 | + run_id: The run ID to evaluate. |
| 478 | + assets: The assets to evaluate. |
| 479 | + run_start_time: The start time of the run. |
| 480 | + run_end_time: The end time of the run. |
| 481 | + all_applicable_rules: Whether to evaluate all rules applicable to the selected run, assets, or time range. |
| 482 | + rule_ids: The rule IDs to evaluate. |
| 483 | + rule_version_ids: The rule version IDs to evaluate. |
| 484 | + report_template_id: The report template ID to evaluate. |
| 485 | + tags: Optional tags to add to generated annotations. |
| 486 | +
|
| 487 | + Returns: |
| 488 | + The result of the rule execution. |
| 489 | + """ |
| 490 | + if count_non_none(run_id, assets, run_start_time, run_end_time) > 1: |
| 491 | + raise ValueError( |
| 492 | + "Pick only one run_id, assets, or (run_start_time and run_end_time) to select what to evaluate against." |
| 493 | + ) |
| 494 | + |
| 495 | + all_applicable_rules = ( |
| 496 | + None if not all_applicable_rules else True |
| 497 | + ) # Cast to None if False so we don't count it against other filters if they aren't opting in. |
| 498 | + if count_non_none(rule_ids, rule_version_ids, report_template_id, all_applicable_rules) > 1: |
| 499 | + raise ValueError( |
| 500 | + "Pick only one rule_ids, rule_version_ids, report_template_id, or all_applicable_rules to further filter which rules to evaluate." |
| 501 | + ) |
| 502 | + |
| 503 | + kwargs: dict[str, Any] = {} |
| 504 | + if run_start_time and run_end_time: |
| 505 | + kwargs["run_time_range"] = RunTimeRange( |
| 506 | + run=run_id, start_time=run_start_time, end_time=run_end_time |
| 507 | + ) |
| 508 | + if run_id: |
| 509 | + kwargs["run"] = ResourceIdentifier(id=run_id) |
| 510 | + if assets: |
| 511 | + kwargs["assets"] = assets |
| 512 | + if all_applicable_rules: |
| 513 | + kwargs["all_applicable_rules"] = all_applicable_rules |
| 514 | + if rule_ids: |
| 515 | + kwargs["rules"] = rule_ids |
| 516 | + if rule_version_ids: |
| 517 | + kwargs["rule_versions"] = rule_version_ids |
| 518 | + if report_template_id: |
| 519 | + kwargs["report_template"] = report_template_id |
| 520 | + if tags: |
| 521 | + kwargs["tags"] = tags |
| 522 | + |
| 523 | + request = EvaluateRulesRequest(**kwargs) |
| 524 | + response = await self._grpc_client.get_stub(RuleEvaluationServiceStub).EvaluateRules( |
| 525 | + request |
| 526 | + ) |
| 527 | + response = cast("EvaluateRulesResponse", response) |
| 528 | + report_id = response.report_id |
| 529 | + if report_id: |
| 530 | + report = await ReportsLowLevelClient(self._grpc_client).get_report(report_id=report_id) |
| 531 | + return report |
| 532 | + return None |
0 commit comments