Skip to content

Commit 0c5b8ef

Browse files
authored
Expose ExperimentConfig (#71)
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent b32d9fa commit 0c5b8ef

8 files changed

Lines changed: 85 additions & 19 deletions

File tree

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
POETRY := poetry
22

33
.PHONY: build
4-
build:
4+
build: lint
55
$(POETRY) build
66

77
.PHONY: publish
88
publish: build
9-
$(POETRY) publish
9+
$(POETRY) publish --username=__token__ --password=$(INFTYAI_PYPI_TOKEN)
1010

1111
.PHONY: up
1212
up:

alphatrion/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from alphatrion.experiment.craft_exp import CraftExperiment
1+
from alphatrion.experiment.craft_exp import CraftExperiment, ExperimentConfig
22
from alphatrion.log.log import log_artifact, log_metrics, log_params
33
from alphatrion.runtime.runtime import init
44
from alphatrion.tracing.tracing import task, workflow
@@ -9,6 +9,7 @@
99
"log_params",
1010
"log_metrics",
1111
"CraftExperiment",
12+
"ExperimentConfig",
1213
"init",
1314
"TrialConfig",
1415
"CheckpointConfig",

alphatrion/experiment/base.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,25 @@
22
from abc import ABC, abstractmethod
33
from dataclasses import dataclass
44

5+
from pydantic import BaseModel, Field
6+
57
from alphatrion.runtime.runtime import global_runtime
68
from alphatrion.trial import trial
79

810

11+
class ExperimentConfig(BaseModel):
12+
"""
13+
Configuration for Experiment.
14+
"""
15+
16+
max_runtime_seconds: int = Field(
17+
default=-1,
18+
description="Maximum runtime seconds for the experiment. \
19+
It will overwrite the trial timeout if both are set. \
20+
Default is -1 (no limit).",
21+
)
22+
23+
924
@dataclass
1025
class Experiment(ABC):
1126
"""
@@ -14,7 +29,8 @@ class Experiment(ABC):
1429

1530
__slots__ = ("_runtime", "_id", "_trials")
1631

17-
def __init__(self):
32+
def __init__(self, config: ExperimentConfig | None = None):
33+
self._config = config or ExperimentConfig()
1834
self._runtime = global_runtime()
1935
# All trials in this experiment, key is trial_id, value is Trial instance.
2036
self._trials = dict()

alphatrion/experiment/craft_exp.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from alphatrion.experiment.base import Experiment
1+
from alphatrion.experiment.base import Experiment, ExperimentConfig
22
from alphatrion.trial.trial import Trial, TrialConfig
33

44

@@ -10,22 +10,23 @@ class CraftExperiment(Experiment):
1010
Opposite to other experiment classes, you need to call all these methods yourself.
1111
"""
1212

13-
def __init__(self):
14-
super().__init__()
13+
def __init__(self, config: ExperimentConfig | None = None):
14+
super().__init__(config=config)
1515

1616
@classmethod
1717
def setup(
1818
cls,
1919
name: str,
2020
description: str | None = None,
2121
meta: dict | None = None,
22+
config: ExperimentConfig | None = None,
2223
) -> "CraftExperiment":
2324
"""
2425
Setup the experiment. If the name already exists in the same project,
2526
it will refer to the existing experiment instead of creating a new one.
2627
"""
2728

28-
exp = CraftExperiment()
29+
exp = CraftExperiment(config=config)
2930
exp_obj = exp._get_by_name(name=name, project_id=exp._runtime._project_id)
3031

3132
# If experiment with the same name exists in the project, use it.
@@ -62,6 +63,15 @@ def start_trial(
6263
:return: the Trial instance
6364
"""
6465

66+
config = config or TrialConfig()
67+
68+
if (
69+
self._config is not None
70+
and self._config.max_runtime_seconds > 0
71+
and config.max_runtime_seconds < 0
72+
):
73+
config.max_runtime_seconds = self._config.max_runtime_seconds
74+
6575
trial = Trial(exp_id=self._id, config=config)
6676
trial._start(name=name, description=description, meta=meta, params=params)
6777
self.register_trial(id=trial.id, instance=trial)

alphatrion/trial/trial.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ class CheckpointConfig(BaseModel):
4747
class TrialConfig(BaseModel):
4848
"""Configuration for an experiment."""
4949

50-
max_duration_seconds: int = Field(
50+
max_runtime_seconds: int = Field(
5151
default=-1,
52-
description="Maximum duration in seconds for the experiment. \
52+
description="Maximum runtime seconds for the trial. \
53+
Trial timeout will override experiment timeout if both are set. \
5354
Default is -1 (no limit).",
5455
)
5556
early_stopping_runs: int = Field(
@@ -216,7 +217,7 @@ def should_early_stop(self, metric_key: str, metric_value: float) -> bool:
216217
return self._early_stopping_counter >= self._config.early_stopping_runs
217218

218219
def _timeout(self) -> int | None:
219-
timeout = self._config.max_duration_seconds
220+
timeout = self._config.max_runtime_seconds
220221
if timeout < 0:
221222
return None
222223

tests/integration/test_log.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ async def fake_sleep(value: float):
325325
config=alpha.TrialConfig(
326326
monitor_metric="accuracy",
327327
early_stopping_runs=3,
328-
max_duration_seconds=3,
328+
max_runtime_seconds=3,
329329
),
330330
) as trial:
331331
start_time = datetime.now()

tests/unit/experiment/test_craft_exp.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77

8-
from alphatrion.experiment.craft_exp import CraftExperiment
8+
from alphatrion.experiment.craft_exp import CraftExperiment, ExperimentConfig
99
from alphatrion.metadata.sql_models import TrialStatus
1010
from alphatrion.runtime.runtime import global_runtime, init
1111
from alphatrion.trial.trial import Trial, TrialConfig, current_trial_id
@@ -35,6 +35,7 @@ async def test_craft_experiment():
3535
trial_obj = trial._get_obj()
3636
assert trial_obj.status == TrialStatus.FINISHED
3737

38+
3839
@pytest.mark.asyncio
3940
async def test_craft_experiment_with_no_context():
4041
init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
@@ -51,6 +52,7 @@ async def fake_work(trial: Trial):
5152
trial_obj = trial._get_obj()
5253
assert trial_obj.status == TrialStatus.FINISHED
5354

55+
5456
@pytest.mark.asyncio
5557
async def test_create_experiment_with_trial():
5658
init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
@@ -129,7 +131,7 @@ async def test_craft_experiment_with_context():
129131
meta={"key": "value"},
130132
) as exp:
131133
trial = exp.start_trial(
132-
name="first-trial", config=TrialConfig(max_duration_seconds=2)
134+
name="first-trial", config=TrialConfig(max_runtime_seconds=2)
133135
)
134136
await trial.wait()
135137
assert trial.cancelled()
@@ -147,7 +149,7 @@ async def fake_work():
147149

148150
duration = random.randint(1, 5)
149151
trial = exp.start_trial(
150-
name="first-trial", config=TrialConfig(max_duration_seconds=duration)
152+
name="first-trial", config=TrialConfig(max_runtime_seconds=duration)
151153
)
152154
# double check current trial id.
153155
assert trial.id == current_trial_id.get()
@@ -171,3 +173,39 @@ async def fake_work():
171173
fake_work(),
172174
)
173175
print("All trials finished.")
176+
177+
178+
@pytest.mark.asyncio
179+
async def test_craft_experiment_with_timeout():
180+
init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
181+
182+
exp = CraftExperiment.setup(
183+
name="timeout_exp",
184+
config=ExperimentConfig(max_runtime_seconds=3),
185+
)
186+
187+
async with exp.start_trial(name="first-trial") as trial:
188+
await trial.wait()
189+
190+
trial_obj = trial._get_obj()
191+
assert trial_obj.status == TrialStatus.FINISHED
192+
193+
194+
@pytest.mark.asyncio
195+
async def test_craft_experiment_with_timeout_overwrite():
196+
init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
197+
198+
exp = CraftExperiment.setup(
199+
name="timeout_exp",
200+
config=ExperimentConfig(max_runtime_seconds=3),
201+
)
202+
203+
start_time = datetime.now()
204+
async with exp.start_trial(
205+
name="first-trial", config=TrialConfig(max_runtime_seconds=1)
206+
) as trial:
207+
await trial.wait()
208+
assert datetime.now() - start_time < timedelta(seconds=3)
209+
210+
trial_obj = trial._get_obj()
211+
assert trial_obj.status == TrialStatus.FINISHED

tests/unit/trial/test_trial.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,25 +18,25 @@ def test_timeout(self):
1818
},
1919
{
2020
"name": "Positive timeout",
21-
"config": TrialConfig(max_duration_seconds=10),
21+
"config": TrialConfig(max_runtime_seconds=10),
2222
"started_at": None,
2323
"expected": 10,
2424
},
2525
{
2626
"name": "Zero timeout",
27-
"config": TrialConfig(max_duration_seconds=0),
27+
"config": TrialConfig(max_runtime_seconds=0),
2828
"started_at": None,
2929
"expected": 0,
3030
},
3131
{
3232
"name": "Negative timeout",
33-
"config": TrialConfig(max_duration_seconds=-5),
33+
"config": TrialConfig(max_runtime_seconds=-5),
3434
"started_at": None,
3535
"expected": None,
3636
},
3737
{
3838
"name": "With started_at, positive timeout",
39-
"config": TrialConfig(max_duration_seconds=5),
39+
"config": TrialConfig(max_runtime_seconds=5),
4040
"started_at": (datetime.now(UTC) - timedelta(seconds=3)).isoformat(),
4141
"expected": 2,
4242
},

0 commit comments

Comments
 (0)