Skip to content

Commit 5a4f563

Browse files
authored
feat: add multi hooks support (#251)
* support multi-hooks

Signed-off-by: kerthcet <kerthcet@gmail.com>

* add more tests

Signed-off-by: kerthcet <kerthcet@gmail.com>

* fix lint

Signed-off-by: kerthcet <kerthcet@gmail.com>

* fix test

Signed-off-by: kerthcet <kerthcet@gmail.com>

* fix test

Signed-off-by: kerthcet <kerthcet@gmail.com>

---------

Signed-off-by: kerthcet <kerthcet@gmail.com>
1 parent 506f322 commit 5a4f563

3 files changed

Lines changed: 436 additions & 26 deletions

File tree

alphatrion/run/hooks.py

Lines changed: 77 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,13 @@
11
"""Built-in post-run hooks for metadata enrichment."""
22

3+
import logging
34
import uuid
45
from typing import Any
56

67
from alphatrion.runtime.runtime import global_runtime
8+
from alphatrion.storage.sql_models import Status
9+
10+
logger = logging.getLogger(__name__)
711

812

913
class PostRunHookFn:
@@ -14,16 +18,17 @@ def sync_metadata(run_id: uuid.UUID, result: Any) -> None:
1418
"""
1519
Sync function result to run metadata.
1620
17-
If the function returns a dict, it will be merged into the run's metadata.
18-
This is useful for automatically capturing metrics, model info, etc.
21+
Looks for 'metadata' key in result dict and syncs it to run metadata.
1922
2023
Example:
2124
async def train_model():
22-
# ... training code ...
2325
return {
24-
"accuracy": 0.95,
25-
"loss": 0.05,
26-
"num_epochs": 10,
26+
"metadata": {
27+
"accuracy": 0.95,
28+
"loss": 0.05,
29+
"num_epochs": 10,
30+
},
31+
"status": "COMPLETED"
2732
}
2833
2934
run = exp.run(train_model, post_run_hooks=[PostRunHookFn.sync_metadata])
@@ -32,6 +37,70 @@ async def train_model():
3237
:param run_id: UUID of the run
3338
:param result: Return value from the run function
3439
"""
35-
if isinstance(result, dict):
40+
if result is None:
41+
return
42+
43+
if isinstance(result, dict) and "metadata" in result:
44+
metadata = result["metadata"]
45+
if isinstance(metadata, dict):
46+
metadb = global_runtime().metadb
47+
metadb.update_run(run_id=run_id, meta=metadata)
48+
else:
49+
logger.warning(
50+
f"PostRunHookFn.sync_metadata: 'metadata' key in result for run {run_id} is not a dict. Skipping metadata sync."
51+
)
52+
else:
53+
logger.warning(
54+
f"PostRunHookFn.sync_metadata: Result for run {run_id} does not contain 'metadata' key or is not a dict. Skipping metadata sync."
55+
)
56+
57+
@staticmethod
def sync_status(run_id: uuid.UUID, result: Any) -> None:
    """
    Sync function result to run status.

    Looks for a 'status' key in the result dict. The value may be a
    ``Status`` member, a member name as a string (matched
    case-insensitively), or the member's integer value. Results that are
    not dicts, or dicts without a 'status' key, are ignored silently.
    Any other value (unknown name, unknown integer, bool, or an
    unsupported type) logs a warning and leaves the run untouched.

    Example:
        async def train_model():
            return {
                "status": "COMPLETED"
            }

        run = exp.run(train_model, post_run_hooks=[
            PostRunHookFn.sync_status
        ])

    :param run_id: UUID of the run
    :param result: Return value from the run function
    """
    # Covers None as well as any other non-dict return value.
    if not isinstance(result, dict) or "status" not in result:
        return

    status_value = result["status"]
    status = None

    if isinstance(status_value, Status):
        status = status_value
    elif isinstance(status_value, str):
        try:
            status = Status[status_value.upper()]
        except KeyError:
            status = None
    # bool is a subclass of int; True/False must not be coerced into
    # whatever Status members happen to have the values 1/0.
    elif isinstance(status_value, int) and not isinstance(status_value, bool):
        try:
            status = Status(status_value)
        except ValueError:
            status = None

    if status is None:
        logger.warning(
            "PostRunHookFn.sync_status: Invalid status value '%s' for run %s. "
            "Skipping status sync.",
            status_value,
            run_id,
        )
        return

    metadb = global_runtime().metadb
    metadb.update_run(run_id=run_id, status=status)

tests/integration/test_run_hooks.py

Lines changed: 96 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
from alphatrion.experiment import CraftExperiment, ExperimentConfig
1010
from alphatrion.run import PostRunHookFn
1111
from alphatrion.runtime.runtime import global_runtime
12+
from alphatrion.storage.sql_models import Status
1213

1314

1415
@pytest.fixture
@@ -32,13 +33,15 @@ async def test_run_hook_sync_metadata(test_org_id, test_user_id, test_team_id):
3233
alpha.init(org_id=test_org_id, team_id=test_team_id, user_id=test_user_id)
3334

3435
async def train_model():
35-
"""Function that returns metrics as dict"""
36+
"""Function that returns metrics in nested metadata structure"""
3637
await asyncio.sleep(0.1)
3738
return {
38-
"accuracy": 0.95,
39-
"loss": 0.05,
40-
"learning_rate": 0.001,
41-
"num_epochs": 10,
39+
"metadata": {
40+
"accuracy": 0.95,
41+
"loss": 0.05,
42+
"learning_rate": 0.001,
43+
"num_epochs": 10,
44+
}
4245
}
4346

4447
async with CraftExperiment.start("test_hook_experiment") as exp:
@@ -84,18 +87,44 @@ async def task_with_string_result():
8487
assert run_obj.meta is None or run_obj.meta == {}
8588

8689

90+
@pytest.mark.asyncio
async def test_run_hook_with_dict_without_metadata_key(
    test_org_id, test_user_id, test_team_id
):
    """sync_metadata leaves run metadata alone when the result dict has no 'metadata' key."""
    alpha.init(org_id=test_org_id, team_id=test_team_id, user_id=test_user_id)

    async def task_without_metadata_key():
        """Return a flat dict with no 'metadata' entry."""
        await asyncio.sleep(0.1)
        return {"accuracy": 0.95}

    hooks = [PostRunHookFn.sync_metadata]

    async with CraftExperiment.start("test_hook_no_metadata_key") as exp:
        run = exp.run(task_without_metadata_key, post_run_hooks=hooks)
        await exp.wait()

        # Read the run back from the metadata store.
        stored_run = global_runtime().metadb.get_run(run_id=run.id)

        # The hook had nothing to sync, so meta stays unset or empty.
        meta = stored_run.meta
        assert meta is None or meta == {}
114+
115+
87116
@pytest.mark.asyncio
88117
async def test_experiment_level_hooks(test_org_id, test_user_id, test_team_id):
89118
"""Test hooks configured at experiment level apply to all runs"""
90119
alpha.init(org_id=test_org_id, team_id=test_team_id, user_id=test_user_id)
91120

92121
async def task1():
93122
await asyncio.sleep(0.1)
94-
return {"task": "task1", "accuracy": 0.92}
123+
return {"metadata": {"task": "task1", "accuracy": 0.92}}
95124

96125
async def task2():
97126
await asyncio.sleep(0.1)
98-
return {"task": "task2", "accuracy": 0.94}
127+
return {"metadata": {"task": "task2", "accuracy": 0.94}}
99128

100129
# Configure experiment with sync_metadata hook
101130
config = ExperimentConfig(post_run_hooks=[PostRunHookFn.sync_metadata])
@@ -138,7 +167,7 @@ def add_custom_info(run_id, result):
138167

139168
async def train_model():
140169
await asyncio.sleep(0.1)
141-
return {"accuracy": 0.95}
170+
return {"metadata": {"accuracy": 0.95}}
142171

143172
async with CraftExperiment.start("test_custom_hook") as exp:
144173
# Use both built-in and custom hooks
@@ -169,7 +198,7 @@ async def test_hook_merges_with_existing_metadata(
169198

170199
async def train_model():
171200
await asyncio.sleep(0.1)
172-
return {"accuracy": 0.96, "loss": 0.04}
201+
return {"metadata": {"accuracy": 0.96, "loss": 0.04}}
173202

174203
async with CraftExperiment.start("test_merge_metadata") as exp:
175204
run = exp.run(train_model, post_run_hooks=[PostRunHookFn.sync_metadata])
@@ -205,7 +234,7 @@ def buggy_hook(run_id, result):
205234

206235
async def train_model():
207236
await asyncio.sleep(0.1)
208-
return {"accuracy": 0.95}
237+
return {"metadata": {"accuracy": 0.95}}
209238

210239
async with CraftExperiment.start("test_hook_failure") as exp:
211240
run = exp.run(
@@ -220,3 +249,60 @@ async def train_model():
220249
metadb = global_runtime().metadb
221250
run_obj = metadb.get_run(run_id=run.id)
222251
assert run_obj.meta["accuracy"] == 0.95
252+
253+
254+
@pytest.mark.asyncio
async def test_both_hooks_together(test_org_id, test_user_id, test_team_id):
    """sync_metadata and sync_status both apply when passed for the same run."""
    alpha.init(org_id=test_org_id, team_id=test_team_id, user_id=test_user_id)

    expected_meta = {"accuracy": 0.95, "loss": 0.05, "num_epochs": 10}

    async def train_model():
        await asyncio.sleep(0.1)
        return {
            "metadata": dict(expected_meta),
            "status": "failed",
        }

    hooks = [PostRunHookFn.sync_metadata, PostRunHookFn.sync_status]

    async with CraftExperiment.start("test_both_hooks") as exp:
        run = exp.run(train_model, post_run_hooks=hooks)
        await exp.wait()

        stored_run = global_runtime().metadb.get_run(run_id=run.id)

        # Written by the sync_metadata hook.
        for key, value in expected_meta.items():
            assert stored_run.meta[key] == value

        # Written by the sync_status hook.
        assert stored_run.status == Status.FAILED
284+
285+
286+
@pytest.mark.asyncio
async def test_sync_metadata_with_none(test_org_id, test_user_id, test_team_id):
    """Neither built-in hook changes the run when the function returns None."""
    alpha.init(org_id=test_org_id, team_id=test_team_id, user_id=test_user_id)

    async def task_with_none_result():
        """Finish without producing a result."""
        await asyncio.sleep(0.1)
        return None

    hooks = [PostRunHookFn.sync_metadata, PostRunHookFn.sync_status]

    async with CraftExperiment.start("test_hook_none_result") as exp:
        run = exp.run(task_with_none_result, post_run_hooks=hooks)
        await exp.wait()

        stored_run = global_runtime().metadb.get_run(run_id=run.id)

        # sync_metadata had nothing to write, so meta stays unset or empty.
        meta = stored_run.meta
        assert meta is None or meta == {}

        # sync_status had nothing to write; the run is expected to be COMPLETED.
        assert stored_run.status == Status.COMPLETED

0 commit comments

Comments
 (0)