55
66import pytest
77
8- from alphatrion .experiment .craft_exp import CraftExperiment
8+ from alphatrion .experiment .craft_exp import CraftExperiment , ExperimentConfig
99from alphatrion .metadata .sql_models import TrialStatus
1010from alphatrion .runtime .runtime import global_runtime , init
1111from alphatrion .trial .trial import Trial , TrialConfig , current_trial_id
@@ -35,6 +35,7 @@ async def test_craft_experiment():
3535 trial_obj = trial ._get_obj ()
3636 assert trial_obj .status == TrialStatus .FINISHED
3737
38+
3839@pytest .mark .asyncio
3940async def test_craft_experiment_with_no_context ():
4041 init (project_id = uuid .uuid4 (), artifact_insecure = True , init_tables = True )
@@ -51,6 +52,7 @@ async def fake_work(trial: Trial):
5152 trial_obj = trial ._get_obj ()
5253 assert trial_obj .status == TrialStatus .FINISHED
5354
55+
5456@pytest .mark .asyncio
5557async def test_create_experiment_with_trial ():
5658 init (project_id = uuid .uuid4 (), artifact_insecure = True , init_tables = True )
@@ -129,7 +131,7 @@ async def test_craft_experiment_with_context():
129131 meta = {"key" : "value" },
130132 ) as exp :
131133 trial = exp .start_trial (
132- name = "first-trial" , config = TrialConfig (max_duration_seconds = 2 )
134+ name = "first-trial" , config = TrialConfig (max_runtime_seconds = 2 )
133135 )
134136 await trial .wait ()
135137 assert trial .cancelled ()
@@ -147,7 +149,7 @@ async def fake_work():
147149
148150 duration = random .randint (1 , 5 )
149151 trial = exp .start_trial (
150- name = "first-trial" , config = TrialConfig (max_duration_seconds = duration )
152+ name = "first-trial" , config = TrialConfig (max_runtime_seconds = duration )
151153 )
152154 # double check current trial id.
153155 assert trial .id == current_trial_id .get ()
@@ -171,3 +173,39 @@ async def fake_work():
171173 fake_work (),
172174 )
173175 print ("All trials finished." )
176+
177+
178+ @pytest .mark .asyncio
179+ async def test_craft_experiment_with_timeout ():
180+ init (project_id = uuid .uuid4 (), artifact_insecure = True , init_tables = True )
181+
182+ exp = CraftExperiment .setup (
183+ name = "timeout_exp" ,
184+ config = ExperimentConfig (max_runtime_seconds = 3 ),
185+ )
186+
187+ async with exp .start_trial (name = "first-trial" ) as trial :
188+ await trial .wait ()
189+
190+ trial_obj = trial ._get_obj ()
191+ assert trial_obj .status == TrialStatus .FINISHED
192+
193+
194+ @pytest .mark .asyncio
195+ async def test_craft_experiment_with_timeout_overwrite ():
196+ init (project_id = uuid .uuid4 (), artifact_insecure = True , init_tables = True )
197+
198+ exp = CraftExperiment .setup (
199+ name = "timeout_exp" ,
200+ config = ExperimentConfig (max_runtime_seconds = 3 ),
201+ )
202+
203+ start_time = datetime .now ()
204+ async with exp .start_trial (
205+ name = "first-trial" , config = TrialConfig (max_runtime_seconds = 1 )
206+ ) as trial :
207+ await trial .wait ()
208+ assert datetime .now () - start_time < timedelta (seconds = 3 )
209+
210+ trial_obj = trial ._get_obj ()
211+ assert trial_obj .status == TrialStatus .FINISHED
0 commit comments