Skip to content

Commit 569cb0e

Browse files
authored
Change func name (#38)
* change start_trial to run_trial Signed-off-by: kerthcet <kerthcet@gmail.com> * Change func name Signed-off-by: kerthcet <kerthcet@gmail.com> --------- Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent e8722e6 commit 569cb0e

5 files changed

Lines changed: 53 additions & 26 deletions

File tree

alphatrion/experiment/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
4747

4848
@classmethod
4949
@abstractmethod
50-
def run(
50+
def start(
5151
cls, name: str, description: str | None = None, meta: dict | None = None
5252
) -> "Experiment":
5353
"""Return a new experiment."""

alphatrion/experiment/craft_exp.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self):
1616
super().__init__()
1717

1818
@classmethod
19-
def run(
19+
def start(
2020
cls,
2121
name: str,
2222
id: uuid.UUID | None = None,
@@ -50,9 +50,10 @@ def start_trial(
5050
) -> Trial:
5151
"""
5252
start_trial starts a new trial in this experiment.
53-
You need to call trial.stop() to stop the trial for proper cleanup,
54-
unless it's a timeout trial. Or you can use 'async with exp.run_trial(...)'
55-
as trial, which will automatically stop the trial at the end of the context.
53+
You need to call trial.cancel() to stop the trial for proper cleanup,
54+
unless it's a timeout trial.
55+
Or you can use 'async with exp.start_trial(...) as trial', which will
56+
automatically stop the trial at the end of the context.
5657
5758
:params description: the description of the trial
5859
:params meta: the metadata of the trial

alphatrion/trial/trial.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ async def __aenter__(self):
103103
return self
104104

105105
async def __aexit__(self, exc_type, exc_val, exc_tb):
106-
self.stop()
106+
self.cancel()
107107

108108
def _construct_meta(self):
109109
self._meta = dict()
@@ -167,7 +167,7 @@ def _timeout(self) -> int | None:
167167
def stopped(self) -> bool:
168168
return self._context.cancelled()
169169

170-
async def wait_stopped(self):
170+
async def wait(self):
171171
await self._context.wait_cancelled()
172172

173173
def _start(
@@ -194,8 +194,8 @@ def _start(
194194
def id(self) -> uuid.UUID:
195195
return self._id
196196

197-
# stop function should be called manually as a pair of start
198-
def stop(self):
197+
# cancel function should be called manually as a pair of start
198+
def cancel(self):
199199
self._context.cancel()
200200

201201
def _stop(self):

tests/integration/test_log.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
async def test_log_artifact():
1414
alpha.init(project_id="test_project", artifact_insecure=True)
1515

16-
async with alpha.CraftExperiment.run(
16+
async with alpha.CraftExperiment.start(
1717
name="context_exp",
1818
description="Context manager test",
1919
meta={"key": "value"},
@@ -49,7 +49,7 @@ async def test_log_artifact():
4949
versions = exp._runtime._artifact.list_versions(exp_obj.uuid)
5050
assert len(versions) == 0
5151

52-
trial.stop()
52+
trial.cancel()
5353

5454
got_exp = exp._runtime._metadb.get_exp(exp_id=exp._id)
5555
assert got_exp is not None
@@ -65,7 +65,7 @@ async def test_log_artifact():
6565
async def test_log_params():
6666
alpha.init(project_id="test_project", artifact_insecure=True)
6767

68-
async with alpha.CraftExperiment.run(name="test_experiment") as exp:
68+
async with alpha.CraftExperiment.start(name="test_experiment") as exp:
6969
trial = exp.start_trial(description="First trial", params={"param1": 0.1})
7070

7171
new_trial = exp._runtime._metadb.get_trial(trial_id=trial.id)
@@ -81,18 +81,18 @@ async def test_log_params():
8181
assert new_trial.status == TrialStatus.RUNNING
8282
assert current_trial_id.get() == trial.id
8383

84-
trial.stop()
84+
trial.cancel()
8585

8686
trial = exp.start_trial(description="Second trial", params={"param1": 0.1})
8787
assert current_trial_id.get() == trial.id
88-
trial.stop()
88+
trial.cancel()
8989

9090

9191
@pytest.mark.asyncio
9292
async def test_log_metrics():
9393
alpha.init(project_id="test_project", artifact_insecure=True)
9494

95-
async with alpha.CraftExperiment.run(name="test_experiment") as exp:
95+
async with alpha.CraftExperiment.start(name="test_experiment") as exp:
9696
trial = exp.start_trial(description="First trial", params={"param1": 0.1})
9797

9898
new_trial = exp._runtime._metadb.get_trial(trial_id=trial._id)
@@ -121,14 +121,14 @@ async def test_log_metrics():
121121
assert metrics[2].value == 0.96
122122
assert metrics[2].step == 2
123123

124-
trial.stop()
124+
trial.cancel()
125125

126126

127127
@pytest.mark.asyncio
128128
async def test_log_metrics_with_save_on_max():
129129
alpha.init(project_id="test_project", artifact_insecure=True)
130130

131-
async with alpha.CraftExperiment.run(
131+
async with alpha.CraftExperiment.start(
132132
name="context_exp",
133133
description="Context manager test",
134134
meta={"key": "value"},
@@ -182,7 +182,7 @@ async def test_log_metrics_with_save_on_max():
182182
async def test_log_metrics_with_save_on_min():
183183
alpha.init(project_id="test_project", artifact_insecure=True)
184184

185-
async with alpha.CraftExperiment.run(
185+
async with alpha.CraftExperiment.start(
186186
name="context_exp",
187187
description="Context manager test",
188188
meta={"key": "value"},

tests/unit/experiment/test_craft_exp.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
import asyncio
22
import random
3+
from datetime import datetime, timedelta
34

45
import pytest
56

67
from alphatrion.experiment.craft_exp import CraftExperiment
78
from alphatrion.metadata.sql_models import TrialStatus
89
from alphatrion.runtime.runtime import init
9-
from alphatrion.trial.trial import TrialConfig, current_trial_id
10+
from alphatrion.trial.trial import Trial, TrialConfig, current_trial_id
1011

1112

1213
@pytest.mark.asyncio
1314
async def test_craft_experiment():
1415
init(project_id="test_project", artifact_insecure=True)
1516

16-
async with CraftExperiment.run(
17+
async with CraftExperiment.start(
1718
name="context_exp",
1819
description="Context manager test",
1920
meta={"key": "value"},
@@ -28,7 +29,7 @@ async def test_craft_experiment():
2829
assert trial_obj is not None
2930
assert trial_obj.description == "First trial"
3031

31-
trial.stop()
32+
trial.cancel()
3233

3334
trial2 = trial._get_obj()
3435
assert trial2.status == TrialStatus.FINISHED
@@ -39,7 +40,7 @@ async def test_create_experiment_with_trial():
3940
init(project_id="test_project", artifact_insecure=True)
4041

4142
trial_id = None
42-
async with CraftExperiment.run(name="context_exp") as exp:
43+
async with CraftExperiment.start(name="context_exp") as exp:
4344
async with exp.start_trial(description="First trial") as trial:
4445
trial_obj = trial._get_obj()
4546
assert trial_obj is not None
@@ -50,19 +51,44 @@ async def test_create_experiment_with_trial():
5051
assert trial_obj.status == TrialStatus.FINISHED
5152

5253

54+
@pytest.mark.asyncio
55+
async def test_create_experiment_with_trial_wait():
56+
init(project_id="test_project", artifact_insecure=True)
57+
58+
async def fake_work(trial: Trial):
59+
await asyncio.sleep(3)
60+
trial.cancel()
61+
62+
trial_id = None
63+
async with CraftExperiment.start(name="context_exp") as exp:
64+
async with exp.start_trial(description="First trial") as trial:
65+
trial_id = current_trial_id.get()
66+
67+
start_time = datetime.now()
68+
69+
asyncio.create_task(fake_work(trial))
70+
assert datetime.now() - start_time <= timedelta(seconds=1)
71+
72+
await trial.wait()
73+
assert datetime.now() - start_time >= timedelta(seconds=3)
74+
75+
trial_obj = exp._runtime._metadb.get_trial(trial_id=trial_id)
76+
assert trial_obj.status == TrialStatus.FINISHED
77+
78+
5379
@pytest.mark.asyncio
5480
async def test_craft_experiment_with_context():
5581
init(project_id="test_project", artifact_insecure=True)
5682

57-
async with CraftExperiment.run(
83+
async with CraftExperiment.start(
5884
name="context_exp",
5985
description="Context manager test",
6086
meta={"key": "value"},
6187
) as exp:
6288
trial = exp.start_trial(
6389
description="First trial", config=TrialConfig(max_duration_seconds=2)
6490
)
65-
await trial.wait_stopped()
91+
await trial.wait()
6692
assert trial.stopped()
6793

6894
trial = trial._get_obj()
@@ -81,15 +107,15 @@ async def fake_work(exp: CraftExperiment):
81107
# double check current trial id.
82108
assert trial.id == current_trial_id.get()
83109

84-
await trial.wait_stopped()
110+
await trial.wait()
85111
assert trial.stopped()
86112
# we don't reset the current trial id.
87113
assert trial.id == current_trial_id.get()
88114

89115
trial = trial._get_obj()
90116
assert trial.status == TrialStatus.FINISHED
91117

92-
async with CraftExperiment.run(
118+
async with CraftExperiment.start(
93119
name="context_exp",
94120
description="Context manager test",
95121
meta={"key": "value"},

0 commit comments

Comments
 (0)