diff --git a/alphatrion/experiment/base.py b/alphatrion/experiment/base.py index 0246a98b..75fb69db 100644 --- a/alphatrion/experiment/base.py +++ b/alphatrion/experiment/base.py @@ -18,24 +18,30 @@ def __init__(self): # All trials in this experiment, key is trial_id, value is Trial instance. self._trials = dict() + @property + def id(self): + return self._id + + def get_trial(self, id: int) -> trial.Trial | None: + return self._trials.get(id) + def _reset(self): self._trials = dict() - self._exp = None def __enter__(self): if self._id is None: - raise RuntimeError("Experiment is not set. Did you call begin()?") + raise RuntimeError("Experiment is not set. Did you call run()?") exp = self._get() if exp is None: raise RuntimeError(f"Experiment {self._id} not found in the database.") - self._runtime.current_exp_uuid = exp.uuid + self._runtime.current_exp = self return self def __exit__(self, exc_type, exc_val, exc_tb): self._reset() - self._runtime.current_exp_uuid = None + self._runtime.current_exp = None @classmethod @abstractmethod diff --git a/alphatrion/metadata/sql.py b/alphatrion/metadata/sql.py index 191c8f27..63880611 100644 --- a/alphatrion/metadata/sql.py +++ b/alphatrion/metadata/sql.py @@ -189,12 +189,13 @@ def update_trial(self, trial_id: int, **kwargs): session.commit() session.close() - def create_metric(self, trial_id: int, key: str, value: float): + def create_metric(self, trial_id: int, key: str, value: float, step: int): session = self._session() new_metric = Metrics( trial_id=trial_id, key=key, value=value, + step=step, ) session.add(new_metric) session.commit() diff --git a/alphatrion/metadata/sql_models.py b/alphatrion/metadata/sql_models.py index b1934829..f3df8ae4 100644 --- a/alphatrion/metadata/sql_models.py +++ b/alphatrion/metadata/sql_models.py @@ -88,4 +88,5 @@ class Metrics(Base): key = Column(String, nullable=False) value = Column(Float, nullable=False) trial_id = Column(Integer, nullable=False) + step = Column(Integer, nullable=False, default=0) created_at = Column(DateTime(timezone=True), default=datetime.now(UTC)) diff --git a/alphatrion/record/record.py b/alphatrion/record/record.py index f15811f0..7b09038c 100644 --- a/alphatrion/record/record.py +++ b/alphatrion/record/record.py @@ -27,9 +27,16 @@ def log_artifact( # We use experiment ID as the repo name rather than the experiment name, # because experiment name is not unique - runtime._artifact.push( - repo_name=str(runtime.current_exp_uuid), paths=paths, version=version - ) + + exp = runtime.current_exp + if exp is None: + raise RuntimeError("No running experiment found in the current context.") + + exp_obj = runtime._metadb.get_exp(exp.id) + if exp_obj is None: + raise RuntimeError(f"Experiment {exp.id} not found in the database.") + + runtime._artifact.push(repo_name=str(exp_obj.uuid), paths=paths, version=version) # log_params is used to save a set of parameters, which is a dict of key-value pairs. @@ -48,9 +55,18 @@ def log_params(params: dict): # metric key must be string, value must be float def log_metrics(metrics: dict[str, float]): runtime = global_runtime() + exp = runtime.current_exp + + trial_id = current_trial_id.get() + trial = exp.get_trial(id=trial_id) + if trial is None: + raise RuntimeError(f"Trial {trial_id} not found in the database.") + + step = trial.increment_step() for key, value in metrics.items(): runtime._metadb.create_metric( key=key, value=value, - trial_id=current_trial_id.get(), + trial_id=trial.id, + step=step, ) diff --git a/alphatrion/runtime/runtime.py b/alphatrion/runtime/runtime.py index 9ccf315a..19fd501d 100644 --- a/alphatrion/runtime/runtime.py +++ b/alphatrion/runtime/runtime.py @@ -29,6 +29,8 @@ def global_runtime(): # Runtime contains all kinds of clients, e.g., metadb client, artifact client, etc. # Stateful information will also be stored here, e.g., current running experiment ID. class Runtime: + __slots__ = ("_project_id", "_metadb", "_artifact", "__current_exp") + def __init__(self, project_id: str, artifact_insecure: bool = False): self._project_id = project_id self._metadb = SQLStore(os.getenv(consts.METADATA_DB_URL), init_tables=True) @@ -36,15 +38,11 @@ def __init__(self, project_id: str, artifact_insecure: bool = False): project_id=self._project_id, insecure=artifact_insecure ) - # Current running Experiment UUID. One experiment at a time. - # Set in Experiment.__enter__ and cleared in __exit__. - # This is for global access, e.g., log_metrics/artifacts/params - self.__current_exp_uuid = None - + # current_exp is the current running experiment. @property - def current_exp_uuid(self): - return self.__current_exp_uuid + def current_exp(self): + return self.__current_exp - @current_exp_uuid.setter - def current_exp_uuid(self, value): - self.__current_exp_uuid = value + @current_exp.setter + def current_exp(self, value): + self.__current_exp = value diff --git a/alphatrion/trial/trial.py b/alphatrion/trial/trial.py index 2738e758..988048f7 100644 --- a/alphatrion/trial/trial.py +++ b/alphatrion/trial/trial.py @@ -79,12 +79,15 @@ class Trial: "_config", "_runtime", "_token", + "_step", ) def __init__(self, exp_id: int, config: TrialConfig | None = None): self._exp_id = exp_id self._config = config or TrialConfig() self._runtime = global_runtime() + # step is used to track the round, e.g. the step in metric logging. + self._step = 0 def _start( self, @@ -103,6 +106,10 @@ def _start( self._token = current_trial_id.set(self._id) return self._id + @property + def id(self): + return self._id + # finish function should be called manually as a pair of start def finish(self, status: TrialStatus = TrialStatus.FINISHED): trial = self._runtime._metadb.get_trial(trial_id=self._id) @@ -119,3 +126,7 @@ def finish(self, status: TrialStatus = TrialStatus.FINISHED): def _get(self): return self._runtime._metadb.get_trial(trial_id=self._id) + + def increment_step(self) -> int: + self._step += 1 + return self._step diff --git a/poetry.lock b/poetry.lock index 75e12abd..c1f9000d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -409,14 +409,14 @@ files = [ [[package]] name = "pydantic" -version = "2.11.8" +version = "2.11.9" description = "Data validation using Python type hints" optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "pydantic-2.11.8-py3-none-any.whl", hash = "sha256:830ec4cccc3cf21be1ef5aec1d3348a179c92a61a7dab0e59837f9cc9fa93351"}, - {file = "pydantic-2.11.8.tar.gz", hash = "sha256:3d080f4a3ac6bde98e959ba552124d46be9f565b7be67769e49fcb286bae1bfb"}, + {file = "pydantic-2.11.9-py3-none-any.whl", hash = "sha256:c42dd626f5cfc1c6950ce6205ea58c93efa406da65f479dcb4029d5934857da2"}, + {file = "pydantic-2.11.9.tar.gz", hash = "sha256:6b8ffda597a14812a7975c90b82a8a2e777d9257aba3453f973acd3c032a18e2"}, ] [package.dependencies] diff --git a/tests/integration/test_log_functions.py b/tests/integration/test_log_functions.py index 34367dd6..2abc53af 100644 --- a/tests/integration/test_log_functions.py +++ b/tests/integration/test_log_functions.py @@ -16,6 +16,9 @@ def test_log_artifact(): ) as exp: trial = exp.start_trial(description="First trial") + exp_obj = exp._runtime._metadb.get_exp(exp_id=exp._id) + assert exp_obj is not None + with tempfile.TemporaryDirectory() as tmpdir: os.chdir(tmpdir) @@ -25,7 +28,7 @@ def test_log_artifact(): alpha.log_artifact(paths="file1.txt", version="v1") versions = exp._runtime._artifact.list_versions( - exp._runtime.current_exp_uuid + exp_obj.uuid ) assert "v1" in versions @@ -35,15 +38,15 @@ def test_log_artifact(): # push folder instead alpha.log_artifact(paths=["file1.txt"], version="v2") versions = exp._runtime._artifact.list_versions( - exp._runtime.current_exp_uuid + exp_obj.uuid ) assert "v2" in versions exp._runtime._artifact.delete( - repo_name=exp._runtime.current_exp_uuid, + repo_name=exp_obj.uuid, versions=["v1", "v2"], ) - versions = exp._runtime._artifact.list_versions(exp._runtime.current_exp_uuid) + versions = exp._runtime._artifact.list_versions(exp_obj.uuid) assert len(versions) == 0 trial.finish() @@ -94,13 +97,26 @@ def test_log_metrics(): assert new_trial is not None assert new_trial.params == {"param1": 0.1} + metrics = exp._runtime._metadb.list_metrics(trial_id=trial._id) + assert len(metrics) == 0 + alpha.log_metrics({"accuracy": 0.95, "loss": 0.1}) metrics = exp._runtime._metadb.list_metrics(trial_id=trial._id) assert len(metrics) == 2 assert metrics[0].key == "accuracy" assert metrics[0].value == 0.95 + assert metrics[0].step == 1 assert metrics[1].key == "loss" assert metrics[1].value == 0.1 + assert metrics[1].step == 1 + + alpha.log_metrics({"accuracy": 0.96}) + + metrics = exp._runtime._metadb.list_metrics(trial_id=trial._id) + assert len(metrics) == 3 + assert metrics[2].key == "accuracy" + assert metrics[2].value == 0.96 + assert metrics[2].step == 2 trial.finish() diff --git a/tests/unit/metadata/test_sql.py b/tests/unit/metadata/test_sql.py index 296b912e..9977c76a 100644 --- a/tests/unit/metadata/test_sql.py +++ b/tests/unit/metadata/test_sql.py @@ -77,8 +77,8 @@ def test_update_trial(db): def test_create_metric(db): trial_id = db.create_trial(1, "test description", None) - db.create_metric(trial_id, "accuracy", 0.95) - db.create_metric(trial_id, "accuracy", 0.85) + db.create_metric(trial_id, "accuracy", 0.95, 1) + db.create_metric(trial_id, "accuracy", 0.85, 2) metrics = db.list_metrics(trial_id) assert len(metrics) == 2