11import asyncio
22import random
3+ from datetime import datetime , timedelta
34
45import pytest
56
67from alphatrion .experiment .craft_exp import CraftExperiment
78from alphatrion .metadata .sql_models import TrialStatus
89from alphatrion .runtime .runtime import init
9- from alphatrion .trial .trial import TrialConfig , current_trial_id
10+ from alphatrion .trial .trial import Trial , TrialConfig , current_trial_id
1011
1112
1213@pytest .mark .asyncio
1314async def test_craft_experiment ():
1415 init (project_id = "test_project" , artifact_insecure = True )
1516
16- async with CraftExperiment .run (
17+ async with CraftExperiment .start (
1718 name = "context_exp" ,
1819 description = "Context manager test" ,
1920 meta = {"key" : "value" },
@@ -28,7 +29,7 @@ async def test_craft_experiment():
2829 assert trial_obj is not None
2930 assert trial_obj .description == "First trial"
3031
31- trial .stop ()
32+ trial .cancel ()
3233
3334 trial2 = trial ._get_obj ()
3435 assert trial2 .status == TrialStatus .FINISHED
@@ -39,7 +40,7 @@ async def test_create_experiment_with_trial():
3940 init (project_id = "test_project" , artifact_insecure = True )
4041
4142 trial_id = None
42- async with CraftExperiment .run (name = "context_exp" ) as exp :
43+ async with CraftExperiment .start (name = "context_exp" ) as exp :
4344 async with exp .start_trial (description = "First trial" ) as trial :
4445 trial_obj = trial ._get_obj ()
4546 assert trial_obj is not None
@@ -50,19 +51,44 @@ async def test_create_experiment_with_trial():
5051 assert trial_obj .status == TrialStatus .FINISHED
5152
5253
54+ @pytest .mark .asyncio
55+ async def test_create_experiment_with_trial_wait ():
56+ init (project_id = "test_project" , artifact_insecure = True )
57+
58+ async def fake_work (trial : Trial ):
59+ await asyncio .sleep (3 )
60+ trial .cancel ()
61+
62+ trial_id = None
63+ async with CraftExperiment .start (name = "context_exp" ) as exp :
64+ async with exp .start_trial (description = "First trial" ) as trial :
65+ trial_id = current_trial_id .get ()
66+
67+ start_time = datetime .now ()
68+
69+ asyncio .create_task (fake_work (trial ))
70+ assert datetime .now () - start_time <= timedelta (seconds = 1 )
71+
72+ await trial .wait ()
73+ assert datetime .now () - start_time >= timedelta (seconds = 3 )
74+
75+ trial_obj = exp ._runtime ._metadb .get_trial (trial_id = trial_id )
76+ assert trial_obj .status == TrialStatus .FINISHED
77+
78+
5379@pytest .mark .asyncio
5480async def test_craft_experiment_with_context ():
5581 init (project_id = "test_project" , artifact_insecure = True )
5682
57- async with CraftExperiment .run (
83+ async with CraftExperiment .start (
5884 name = "context_exp" ,
5985 description = "Context manager test" ,
6086 meta = {"key" : "value" },
6187 ) as exp :
6288 trial = exp .start_trial (
6389 description = "First trial" , config = TrialConfig (max_duration_seconds = 2 )
6490 )
65- await trial .wait_stopped ()
91+ await trial .wait ()
6692 assert trial .stopped ()
6793
6894 trial = trial ._get_obj ()
@@ -81,15 +107,15 @@ async def fake_work(exp: CraftExperiment):
81107 # double check current trial id.
82108 assert trial .id == current_trial_id .get ()
83109
84- await trial .wait_stopped ()
110+ await trial .wait ()
85111 assert trial .stopped ()
86112 # we don't reset the current trial id.
87113 assert trial .id == current_trial_id .get ()
88114
89115 trial = trial ._get_obj ()
90116 assert trial .status == TrialStatus .FINISHED
91117
92- async with CraftExperiment .run (
118+ async with CraftExperiment .start (
93119 name = "context_exp" ,
94120 description = "Context manager test" ,
95121 meta = {"key" : "value" },
0 commit comments