Skip to content

Commit 85abd50

Browse files
authored
Support timeout trial (#32)
* Support timeout trial Signed-off-by: kerthcet <kerthcet@gmail.com> * lock poetry Signed-off-by: kerthcet <kerthcet@gmail.com> * update tests Signed-off-by: kerthcet <kerthcet@gmail.com> * fix test Signed-off-by: kerthcet <kerthcet@gmail.com> * poetry lock Signed-off-by: kerthcet <kerthcet@gmail.com> --------- Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent fd3d96c commit 85abd50

13 files changed

Lines changed: 310 additions & 65 deletions

File tree

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ format:
2323

2424
.PHONY: test
2525
test: lint
26-
$(POETRY) run pytest tests/unit
26+
$(POETRY) run pytest tests/unit --timeout=15
2727

2828
.PHONY: test-integration
2929
test-integration: lint

alphatrion/experiment/base.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import weakref
1+
import uuid
22
from abc import ABC, abstractmethod
33
from dataclasses import dataclass
44

@@ -29,7 +29,7 @@ def get_trial(self, id: int) -> trial.Trial | None:
2929
def _reset(self):
3030
self._trials = dict()
3131

32-
def __enter__(self):
32+
async def __aenter__(self):
3333
if self._id is None:
3434
raise RuntimeError("Experiment is not set. Did you call run()?")
3535

@@ -38,10 +38,10 @@ def __enter__(self):
3838
raise RuntimeError(f"Experiment {self._id} not found in the database.")
3939

4040
# Use weakref to avoid circular reference
41-
self._runtime.current_exp = weakref.ref(self)
41+
self._runtime.current_exp = self
4242
return self
4343

44-
def __exit__(self, exc_type, exc_val, exc_tb):
44+
async def __aexit__(self, exc_type, exc_val, exc_tb):
4545
self._reset()
4646
self._runtime.current_exp = None
4747

@@ -53,9 +53,12 @@ def run(
5353
"""Return a new experiment."""
5454
...
5555

56-
def _register_trial(self, id: int, instance: trial.Trial):
56+
def register_trial(self, id: uuid.UUID, instance: trial.Trial):
5757
self._trials[id] = instance
5858

59+
def unregister_trial(self, id: uuid.UUID):
60+
self._trials.pop(id, None)
61+
5962
def _create(
6063
self,
6164
name: str,

alphatrion/experiment/craft_exp.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import uuid
2+
13
from alphatrion.experiment.base import Experiment
24
from alphatrion.trial.trial import Trial, TrialConfig
35

@@ -14,21 +16,32 @@ def __init__(self):
1416
super().__init__()
1517

1618
@classmethod
17-
def run(cls, name: str, description: str | None = None, meta: dict | None = None):
19+
def run(
20+
cls,
21+
name: str,
22+
id: uuid.UUID | None = None,
23+
description: str | None = None,
24+
meta: dict | None = None,
25+
) -> "CraftExperiment":
1826
"""
1927
Begin the experiment. This method must be used to start multi-trial experiment.
28+
If id is provided, the experiment with the given id will be used.
2029
"""
2130

2231
exp = CraftExperiment()
23-
exp._create(
24-
name=name,
25-
description=description,
26-
meta=meta,
27-
)
32+
33+
if id is not None:
34+
exp._id = id
35+
else:
36+
exp._create(
37+
name=name,
38+
description=description,
39+
meta=meta,
40+
)
2841

2942
return exp
3043

31-
def start_trial(
44+
async def start_trial(
3245
self,
3346
description: str | None = None,
3447
meta: dict | None = None,
@@ -46,11 +59,6 @@ def start_trial(
4659
"""
4760

4861
trial = Trial(exp_id=self._id, config=config)
49-
trial._start(description=description, meta=meta, params=params)
50-
self._register_trial(id=trial._id, instance=trial)
62+
await trial._start(description=description, meta=meta, params=params)
63+
self.register_trial(id=trial.id, instance=trial)
5164
return trial
52-
53-
# @classmethod
54-
# # TODO: support async
55-
# async def async_trial(cls):
56-
# pass

alphatrion/metadata/sql.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import uuid
2+
13
from sqlalchemy import create_engine
24
from sqlalchemy.orm import sessionmaker
35

@@ -166,7 +168,7 @@ def create_trial(
166168
meta: dict | None,
167169
params: dict | None = None,
168170
status: TrialStatus = TrialStatus.PENDING,
169-
) -> int:
171+
) -> uuid.UUID:
170172
session = self._session()
171173
new_trial = Trial(
172174
experiment_id=exp_id,
@@ -183,13 +185,13 @@ def create_trial(
183185

184186
return trial_id
185187

186-
def get_trial(self, trial_id: int) -> Trial | None:
188+
def get_trial(self, trial_id: uuid.UUID) -> Trial | None:
187189
session = self._session()
188190
trial = session.query(Trial).filter(Trial.uuid == trial_id).first()
189191
session.close()
190192
return trial
191193

192-
def update_trial(self, trial_id: int, **kwargs):
194+
def update_trial(self, trial_id: uuid.UUID, **kwargs):
193195
session = self._session()
194196
trial = session.query(Trial).filter(Trial.uuid == trial_id).first()
195197
if trial:
@@ -198,7 +200,7 @@ def update_trial(self, trial_id: int, **kwargs):
198200
session.commit()
199201
session.close()
200202

201-
def create_metric(self, trial_id: int, key: str, value: float, step: int):
203+
def create_metric(self, trial_id: uuid.UUID, key: str, value: float, step: int):
202204
session = self._session()
203205
new_metric = Metrics(
204206
trial_id=trial_id,
@@ -210,7 +212,7 @@ def create_metric(self, trial_id: int, key: str, value: float, step: int):
210212
session.commit()
211213
session.close()
212214

213-
def list_metrics(self, trial_id: int) -> list[Metrics]:
215+
def list_metrics(self, trial_id: uuid.UUID) -> list[Metrics]:
214216
session = self._session()
215217
metrics = session.query(Metrics).filter(Metrics.trial_id == trial_id).all()
216218
session.close()

alphatrion/metadata/sql_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class Metrics(Base):
8585
key = Column(String, nullable=False)
8686
value = Column(Float, nullable=False)
8787
trial_id = Column(UUID(as_uuid=True), nullable=False)
88+
# TODO: do we need?
8889
step = Column(Integer, nullable=False, default=0)
8990
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))
9091

alphatrion/runtime/runtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(self, project_id: str, artifact_insecure: bool = False):
4141
# current_exp is the current running experiment.
4242
@property
4343
def current_exp(self):
44-
return self.__current_exp()
44+
return self.__current_exp
4545

4646
@current_exp.setter
4747
def current_exp(self, value):

alphatrion/trial/trial.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import contextvars
2+
import uuid
23
from datetime import UTC, datetime
34

45
from pydantic import BaseModel, Field, field_validator
56

67
from alphatrion.metadata.sql_models import COMPLETED_STATUS, TrialStatus
78
from alphatrion.runtime.runtime import global_runtime
9+
from alphatrion.utils.context import Context
810

911
# Used in record/record.py to log params/metrics
1012
current_trial_id = contextvars.ContextVar("current_trial_id", default=None)
@@ -57,15 +59,15 @@ class TrialConfig(BaseModel):
5759
"""Configuration for an experiment."""
5860

5961
max_duration_seconds: int = Field(
60-
default=86400,
62+
default=-1,
6163
description="Maximum duration in seconds for the experiment. \
62-
Default is 86400 seconds (1 day).",
63-
)
64-
max_retries: int = Field(
65-
default=0,
66-
description="Maximum number of retries for the experiment. \
67-
Default is 0 (no retries).",
64+
Default is -1 (no limit).",
6865
)
66+
# max_retries: int = Field(
67+
# default=0,
68+
# description="Maximum number of retries for the experiment. \
69+
# Default is 0 (no retries).",
70+
# )
6971
checkpoint: CheckpointConfig = Field(
7072
default=CheckpointConfig(),
7173
description="Configuration for checkpointing.",
@@ -78,8 +80,9 @@ class Trial:
7880
"_exp_id",
7981
"_config",
8082
"_runtime",
81-
"_token",
8283
"_step",
84+
"_context",
85+
"_token",
8386
)
8487

8588
def __init__(self, exp_id: int, config: TrialConfig | None = None):
@@ -88,13 +91,25 @@ def __init__(self, exp_id: int, config: TrialConfig | None = None):
8891
self._runtime = global_runtime()
8992
# step is used to track the round, e.g. the step in metric logging.
9093
self._step = 0
94+
self._context = Context(
95+
cancel_func=self._stop,
96+
timeout=self._config.max_duration_seconds
97+
if self._config.max_duration_seconds > 0
98+
else None,
99+
)
100+
101+
def stopped(self) -> bool:
102+
return self._context.cancelled()
91103

92-
def _start(
104+
async def wait_stopped(self):
105+
await self._context.wait_cancelled()
106+
107+
async def _start(
93108
self,
94109
description: str | None = None,
95110
meta: dict | None = None,
96111
params: dict | None = None,
97-
) -> int:
112+
) -> uuid.UUID:
98113
self._id = self._runtime._metadb.create_trial(
99114
exp_id=self._exp_id,
100115
description=description,
@@ -103,26 +118,31 @@ def _start(
103118
status=TrialStatus.RUNNING,
104119
)
105120

121+
# We don't reset the trial id context var here, because
122+
# each trial runs in its own context.
106123
self._token = current_trial_id.set(self._id)
124+
await self._context.start()
107125
return self._id
108126

109127
@property
110-
def id(self):
128+
def id(self) -> uuid.UUID:
111129
return self._id
112130

113-
# finish function should be called manually as a pair of start
114-
def finish(self, status: TrialStatus = TrialStatus.FINISHED):
131+
# stop function should be called manually as a pair of start
132+
def stop(self):
133+
self._context.cancel()
134+
135+
def _stop(self):
115136
trial = self._runtime._metadb.get_trial(trial_id=self._id)
116137
if trial is not None and trial.status not in COMPLETED_STATUS:
117138
duration = (
118139
datetime.now(UTC) - trial.created_at.replace(tzinfo=UTC)
119140
).total_seconds()
120141
self._runtime._metadb.update_trial(
121-
trial_id=self._id, status=status, duration=duration
142+
trial_id=self._id, status=TrialStatus.FINISHED, duration=duration
122143
)
123144

124-
# recover the context var
125-
current_trial_id.reset(self._token)
145+
self._runtime.current_exp.unregister_trial(self._id)
126146

127147
def _get(self):
128148
return self._runtime._metadb.get_trial(trial_id=self._id)

alphatrion/utils/context.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import asyncio
2+
from collections.abc import Callable
3+
4+
5+
# Inspired by golang context package
6+
class Context:
7+
def __init__(self, cancel_func: Callable | None = None, timeout=None):
8+
"""A context for managing cancellation and timeouts.
9+
:param cancel_func: A function to call when the context is cancelled.
10+
:param timeout: Timeout in seconds. If None, no timeout is set.
11+
"""
12+
self._cancel_event = asyncio.Event()
13+
self._cancel_func = cancel_func
14+
self._timeout = timeout
15+
16+
async def start(self):
17+
if self._timeout:
18+
asyncio.create_task(self._auto_cancel(self._timeout))
19+
20+
async def _auto_cancel(self, timeout):
21+
await asyncio.sleep(timeout)
22+
self.cancel()
23+
24+
def cancel(self):
25+
if self.cancelled():
26+
return
27+
if self._cancel_func:
28+
self._cancel_func()
29+
self._cancel_event.set()
30+
31+
def cancelled(self):
32+
return self._cancel_event.is_set()
33+
34+
async def wait_cancelled(self):
35+
await self._cancel_event.wait()

poetry.lock

Lines changed: 35 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ dependencies = [
1717
[tool.poetry.group.dev.dependencies]
1818
pytest = ">=8.4.2,<9.0.0"
1919
ruff = "^0.12.12"
20+
pytest-asyncio = ">=0.22.0,<1.0.0"
21+
pytest-timeout = ">=2.1.0,<3.0.0"
2022

2123
[build-system]
2224
requires = ["poetry-core>=2.0.0,<3.0.0"]

0 commit comments

Comments
 (0)