Skip to content

Commit 04748fe

Browse files
committed
Convert branch state management to jobs
1 parent eb40828 commit 04748fe

7 files changed

Lines changed: 205 additions & 84 deletions

File tree

src/api/organization/project/branch/__init__.py

Lines changed: 32 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
deployment_status,
4646
)
4747
from .....deployment.kubernetes._util import core_v1_client
48-
from .....deployment.kubernetes.neonvm import PowerState as NeonVMPowerState
4948
from .....deployment.kubernetes.neonvm import set_virtualmachine_power_state
5049
from .....deployment.kubernetes.volume_clone import (
5150
clone_branch_database_volume,
@@ -95,6 +94,12 @@
9594
from ....settings import get_settings as get_api_settings
9695
from .api_keys import api as api_key_api
9796
from .auth import api as auth_api
97+
from .control_tasks import (
98+
_CONTROL_TO_POWER_STATE,
99+
_CONTROL_TRANSITION_INITIAL,
100+
dispatch_control,
101+
perform_control,
102+
)
98103
from .resize_tasks import dispatch_resize
99104
from .tasks import task_api
100105

@@ -284,6 +289,12 @@ async def refresh_branch_status(branch_id: Identifier) -> BranchServiceStatus:
284289

285290

286291
async def _refresh_branch_status(branch: Branch) -> BranchServiceStatus:
292+
if branch.control_task_id is not None:
293+
result = perform_control.AsyncResult(str(branch.control_task_id))
294+
action = (result.kwargs or {}).get("action")
295+
if action in _CONTROL_TRANSITION_INITIAL:
296+
return _CONTROL_TRANSITION_INITIAL[action]
297+
287298
current_status = _parse_branch_status(branch.status)
288299
status = deployment_status(branch.id)
289300

@@ -1908,109 +1919,58 @@ async def resize(
19081919
404: NotFound,
19091920
}
19101921

1911-
_CONTROL_TO_AUTOSCALER_POWERSTATE: dict[str, NeonVMPowerState] = {
1912-
"pause": "Stopped",
1913-
"resume": "Running",
1914-
"start": "Running",
1915-
"stop": "Stopped",
1916-
}
1917-
1918-
_CONTROL_TRANSITION_INITIAL: dict[str, BranchServiceStatus] = {
1919-
"pause": BranchServiceStatus.PAUSING,
1920-
"resume": BranchServiceStatus.RESUMING,
1921-
"start": BranchServiceStatus.STARTING,
1922-
"stop": BranchServiceStatus.STOPPING,
1923-
}
1924-
1925-
_CONTROL_TRANSITION_FINAL: dict[str, BranchServiceStatus | None] = {
1926-
"pause": BranchServiceStatus.PAUSED,
1927-
"resume": BranchServiceStatus.STARTING,
1928-
"start": None,
1929-
"stop": BranchServiceStatus.STOPPED,
1930-
}
1931-
19321922

19331923
async def _set_branch_status(session: SessionDep, branch: Branch, status: BranchServiceStatus):
19341924
branch.set_status(status)
19351925
await session.commit()
19361926

19371927

1938-
async def _set_final_branch_status(session: SessionDep, branch: Branch, action: str) -> None:
1939-
final_status = _CONTROL_TRANSITION_FINAL[action]
1940-
if final_status is None:
1941-
return
1942-
await _set_branch_status(session, branch, final_status)
1943-
1944-
1945-
async def _set_autoscaler_power_state(action: str, namespace: str, name: str) -> None:
1946-
power_state = _CONTROL_TO_AUTOSCALER_POWERSTATE.get(action)
1947-
if power_state is None:
1948-
return
1949-
await set_virtualmachine_power_state(namespace, name, power_state)
1950-
1951-
1952-
async def _apply_branch_action(
1953-
*,
1954-
action: str,
1955-
autoscaler_namespace: str,
1956-
autoscaler_vm_name: str,
1957-
) -> None:
1958-
await _set_autoscaler_power_state(action, autoscaler_namespace, autoscaler_vm_name)
1959-
1960-
19611928
@instance_api.post(
19621929
"/pause",
19631930
name="organizations:projects:branch:pause",
1964-
status_code=204,
1931+
status_code=202,
19651932
responses=_control_responses,
19661933
)
19671934
@instance_api.post(
19681935
"/resume",
19691936
name="organizations:projects:branch:resume",
1970-
status_code=204,
1937+
status_code=202,
19711938
responses=_control_responses,
19721939
)
19731940
@instance_api.post(
19741941
"/start",
19751942
name="organizations:projects:branch:start",
1976-
status_code=204,
1943+
status_code=202,
19771944
responses=_control_responses,
19781945
)
19791946
@instance_api.post(
19801947
"/stop",
19811948
name="organizations:projects:branch:stop",
1982-
status_code=204,
1949+
status_code=202,
19831950
responses=_control_responses,
19841951
)
19851952
async def control_branch(
19861953
session: SessionDep,
19871954
request: Request,
1988-
_organization: OrganizationDep,
1989-
_project: ProjectDep,
1955+
organization: OrganizationDep,
1956+
project: ProjectDep,
19901957
branch: BranchDep,
19911958
):
19921959
action = request.scope["route"].name.split(":")[-1]
1993-
assert action in _CONTROL_TO_AUTOSCALER_POWERSTATE
1994-
branch_in_session = await session.merge(branch)
1995-
branch_id = branch_in_session.id
1996-
autoscaler_namespace, autoscaler_vm_name = get_autoscaler_vm_identity(branch_id)
1997-
await _set_branch_status(session, branch_in_session, _CONTROL_TRANSITION_INITIAL[action])
1998-
try:
1999-
await _apply_branch_action(
2000-
action=action,
2001-
autoscaler_namespace=autoscaler_namespace,
2002-
autoscaler_vm_name=autoscaler_vm_name,
2003-
)
2004-
except ApiException as e:
2005-
await _set_branch_status(session, branch_in_session, BranchServiceStatus.ERROR)
2006-
status = 404 if e.status == 404 else 400
2007-
raise HTTPException(status_code=status, detail=e.body or str(e)) from e
2008-
except VelaKubernetesError as e:
2009-
await _set_branch_status(session, branch_in_session, BranchServiceStatus.ERROR)
2010-
raise HTTPException(status_code=500, detail=str(e)) from e
2011-
else:
2012-
await _set_final_branch_status(session, branch_in_session, action)
2013-
return Response(status_code=204)
1960+
assert action in _CONTROL_TO_POWER_STATE
1961+
task_id = dispatch_control(str(branch.id), action)
1962+
branch.control_task_id = task_id
1963+
await session.commit()
1964+
1965+
task_url = url_path_for(
1966+
request,
1967+
"organizations:projects:branch:tasks:detail",
1968+
organization_id=await organization.awaitable_attrs.id,
1969+
project_id=await project.awaitable_attrs.id,
1970+
branch_id=await branch.awaitable_attrs.id,
1971+
task_id=task_id,
1972+
)
1973+
return Response(status_code=202, headers={"Location": task_url})
20141974

20151975

20161976
instance_api.include_router(auth_api, prefix="/auth")
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""Celery task for branch start/stop/pause/resume lifecycle management.
2+
3+
A single task patches the Kubernetes VM power state, polls until the VM
4+
reaches the desired phase, then updates the branch status in the database.
5+
"""
6+
7+
import asyncio
8+
import logging
9+
from uuid import UUID
10+
11+
from asgiref.sync import async_to_sync
12+
from ulid import ULID
13+
14+
from .....database import AsyncSessionLocal
15+
from .....deployment import get_autoscaler_vm_identity
16+
from .....deployment.kubernetes.neonvm import Phase, PowerState, get_neon_vm, set_virtualmachine_power_state
17+
from .....models.branch import Branch, BranchServiceStatus
18+
from .....worker import app
19+
20+
logger = logging.getLogger(__name__)
21+
22+
_CONTROL_TO_POWER_STATE: dict[str, PowerState] = {
23+
"pause": "Stopped",
24+
"resume": "Running",
25+
"start": "Running",
26+
"stop": "Stopped",
27+
}
28+
29+
_CONTROL_TRANSITION_INITIAL: dict[str, BranchServiceStatus] = {
30+
"pause": BranchServiceStatus.PAUSING,
31+
"resume": BranchServiceStatus.RESUMING,
32+
"start": BranchServiceStatus.STARTING,
33+
"stop": BranchServiceStatus.STOPPING,
34+
}
35+
36+
_CONTROL_TRANSITION_FINAL: dict[str, BranchServiceStatus | None] = {
37+
"pause": BranchServiceStatus.PAUSED,
38+
"resume": BranchServiceStatus.STARTING,
39+
"start": None,
40+
"stop": BranchServiceStatus.STOPPED,
41+
}
42+
43+
_DESIRED_PHASES: dict[str, set[Phase]] = {
44+
"start": {Phase.running},
45+
"resume": {Phase.running},
46+
"stop": {Phase.stopped, Phase.succeeded},
47+
"pause": {Phase.stopped, Phase.succeeded},
48+
}
49+
50+
_POLL_INTERVAL_SEC = 5
51+
_TIMEOUT_SEC = 600 # 10 minutes
52+
53+
54+
async def _async_perform_control(branch_id: str, action: str) -> dict:
55+
ulid = ULID.from_str(branch_id)
56+
namespace, name = get_autoscaler_vm_identity(ulid)
57+
58+
power_state = _CONTROL_TO_POWER_STATE[action]
59+
await set_virtualmachine_power_state(namespace, name, power_state)
60+
61+
desired_phases = _DESIRED_PHASES[action]
62+
elapsed = 0.0
63+
while True:
64+
vm = await get_neon_vm(namespace, name)
65+
if vm.status is not None and vm.status.phase in desired_phases:
66+
break
67+
if elapsed >= _TIMEOUT_SEC:
68+
raise TimeoutError(f"VM {name!r} did not reach {desired_phases} within {_TIMEOUT_SEC}s")
69+
await asyncio.sleep(_POLL_INTERVAL_SEC)
70+
elapsed += _POLL_INTERVAL_SEC
71+
72+
final_status = _CONTROL_TRANSITION_FINAL[action]
73+
async with AsyncSessionLocal() as session:
74+
branch = await session.get(Branch, ulid)
75+
if branch is None:
76+
logger.error("Branch %s not found after control action %s", branch_id, action)
77+
return {"action": action}
78+
if final_status is not None:
79+
branch.set_status(final_status)
80+
branch.control_task_id = None
81+
await session.commit()
82+
83+
return {"action": action}
84+
85+
86+
async def _async_perform_control_with_error_handling(branch_id: str, action: str) -> dict:
87+
try:
88+
return await _async_perform_control(branch_id, action)
89+
except Exception:
90+
ulid = ULID.from_str(branch_id)
91+
async with AsyncSessionLocal() as session:
92+
branch = await session.get(Branch, ulid)
93+
if branch is not None:
94+
branch.set_status(BranchServiceStatus.ERROR)
95+
branch.control_task_id = None
96+
await session.commit()
97+
raise
98+
99+
100+
@app.task(name="simplyblock.vela.branch.control")
101+
def perform_control(branch_id: str, action: str) -> dict:
102+
"""Patch VM power state, wait for desired phase, update branch status."""
103+
return async_to_sync(_async_perform_control_with_error_handling)(branch_id, action)
104+
105+
106+
def dispatch_control(branch_id: str, action: str) -> UUID:
107+
"""Dispatch perform_control asynchronously; return the Celery task UUID."""
108+
result = perform_control.apply_async(kwargs={"branch_id": branch_id, "action": action})
109+
return UUID(result.id)

src/api/organization/project/branch/tasks.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Branch task list/detail endpoints.
22
3-
Exposes Celery task state (currently resize only) under:
3+
Exposes Celery task state (resize and control) under:
44
GET .../branches/{branch_id}/tasks
55
GET .../branches/{branch_id}/tasks/{task_id}
66
"""
@@ -14,6 +14,7 @@
1414

1515
from ...._util import Forbidden, NotFound, Unauthenticated
1616
from ....dependencies import BranchDep, OrganizationDep, ProjectDep
17+
from .control_tasks import perform_control
1718
from .resize_tasks import finalize_resize
1819

1920
task_api = APIRouter(tags=["branch"])
@@ -37,7 +38,7 @@ class BranchTaskPublic(BaseModel):
3738
date_done: datetime | None
3839

3940

40-
def _build_task_public(task_id: UUID) -> BranchTaskPublic:
41+
def _build_resize_task_public(task_id: UUID) -> BranchTaskPublic:
4142
result = finalize_resize.AsyncResult(str(task_id))
4243
state = result.state
4344
status = _CELERY_STATE_TO_STATUS.get(state, state)
@@ -53,6 +54,23 @@ def _build_task_public(task_id: UUID) -> BranchTaskPublic:
5354
)
5455

5556

57+
def _build_control_task_public(task_id: UUID) -> BranchTaskPublic:
58+
result = perform_control.AsyncResult(str(task_id))
59+
state = result.state
60+
status = _CELERY_STATE_TO_STATUS.get(state, state)
61+
kwargs: dict = result.kwargs or {}
62+
action = kwargs.get("action", "control")
63+
return BranchTaskPublic(
64+
id=task_id,
65+
task_type=action,
66+
status=status,
67+
parameters={"action": action},
68+
result=result.result if state == "SUCCESS" else None,
69+
error=str(result.traceback) if state == "FAILURE" and result.traceback else None,
70+
date_done=result.date_done,
71+
)
72+
73+
5674
@task_api.get(
5775
"/",
5876
name="organizations:projects:branch:tasks:list",
@@ -64,9 +82,12 @@ async def list_tasks(
6482
_project: ProjectDep,
6583
branch: BranchDep,
6684
) -> list[BranchTaskPublic]:
67-
if branch.resize_task_id is None:
68-
return []
69-
return [_build_task_public(branch.resize_task_id)]
85+
tasks = []
86+
if branch.resize_task_id is not None:
87+
tasks.append(_build_resize_task_public(branch.resize_task_id))
88+
if branch.control_task_id is not None:
89+
tasks.append(_build_control_task_public(branch.control_task_id))
90+
return tasks
7091

7192

7293
@task_api.get(
@@ -81,6 +102,8 @@ async def get_task(
81102
branch: BranchDep,
82103
task_id: UUID,
83104
) -> BranchTaskPublic:
84-
if branch.resize_task_id != task_id:
85-
raise HTTPException(status_code=404, detail="Task not found")
86-
return _build_task_public(task_id)
105+
if branch.resize_task_id == task_id:
106+
return _build_resize_task_public(task_id)
107+
if branch.control_task_id == task_id:
108+
return _build_control_task_public(task_id)
109+
raise HTTPException(status_code=404, detail="Task not found")

src/models/branch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from enum import StrEnum
44
from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Literal, Optional
55

6-
import sqlalchemy as sa
76
from pydantic import BaseModel, model_validator
87
from pydantic import Field as PydanticField
98
from sqlalchemy import BigInteger, Boolean, Column, String, Text, UniqueConstraint, text
@@ -109,7 +108,8 @@ class Branch(AsyncAttrs, Model, table=True):
109108
)
110109
pitr_enabled: bool = Field(default=False, sa_column=Column(Boolean, nullable=False, server_default=text("false")))
111110
resize_task_id: uuid.UUID | None = Field(default=None, nullable=True)
112-
db_port: int | None = Field(default=None, sa_column=Column(sa.Integer, nullable=True))
111+
db_port: int | None = None
112+
control_task_id: uuid.UUID | None = None
113113

114114
__table_args__ = (UniqueConstraint("project_id", "name", name="unique_branch_name_per_project"),)
115115

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""Add control_task_id to branch table
2+
3+
Revision ID: b1c2d3e4f5a6
4+
Revises: 07822c477427
5+
Create Date: 2026-04-01 00:00:00.000000
6+
7+
"""
8+
from typing import Sequence, Union
9+
10+
import sqlalchemy as sa
11+
from alembic import op
12+
13+
14+
# revision identifiers, used by Alembic.
15+
revision: str = "b1c2d3e4f5a6"
16+
down_revision: Union[str, Sequence[str], None] = "07822c477427"
17+
branch_labels: Union[str, Sequence[str], None] = None
18+
depends_on: Union[str, Sequence[str], None] = None
19+
20+
21+
def upgrade() -> None:
22+
op.add_column("branch", sa.Column("control_task_id", sa.Uuid(), nullable=True))
23+
24+
25+
def downgrade() -> None:
26+
op.drop_column("branch", "control_task_id")

0 commit comments

Comments
 (0)