Skip to content

Commit ded0b20

Browse files
committed
Allow access to task arguments while pending
1 parent 27b9460 commit ded0b20

2 files changed

Lines changed: 60 additions & 21 deletions

File tree

src/api/organization/project/branch/__init__.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,8 @@
9595
from .auth import api as auth_api
9696
from .control_tasks import (
9797
_CONTROL_TO_POWER_STATE,
98-
_CONTROL_TRANSITION_INITIAL,
9998
dispatch_control,
100-
perform_control,
99+
get_control_in_progress_status,
101100
)
102101
from .resize_tasks import dispatch_resize
103102
from .tasks import task_api
@@ -233,29 +232,34 @@ def _should_update_branch_status(
233232
def _adjust_derived_status_for_stuck_creation(
    branch: Branch, current: BranchServiceStatus, derived: BranchServiceStatus
) -> BranchServiceStatus:
    """Reconcile a STOPPED/ERROR derived status with a CREATING/STARTING branch.

    ``derived`` is returned untouched unless it is STOPPED or ERROR *and*
    ``current`` is CREATING or STARTING. In that window the outcome depends on
    how long the branch has held its current status (measured from
    ``branch.status_updated_at``, falling back to ``branch.created_datetime``):

    * STARTING: inside ``_STARTING_STATUS_ERROR_GRACE_PERIOD`` the branch stays
      STARTING (both STOPPED and ERROR are suppressed); once the grace period
      has elapsed the result is ERROR.
    * CREATING: only a STOPPED derivation that has outlived
      ``_CREATING_STATUS_ERROR_GRACE_PERIOD`` is escalated to ERROR; every
      other combination returns ``derived`` as-is.
    """
    still_coming_up = current in {BranchServiceStatus.CREATING, BranchServiceStatus.STARTING}
    services_look_down = derived in {BranchServiceStatus.STOPPED, BranchServiceStatus.ERROR}
    if not (services_look_down and still_coming_up):
        return derived

    age = datetime.now(UTC) - (branch.status_updated_at or branch.created_datetime)

    if current == BranchServiceStatus.STARTING:
        if age < _STARTING_STATUS_ERROR_GRACE_PERIOD:
            # Within grace period: suppress both STOPPED and ERROR — stay STARTING
            return current
        logger.warning(
            "Branch %s still STARTING after %s; marking ERROR",
            branch.id,
            age,
        )
        return BranchServiceStatus.ERROR

    # current == CREATING: an ERROR derivation is passed through unchanged,
    # and STOPPED is tolerated until the creation grace period runs out.
    if derived != BranchServiceStatus.STOPPED or age < _CREATING_STATUS_ERROR_GRACE_PERIOD:
        return derived
    logger.warning(
        "Branch %s still CREATING after %s with STOPPED services; marking ERROR",
        branch.id,
        age,
    )
    return BranchServiceStatus.ERROR
259263

260264

261265
async def refresh_branch_status(branch_id: Identifier) -> BranchServiceStatus:
@@ -279,10 +283,9 @@ async def refresh_branch_status(branch_id: Identifier) -> BranchServiceStatus:
279283
def _active_task_status(branch: Branch) -> BranchServiceStatus | None:
    """Return the status implied by the branch's in-flight task, or None if idle.

    A control task is consulted first via ``get_control_in_progress_status``;
    if that lookup yields no status, the check falls through to the resize
    task, which maps to RESIZING.
    """
    control_task_id = branch.control_task_id
    if control_task_id is not None:
        control_status = get_control_in_progress_status(control_task_id)
        if control_status is not None:
            return control_status
    return BranchServiceStatus.RESIZING if branch.resize_task_id is not None else None

src/api/organization/project/branch/control_tasks.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import asyncio
88
import logging
9+
from typing import Any
910
from uuid import UUID
1011

1112
from asgiref.sync import async_to_sync
@@ -51,6 +52,24 @@
5152
_TIMEOUT_SEC = 600 # 10 minutes
5253

5354

55+
class _TaskRequest:
    """Stand-in for a Celery request, used to seed the result backend.

    Celery only writes task kwargs to the backend when ``store_result`` is
    given a request object (and ``result_extended = True`` is set). Building
    this stub at dispatch time lets ``AsyncResult.kwargs`` be read right away,
    before any worker has picked up the message.
    """

    __slots__ = ("id", "task", "args", "kwargs")

    def __init__(self, task_id: str, kwargs: dict[str, Any]) -> None:
        self.id, self.kwargs = task_id, kwargs
        # The stub always describes a ``perform_control`` invocation with no
        # positional arguments.
        self.task: str = perform_control.name
        self.args: tuple[()] = ()
71+
72+
5473
async def _async_perform_control(branch_id: str, action: str) -> dict:
5574
ulid = ULID.from_str(branch_id)
5675
namespace, name = get_autoscaler_vm_identity(ulid)
@@ -103,7 +122,24 @@ def perform_control(branch_id: str, action: str) -> dict:
103122
return async_to_sync(_async_perform_control_with_error_handling)(branch_id, action)
104123

105124

125+
def get_control_in_progress_status(task_id: UUID) -> BranchServiceStatus | None:
    """Map a control task to its transitional branch status.

    Reads the ``action`` keyword argument recorded for ``task_id`` in the
    result backend and looks it up in ``_CONTROL_TRANSITION_INITIAL``.
    Returns None when no kwargs are recorded or the action has no entry.
    """
    stored_kwargs = perform_control.AsyncResult(str(task_id)).kwargs
    requested_action = stored_kwargs.get("action") if stored_kwargs else None
    if requested_action not in _CONTROL_TRANSITION_INITIAL:
        return None
    return _CONTROL_TRANSITION_INITIAL[requested_action]
132+
133+
106134
def dispatch_control(branch_id: str, action: str) -> UUID:
    """Dispatch perform_control asynchronously; return the Celery task UUID.

    The task kwargs are pre-stored in the result backend so that
    ``AsyncResult.kwargs`` (and therefore ``get_control_in_progress_status``)
    works even before the worker picks up the message.

    The placeholder is written *before* the message is published, under a
    task id generated here. Writing it after ``apply_async`` would race with
    the worker: a task that starts or finishes quickly could have its real
    state in the backend overwritten by the late ``PENDING`` record.

    Args:
        branch_id: Identifier of the branch to control, passed to the task.
        action: Control action name, passed to the task.

    Returns:
        The UUID of the dispatched Celery task.
    """
    from uuid import uuid4

    kwargs = {"branch_id": branch_id, "action": action}
    task_id = str(uuid4())
    # Seed the backend first. If publishing fails below, the placeholder is
    # keyed by an id that is never returned to any caller.
    app.backend.store_result(task_id, None, "PENDING", request=_TaskRequest(task_id, kwargs))
    perform_control.apply_async(kwargs=kwargs, task_id=task_id)
    return UUID(task_id)

0 commit comments

Comments
 (0)