Skip to content

Commit 64f8817

Browse files
authored
support loading start time from env (#33)
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 85abd50 commit 64f8817

3 files changed

Lines changed: 72 additions & 4 deletions

File tree

alphatrion/trial/trial.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextvars
2+
import os
23
import uuid
34
from datetime import UTC, datetime
45

@@ -93,11 +94,26 @@ def __init__(self, exp_id: int, config: TrialConfig | None = None):
9394
self._step = 0
9495
self._context = Context(
9596
cancel_func=self._stop,
96-
timeout=self._config.max_duration_seconds
97-
if self._config.max_duration_seconds > 0
98-
else None,
97+
timeout=self._timeout(),
9998
)
10099

100+
def _timeout(self) -> int | None:
101+
timeout = self._config.max_duration_seconds
102+
if timeout < 0:
103+
return None
104+
105+
# Adjust timeout based on the trial start time from environment variable,
106+
# this is useful when running in cloud env when the trial process may be
107+
# restarted.
108+
start_time = os.environ.get("ALPHATRION_TRIAL_START_TIME", None)
109+
if start_time is not None:
110+
elapsed = (
111+
datetime.now(UTC)
112+
- datetime.fromisoformat(start_time).replace(tzinfo=UTC)
113+
).total_seconds()
114+
timeout -= int(elapsed)
115+
return timeout
116+
101117
def stopped(self) -> bool:
102118
return self._context.cancelled()
103119

alphatrion/utils/context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ def __init__(self, cancel_func: Callable | None = None, timeout=None):
1414
self._timeout = timeout
1515

1616
async def start(self):
17-
if self._timeout:
17+
# If timeout is None, it means no timeout is set.
18+
# If timeout is negative, it means already timed out.
19+
if self._timeout is not None:
1820
asyncio.create_task(self._auto_cancel(self._timeout))
1921

2022
async def _auto_cancel(self, timeout):

tests/unit/trial/test_trial.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import os
2+
import unittest
3+
from datetime import UTC, datetime, timedelta
4+
5+
from alphatrion.trial.trial import Trial, TrialConfig
6+
7+
8+
class TestTrial(unittest.IsolatedAsyncioTestCase):
9+
def test_timeout(self):
10+
test_cases = [
11+
{
12+
"name": "No timeout",
13+
"config": TrialConfig(),
14+
"started_at": None,
15+
"expected": None,
16+
},
17+
{
18+
"name": "Positive timeout",
19+
"config": TrialConfig(max_duration_seconds=10),
20+
"started_at": None,
21+
"expected": 10,
22+
},
23+
{
24+
"name": "Zero timeout",
25+
"config": TrialConfig(max_duration_seconds=0),
26+
"started_at": None,
27+
"expected": 0,
28+
},
29+
{
30+
"name": "Negative timeout",
31+
"config": TrialConfig(max_duration_seconds=-5),
32+
"started_at": None,
33+
"expected": None,
34+
},
35+
{
36+
"name": "With started_at, positive timeout",
37+
"config": TrialConfig(max_duration_seconds=5),
38+
"started_at": (
39+
datetime.now(UTC) - timedelta(seconds=3)
40+
).isoformat(),
41+
"expected": 2,
42+
},
43+
]
44+
45+
for case in test_cases:
46+
if case["started_at"]:
47+
os.environ["ALPHATRION_TRIAL_START_TIME"] = case["started_at"]
48+
with self.subTest(name=case["name"]):
49+
trial = Trial(exp_id=1, config=case["config"])
50+
self.assertEqual(trial._timeout(), case["expected"])

0 commit comments

Comments
 (0)