Skip to content

Commit e2a85dd

Browse files
authored
Add run_id to metrics (#56)
Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent fcabc0c commit e2a85dd

8 files changed

Lines changed: 67 additions & 18 deletions

File tree

.env.example

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# We use PG by default as the metadata database.
2-
METADATA_DB_URL=postgresql+psycopg2://user:pass@localhost:5432/mydb
2+
METADATA_DB_URL=postgresql+psycopg2://alphatrion:alphatr1on@localhost:5432/alphatrion
33
ARTIFACT_REGISTRY_URL=http://localhost:5000/
44
LOG_LEVEL=INFO

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ pip install alphatrion
3131

3232
### Initialize the Environment
3333

34+
Run the following command for setup:
35+
3436
```bash
35-
make up
37+
cp .env.example .env && make up  <!-- NOTE(review): the committed line uses a single `&`, which backgrounds `cp` and races it against `make up`; sequential `&&` is almost certainly intended -->
3638
```
3739

3840
You can login to pgAdmin at `http://localhost:8080` to see the Postgres database. The host name for registering a new server is `postgres`, and the username and password are `alphatrion` and `alphatr1on`, respectively.

alphatrion/log/log.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from alphatrion.run.run import current_run_id
12
from alphatrion.runtime.runtime import global_runtime
23
from alphatrion.trial.trial import current_trial_id
34
from alphatrion.utils import time as utime
@@ -51,7 +52,13 @@ async def log_params(params: dict):
5152
# metric key must be string, value must be float.
5253
# If save_on_best is enabled in the trial config, and the metric is the best metric
5354
# so far, the trial will checkpoint the current data.
55+
#
56+
# Note: log_metrics can only be called inside a Run, because it needs a run_id.
5457
async def log_metrics(metrics: dict[str, float]):
58+
run_id = current_run_id.get()
59+
if run_id is None:
60+
raise RuntimeError("log_metrics must be called inside a Run.")
61+
5562
runtime = global_runtime()
5663
exp = runtime.current_exp
5764

@@ -70,6 +77,7 @@ async def log_metrics(metrics: dict[str, float]):
7077
value=value,
7178
project_id=runtime._project_id,
7279
trial_id=trial_id,
80+
run_id=run_id,
7381
step=step,
7482
)
7583

alphatrion/metadata/sql.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ def create_metric(
262262
self,
263263
project_id: uuid.UUID,
264264
trial_id: uuid.UUID,
265+
run_id: uuid.UUID,
265266
key: str,
266267
value: float,
267268
step: int,
@@ -270,6 +271,7 @@ def create_metric(
270271
new_metric = Metric(
271272
project_id=project_id,
272273
trial_id=trial_id,
274+
run_id=run_id,
273275
key=key,
274276
value=value,
275277
step=step,

alphatrion/metadata/sql_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,5 +111,6 @@ class Metric(Base):
111111
value = Column(Float, nullable=False)
112112
project_id = Column(UUID(as_uuid=True), nullable=False)
113113
trial_id = Column(UUID(as_uuid=True), nullable=False)
114+
run_id = Column(UUID(as_uuid=True), nullable=False)
114115
step = Column(Integer, nullable=False, default=0)
115116
created_at = Column(DateTime(timezone=True), default=datetime.now(UTC))

hack/seed.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def generate_metric(runs: list[Run]) -> Metric:
9696
return Metric(
9797
project_id=run.project_id,
9898
trial_id=run.trial_id,
99+
run_id=run.uuid,
99100
key=random.choice(["accuracy", "loss", "precision", "fitness"]),
100101
value=random.uniform(0, 1),
101102
step=random.randint(1, 1000),

tests/integration/test_log.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,9 @@ async def test_log_params():
9595
async def test_log_metrics():
9696
alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
9797

98+
async def log_metric(metrics: dict):
99+
await alpha.log_metrics(metrics)
100+
98101
async with alpha.CraftExperiment.start(name="log_metrics_exp") as exp:
99102
trial = exp.start_trial(name="first-trial", params={"param1": 0.1})
100103

@@ -105,7 +108,8 @@ async def test_log_metrics():
105108
metrics = exp._runtime._metadb.list_metrics(trial_id=trial._id)
106109
assert len(metrics) == 0
107110

108-
await alpha.log_metrics({"accuracy": 0.95, "loss": 0.1})
111+
run = trial.start_run(lambda: log_metric({"accuracy": 0.95, "loss": 0.1}))
112+
await run.wait()
109113

110114
metrics = exp._runtime._metadb.list_metrics(trial_id=trial._id)
111115
assert len(metrics) == 2
@@ -115,14 +119,21 @@ async def test_log_metrics():
115119
assert metrics[1].key == "loss"
116120
assert metrics[1].value == 0.1
117121
assert metrics[1].step == 1
122+
run_id_1 = metrics[0].run_id
123+
assert run_id_1 is not None
124+
assert metrics[0].run_id == metrics[1].run_id
118125

119-
await alpha.log_metrics({"accuracy": 0.96})
126+
run = trial.start_run(lambda: log_metric({"accuracy": 0.96}))
127+
await run.wait()
120128

121129
metrics = exp._runtime._metadb.list_metrics(trial_id=trial._id)
122130
assert len(metrics) == 3
123131
assert metrics[2].key == "accuracy"
124132
assert metrics[2].value == 0.96
125133
assert metrics[2].step == 2
134+
run_id_2 = metrics[2].run_id
135+
assert run_id_2 is not None
136+
assert run_id_2 != run_id_1
126137

127138
trial.cancel()
128139

@@ -131,6 +142,9 @@ async def test_log_metrics():
131142
async def test_log_metrics_with_save_on_max():
132143
alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
133144

145+
async def log_metric(value: float):
146+
await alpha.log_metrics({"accuracy": value})
147+
134148
async with alpha.CraftExperiment.start(
135149
name="log_metrics_with_save_on_max",
136150
description="Context manager test",
@@ -139,7 +153,7 @@ async def test_log_metrics_with_save_on_max():
139153
with tempfile.TemporaryDirectory() as tmpdir:
140154
os.chdir(tmpdir)
141155

142-
_ = exp.start_trial(
156+
trial = exp.start_trial(
143157
name="trial-with-save_on_best",
144158
config=alpha.TrialConfig(
145159
checkpoint=alpha.CheckpointConfig(
@@ -156,35 +170,47 @@ async def test_log_metrics_with_save_on_max():
156170
with open(file1, "w") as f:
157171
f.write("This is file1.")
158172

159-
await alpha.log_metrics({"accuracy": 0.90})
173+
run = trial.start_run(lambda: log_metric(0.90))
174+
await run.wait()
160175

161176
versions = exp._runtime._artifact.list_versions(exp.id)
162177
assert len(versions) == 1
163178

164179
# To avoid the same timestamp hash, we wait for 1 second
165180
time.sleep(1)
166181

167-
await alpha.log_metrics({"accuracy": 0.78})
182+
run = trial.start_run(lambda: log_metric(0.78))
183+
await run.wait()
184+
168185
versions = exp._runtime._artifact.list_versions(exp.id)
169186
assert len(versions) == 1
170187

171188
time.sleep(1)
172189

173-
await alpha.log_metrics({"accuracy": 0.91})
190+
run = trial.start_run(lambda: log_metric(0.91))
191+
await run.wait()
192+
174193
versions = exp._runtime._artifact.list_versions(exp.id)
175194
assert len(versions) == 2
176195

177196
time.sleep(1)
178197

179-
await alpha.log_metrics({"accuracy2": 0.98})
198+
run = trial.start_run(lambda: log_metric(0.98))
199+
await run.wait()
200+
180201
versions = exp._runtime._artifact.list_versions(exp.id)
181-
assert len(versions) == 2
202+
assert len(versions) == 3
203+
204+
trial.cancel()
182205

183206

184207
@pytest.mark.asyncio
185208
async def test_log_metrics_with_save_on_min():
186209
alpha.init(project_id=uuid.uuid4(), artifact_insecure=True, init_tables=True)
187210

211+
async def log_metric(value: float):
212+
await alpha.log_metrics({"accuracy": value})
213+
188214
async with alpha.CraftExperiment.start(
189215
name="log_metrics_with_save_on_min",
190216
description="Context manager test",
@@ -193,7 +219,7 @@ async def test_log_metrics_with_save_on_min():
193219
with tempfile.TemporaryDirectory() as tmpdir:
194220
os.chdir(tmpdir)
195221

196-
_ = exp.start_trial(
222+
trial = exp.start_trial(
197223
name="trial-with-save_on_best",
198224
config=alpha.TrialConfig(
199225
checkpoint=alpha.CheckpointConfig(
@@ -210,29 +236,37 @@ async def test_log_metrics_with_save_on_min():
210236
with open(file1, "w") as f:
211237
f.write("This is file1.")
212238

213-
await alpha.log_metrics({"accuracy": 0.30})
239+
run = trial.start_run(lambda: log_metric(0.30))
240+
await run.wait()
214241

215242
versions = exp._runtime._artifact.list_versions(exp.id)
216243
assert len(versions) == 1
217244

218245
# To avoid the same timestamp hash, we wait for 1 second
219246
time.sleep(1)
220247

221-
await alpha.log_metrics({"accuracy": 0.58})
248+
run = trial.start_run(lambda: log_metric(0.58))
249+
await run.wait()
250+
222251
versions = exp._runtime._artifact.list_versions(exp.id)
223252
assert len(versions) == 1
224253

225254
time.sleep(1)
226255

227-
await alpha.log_metrics({"accuracy": 0.21})
256+
run = trial.start_run(lambda: log_metric(0.21))
257+
await run.wait()
258+
228259
versions = exp._runtime._artifact.list_versions(exp.id)
229260
assert len(versions) == 2
230261

231262
time.sleep(1)
232263

233-
await alpha.log_metrics({"accuracy2": 0.18})
264+
task = trial.start_run(lambda: log_metric(0.18))
265+
await task.wait()
234266
versions = exp._runtime._artifact.list_versions(exp.id)
235-
assert len(versions) == 2
267+
assert len(versions) == 3
268+
269+
trial.cancel()
236270

237271

238272
@pytest.mark.asyncio

tests/unit/metadata/test_sql.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,9 @@ def test_create_metric(db):
8888
project_id = uuid.uuid4()
8989
exp_id = db.create_exp("test_exp", project_id, "test description")
9090
trial_id = db.create_trial(exp_id=exp_id, project_id=project_id, name="test-trial")
91-
db.create_metric(project_id, trial_id, "accuracy", 0.95, 1)
92-
db.create_metric(project_id, trial_id, "accuracy", 0.85, 2)
91+
run_id = db.create_run(trial_id=trial_id, project_id=project_id)
92+
db.create_metric(project_id, trial_id, run_id, "accuracy", 0.95, 1)
93+
db.create_metric(project_id, trial_id, run_id, "accuracy", 0.85, 2)
9394

9495
metrics = db.list_metrics(trial_id)
9596
assert len(metrics) == 2

Comments (0)