Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions alphatrion/experiment/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,30 @@ def __init__(self):
# All trials in this experiment, key is trial_id, value is Trial instance.
self._trials = dict()

@property
def id(self):
    """Return the experiment ID held in ``self._id``."""
    return self._id

def get_trial(self, id: int) -> trial.Trial | None:
    """Look up one of this experiment's trials by ID.

    Args:
        id: Trial ID, used as the key into ``self._trials``.

    Returns:
        The entry stored under ``id``, or ``None`` when no such key exists.
    """
    registered = self._trials
    return registered.get(id)

def _reset(self):
    """Drop per-run state: clear ``self._exp`` and start a fresh trial map."""
    self._exp = None
    self._trials = {}

def __enter__(self):
    """Enter the experiment context and register it on the runtime.

    Returns:
        This experiment, so ``with ... as exp`` binds it.

    Raises:
        RuntimeError: If ``self._id`` is unset, or if ``self._get()``
            returns ``None`` for it.
    """
    if self._id is None:
        raise RuntimeError("Experiment is not set. Did you call run()?")

    # Fail fast when the ID has no backing record.
    if self._get() is None:
        raise RuntimeError(f"Experiment {self._id} not found in the database.")

    # The runtime tracks the experiment object itself rather than its UUID;
    # the logging helpers resolve whatever they need from it.
    self._runtime.current_exp = self
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    """Leave the experiment context.

    Clears this experiment's per-run state via ``self._reset()`` and
    unregisters it from the runtime. Returns ``None``, so an exception
    raised inside the ``with`` body is not suppressed.
    """
    self._reset()
    # Only ``current_exp`` is tracked on the runtime; there is no separate
    # UUID attribute to clear.
    self._runtime.current_exp = None

@classmethod
@abstractmethod
Expand Down
3 changes: 2 additions & 1 deletion alphatrion/metadata/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,13 @@ def update_trial(self, trial_id: int, **kwargs):
session.commit()
session.close()

def create_metric(self, trial_id: int, key: str, value: float):
def create_metric(self, trial_id: int, key: str, value: float, step: int):
session = self._session()
new_metric = Metrics(
trial_id=trial_id,
key=key,
value=value,
step=step,
)
session.add(new_metric)
session.commit()
Expand Down
1 change: 1 addition & 0 deletions alphatrion/metadata/sql_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,5 @@ class Metrics(Base):
key = Column(String, nullable=False)
value = Column(Float, nullable=False)
trial_id = Column(Integer, nullable=False)
step = Column(Integer, nullable=False, default=0)
# Pass a callable so the timestamp is computed for each insert. A bare
# ``datetime.now(UTC)`` is evaluated once, when the class body runs, and
# that single value would then be reused as the default for every row.
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
24 changes: 20 additions & 4 deletions alphatrion/record/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,16 @@ def log_artifact(

# We use experiment ID as the repo name rather than the experiment name,
# because experiment name is not unique
runtime._artifact.push(
repo_name=str(runtime.current_exp_uuid), paths=paths, version=version
)

exp = runtime.current_exp
if exp is None:
raise RuntimeError("No running experiment found in the current context.")

exp_obj = runtime._metadb.get_exp(exp.id)
if exp_obj is None:
raise RuntimeError(f"Experiment {exp.id} not found in the database.")

runtime._artifact.push(repo_name=str(exp_obj.uuid), paths=paths, version=version)


# log_params is used to save a set of parameters, which is a dict of key-value pairs.
Expand All @@ -48,9 +55,18 @@ def log_params(params: dict):
# metric key must be string, value must be float
def log_metrics(metrics: dict[str, float]):
    """Record a batch of metrics for the current trial under a single step.

    Each call advances the trial's step counter once; every key/value pair
    in ``metrics`` is stored with that same step number.

    Args:
        metrics: Mapping of metric name to value.

    Raises:
        RuntimeError: If no experiment is currently running, or if the
            current trial ID is not registered on that experiment.
    """
    runtime = global_runtime()
    exp = runtime.current_exp
    # Same guard as log_artifact: fail with a clear message instead of an
    # AttributeError on ``None`` when called outside an experiment context.
    if exp is None:
        raise RuntimeError("No running experiment found in the current context.")

    trial_id = current_trial_id.get()
    trial = exp.get_trial(id=trial_id)
    if trial is None:
        raise RuntimeError(f"Trial {trial_id} not found in the database.")

    # One step per call, shared by all metrics in this batch.
    step = trial.increment_step()
    for key, value in metrics.items():
        runtime._metadb.create_metric(
            key=key,
            value=value,
            trial_id=trial.id,
            step=step,
        )
18 changes: 8 additions & 10 deletions alphatrion/runtime/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,20 @@ def global_runtime():
# Runtime contains all kinds of clients, e.g., metadb client, artifact client, etc.
# Stateful information will also be stored here, e.g., current running experiment ID.
class Runtime:
    """Holds the shared clients and the currently running experiment.

    Attributes are fixed by ``__slots__``: the project ID, the metadata
    store client, the artifact client, and the current experiment.
    """

    __slots__ = ("_project_id", "_metadb", "_artifact", "__current_exp")

    def __init__(self, project_id: str, artifact_insecure: bool = False):
        """Build the metadata and artifact clients for a project.

        Args:
            project_id: Project identifier, also passed to the artifact client.
            artifact_insecure: Forwarded to ``Artifact`` as ``insecure``.
        """
        self._project_id = project_id
        self._metadb = SQLStore(os.getenv(consts.METADATA_DB_URL), init_tables=True)
        self._artifact = Artifact(
            project_id=self._project_id, insecure=artifact_insecure
        )

        # A slot has no value until it is assigned. Start at None so that
        # reading ``current_exp`` before any experiment is entered returns
        # None instead of raising AttributeError.
        self.__current_exp = None

    # current_exp is the current running experiment.
    @property
    def current_exp(self):
        """Return the running experiment, or ``None`` when there is none."""
        return self.__current_exp

    @current_exp.setter
    def current_exp(self, value):
        self.__current_exp = value
11 changes: 11 additions & 0 deletions alphatrion/trial/trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,15 @@ class Trial:
"_config",
"_runtime",
"_token",
"_step",
)

def __init__(self, exp_id: int, config: TrialConfig | None = None):
    """Create a trial attached to an experiment.

    Args:
        exp_id: ID of the experiment this trial belongs to.
        config: Trial configuration; a default ``TrialConfig()`` is used
            when it is omitted or falsy.
    """
    self._exp_id = exp_id
    self._config = config if config else TrialConfig()
    self._runtime = global_runtime()
    # Step counter for this trial, e.g. the step recorded with each
    # metric-logging call. Starts at 0.
    self._step = 0

def _start(
self,
Expand All @@ -103,6 +106,10 @@ def _start(
self._token = current_trial_id.set(self._id)
return self._id

@property
def id(self):
    """Return the trial ID held in ``self._id``."""
    return self._id

# finish function should be called manually as a pair of start
def finish(self, status: TrialStatus = TrialStatus.FINISHED):
trial = self._runtime._metadb.get_trial(trial_id=self._id)
Expand All @@ -119,3 +126,7 @@ def finish(self, status: TrialStatus = TrialStatus.FINISHED):

def _get(self):
    """Fetch this trial's record from the metadata store by ``self._id``."""
    metadb = self._runtime._metadb
    return metadb.get_trial(trial_id=self._id)

def increment_step(self) -> int:
    """Advance the step counter by one and return the new value."""
    advanced = self._step + 1
    self._step = advanced
    return advanced
6 changes: 3 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 20 additions & 4 deletions tests/integration/test_log_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def test_log_artifact():
) as exp:
trial = exp.start_trial(description="First trial")

exp_obj = exp._runtime._metadb.get_exp(exp_id=exp._id)
assert exp_obj is not None

with tempfile.TemporaryDirectory() as tmpdir:
os.chdir(tmpdir)

Expand All @@ -25,7 +28,7 @@ def test_log_artifact():

alpha.log_artifact(paths="file1.txt", version="v1")
versions = exp._runtime._artifact.list_versions(
exp._runtime.current_exp_uuid
exp_obj.uuid
)
assert "v1" in versions

Expand All @@ -35,15 +38,15 @@ def test_log_artifact():
# push folder instead
alpha.log_artifact(paths=["file1.txt"], version="v2")
versions = exp._runtime._artifact.list_versions(
exp._runtime.current_exp_uuid
exp_obj.uuid
)
assert "v2" in versions

exp._runtime._artifact.delete(
repo_name=exp._runtime.current_exp_uuid,
repo_name=exp_obj.uuid,
versions=["v1", "v2"],
)
versions = exp._runtime._artifact.list_versions(exp._runtime.current_exp_uuid)
versions = exp._runtime._artifact.list_versions(exp_obj.uuid)
assert len(versions) == 0

trial.finish()
Expand Down Expand Up @@ -94,13 +97,26 @@ def test_log_metrics():
assert new_trial is not None
assert new_trial.params == {"param1": 0.1}

metrics = exp._runtime._metadb.list_metrics(trial_id=trial._id)
assert len(metrics) == 0

alpha.log_metrics({"accuracy": 0.95, "loss": 0.1})

metrics = exp._runtime._metadb.list_metrics(trial_id=trial._id)
assert len(metrics) == 2
assert metrics[0].key == "accuracy"
assert metrics[0].value == 0.95
assert metrics[0].step == 1
assert metrics[1].key == "loss"
assert metrics[1].value == 0.1
assert metrics[1].step == 1

alpha.log_metrics({"accuracy": 0.96})

metrics = exp._runtime._metadb.list_metrics(trial_id=trial._id)
assert len(metrics) == 3
assert metrics[2].key == "accuracy"
assert metrics[2].value == 0.96
assert metrics[2].step == 2

trial.finish()
4 changes: 2 additions & 2 deletions tests/unit/metadata/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def test_update_trial(db):

def test_create_metric(db):
trial_id = db.create_trial(1, "test description", None)
db.create_metric(trial_id, "accuracy", 0.95)
db.create_metric(trial_id, "accuracy", 0.85)
db.create_metric(trial_id, "accuracy", 0.95, 1)
db.create_metric(trial_id, "accuracy", 0.85, 2)

metrics = db.list_metrics(trial_id)
assert len(metrics) == 2
Expand Down
Loading