11import contextvars
2+ import uuid
23from datetime import UTC , datetime
34
45from pydantic import BaseModel , Field , field_validator
56
67from alphatrion .metadata .sql_models import COMPLETED_STATUS , TrialStatus
78from alphatrion .runtime .runtime import global_runtime
9+ from alphatrion .utils .context import Context
810
# Context variable holding the id of the trial that is currently running.
# Used in record/record.py to log params/metrics. It defaults to None and
# is set by Trial._start().
current_trial_id = contextvars.ContextVar("current_trial_id", default=None)
@@ -57,15 +59,15 @@ class TrialConfig(BaseModel):
5759 """Configuration for an experiment."""
5860
5961 max_duration_seconds : int = Field (
60- default = 86400 ,
62+ default = - 1 ,
6163 description = "Maximum duration in seconds for the experiment. \
62- Default is 86400 seconds (1 day)." ,
63- )
64- max_retries : int = Field (
65- default = 0 ,
66- description = "Maximum number of retries for the experiment. \
67- Default is 0 (no retries)." ,
64+ Default is -1 (no limit)." ,
6865 )
66+ # max_retries: int = Field(
67+ # default=0,
68+ # description="Maximum number of retries for the experiment. \
69+ # Default is 0 (no retries).",
70+ # )
6971 checkpoint : CheckpointConfig = Field (
7072 default = CheckpointConfig (),
7173 description = "Configuration for checkpointing." ,
@@ -78,8 +80,9 @@ class Trial:
7880 "_exp_id" ,
7981 "_config" ,
8082 "_runtime" ,
81- "_token" ,
8283 "_step" ,
84+ "_context" ,
85+ "_token" ,
8386 )
8487
8588 def __init__ (self , exp_id : int , config : TrialConfig | None = None ):
@@ -88,13 +91,25 @@ def __init__(self, exp_id: int, config: TrialConfig | None = None):
8891 self ._runtime = global_runtime ()
8992 # step is used to track the round, e.g. the step in metric logging.
9093 self ._step = 0
94+ self ._context = Context (
95+ cancel_func = self ._stop ,
96+ timeout = self ._config .max_duration_seconds
97+ if self ._config .max_duration_seconds > 0
98+ else None ,
99+ )
100+
def stopped(self) -> bool:
    """Return the result of ``cancelled()`` on this trial's context.

    NOTE(review): presumably True once the trial has been stopped or has
    timed out -- confirm against ``Context.cancelled``.
    """
    return self._context.cancelled()
91103
92- def _start (
async def wait_stopped(self):
    """Await ``wait_cancelled()`` on this trial's context.

    NOTE(review): presumably returns once the trial has been stopped --
    confirm against ``Context.wait_cancelled``.
    """
    await self._context.wait_cancelled()
106+
107+ async def _start (
93108 self ,
94109 description : str | None = None ,
95110 meta : dict | None = None ,
96111 params : dict | None = None ,
97- ) -> int :
112+ ) -> uuid . UUID :
98113 self ._id = self ._runtime ._metadb .create_trial (
99114 exp_id = self ._exp_id ,
100115 description = description ,
@@ -103,26 +118,31 @@ def _start(
103118 status = TrialStatus .RUNNING ,
104119 )
105120
121+ # We don't reset the trial id context var here, because
122+ # each trial runs in its own context.
106123 self ._token = current_trial_id .set (self ._id )
124+ await self ._context .start ()
107125 return self ._id
108126
@property
def id(self) -> uuid.UUID:
    """Return the trial id stored in ``self._id``."""
    return self._id
112130
113- # finish function should be called manually as a pair of start
114- def finish (self , status : TrialStatus = TrialStatus .FINISHED ):
131+ # stop() must be called manually; it is the counterpart of start.
def stop(self):
    """Cancel this trial's context.

    ``_stop`` is registered as the context's ``cancel_func`` in
    ``__init__``. NOTE(review): presumably cancelling is what triggers
    it -- confirm against ``Context.cancel``.
    """
    self._context.cancel()
134+
135+ def _stop (self ):
115136 trial = self ._runtime ._metadb .get_trial (trial_id = self ._id )
116137 if trial is not None and trial .status not in COMPLETED_STATUS :
117138 duration = (
118139 datetime .now (UTC ) - trial .created_at .replace (tzinfo = UTC )
119140 ).total_seconds ()
120141 self ._runtime ._metadb .update_trial (
121- trial_id = self ._id , status = status , duration = duration
142+ trial_id = self ._id , status = TrialStatus . FINISHED , duration = duration
122143 )
123144
124- # recover the context var
125- current_trial_id .reset (self ._token )
145+ self ._runtime .current_exp .unregister_trial (self ._id )
126146
def _get(self):
    """Return the metadata-store record for this trial.

    The result is whatever ``get_trial`` returns for ``self._id``;
    ``_stop`` treats a ``None`` result from the same call as "no record".
    """
    return self._runtime._metadb.get_trial(trial_id=self._id)
0 commit comments