|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | from dataclasses import dataclass |
| 4 | +from functools import cache |
4 | 5 | from pathlib import Path |
5 | | -from typing import Any, Dict, List, Optional, Union, cast |
| 6 | +from typing import Any, Dict, List, Optional, Tuple, Union, cast |
6 | 7 |
|
7 | 8 | from sift.annotations.v1.annotations_pb2 import AnnotationType |
8 | 9 | from sift.assets.v1.assets_pb2 import Asset |
9 | 10 | from sift.assets.v1.assets_pb2_grpc import AssetServiceStub |
| 11 | +from sift.channels.v3.channels_pb2 import Channel as ChannelPb |
10 | 12 | from sift.channels.v3.channels_pb2_grpc import ChannelServiceStub |
| 13 | +from sift.common.type.v1.user_pb2 import User as UserPb |
11 | 14 | from sift.rules.v1.rules_pb2 import ( |
12 | 15 | ANNOTATION, |
13 | 16 | AnnotationActionConfiguration, |
|
53 | 56 | class RuleService: |
54 | 57 | """ |
55 | 58 | A service for managing rules. Allows for loading rules from YAML and creating or updating them in the Sift API. |
| 59 | +
|
| 60 | + Args: |
| 61 | + channel: The configured Sift channel. |
| 62 | + enable_caching: Enable caching on various API calls to speed up rule creation. Use this for short lived |
| 63 | + instantiations of the RuleService where assets, channels, users are unlikely to change. |
56 | 64 | """ |
57 | 65 |
|
58 | 66 | _asset_service_stub: AssetServiceStub |
59 | 67 | _channel_service_stub: ChannelServiceStub |
60 | 68 | _rule_service_stub: RuleServiceStub |
61 | 69 | _user_service_stub: UserServiceStub |
| 70 | + _enable_caching: bool |
62 | 71 |
|
63 | | - def __init__(self, channel: SiftChannel): |
| 72 | + def __init__(self, channel: SiftChannel, enable_caching=False): |
64 | 73 | self._asset_service_stub = AssetServiceStub(channel) |
65 | 74 | self._channel_service_stub = ChannelServiceStub(channel) |
66 | 75 | self._rule_service_stub = RuleServiceStub(channel) |
67 | 76 | self._user_service_stub = UserServiceStub(channel) |
| 77 | + self._enable_caching = enable_caching |
68 | 78 |
|
69 | 79 | def load_rules_from_yaml( |
70 | 80 | self, |
@@ -401,8 +411,7 @@ def _update_req_from_rule_config( |
401 | 411 | assignee = config.action.assignee |
402 | 412 | user_id = None |
403 | 413 | if assignee: |
404 | | - users = get_active_users( |
405 | | - user_service=self._user_service_stub, |
| 414 | + users = self._get_active_users( |
406 | 415 | filter=f"name=='{assignee}'", |
407 | 416 | ) |
408 | 417 | if not users: |
@@ -453,8 +462,7 @@ def _update_req_from_rule_config( |
453 | 462 |
|
454 | 463 | # Validate channels are present within each asset |
455 | 464 | for asset in assets: |
456 | | - found_channels = get_channels( |
457 | | - channel_service=self._channel_service_stub, |
| 465 | + found_channels = self._get_channels( |
458 | 466 | filter=f"asset_id == '{asset.asset_id}' && {name_in}", |
459 | 467 | ) |
460 | 468 | found_channels_names = [channel.name for channel in found_channels] |
@@ -598,8 +606,35 @@ def _get_rule_from_rule_id(self, rule_id: str) -> Optional[Rule]: |
598 | 606 | return None |
599 | 607 |
|
600 | 608 | def _get_assets(self, names: List[str] = [], ids: List[str] = []) -> List[Asset]: |
| 609 | + if self._enable_caching: |
| 610 | + return self._get_assets_cached(tuple(sorted(names)), tuple(sorted(ids))) |
| 611 | + else: |
| 612 | + return list_assets_impl(self._asset_service_stub, names, ids) |
| 613 | + |
| 614 | + def _get_channels(self, filter: str) -> List[ChannelPb]: |
| 615 | + if self._enable_caching: |
| 616 | + return self._get_channels_cached(filter) |
| 617 | + else: |
| 618 | + return get_channels(channel_service=self._channel_service_stub, filter=filter) |
| 619 | + |
| 620 | + def _get_active_users(self, filter: str) -> List[UserPb]: |
| 621 | + if self._enable_caching: |
| 622 | + return self._get_active_users_cached(filter) |
| 623 | + else: |
| 624 | + return get_active_users(user_service=self._user_service_stub, filter=filter) |
| 625 | + |
| 626 | + @cache |
| 627 | + def _get_assets_cached(self, names: Tuple[str], ids: Tuple[str]) -> List[Asset]: |
601 | 628 | return list_assets_impl(self._asset_service_stub, names, ids) |
602 | 629 |
|
| 630 | + @cache |
| 631 | + def _get_channels_cached(self, filter: str) -> List[ChannelPb]: |
| 632 | + return get_channels(channel_service=self._channel_service_stub, filter=filter) |
| 633 | + |
| 634 | + @cache |
| 635 | + def _get_active_users_cached(self, filter: str) -> List[UserPb]: |
| 636 | + return get_active_users(user_service=self._user_service_stub, filter=filter) |
| 637 | + |
603 | 638 |
|
604 | 639 | @dataclass |
605 | 640 | class RuleChannelReference: |
|
0 commit comments