Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions alphatrion/run/run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import uuid
from datetime import UTC, datetime

from alphatrion.runtime.contextvars import current_run_id
from alphatrion.runtime.runtime import global_runtime
Expand Down Expand Up @@ -50,19 +51,28 @@ def done(self):
if self.cancelled():
return

run = self._runtime._metadb.get_run(run_id=self.id)
duration = (
datetime.now(UTC) - run.created_at.replace(tzinfo=UTC)
).total_seconds()

self._runtime.metadb.update_run(
run_id=self._id,
status=Status.COMPLETED,
run_id=self._id, status=Status.COMPLETED, duration=duration
)
self._result = self._task.result()

def cancel(self):
# TODO: we should wait for the task to be actually cancelled
# and catch the CancelledError exception in the task function.
self._task.cancel()

run = self._runtime._metadb.get_run(run_id=self.id)
duration = (
datetime.now(UTC) - run.created_at.replace(tzinfo=UTC)
).total_seconds()

self._runtime.metadb.update_run(
run_id=self._id,
status=Status.CANCELLED,
run_id=self._id, status=Status.CANCELLED, duration=duration
)

def cancelled(self) -> bool:
Expand Down
1 change: 1 addition & 0 deletions alphatrion/storage/sql_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ class Run(Base):
nullable=True,
comment="Additional metadata for the run",
)
duration = Column(Float, default=0.0, comment="Duration of the run in seconds")
status = Column(
Integer,
default=Status.PENDING,
Expand Down
32 changes: 32 additions & 0 deletions migrations/versions/0f417c7cf4d3_add_duration_for_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""add duration for run

Revision ID: 0f417c7cf4d3
Revises: 766c8b7fe6c5
Create Date: 2026-03-02 20:20:56.486059

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = '0f417c7cf4d3'
down_revision: Union[str, Sequence[str], None] = '766c8b7fe6c5'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('runs', sa.Column('duration', sa.Float(), nullable=True, comment='Duration of the run in seconds'))
# ### end Alembic commands ###


def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('runs', 'duration')
# ### end Alembic commands ###
16 changes: 14 additions & 2 deletions tests/unit/experiment/test_experimant.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,16 +220,24 @@ async def fake_work(exp_id: uuid.UUID):
async with CraftExperiment.start(name="first-experiment") as exp:
start_time = datetime.now()

exp.run(lambda: fake_work(exp.id))
run1 = exp.run(lambda: fake_work(exp.id))
assert len(exp._runs) == 1

exp.run(lambda: fake_work(exp.id))
run2 = exp.run(lambda: fake_work(exp.id))
assert len(exp._runs) == 2

await exp.wait()
assert datetime.now() - start_time >= timedelta(seconds=3)
assert len(exp._runs) == 0

run1_obj = run1._get_obj()
assert run1_obj.status == Status.COMPLETED
assert run1_obj.duration >= 3.0

run2_obj = run2._get_obj()
assert run2_obj.status == Status.COMPLETED
assert run2_obj.duration >= 3.0


@pytest.mark.asyncio
async def test_create_experiment_with_run_cancelled():
Expand All @@ -256,12 +264,16 @@ async def fake_work(timeout: int):

run_0_obj = run_0._get_obj()
assert run_0_obj.status == Status.COMPLETED
assert run_0_obj.duration >= 1.0
run_1_obj = run_1._get_obj()
assert run_1_obj.status == Status.CANCELLED
assert run_1_obj.duration >= 2.0
run_2_obj = run_2._get_obj()
assert run_2_obj.status == Status.CANCELLED
assert run_2_obj.duration >= 2.0
run_3_obj = run_3._get_obj()
assert run_3_obj.status == Status.CANCELLED
assert run_3_obj.duration >= 2.0


@pytest.mark.asyncio
Expand Down
Loading