|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import logging |
4 | | -from typing import TYPE_CHECKING, Any, cast |
| 4 | +from typing import TYPE_CHECKING, Any, Sequence, cast |
5 | 5 |
|
6 | 6 | from sift.common.type.v1.resource_identifier_pb2 import ResourceIdentifier, ResourceIdentifiers |
7 | 7 | from sift.rule_evaluation.v1.rule_evaluation_pb2 import ( |
@@ -97,6 +97,48 @@ async def get_rule(self, rule_id: str | None = None, client_key: str | None = No |
97 | 97 | grpc_rule = cast("GetRuleResponse", response).rule |
98 | 98 | return Rule._from_proto(grpc_rule) |
99 | 99 |
|
| 100 | + def _update_rule_request_from_create(self, create: RuleCreate) -> UpdateRuleRequest: |
| 101 | + """Create an UpdateRuleRequest from a RuleCreate object. |
| 102 | +
|
| 103 | + Args: |
| 104 | + create: The RuleCreate model with the rule configuration. |
| 105 | +
|
| 106 | + Returns: |
| 107 | + The UpdateRuleRequest proto message. |
| 108 | + """ |
| 109 | + expression_proto = RuleConditionExpression( |
| 110 | + calculated_channel=CalculatedChannelConfig( |
| 111 | + expression=create.expression, |
| 112 | + channel_references={ |
| 113 | + c.channel_reference: ChannelReferenceProto(name=c.channel_identifier) |
| 114 | + for c in create.channel_references |
| 115 | + }, |
| 116 | + ) |
| 117 | + ) |
| 118 | + conditions_request = [ |
| 119 | + UpdateConditionRequest( |
| 120 | + expression=expression_proto, |
| 121 | + actions=[create.action._to_update_request()], |
| 122 | + ) |
| 123 | + ] |
| 124 | + update_request = UpdateRuleRequest( |
| 125 | + name=create.name, |
| 126 | + description=create.description, |
| 127 | + is_enabled=True, |
| 128 | + organization_id=create.organization_id or "", |
| 129 | + client_key=create.client_key, |
| 130 | + is_external=create.is_external, |
| 131 | + conditions=conditions_request, |
| 132 | + asset_configuration=RuleAssetConfiguration( |
| 133 | + asset_ids=create.asset_ids or [], |
| 134 | + tag_ids=create.asset_tag_ids or [], |
| 135 | + ), |
| 136 | + contextual_channels=ContextualChannels( |
| 137 | + channels=[ChannelReferenceProto(name=c) for c in create.contextual_channels or []] |
| 138 | + ), # type: ignore |
| 139 | + ) |
| 140 | + return update_request |
| 141 | + |
100 | 142 | async def batch_get_rules( |
101 | 143 | self, rule_ids: list[str] | None = None, client_keys: list[str] | None = None |
102 | 144 | ) -> list[Rule]: |
@@ -138,39 +180,7 @@ async def create_rule( |
138 | 180 | Returns: |
139 | 181 | The created Rule. |
140 | 182 | """ |
141 | | - # Convert rule to UpdateRuleRequest |
142 | | - expression_proto = RuleConditionExpression( |
143 | | - calculated_channel=CalculatedChannelConfig( |
144 | | - expression=create.expression, |
145 | | - channel_references={ |
146 | | - c.channel_reference: ChannelReferenceProto(name=c.channel_identifier) |
147 | | - for c in create.channel_references |
148 | | - }, |
149 | | - ) |
150 | | - ) |
151 | | - conditions_request = [ |
152 | | - UpdateConditionRequest( |
153 | | - expression=expression_proto, |
154 | | - actions=[create.action._to_update_request()], |
155 | | - ) |
156 | | - ] |
157 | | - update_request = UpdateRuleRequest( |
158 | | - name=create.name, |
159 | | - description=create.description, |
160 | | - is_enabled=True, |
161 | | - organization_id=create.organization_id or "", |
162 | | - client_key=create.client_key, |
163 | | - is_external=create.is_external, |
164 | | - conditions=conditions_request, |
165 | | - asset_configuration=RuleAssetConfiguration( |
166 | | - asset_ids=create.asset_ids or [], |
167 | | - tag_ids=create.asset_tag_ids or [], |
168 | | - ), |
169 | | - contextual_channels=ContextualChannels( |
170 | | - channels=[ChannelReferenceProto(name=c) for c in create.contextual_channels or []] |
171 | | - ), # type: ignore |
172 | | - ) |
173 | | - |
| 183 | + update_request = self._update_rule_request_from_create(create) |
174 | 184 | request = CreateRuleRequest(update=update_request) |
175 | 185 | created_rule = cast( |
176 | 186 | "CreateRuleResponse", |
@@ -301,22 +311,59 @@ async def update_rule( |
301 | 311 | # Get the updated rule |
302 | 312 | return await self.get_rule(rule_id=rule.id_) |
303 | 313 |
|
304 | | - async def batch_update_rules(self, rules: list[RuleUpdate]) -> BatchUpdateRulesResponse: |
305 | | - """Batch update rules. |
| 314 | + async def batch_update_rules( |
| 315 | + self, |
| 316 | + rules: Sequence[RuleCreate | RuleUpdate], |
| 317 | + validate_only: bool = False, |
| 318 | + override_expression_validation: bool = False, |
| 319 | + ) -> BatchUpdateRulesResponse: |
| 320 | + """Batch update or create rules. |
306 | 321 |
|
307 | 322 | Args: |
308 | | - rules: List of rule updates to apply. |
| 323 | + rules: List of rule creates or updates to apply. RuleUpdate objects must have |
| 324 | + resource_id set. |
309 | 325 |
|
310 | 326 | Returns: |
311 | 327 | The batch update response. |
312 | | - """ |
313 | | - update_requests = [] |
314 | | - for rule_update in rules: |
315 | | - rule = await self.get_rule(rule_id=rule_update.resource_id) |
316 | | - request = self._update_rule_request_from_update(rule, rule_update) |
317 | | - update_requests.append(request) |
318 | 328 |
|
319 | | - request = BatchUpdateRulesRequest(rules=update_requests) # type: ignore |
| 329 | + Raises: |
| 330 | + ValueError: If any RuleUpdate objects are missing resource_id or the rule is not found for updating. |
| 331 | + """ |
| 332 | + # Collect resource_ids from only RuleUpdate objects |
| 333 | + rule_ids: list[str] = [] |
| 334 | + for rule in rules: |
| 335 | + if isinstance(rule, RuleUpdate): |
| 336 | + if rule.resource_id is None: |
| 337 | + raise ValueError("RuleUpdate objects must have resource_id set") |
| 338 | + rule_ids.append(rule.resource_id) |
| 339 | + |
| 340 | + # Fetch existing rules for updates |
| 341 | + existing_rules = await self.batch_get_rules(rule_ids=rule_ids) if rule_ids else [] |
| 342 | + existing_rules_by_id = {rule.id_: rule for rule in existing_rules} |
| 343 | + |
| 344 | + # Build update requests maintaining the input order |
| 345 | + update_requests: list[UpdateRuleRequest] = [] |
| 346 | + for rule in rules: |
| 347 | + if isinstance(rule, RuleCreate): |
| 348 | + # Convert RuleCreate to UpdateRuleRequest |
| 349 | + update_request = self._update_rule_request_from_create(rule) |
| 350 | + update_requests.append(update_request) |
| 351 | + elif isinstance(rule, RuleUpdate): |
| 352 | + # Use existing rule + update to create request |
| 353 | + existing_rule = existing_rules_by_id.get(rule.resource_id) |
| 354 | + if existing_rule is None: |
| 355 | + raise ValueError( |
| 356 | + f"Rule with resource_id {rule.resource_id} not found for update" |
| 357 | + ) |
| 358 | + update_request = self._update_rule_request_from_update(existing_rule, rule) |
| 359 | + update_requests.append(update_request) |
| 360 | + |
| 361 | + # Call the batch update request |
| 362 | + request = BatchUpdateRulesRequest( |
| 363 | + rules=update_requests, |
| 364 | + validate_only=validate_only, |
| 365 | + override_expression_validation=override_expression_validation, |
| 366 | + ) # type: ignore |
320 | 367 | response = await self._grpc_client.get_stub(RuleServiceStub).BatchUpdateRules(request) |
321 | 368 | return cast("BatchUpdateRulesResponse", response) |
322 | 369 |
|
|
0 commit comments