-
-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathchannel_concurrency.py
More file actions
209 lines (162 loc) · 7.16 KB
/
channel_concurrency.py
File metadata and controls
209 lines (162 loc) · 7.16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
"""
Channel Concurrency Example
============================
This example demonstrates how to safely share and update channel data
when tasks run in parallel using ``ParallelGroup``.
Problem
-------
A naive read-modify-write (``get`` -> compute -> ``set``) is **not** atomic.
When multiple tasks execute in parallel threads, updates can be lost because
two tasks may read the same value and overwrite each other's writes.
Solutions
---------
1. ``channel.atomic_add(key, amount)`` — atomic numeric add; pass a negative
amount to decrement. Backed by a per-key lock in MemoryChannel and
``INCRBYFLOAT`` in Redis.
2. ``channel.lock(key)`` — advisory lock for arbitrary compound
operations. ``MemoryChannel`` uses a per-key ``threading.RLock`` for
in-process coordination; ``RedisChannel`` uses ``redis.lock.Lock``
(SET NX + Lua release) for cross-client distributed coordination.
Expected Output
---------------
=== Channel Concurrency Demo ===
--- Unsafe parallel increment (race condition) ---
Expected counter: 500, Actual: <usually less than 500; varies per run>
Updates lost! (or "(Got lucky — no interleaving this run)" if no race occurred)
--- Safe parallel increment with atomic_add() ---
Expected counter: 500, Actual: 500
--- Safe compound update with lock() ---
Overflow events: 5
Counter after resets: 0
Done!
"""
from graflow.core.context import TaskExecutionContext
from graflow.core.decorators import task
from graflow.core.task import ParallelGroup
from graflow.core.workflow import workflow
def demo_unsafe_increment() -> None:
    """Demonstrate lost updates from an unguarded get/compute/set sequence.

    Several workers run inside a ``ParallelGroup``. Each one bumps the shared
    ``"counter"`` key by reading it, yielding, and writing ``value + 1``
    back. Nothing makes that sequence atomic, so the final count may fall
    short of ``worker_count * bumps_per_worker``. The report task prints
    either way.
    """
    import time

    print("--- Unsafe parallel increment (race condition) ---")
    worker_count = 5
    bumps_per_worker = 100
    target = worker_count * bumps_per_worker

    with workflow("unsafe_demo") as ctx:

        @task(inject_context=True)
        def init_counter(context: TaskExecutionContext):
            context.get_channel().set("counter", 0)

        # One task per worker; each performs a naive, unguarded get -> set.
        workers = []
        for idx in range(worker_count):

            @task(inject_context=True, id=f"unsafe_worker_{idx}")
            def unsafe_worker(context: TaskExecutionContext):
                chan = context.get_channel()
                for _ in range(bumps_per_worker):
                    snapshot = chan.get("counter")
                    # Yield between the read and the write to invite
                    # interleaving with the other workers.
                    time.sleep(0)
                    chan.set("counter", snapshot + 1)

            workers.append(unsafe_worker)

        @task(inject_context=True)
        def report(context: TaskExecutionContext):
            observed = context.get_channel().get("counter")
            print(f" Expected counter: {target}, Actual: {observed}")
            verdict = (
                " Updates lost!\n"
                if observed < target
                else " (Got lucky — no interleaving this run)\n"
            )
            print(verdict)

        _ = init_counter >> ParallelGroup(workers, name="unsafe_group") >> report
        ctx.execute("init_counter")
def demo_atomic_add() -> None:
    """Demonstrate that ``atomic_add`` keeps parallel counter updates intact.

    Each worker in the ``ParallelGroup`` adds 1 to ``"counter"`` a fixed
    number of times through ``channel.atomic_add``. The report task prints
    the expected total next to the value actually stored.
    """
    print("--- Safe parallel increment with atomic_add() ---")
    worker_count = 5
    adds_per_worker = 100
    target = worker_count * adds_per_worker

    with workflow("add_demo") as ctx:

        @task(inject_context=True)
        def init_counter(context: TaskExecutionContext):
            context.get_channel().set("counter", 0)

        workers = []
        for idx in range(worker_count):

            @task(inject_context=True, id=f"add_worker_{idx}")
            def add_worker(context: TaskExecutionContext):
                chan = context.get_channel()
                # Each call is one atomic add, so there is no
                # read-modify-write window for another worker to slip into.
                for _ in range(adds_per_worker):
                    chan.atomic_add("counter", 1)

            workers.append(add_worker)

        @task(inject_context=True)
        def report(context: TaskExecutionContext):
            observed = context.get_channel().get("counter")
            print(f" Expected counter: {target}, Actual: {observed}\n")

        _ = init_counter >> ParallelGroup(workers, name="add_group") >> report
        ctx.execute("init_counter")
def demo_advisory_lock() -> None:
    """Show lock() for compound read-modify-write that atomic_add() can't express.

    Each worker performs ``increments_per_worker`` guarded updates of the
    shared ``"counter"`` key. An update computes the next value:

    - If the next value would reach ``threshold``, the counter wraps back
      to 0 and ``"overflow_count"`` is incremented.
    - Otherwise the next value is stored.

    The whole read-check-write sequence runs under
    ``channel.lock("counter")``, so no update is lost and the outcome is
    deterministic:

    - ``total // threshold`` overflow events;
    - a final counter of ``total % threshold``;

    where ``total = num_workers * increments_per_worker``. With the defaults
    below that is 5 overflows and a counter of 0.
    """
    print("--- Safe compound update with lock() ---")
    threshold = 10
    num_workers = 5
    increments_per_worker = 10

    with workflow("lock_demo") as ctx:

        @task(inject_context=True)
        def init(context: TaskExecutionContext):
            channel = context.get_channel()
            channel.set("counter", 0)
            channel.set("overflow_count", 0)

        workers = []
        for i in range(num_workers):

            @task(inject_context=True, id=f"lock_worker_{i}")
            def lock_worker(context: TaskExecutionContext):
                channel = context.get_channel()
                for _ in range(increments_per_worker):
                    # Advisory lock protects the entire read-modify-write
                    # block. Without it, two workers could read the same value
                    # and one increment (or one reset) would be lost.
                    with channel.lock("counter"):
                        next_val = channel.get("counter") + 1
                        # Check the incremented value, not the old one.
                        # Otherwise the counter briefly holds ``threshold``
                        # and each overflow costs threshold + 1 updates.
                        if next_val >= threshold:
                            channel.set("counter", 0)
                            channel.atomic_add("overflow_count", 1)
                        else:
                            channel.set("counter", next_val)

            workers.append(lock_worker)

        @task(inject_context=True)
        def report(context: TaskExecutionContext):
            channel = context.get_channel()
            overflows = channel.get("overflow_count")
            counter = channel.get("counter")
            print(f" Overflow events: {overflows}")
            print(f" Counter after resets: {counter}\n")

        parallel = ParallelGroup(workers, name="lock_group")
        _ = init >> parallel >> report
        ctx.execute("init")
def main():
    """Run every channel-concurrency demo in order, framed by banner lines."""
    print("=== Channel Concurrency Demo ===\n")
    for demo in (demo_unsafe_increment, demo_atomic_add, demo_advisory_lock):
        demo()
    print("Done!")
# Script entry point: run the demos only when executed directly, not on import.
if __name__ == "__main__":
    main()
# ============================================================================
# Key Takeaways:
# ============================================================================
#
# 1. **channel.atomic_add(key, amount)**
# - Atomic numeric add/subtract — no lost updates
# - Initialises missing keys to 0 automatically
# - MemoryChannel: per-key RLock; Redis: INCRBYFLOAT (server-side atomic)
# - Use for counters, metrics, scores
#
# 2. **channel.lock(key)**
# - Advisory lock for compound operations that atomic_add() can't express
# - Wrap with ``with channel.lock(key):`` context manager
# - MemoryChannel: per-key RLock; Redis: distributed lock for the same key
# - Use for conditional updates and other compound read-modify-write logic
#
# 3. **When to use which**
# - Simple counter? → channel.atomic_add("counter", 1)
# - Decrement? → channel.atomic_add("counter", -1)
# - Conditional update? → with channel.lock("key"): ...
# - Multi-key update? → with channel.lock("guard-key"): ...
#   (the lock is advisory: every writer of those keys must take the same lock)
# - No concurrency concern? → channel.get() / channel.set() is fine
#
# ============================================================================