forked from microsoft/durabletask-python
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathorchestration_entity_context.py
More file actions
123 lines (100 loc) · 5.5 KB
/
Copy pathorchestration_entity_context.py
File metadata and controls
123 lines (100 loc) · 5.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from collections.abc import Generator
from datetime import datetime
from typing import Any
from durabletask.internal.helpers import get_string_value
import durabletask.internal.orchestrator_service_pb2 as pb
from durabletask.entities import EntityInstanceId
class OrchestrationEntityContext:
    """Track the entity-lock (critical section) state of one orchestration.

    The context records which entities were requested for the current
    critical section (``critical_section_locks``), which of those are
    currently free to be called (``available_locks``), and whether the lock
    request has been answered yet (``lock_acquisition_pending``).  The
    ``validate_*`` methods return ``(ok, error_message)`` tuples instead of
    raising, with an empty message on success.
    """

    def __init__(self, instance_id: str):
        """Create an empty context for the orchestration *instance_id*."""
        self.instance_id = instance_id
        # True between emit_acquire_message() and complete_acquire().
        self.lock_acquisition_pending = False
        # None means "not inside a critical section".
        self.critical_section_id: str | None = None
        # Every entity locked by the current critical section.
        self.critical_section_locks: list[EntityInstanceId] = []
        # Subset of critical_section_locks with no two-way call in flight.
        self.available_locks: list[EntityInstanceId] = []

    @property
    def is_inside_critical_section(self) -> bool:
        """Return True while a critical section ID is set."""
        return self.critical_section_id is not None

    def get_available_entities(self) -> Generator[EntityInstanceId, None, None]:
        """Yield the currently available locks; nothing outside a critical section."""
        if self.is_inside_critical_section:
            yield from self.available_locks

    def validate_suborchestration_transition(self) -> tuple[bool, str]:
        """Check whether a sub-orchestration may be called right now."""
        if self.is_inside_critical_section:
            return False, "While holding locks, cannot call suborchestrators."
        return True, ""

    def validate_operation_transition(self, target_instance_id: EntityInstanceId, one_way: bool) -> tuple[bool, str]:
        """Check whether an entity operation may be issued right now.

        Outside a critical section every operation is allowed.  Inside one,
        a one-way operation (signal) is rejected when it targets a locked
        entity, and a two-way operation (call) is only allowed on an entity
        that is in ``available_locks``.

        Side effect: a successful two-way call removes the target from
        ``available_locks``; ``recover_lock_after_call`` puts it back.

        Returns:
            ``(True, "")`` when allowed, otherwise ``(False, reason)``.
        """
        if not self.is_inside_critical_section:
            return True, ""

        if one_way:
            if target_instance_id in self.critical_section_locks:
                return False, "Must not signal a locked entity from a critical section."
            return True, ""

        try:
            # Claim the lock for the duration of the call.
            self.available_locks.remove(target_instance_id)
        except ValueError:
            if self.lock_acquisition_pending:
                return False, "Must await the completion of the lock request prior to calling any entity."
            if target_instance_id in self.critical_section_locks:
                return False, "Must not call an entity from a critical section while a prior call to the same entity is still pending."
            return False, "Must not call an entity from a critical section if it is not one of the locked entities."
        return True, ""

    def validate_acquire_transition(self) -> tuple[bool, str]:
        """Check whether a new critical section may be entered right now."""
        if self.is_inside_critical_section:
            return False, "Must not enter another critical section from within a critical section."
        return True, ""

    def recover_lock_after_call(self, target_instance_id: EntityInstanceId) -> None:
        """Mark *target_instance_id* as available again (no-op outside a critical section)."""
        if self.is_inside_critical_section:
            self.available_locks.append(target_instance_id)

    def emit_lock_release_messages(self) -> Generator[pb.SendEntityMessageAction, None, None]:
        """Yield one unlock message per locked entity, then leave the critical section.

        This is a generator: the lock lists and the critical section ID are
        cleared only after the last message has been yielded, i.e. once the
        caller has consumed the generator completely.  Outside a critical
        section nothing is yielded and no state changes.
        """
        if self.is_inside_critical_section:
            for entity_id in self.critical_section_locks:
                unlock_event = pb.SendEntityMessageAction(entityUnlockSent=pb.EntityUnlockSentEvent(
                    criticalSectionId=self.critical_section_id,
                    targetInstanceId=get_string_value(str(entity_id)),
                    parentInstanceId=get_string_value(self.instance_id)
                ))
                yield unlock_event

            self.critical_section_locks = []
            self.available_locks = []
            self.critical_section_id = None

    def emit_request_message(self, target: Any, operation_name: str, one_way: bool, operation_id: str,
                             scheduled_time_utc: datetime, input: str | None,
                             request_time: datetime | None = None, create_trace: bool = False) -> Any:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def emit_acquire_message(
        self,
        critical_section_id: str,
        entities: list[EntityInstanceId],
    ) -> tuple[None, None] | tuple[pb.SendEntityMessageAction, pb.OrchestrationInstance]:
        """Build the lock request for *entities* and enter the critical section.

        Args:
            critical_section_id: Identifier to assign to the new critical section.
            entities: Entities to lock; may contain duplicates, in any order.

        Returns:
            ``(None, None)`` when *entities* is empty (no state is changed).
            Otherwise ``(request, target)`` where *request* carries the sorted,
            de-duplicated lock set and *target* addresses the first entity in
            that set.
        """
        if not entities:
            return None, None

        # Acquire the locks in a globally fixed order to avoid deadlocks.
        # Duplicates are adjacent after sorting, so comparing each element to
        # the last one kept is enough to drop them.
        unique_ids: list[EntityInstanceId] = []
        for entity_id in sorted(entities):
            if not unique_ids or entity_id != unique_ids[-1]:
                unique_ids.append(entity_id)

        target = pb.OrchestrationInstance(instanceId=str(unique_ids[0]))
        request = pb.SendEntityMessageAction(entityLockRequested=pb.EntityLockRequestedEvent(
            criticalSectionId=critical_section_id,
            parentInstanceId=get_string_value(self.instance_id),
            lockSet=[str(eid) for eid in unique_ids],
            position=0,
        ))

        self.critical_section_id = critical_section_id
        self.critical_section_locks = unique_ids
        self.lock_acquisition_pending = True

        return request, target

    def complete_acquire(self, critical_section_id: str) -> None:
        """Record that the pending lock request has been granted.

        Raises:
            RuntimeError: if *critical_section_id* is not the ID of the
                critical section this context is currently tracking.
        """
        if self.critical_section_id != critical_section_id:
            raise RuntimeError(f"Unexpected lock acquire for critical section ID '{critical_section_id}' (expected '{self.critical_section_id}')")
        # Copy rather than alias: validate_operation_transition() removes
        # entries from available_locks while a call is in flight, and that must
        # not shrink critical_section_locks, which drives the "prior call still
        # pending" check and the unlock messages sent on release.
        self.available_locks = list(self.critical_section_locks)
        self.lock_acquisition_pending = False

    def adjust_outgoing_message(self, instance_id: str, request_message: Any, capped_time: datetime) -> str:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def deserialize_entity_response_event(self, event_content: str) -> Any:
        """Not implemented; always raises ``NotImplementedError``."""
        raise NotImplementedError()