@@ -59,6 +59,11 @@ class TrialConfig(BaseModel):
5959 after which experiment will be stopped. Default is -1 (no early stopping). \
6060 Count each time when calling log_metrics with the monitored metric." ,
6161 )
62+ max_run_number : int = Field (
63+ default = - 1 ,
64+ description = "Maximum number of runs for the trial. \
65+ Default is -1 (no limit). Count by the finished runs." ,
66+ )
6267 monitor_metric : str | None = Field (
6368 default = None ,
6469 description = "The metric to monitor for saving the best checkpoint. \
@@ -110,7 +115,10 @@ class Trial:
110115 # key is run_id, value is Run instance
111116 "_runs" ,
112117 "_running_tasks" ,
118+ # Only used when early_stopping_runs > 0
113119 "_early_stopping_counter" ,
120+ # Only used when max_run_number > 0
121+ "_total_runs_counter" ,
114122 )
115123
116124 def __init__ (self , exp_id : int , config : TrialConfig | None = None ):
@@ -126,6 +134,7 @@ def __init__(self, exp_id: int, config: TrialConfig | None = None):
126134 self ._runs = dict ()
127135 self ._running_tasks = dict ()
128136 self ._early_stopping_counter = 0
137+ self ._total_runs_counter = 0
129138
130139 async def __aenter__ (self ):
131140 return self
@@ -223,9 +232,10 @@ def _timeout(self) -> int | None:
223232 timeout -= int (elapsed )
224233 return timeout
225234
226- def stopped (self ) -> bool :
227- return self ._context .cancelled ()
228-
235+ # Make sure there is a termination condition, either a timeout or an explicit call to cancel().
236+ # Previously, cancel() was called automatically once all the tasks were done;
237+ # however, that was unpredictable because some tasks may be waiting for
238+ # external events, so it is now left to the user to decide when to stop the trial.
229239 async def wait (self ):
230240 await self ._context .wait ()
231241
@@ -287,18 +297,22 @@ def start_run(self, call_func: callable) -> Run:
287297 run ._start ()
288298 self ._runs [run .id ] = run
289299
290- # the created task will also inherit the current context,
300+ # The created task will also inherit the current context,
291301 # including the current_trial_id context var.
292302 task = asyncio .create_task (call_func ())
293303 self ._running_tasks [run .id ] = task
304+ run .register_task (task )
305+
294306 task .add_done_callback (lambda t : self ._running_tasks .pop (run .id , None ))
295307 task .add_done_callback (lambda t : self ._runs .pop (run .id , None ))
296- # FIXME: One potential issue here is once the former task finished
297- # very fast, it could lead to cancelling the trial even if there are
298- # other pending tasks ready to run. We may need a more robust way to
299- # handle this.
300- task .add_done_callback (
301- lambda t : self .cancel () if len (self ._running_tasks ) == 0 else None
302- )
308+ if self ._config .max_run_number > 0 :
309+ task .add_done_callback (
310+ lambda t : (
311+ setattr (self , "_total_runs_counter" , self ._total_runs_counter + 1 ),
312+ self .cancel ()
313+ if self ._total_runs_counter >= self ._config .max_run_number
314+ else None ,
315+ )
316+ )
303317
304318 return run
0 commit comments