Skip to content

Commit fcb1ec1

Browse files
committed
Add guards to .run_every()
1 parent 2f28c2a commit fcb1ec1

10 files changed

Lines changed: 340 additions & 15 deletions

File tree

examples/stress_task.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""CPU and memory stress task for cloud_poc.
2+
3+
Submits a task that burns CPU for a given duration while repeatedly
4+
allocating and releasing chunks of memory.
5+
6+
Usage:
7+
BROKER_URL="http://localhost:8000/api/v1/broker?api_key=<key>" \
8+
python stress_task.py [cpu_seconds] [mem_mib_per_iter] [mem_iters]
9+
10+
Arguments (positional, all optional):
11+
cpu_seconds Seconds to keep the CPU busy (default: 10)
12+
mem_mib_per_iter MiB to allocate per iteration (default: 64)
13+
mem_iters Number of allocation iterations (default: 8)
14+
"""
15+
16+
import asyncio
17+
import os
18+
import sys
19+
20+
import offwork
21+
from offwork import progress
22+
23+
broker_url = os.environ.get("BROKER_URL")
24+
if not broker_url:
25+
print("error: BROKER_URL is not set", file=sys.stderr)
26+
sys.exit(1)
27+
28+
offwork.connect(broker_url)
29+
30+
31+
def _burn_cpu(seconds: float) -> int:
32+
"""Busy-loop for *seconds* wall time. Returns iteration count."""
33+
import time
34+
import hashlib
35+
36+
deadline = time.monotonic() + seconds
37+
data = b"offwork-stress" * 256 # ~3.5 KB
38+
iters = 0
39+
while time.monotonic() < deadline:
40+
hashlib.sha256(data).digest()
41+
iters += 1
42+
return iters
43+
44+
45+
def _alloc_mib(mib: float) -> int:
46+
"""Allocate *mib* MiB, touch every page, return checksum byte."""
47+
size = int(mib * 1024 * 1024)
48+
chunk = bytearray(size)
49+
# Touch every 4096-byte page so the OS actually faults the memory in.
50+
for i in range(0, size, 4096):
51+
chunk[i] = (i // 4096) & 0xFF
52+
return chunk[-1]
53+
54+
55+
@offwork.task
56+
def stress(cpu_seconds: float, mem_mib_per_iter: float, mem_iters: int) -> dict:
57+
"""Stress the CPU and memory on the worker.
58+
59+
Args:
60+
cpu_seconds: Wall-clock seconds to spend hashing.
61+
mem_mib_per_iter: MiB to allocate and touch per memory iteration.
62+
mem_iters: How many times to allocate / release that block.
63+
64+
Returns a summary dict with hash iteration count and total memory touched.
65+
"""
66+
import time
67+
68+
start = time.monotonic()
69+
total_steps = mem_iters # N CPU phase + N mem phases
70+
71+
checksums = []
72+
for i in range(mem_iters):
73+
# --- CPU phase ---
74+
progress(i, total_steps, message=f"Iter {i + 1}/{mem_iters}")
75+
hash_iters = _burn_cpu(cpu_seconds)
76+
checksums.append(_alloc_mib(mem_mib_per_iter))
77+
78+
elapsed = time.monotonic() - start
79+
total_mib = mem_mib_per_iter * mem_iters
80+
return {
81+
"elapsed_seconds": round(elapsed, 2),
82+
"cpu_seconds_requested": cpu_seconds,
83+
"hash_iterations": hash_iters,
84+
"memory_iterations": mem_iters,
85+
"mem_mib_per_iter": mem_mib_per_iter,
86+
"total_mem_touched_mib": total_mib,
87+
"checksum": sum(checksums) & 0xFF,
88+
}
89+
90+
91+
async def main() -> None:
92+
cpu_seconds = float(sys.argv[1]) if len(sys.argv) > 1 else 10.0
93+
mem_mib_per_iter = float(sys.argv[2]) if len(sys.argv) > 2 else 64.0
94+
mem_iters = int(sys.argv[3]) if len(sys.argv) > 3 else 8
95+
96+
print(f"submitting stress task: cpu={cpu_seconds}s mem={mem_mib_per_iter}MiB x {mem_iters} iters")
97+
future = await stress.start(cpu_seconds, mem_mib_per_iter, mem_iters)
98+
print(f"task id: {future.task_id}")
99+
100+
while not await future.done():
101+
p = await future.progress()
102+
if p is not None:
103+
print(f" [{p.current}/{p.total}] {p.message}")
104+
await asyncio.sleep(2.0)
105+
106+
result = await future
107+
print("\nresult:")
108+
for k, v in result.items():
109+
print(f" {k}: {v}")
110+
111+
112+
asyncio.run(main())

offwork/core/envelope.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@ def verify_task_envelope(
171171
retry_delay=data.get("retry_delay", 1.0),
172172
scheduled_at=data.get("scheduled_at"),
173173
recur_interval=data.get("recur_interval"),
174+
recur_deadline=data.get("recur_deadline"),
175+
recur_remaining=data.get("recur_remaining"),
174176
schedule_id=data.get("schedule_id"),
175177
throttle=data.get("throttle"),
176178
)

offwork/core/task.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,8 @@ class Task:
418418
retry_delay: float = 1.0
419419
scheduled_at: float | None = None
420420
recur_interval: float | None = None
421+
recur_deadline: float | None = None
422+
recur_remaining: int | None = None
421423
schedule_id: str | None = None
422424
throttle: float | None = None
423425

@@ -442,6 +444,10 @@ def _to_dict(self) -> dict[str, Any]:
442444
d["scheduled_at"] = self.scheduled_at
443445
if self.recur_interval is not None:
444446
d["recur_interval"] = self.recur_interval
447+
if self.recur_deadline is not None:
448+
d["recur_deadline"] = self.recur_deadline
449+
if self.recur_remaining is not None:
450+
d["recur_remaining"] = self.recur_remaining
445451
if self.schedule_id is not None:
446452
d["schedule_id"] = self.schedule_id
447453
if self.throttle is not None:
@@ -467,6 +473,8 @@ def from_json(cls, json_str: str | bytes) -> Self:
467473
retry_delay=data.get("retry_delay", 1.0),
468474
scheduled_at=data.get("scheduled_at"),
469475
recur_interval=data.get("recur_interval"),
476+
recur_deadline=data.get("recur_deadline"),
477+
recur_remaining=data.get("recur_remaining"),
470478
schedule_id=data.get("schedule_id"),
471479
throttle=data.get("throttle"),
472480
)

offwork/graph/tracing.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ async def run_every(
137137
frequency: Any,
138138
*args: Any,
139139
_start_at: Any = None,
140+
run_for: Any = None,
141+
max_runs: int | None = None,
140142
backend: str | Backend | None = None,
141143
**kwargs: Any,
142144
) -> object:
@@ -146,9 +148,19 @@ async def run_every(
146148
start_ts: float | None = None
147149
if _start_at is not None:
148150
start_ts = _start_at.timestamp() if isinstance(_start_at, datetime) else float(_start_at)
151+
if run_for is None and max_runs is None:
152+
run_for = timedelta(hours=1)
153+
run_for_seconds: float | None = None
154+
if run_for is not None:
155+
run_for_seconds = run_for.total_seconds() if isinstance(run_for, timedelta) else float(run_for)
156+
if run_for_seconds <= 0:
157+
raise ValueError(f"run_for must be positive, got {run_for}")
158+
if max_runs is not None and max_runs <= 0:
159+
raise ValueError(f"max_runs must be positive, got {max_runs}")
149160
return await submit_recurring(
150161
func, wrapper, *args,
151162
_backend=backend, _interval=interval, _start_at=start_ts,
163+
_run_for=run_for_seconds, _max_runs=max_runs,
152164
**kwargs,
153165
)
154166

offwork/typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ async def run_every(
3838
frequency: timedelta | float,
3939
*args: Any,
4040
_start_at: datetime | None = ...,
41+
run_for: timedelta | float | None = ...,
42+
max_runs: int | None = ...,
4143
**kwargs: Any,
4244
) -> ScheduleHandle: ...
4345

offwork/worker/remote.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,8 @@ async def submit_recurring(
280280
_root_token: bytes | None = None,
281281
_interval: float = 0,
282282
_start_at: float | None = None,
283+
_run_for: float | None = None,
284+
_max_runs: int | None = None,
283285
**kwargs: Any,
284286
) -> ScheduleHandle:
285287
"""Submit a recurring task and return a :class:`ScheduleHandle`."""
@@ -302,6 +304,7 @@ async def submit_recurring(
302304

303305
schedule_id = uuid.uuid4().hex[:12]
304306
scheduled_at = _start_at or time.time()
307+
recur_deadline = scheduled_at + _run_for if _run_for is not None else None
305308

306309
opts = getattr(wrapper, "__offwork_options__", {})
307310
task = Task(
@@ -315,6 +318,8 @@ async def submit_recurring(
315318
throttle=opts.get("throttle"),
316319
scheduled_at=scheduled_at,
317320
recur_interval=_interval,
321+
recur_deadline=recur_deadline,
322+
recur_remaining=_max_runs,
318323
schedule_id=schedule_id,
319324
)
320325

@@ -634,7 +639,18 @@ async def _handle_task(
634639

635640
# Re-enqueue recurring task — worker re-signs with its own identity.
636641
if task.recur_interval is not None and task.schedule_id is not None:
637-
if not await backend.is_schedule_cancelled(task.schedule_id):
642+
next_at = time.time() + task.recur_interval
643+
remaining = task.recur_remaining
644+
deadline_exceeded = task.recur_deadline is not None and next_at > task.recur_deadline
645+
runs_exhausted = remaining is not None and remaining <= 1
646+
if deadline_exceeded or runs_exhausted:
647+
await backend.cancel_schedule(task.schedule_id)
648+
logger.info(
649+
"Recurring schedule %s exhausted (%s)",
650+
task.schedule_id,
651+
"deadline" if deadline_exceeded else "max_runs",
652+
)
653+
elif not await backend.is_schedule_cancelled(task.schedule_id):
638654
next_task = Task(
639655
graph_json=task.graph_json,
640656
function_name=task.function_name,
@@ -644,8 +660,10 @@ async def _handle_task(
644660
retries=task.retries,
645661
retry_delay=task.retry_delay,
646662
throttle=task.throttle,
647-
scheduled_at=time.time() + task.recur_interval,
663+
scheduled_at=next_at,
648664
recur_interval=task.recur_interval,
665+
recur_deadline=task.recur_deadline,
666+
recur_remaining=remaining - 1 if remaining is not None else None,
649667
schedule_id=task.schedule_id,
650668
)
651669
await backend.submit(_encode_task(next_task, root_token))

offwork/worker/result.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -214,22 +214,52 @@ def __await__(self) -> Generator[Any, None, Any]:
214214

215215
# -- cancellation ----------------------------------------------------------
216216

217-
async def cancel(self) -> None:
217+
async def cancel(self, wait: bool | float = False) -> bool:
218218
"""Cancel the task.
219219
220-
Marks the task as cancelled in the backend. If the worker
221-
hasn't started execution yet, it will skip the task. If
222-
execution is already in progress, it will continue but the
223-
client will receive a :class:`TaskCancelled` error.
220+
Marks the task as cancelled in the backend. The worker
221+
observes the flag via its heartbeat loop and aborts execution
222+
cooperatively.
224223
225-
Awaiting the result after cancellation raises
226-
:class:`TaskCancelled`.
224+
Parameters
225+
----------
226+
wait
227+
If ``False`` (default), return immediately after signalling
228+
cancellation. If ``True``, block until the worker confirms
229+
(default 30s timeout). If a number, wait that many seconds
230+
for confirmation.
231+
232+
Returns
233+
-------
234+
bool
235+
``True`` if cancellation was confirmed by the worker (or
236+
``wait=False``). ``False`` if the wait timed out.
227237
"""
228238
await self._backend.cancel_task(self._task_id)
229-
await self._backend.send_result(
230-
self._task_id,
231-
ResultEnvelope.cancelled(self._task_id).to_json(),
232-
)
239+
if wait is False:
240+
# Pre-seed a cancelled envelope so a client that never awaits
241+
# confirmation still gets TaskCancelled when it reads.
242+
await self._backend.send_result(
243+
self._task_id,
244+
ResultEnvelope.cancelled(self._task_id).to_json(),
245+
)
246+
return True
247+
timeout = 30.0 if wait is True else float(wait)
248+
deadline = time.monotonic() + timeout
249+
while True:
250+
raw = await self._backend.try_get_result(self._task_id)
251+
if raw is not None:
252+
self._envelope = ResultEnvelope.from_json(raw)
253+
return True
254+
if time.monotonic() >= deadline:
255+
# Fall back: seed a cancelled envelope so subsequent reads
256+
# don't hang forever, and tell the caller we timed out.
257+
await self._backend.send_result(
258+
self._task_id,
259+
ResultEnvelope.cancelled(self._task_id).to_json(),
260+
)
261+
return False
262+
await asyncio.sleep(0.5)
233263

234264
# -- progress --------------------------------------------------------------
235265

offwork/worker/schedule.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
"""Schedule handle for recurring task management."""
22

3+
import asyncio
4+
import time
5+
36
from offwork.worker.backends.base import Backend
47

58

@@ -14,13 +17,37 @@ def __init__(self, schedule_id: str, backend: Backend) -> None:
1417
def schedule_id(self) -> str:
1518
return self._schedule_id
1619

17-
async def cancel(self) -> None:
20+
async def cancel(self, wait: bool | float = False) -> bool:
1821
"""Cancel this recurring schedule.
1922
20-
The worker will stop re-enqueuing new occurrences after the
23+
The worker stops re-enqueuing new occurrences after the
2124
current one completes.
25+
26+
Parameters
27+
----------
28+
wait
29+
If ``False`` (default), return immediately. If ``True``,
30+
block until the backend confirms the schedule is marked
31+
cancelled (default 30s timeout). If a number, wait that
32+
many seconds.
33+
34+
Returns
35+
-------
36+
bool
37+
``True`` if cancellation was acknowledged (or
38+
``wait=False``). ``False`` on timeout.
2239
"""
2340
await self._backend.cancel_schedule(self._schedule_id)
41+
if wait is False:
42+
return True
43+
timeout = 30.0 if wait is True else float(wait)
44+
deadline = time.monotonic() + timeout
45+
while True:
46+
if await self._backend.is_schedule_cancelled(self._schedule_id):
47+
return True
48+
if time.monotonic() >= deadline:
49+
return False
50+
await asyncio.sleep(0.5)
2451

2552
def __repr__(self) -> str:
2653
return f"ScheduleHandle(schedule_id={self._schedule_id!r})"

0 commit comments

Comments
 (0)