Skip to content

Commit 4535011

Browse files
authored
Add max_run_number (#47)
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 9814d7c commit 4535011

5 files changed

Lines changed: 89 additions & 29 deletions

File tree

alphatrion/__init__.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
from alphatrion.experiment.craft_exp import CraftExperiment
22
from alphatrion.log.log import log_artifact, log_metrics, log_params
33
from alphatrion.runtime.runtime import init
4+
from alphatrion.trial.trial import CheckpointConfig, TrialConfig
45

5-
__all__ = ["log_artifact", "log_params", "log_metrics", "CraftExperiment", "init"]
6+
__all__ = [
7+
"log_artifact",
8+
"log_params",
9+
"log_metrics",
10+
"CraftExperiment",
11+
"init",
12+
"TrialConfig",
13+
"CheckpointConfig",
14+
]

alphatrion/run/run.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import uuid
23

34
from alphatrion.runtime.runtime import global_runtime
@@ -16,3 +17,9 @@ def _start(self):
1617
self._id = self._runtime._metadb.create_run(
1718
project_id=self._runtime._project_id, trial_id=self._trial_id
1819
)
20+
21+
def register_task(self, task: asyncio.Task):
22+
self._task = task
23+
24+
async def wait(self):
25+
await self._task

alphatrion/trial/trial.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,11 @@ class TrialConfig(BaseModel):
5959
after which experiment will be stopped. Default is -1 (no early stopping). \
6060
Count each time when calling log_metrics with the monitored metric.",
6161
)
62+
max_run_number: int = Field(
63+
default=-1,
64+
description="Maximum number of runs for the trial. \
65+
Default is -1 (no limit). Count by the finished runs.",
66+
)
6267
monitor_metric: str | None = Field(
6368
default=None,
6469
description="The metric to monitor for saving the best checkpoint. \
@@ -110,7 +115,10 @@ class Trial:
110115
# key is run_id, value is Run instance
111116
"_runs",
112117
"_running_tasks",
118+
# Only takes effect when early_stopping_runs > 0
113119
"_early_stopping_counter",
120+
# Only takes effect when max_run_number > 0
121+
"_total_runs_counter",
114122
)
115123

116124
def __init__(self, exp_id: int, config: TrialConfig | None = None):
@@ -126,6 +134,7 @@ def __init__(self, exp_id: int, config: TrialConfig | None = None):
126134
self._runs = dict()
127135
self._running_tasks = dict()
128136
self._early_stopping_counter = 0
137+
self._total_runs_counter = 0
129138

130139
async def __aenter__(self):
131140
return self
@@ -223,9 +232,10 @@ def _timeout(self) -> int | None:
223232
timeout -= int(elapsed)
224233
return timeout
225234

226-
def stopped(self) -> bool:
227-
return self._context.cancelled()
228-
235+
# Make sure you have a termination condition, either by timeout or by calling cancel().
236+
# Previously, once all the tasks were done, we would call cancel()
237+
# automatically, however, this is unpredictable because some tasks may be waiting
238+
# for external events, so we leave it to the user to decide when to stop the trial.
229239
async def wait(self):
230240
await self._context.wait()
231241

@@ -287,18 +297,22 @@ def start_run(self, call_func: callable) -> Run:
287297
run._start()
288298
self._runs[run.id] = run
289299

290-
# the created task will also inherit the current context,
300+
# The created task will also inherit the current context,
291301
# including the current_trial_id context var.
292302
task = asyncio.create_task(call_func())
293303
self._running_tasks[run.id] = task
304+
run.register_task(task)
305+
294306
task.add_done_callback(lambda t: self._running_tasks.pop(run.id, None))
295307
task.add_done_callback(lambda t: self._runs.pop(run.id, None))
296-
# FIXME: One potential issue here is once the former task finished
297-
# very fast, it could lead to cancelling the trial even if there are
298-
# other pending tasks ready to run. We may need a more robust way to
299-
# handle this.
300-
task.add_done_callback(
301-
lambda t: self.cancel() if len(self._running_tasks) == 0 else None
302-
)
308+
if self._config.max_run_number > 0:
309+
task.add_done_callback(
310+
lambda t: (
311+
setattr(self, "_total_runs_counter", self._total_runs_counter + 1),
312+
self.cancel()
313+
if self._total_runs_counter >= self._config.max_run_number
314+
else None,
315+
)
316+
)
303317

304318
return run

tests/integration/test_log.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import alphatrion as alpha
1111
from alphatrion.metadata.sql_models import TrialStatus
12-
from alphatrion.trial.trial import CheckpointConfig, TrialConfig, current_trial_id
12+
from alphatrion.trial.trial import current_trial_id
1313

1414

1515
@pytest.mark.asyncio
@@ -141,8 +141,8 @@ async def test_log_metrics_with_save_on_max():
141141

142142
_ = exp.start_trial(
143143
name="trial-with-save_on_best",
144-
config=TrialConfig(
145-
checkpoint=CheckpointConfig(
144+
config=alpha.TrialConfig(
145+
checkpoint=alpha.CheckpointConfig(
146146
enabled=True,
147147
path=tmpdir,
148148
save_on_best=True,
@@ -195,8 +195,8 @@ async def test_log_metrics_with_save_on_min():
195195

196196
_ = exp.start_trial(
197197
name="trial-with-save_on_best",
198-
config=TrialConfig(
199-
checkpoint=CheckpointConfig(
198+
config=alpha.TrialConfig(
199+
checkpoint=alpha.CheckpointConfig(
200200
enabled=True,
201201
path=tmpdir,
202202
save_on_best=True,
@@ -251,7 +251,7 @@ async def fake_sleep(value: float):
251251
) as exp:
252252
async with exp.start_trial(
253253
name="trial-with-early-stopping",
254-
config=TrialConfig(
254+
config=alpha.TrialConfig(
255255
monitor_metric="accuracy",
256256
early_stopping_runs=2,
257257
),
@@ -284,20 +284,48 @@ async def fake_sleep(value: float):
284284
await alpha.log_metrics({"accuracy": value})
285285

286286
async with alpha.CraftExperiment.start(
287-
name="log_metrics_with_early_stopping_never_triggered"
287+
name="log_metrics_with_both_early_stopping_and_timeout"
288288
) as exp:
289289
async with exp.start_trial(
290290
name="trial-with-early-stopping",
291-
config=TrialConfig(
291+
config=alpha.TrialConfig(
292292
monitor_metric="accuracy",
293293
early_stopping_runs=3,
294+
max_duration_seconds=3,
294295
),
295296
) as trial:
296297
start_time = datetime.now()
297298
trial.start_run(lambda: fake_work(1))
298299
trial.start_run(lambda: fake_work(2))
299-
trial.start_run(lambda: fake_sleep(3))
300+
trial.start_run(lambda: fake_sleep(2))
301+
# running in parallel.
300302
await trial.wait()
301303

302304
assert len(trial._runtime._metadb.list_metrics(trial_id=trial.id)) == 3
303305
assert datetime.now() - start_time >= timedelta(seconds=3)
306+
307+
308+
@pytest.mark.asyncio
309+
async def test_log_metrics_with_max_run_number():
310+
alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
311+
312+
async def fake_work(value: float):
313+
await alpha.log_metrics({"accuracy": value})
314+
print("fake finished.")
315+
316+
async with alpha.CraftExperiment.start(
317+
name="log_metrics_with_max_run_number"
318+
) as exp:
319+
async with exp.start_trial(
320+
name="trial-with-max-run-number",
321+
config=alpha.TrialConfig(
322+
monitor_metric="accuracy",
323+
max_run_number=5,
324+
),
325+
) as trial:
326+
while not trial.cancelled():
327+
run = trial.start_run(lambda: fake_work(1))
328+
# running in serial.
329+
await run.wait()
330+
331+
assert len(trial._runtime._metadb.list_metrics(trial_id=trial.id)) == 5

tests/unit/experiment/test_craft_exp.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from alphatrion.experiment.craft_exp import CraftExperiment
99
from alphatrion.metadata.sql_models import TrialStatus
10-
from alphatrion.runtime.runtime import init
10+
from alphatrion.runtime.runtime import global_runtime, init
1111
from alphatrion.trial.trial import Trial, TrialConfig, current_trial_id
1212

1313

@@ -117,7 +117,7 @@ async def test_craft_experiment_with_context():
117117
name="first-trial", config=TrialConfig(max_duration_seconds=2)
118118
)
119119
await trial.wait()
120-
assert trial.stopped()
120+
assert trial.cancelled()
121121

122122
trial = trial._get_obj()
123123
assert trial.status == TrialStatus.FINISHED
@@ -127,7 +127,9 @@ async def test_craft_experiment_with_context():
127127
async def test_craft_experiment_with_multi_trials_in_parallel():
128128
init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
129129

130-
async def fake_work(exp: CraftExperiment):
130+
async def fake_work():
131+
exp = global_runtime().current_exp
132+
131133
duration = random.randint(1, 5)
132134
trial = exp.start_trial(
133135
name="first-trial", config=TrialConfig(max_duration_seconds=duration)
@@ -136,7 +138,7 @@ async def fake_work(exp: CraftExperiment):
136138
assert trial.id == current_trial_id.get()
137139

138140
await trial.wait()
139-
assert trial.stopped()
141+
assert trial.cancelled()
140142
# we don't reset the current trial id.
141143
assert trial.id == current_trial_id.get()
142144

@@ -147,10 +149,10 @@ async def fake_work(exp: CraftExperiment):
147149
name="context_exp",
148150
description="Context manager test",
149151
meta={"key": "value"},
150-
) as exp:
152+
):
151153
await asyncio.gather(
152-
fake_work(exp),
153-
fake_work(exp),
154-
fake_work(exp),
154+
fake_work(),
155+
fake_work(),
156+
fake_work(),
155157
)
156158
print("All trials finished.")

0 commit comments

Comments
 (0)