diff --git a/alphatrion/log/log.py b/alphatrion/log/log.py index c88ac791..dc16cb69 100644 --- a/alphatrion/log/log.py +++ b/alphatrion/log/log.py @@ -39,11 +39,14 @@ async def log_artifact( # log_params is used to save a set of parameters, which is a dict of key-value pairs. # should be called after starting a trial. async def log_params(params: dict): + trial_id = current_trial_id.get() + if trial_id is None: + raise RuntimeError("log_params must be called inside a Trial.") runtime = global_runtime() # TODO: should we upload to the artifact as well? # current_trial_id is protect by contextvar, so it's safe to use in async runtime._metadb.update_trial( - trial_id=current_trial_id.get(), + trial_id=trial_id, params=params, ) @@ -63,6 +66,9 @@ async def log_metrics(metrics: dict[str, float]): exp = runtime.current_exp trial_id = current_trial_id.get() + if trial_id is None: + raise RuntimeError("log_metrics must be called inside a Trial.") + trial = exp.get_trial(id=trial_id) if trial is None: raise RuntimeError(f"Trial {trial_id} not found in the database.") diff --git a/alphatrion/run/run.py b/alphatrion/run/run.py index 672827bf..60c75554 100644 --- a/alphatrion/run/run.py +++ b/alphatrion/run/run.py @@ -16,13 +16,21 @@ def __init__(self, trial_id: uuid.UUID): def id(self) -> uuid.UUID: return self._id - def _start(self): + def _start(self, call_func: callable) -> asyncio.Task | None: self._id = self._runtime._metadb.create_run( project_id=self._runtime._project_id, trial_id=self._trial_id ) - def register_task(self, task: asyncio.Task): - self._task = task + # current_run_id context var is used in tracing workflow/task decorators. + token = current_run_id.set(self.id) + try: + # The created task will also inherit the current context, + # including the current_trial_id, current_run_id context var. + self._task = asyncio.create_task(call_func()) + finally: + current_run_id.reset(token) + + return self._task async def wait(self): await self._task diff --git a/alphatrion/trial/trial.py b/alphatrion/trial/trial.py index 9e126bd3..1410d532 100644 --- a/alphatrion/trial/trial.py +++ b/alphatrion/trial/trial.py @@ -1,4 +1,3 @@ -import asyncio import contextvars import os import uuid @@ -7,7 +6,7 @@ from pydantic import BaseModel, Field, model_validator from alphatrion.metadata.sql_models import COMPLETED_STATUS, TrialStatus -from alphatrion.run.run import Run, current_run_id +from alphatrion.run.run import Run from alphatrion.runtime.runtime import global_runtime from alphatrion.utils.context import Context @@ -308,20 +307,11 @@ def start_run(self, call_func: callable) -> Run: :return: the Run instance.""" run = Run(trial_id=self._id) - run._start() + task = run._start(call_func) + if task is None: + raise RuntimeError("Failed to start the run task.") self._runs[run.id] = run - - # current_run_id context var is used in tracing workflow/task decorators. - token = current_run_id.set(run.id) - try: - # The created task will also inherit the current context, - # including the current_trial_id, current_run_id context var. - task = asyncio.create_task(call_func()) - finally: - current_run_id.reset(token) - self._running_tasks[run.id] = task - run.register_task(task) task.add_done_callback(lambda t: self._running_tasks.pop(run.id, None)) task.add_done_callback(lambda t: self._runs.pop(run.id, None))