Skip to content

Commit dd3aae2

Browse files
committed
Convert branch state management to jobs
1 parent 3a6e4c2 commit dd3aae2

7 files changed

Lines changed: 199 additions & 81 deletions

File tree

src/api/organization/project/branch/__init__.py

Lines changed: 26 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
deployment_status,
4646
)
4747
from .....deployment.kubernetes._util import core_v1_client
48-
from .....deployment.kubernetes.neonvm import PowerState as NeonVMPowerState
4948
from .....deployment.kubernetes.neonvm import set_virtualmachine_power_state
5049
from .....deployment.kubernetes.volume_clone import (
5150
clone_branch_database_volume,
@@ -94,6 +93,11 @@
9493
from ....settings import get_settings as get_api_settings
9594
from .api_keys import api as api_key_api
9695
from .auth import api as auth_api
96+
from .control_tasks import (
97+
_CONTROL_TO_POWER_STATE,
98+
_CONTROL_TRANSITION_INITIAL,
99+
dispatch_control,
100+
)
97101
from .resize_tasks import dispatch_resize
98102
from .tasks import task_api
99103

@@ -1889,109 +1893,61 @@ async def resize(
18891893
404: NotFound,
18901894
}
18911895

1892-
_CONTROL_TO_AUTOSCALER_POWERSTATE: dict[str, NeonVMPowerState] = {
1893-
"pause": "Stopped",
1894-
"resume": "Running",
1895-
"start": "Running",
1896-
"stop": "Stopped",
1897-
}
1898-
1899-
_CONTROL_TRANSITION_INITIAL: dict[str, BranchServiceStatus] = {
1900-
"pause": BranchServiceStatus.PAUSING,
1901-
"resume": BranchServiceStatus.RESUMING,
1902-
"start": BranchServiceStatus.STARTING,
1903-
"stop": BranchServiceStatus.STOPPING,
1904-
}
1905-
1906-
_CONTROL_TRANSITION_FINAL: dict[str, BranchServiceStatus | None] = {
1907-
"pause": BranchServiceStatus.PAUSED,
1908-
"resume": BranchServiceStatus.STARTING,
1909-
"start": None,
1910-
"stop": BranchServiceStatus.STOPPED,
1911-
}
1912-
19131896

19141897
async def _set_branch_status(session: SessionDep, branch: Branch, status: BranchServiceStatus):
    """Store ``status`` on ``branch`` and commit the session right away."""
    branch.set_status(status)
    await session.commit()
19171900

19181901

1919-
async def _set_final_branch_status(session: SessionDep, branch: Branch, action: str) -> None:
1920-
final_status = _CONTROL_TRANSITION_FINAL[action]
1921-
if final_status is None:
1922-
return
1923-
await _set_branch_status(session, branch, final_status)
1924-
1925-
1926-
async def _set_autoscaler_power_state(action: str, namespace: str, name: str) -> None:
1927-
power_state = _CONTROL_TO_AUTOSCALER_POWERSTATE.get(action)
1928-
if power_state is None:
1929-
return
1930-
await set_virtualmachine_power_state(namespace, name, power_state)
1931-
1932-
1933-
async def _apply_branch_action(
1934-
*,
1935-
action: str,
1936-
autoscaler_namespace: str,
1937-
autoscaler_vm_name: str,
1938-
) -> None:
1939-
await _set_autoscaler_power_state(action, autoscaler_namespace, autoscaler_vm_name)
1940-
1941-
19421902
# Shared handler for the pause/resume/start/stop routes. The action is taken
# from the last ":"-separated segment of the matched route name. The handler
# commits the transitional branch status, dispatches the control task, stores
# the task id on the branch and answers 202 with a Location header pointing at
# the task detail route.
@instance_api.post(
19431903
"/pause",
19441904
name="organizations:projects:branch:pause",
1945-
status_code=204,
1905+
status_code=202,
19461906
responses=_control_responses,
19471907
)
19481908
@instance_api.post(
19491909
"/resume",
19501910
name="organizations:projects:branch:resume",
1951-
status_code=204,
1911+
status_code=202,
19521912
responses=_control_responses,
19531913
)
19541914
@instance_api.post(
19551915
"/start",
19561916
name="organizations:projects:branch:start",
1957-
status_code=204,
1917+
status_code=202,
19581918
responses=_control_responses,
19591919
)
19601920
@instance_api.post(
19611921
"/stop",
19621922
name="organizations:projects:branch:stop",
1963-
status_code=204,
1923+
status_code=202,
19641924
responses=_control_responses,
19651925
)
19661926
async def control_branch(
19671927
session: SessionDep,
19681928
request: Request,
1969-
_organization: OrganizationDep,
1970-
_project: ProjectDep,
1929+
organization: OrganizationDep,
1930+
project: ProjectDep,
19711931
branch: BranchDep,
19721932
):
19731933
action = request.scope["route"].name.split(":")[-1]
1974-
assert action in _CONTROL_TO_AUTOSCALER_POWERSTATE
1934+
# NOTE(review): assert statements are removed when Python runs with -O, so
# this check would silently disappear there; the next lookup in
# _CONTROL_TRANSITION_INITIAL would then raise KeyError instead.
assert action in _CONTROL_TO_POWER_STATE
19751935
branch_in_session = await session.merge(branch)
1976-
branch_id = branch_in_session.id
1977-
autoscaler_namespace, autoscaler_vm_name = get_autoscaler_vm_identity(branch_id)
19781936
await _set_branch_status(session, branch_in_session, _CONTROL_TRANSITION_INITIAL[action])
1979-
try:
1980-
await _apply_branch_action(
1981-
action=action,
1982-
autoscaler_namespace=autoscaler_namespace,
1983-
autoscaler_vm_name=autoscaler_vm_name,
1984-
)
1985-
except ApiException as e:
1986-
await _set_branch_status(session, branch_in_session, BranchServiceStatus.ERROR)
1987-
status = 404 if e.status == 404 else 400
1988-
raise HTTPException(status_code=status, detail=e.body or str(e)) from e
1989-
except VelaKubernetesError as e:
1990-
await _set_branch_status(session, branch_in_session, BranchServiceStatus.ERROR)
1991-
raise HTTPException(status_code=500, detail=str(e)) from e
1992-
else:
1993-
await _set_final_branch_status(session, branch_in_session, action)
1994-
return Response(status_code=204)
1937+
1938+
# NOTE(review): the transitional status was already committed above. If
# dispatch_control raises here, nothing in this handler resets it, so the
# branch presumably stays in that status - confirm whether an ERROR fallback
# is wanted.
task_id = dispatch_control(str(branch_in_session.id), action)
1939+
# NOTE(review): the task id is committed only after the task was dispatched,
# while the worker sets control_task_id to None when it finishes. A task that
# completes before this commit could presumably leave a stale id behind -
# confirm.
branch_in_session.control_task_id = task_id
1940+
await session.commit()
1941+
1942+
task_url = url_path_for(
1943+
request,
1944+
"organizations:projects:branch:tasks:detail",
1945+
organization_id=await organization.awaitable_attrs.id,
1946+
project_id=await project.awaitable_attrs.id,
1947+
branch_id=await branch_in_session.awaitable_attrs.id,
1948+
task_id=task_id,
1949+
)
1950+
return Response(status_code=202, headers={"Location": task_url})
19951951

19961952

19971953
instance_api.include_router(auth_api, prefix="/auth")
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""Celery task for branch start/stop/pause/resume lifecycle management.
2+
3+
A single task patches the Kubernetes VM power state, polls until the VM
4+
reaches the desired phase, then updates the branch status in the database.
5+
"""
6+
7+
import asyncio
8+
import logging
9+
from uuid import UUID
10+
11+
from asgiref.sync import async_to_sync
12+
from ulid import ULID
13+
14+
from .....database import AsyncSessionLocal
15+
from .....deployment import get_autoscaler_vm_identity
16+
from .....deployment.kubernetes.neonvm import Phase, PowerState, get_neon_vm, set_virtualmachine_power_state
17+
from .....models.branch import Branch, BranchServiceStatus
18+
from .....worker import app
19+
20+
logger = logging.getLogger(__name__)
21+
22+
# Power state passed to set_virtualmachine_power_state for each control action.
_CONTROL_TO_POWER_STATE: dict[str, PowerState] = {
23+
"pause": "Stopped",
24+
"resume": "Running",
25+
"start": "Running",
26+
"stop": "Stopped",
27+
}
28+
29+
# Transitional branch status for each action; imported by the branch API module,
# which sets it before dispatching the task.
_CONTROL_TRANSITION_INITIAL: dict[str, BranchServiceStatus] = {
30+
"pause": BranchServiceStatus.PAUSING,
31+
"resume": BranchServiceStatus.RESUMING,
32+
"start": BranchServiceStatus.STARTING,
33+
"stop": BranchServiceStatus.STOPPING,
34+
}
35+
36+
# Branch status written once the VM reached a desired phase. None means the
# task leaves the branch status untouched.
_CONTROL_TRANSITION_FINAL: dict[str, BranchServiceStatus | None] = {
37+
"pause": BranchServiceStatus.PAUSED,
38+
"resume": BranchServiceStatus.STARTING,
39+
"start": None,
40+
"stop": BranchServiceStatus.STOPPED,
41+
}
42+
43+
# VM phases that end the polling loop for each action.
_DESIRED_PHASES: dict[str, set[Phase]] = {
44+
"start": {Phase.running},
45+
"resume": {Phase.running},
46+
"stop": {Phase.stopped, Phase.succeeded},
47+
"pause": {Phase.stopped, Phase.succeeded},
48+
}
49+
50+
# Seconds slept between two VM phase polls.
_POLL_INTERVAL_SEC = 5
51+
# Upper bound, in seconds, for waiting on a desired phase before TimeoutError.
_TIMEOUT_SEC = 600 # 10 minutes
52+
53+
54+
async def _async_perform_control(branch_id: str, action: str) -> dict:
    """Drive a branch VM to the power state requested by ``action``.

    Patches the VM power state, polls ``get_neon_vm`` every
    ``_POLL_INTERVAL_SEC`` seconds until the VM reports one of the phases in
    ``_DESIRED_PHASES[action]``, then writes the final branch status (when the
    action defines one) and clears ``control_task_id``.

    Args:
        branch_id: String form of the branch ULID.
        action: A key of ``_CONTROL_TO_POWER_STATE``.

    Returns:
        ``{"action": action}``. The same value is returned when the branch row
        can no longer be found after the VM reached its phase.

    Raises:
        KeyError: ``action`` is unknown; raised before the VM is patched.
        TimeoutError: The VM did not reach a desired phase within
            ``_TIMEOUT_SEC`` seconds.
    """
    ulid = ULID.from_str(branch_id)
    namespace, name = get_autoscaler_vm_identity(ulid)

    # Resolve every per-action lookup up front so an unknown action fails
    # before any side effect on the VM.
    power_state = _CONTROL_TO_POWER_STATE[action]
    desired_phases = _DESIRED_PHASES[action]
    final_status = _CONTROL_TRANSITION_FINAL[action]

    await set_virtualmachine_power_state(namespace, name, power_state)

    # Measure the timeout on the event loop's monotonic clock. Summing only the
    # sleep intervals would ignore the time spent inside get_neon_vm, letting
    # the real wait grow well beyond _TIMEOUT_SEC when the API is slow.
    loop = asyncio.get_running_loop()
    deadline = loop.time() + _TIMEOUT_SEC
    while True:
        vm = await get_neon_vm(namespace, name)
        if vm.status is not None and vm.status.phase in desired_phases:
            break
        if loop.time() >= deadline:
            raise TimeoutError(f"VM {name!r} did not reach {desired_phases} within {_TIMEOUT_SEC}s")
        await asyncio.sleep(_POLL_INTERVAL_SEC)

    async with AsyncSessionLocal() as session:
        branch = await session.get(Branch, ulid)
        if branch is None:
            logger.error("Branch %s not found after control action %s", branch_id, action)
            return {"action": action}
        if final_status is not None:
            branch.set_status(final_status)
        # The task is done, so the branch no longer references it.
        branch.control_task_id = None
        await session.commit()

    return {"action": action}
84+
85+
86+
async def _async_perform_control_with_error_handling(branch_id: str, action: str) -> dict:
    """Run the control action and flag the branch as ERROR when it fails.

    Args:
        branch_id: String form of the branch ULID.
        action: Control action forwarded to ``_async_perform_control``.

    Returns:
        The result of ``_async_perform_control``.

    Raises:
        Exception: Whatever ``_async_perform_control`` raised. The original
            error is always re-raised, even if the ERROR status could not be
            stored.
    """
    try:
        return await _async_perform_control(branch_id, action)
    except Exception:
        logger.exception("Control action %s failed for branch %s", action, branch_id)
        # Best effort: a failure while recording the ERROR status (invalid
        # ULID, database unavailable) must not replace the original error.
        try:
            async with AsyncSessionLocal() as session:
                branch = await session.get(Branch, ULID.from_str(branch_id))
                if branch is not None:
                    branch.set_status(BranchServiceStatus.ERROR)
                    branch.control_task_id = None
                    await session.commit()
        except Exception:
            logger.exception("Could not mark branch %s as ERROR after failed %s", branch_id, action)
        raise
98+
99+
100+
@app.task(name="simplyblock.vela.branch.control")
def perform_control(branch_id: str, action: str) -> dict:
    """Celery entry point: run the async control routine synchronously.

    Returns whatever ``_async_perform_control_with_error_handling`` returns.
    """
    run_control = async_to_sync(_async_perform_control_with_error_handling)
    return run_control(branch_id, action)
104+
105+
106+
def dispatch_control(branch_id: str, action: str) -> UUID:
    """Queue ``perform_control`` for a branch and return the task id as a UUID."""
    task_kwargs = {"branch_id": branch_id, "action": action}
    async_result = perform_control.apply_async(kwargs=task_kwargs)
    return UUID(async_result.id)

src/api/organization/project/branch/tasks.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Branch task list/detail endpoints.
22
3-
Exposes Celery task state (currently resize only) under:
3+
Exposes Celery task state (resize and control) under:
44
GET .../branches/{branch_id}/tasks
55
GET .../branches/{branch_id}/tasks/{task_id}
66
"""
@@ -14,6 +14,7 @@
1414

1515
from ...._util import Forbidden, NotFound, Unauthenticated
1616
from ....dependencies import BranchDep, OrganizationDep, ProjectDep
17+
from .control_tasks import perform_control
1718
from .resize_tasks import finalize_resize
1819

1920
task_api = APIRouter(tags=["branch"])
@@ -37,7 +38,7 @@ class BranchTaskPublic(BaseModel):
3738
date_done: datetime | None
3839

3940

40-
def _build_task_public(task_id: UUID) -> BranchTaskPublic:
41+
def _build_resize_task_public(task_id: UUID) -> BranchTaskPublic:
4142
result = finalize_resize.AsyncResult(str(task_id))
4243
state = result.state
4344
status = _CELERY_STATE_TO_STATUS.get(state, state)
@@ -53,6 +54,23 @@ def _build_task_public(task_id: UUID) -> BranchTaskPublic:
5354
)
5455

5556

57+
def _build_control_task_public(task_id: UUID) -> BranchTaskPublic:
    """Build the public view of a control task from its Celery result.

    The action is read from the task kwargs and falls back to ``"control"``
    when they are not available. ``result`` is filled only for ``SUCCESS``,
    ``error`` only for ``FAILURE`` with a traceback.
    """
    async_result = perform_control.AsyncResult(str(task_id))
    celery_state = async_result.state
    public_status = _CELERY_STATE_TO_STATUS.get(celery_state, celery_state)

    task_kwargs: dict = async_result.kwargs or {}
    action = task_kwargs.get("action", "control")

    outcome = None
    if celery_state == "SUCCESS":
        outcome = async_result.result

    error_text = None
    if celery_state == "FAILURE" and async_result.traceback:
        error_text = str(async_result.traceback)

    return BranchTaskPublic(
        id=task_id,
        task_type=action,
        status=public_status,
        parameters={"action": action},
        result=outcome,
        error=error_text,
        date_done=async_result.date_done,
    )
72+
73+
5674
@task_api.get(
5775
"/",
5876
name="organizations:projects:branch:tasks:list",
@@ -64,9 +82,12 @@ async def list_tasks(
6482
_project: ProjectDep,
6583
branch: BranchDep,
6684
) -> list[BranchTaskPublic]:
67-
if branch.resize_task_id is None:
68-
return []
69-
return [_build_task_public(branch.resize_task_id)]
85+
tasks = []
86+
if branch.resize_task_id is not None:
87+
tasks.append(_build_resize_task_public(branch.resize_task_id))
88+
if branch.control_task_id is not None:
89+
tasks.append(_build_control_task_public(branch.control_task_id))
90+
return tasks
7091

7192

7293
@task_api.get(
@@ -81,6 +102,8 @@ async def get_task(
81102
branch: BranchDep,
82103
task_id: UUID,
83104
) -> BranchTaskPublic:
84-
if branch.resize_task_id != task_id:
85-
raise HTTPException(status_code=404, detail="Task not found")
86-
return _build_task_public(task_id)
105+
if branch.resize_task_id == task_id:
106+
return _build_resize_task_public(task_id)
107+
if branch.control_task_id == task_id:
108+
return _build_control_task_public(task_id)
109+
raise HTTPException(status_code=404, detail="Task not found")

src/models/branch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ class Branch(AsyncAttrs, Model, table=True):
109109
)
110110
pitr_enabled: bool = Field(default=False, sa_column=Column(Boolean, nullable=False, server_default=text("false")))
111111
resize_task_id: uuid.UUID | None = Field(default=None, nullable=True)
112-
db_port: int | None = Field(default=None, sa_column=Column(sa.Integer, nullable=True))
112+
db_port: int | None = None
113+
control_task_id: uuid.UUID | None = None
113114

114115
__table_args__ = (UniqueConstraint("project_id", "name", name="unique_branch_name_per_project"),)
115116

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Add control_task_id to branch table
2+
3+
Revision ID: b1c2d3e4f5a6
4+
Revises: ad471311850e
5+
Create Date: 2026-04-01 00:00:00.000000
6+
7+
"""
8+
from typing import Sequence, Union
9+
10+
import sqlalchemy as sa
11+
from alembic import op
12+
13+
14+
# revision identifiers, used by Alembic.
15+
revision: str = "b1c2d3e4f5a6"
16+
# NOTE(review): the module docstring states "Revises: ad471311850e" while
# down_revision below is "a1b2c3d4e5f6". Only one of them can be the real
# parent revision - confirm against the migration history before merging.
down_revision: Union[str, Sequence[str], None] = "a1b2c3d4e5f6"
17+
branch_labels: Union[str, Sequence[str], None] = None
18+
depends_on: Union[str, Sequence[str], None] = None
19+
20+
21+
def upgrade() -> None:
    """Add the nullable ``control_task_id`` UUID column to the ``branch`` table."""
    control_task_id = sa.Column("control_task_id", sa.Uuid(), nullable=True)
    op.add_column("branch", control_task_id)
23+
24+
25+
def downgrade() -> None:
    """Drop ``control_task_id`` from the ``branch`` table again."""
    op.drop_column(table_name="branch", column_name="control_task_id")

src/worker/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@ class Settings(BaseSettings):
2020
app.conf.task_chord_propagates = False
2121

2222
# Register tasks — must be imported after `app` is defined.
23+
from ..api.organization.project.branch import control_tasks as _api_control_tasks # noqa: E402, F401
2324
from ..api.organization.project.branch import resize_tasks as _api_resize_tasks # noqa: E402, F401
2425
from ..deployment import resize as _deployment_resize # noqa: E402, F401

0 commit comments

Comments
 (0)