99from alphatrion .experiment import CraftExperiment , ExperimentConfig
1010from alphatrion .run import PostRunHookFn
1111from alphatrion .runtime .runtime import global_runtime
12+ from alphatrion .storage .sql_models import Status
1213
1314
1415@pytest .fixture
@@ -32,13 +33,15 @@ async def test_run_hook_sync_metadata(test_org_id, test_user_id, test_team_id):
3233 alpha .init (org_id = test_org_id , team_id = test_team_id , user_id = test_user_id )
3334
3435 async def train_model ():
35- """Function that returns metrics as dict """
36+ """Function that returns metrics in nested metadata structure """
3637 await asyncio .sleep (0.1 )
3738 return {
38- "accuracy" : 0.95 ,
39- "loss" : 0.05 ,
40- "learning_rate" : 0.001 ,
41- "num_epochs" : 10 ,
39+ "metadata" : {
40+ "accuracy" : 0.95 ,
41+ "loss" : 0.05 ,
42+ "learning_rate" : 0.001 ,
43+ "num_epochs" : 10 ,
44+ }
4245 }
4346
4447 async with CraftExperiment .start ("test_hook_experiment" ) as exp :
@@ -84,18 +87,44 @@ async def task_with_string_result():
8487 assert run_obj .meta is None or run_obj .meta == {}
8588
8689
90+ @pytest .mark .asyncio
91+ async def test_run_hook_with_dict_without_metadata_key (
92+ test_org_id , test_user_id , test_team_id
93+ ):
94+ """Test sync_metadata hook with dict result but no 'metadata' key"""
95+ alpha .init (org_id = test_org_id , team_id = test_team_id , user_id = test_user_id )
96+
97+ async def task_without_metadata_key ():
98+ """Function that returns dict without 'metadata' key"""
99+ await asyncio .sleep (0.1 )
100+ return {"accuracy" : 0.95 }
101+
102+ async with CraftExperiment .start ("test_hook_no_metadata_key" ) as exp :
103+ run = exp .run (
104+ task_without_metadata_key , post_run_hooks = [PostRunHookFn .sync_metadata ]
105+ )
106+ await exp .wait ()
107+
108+ # Verify metadata was not updated
109+ metadb = global_runtime ().metadb
110+ run_obj = metadb .get_run (run_id = run .id )
111+
112+ # Metadata should be None or empty (hook didn't update it)
113+ assert run_obj .meta is None or run_obj .meta == {}
114+
115+
87116@pytest .mark .asyncio
88117async def test_experiment_level_hooks (test_org_id , test_user_id , test_team_id ):
89118 """Test hooks configured at experiment level apply to all runs"""
90119 alpha .init (org_id = test_org_id , team_id = test_team_id , user_id = test_user_id )
91120
92121 async def task1 ():
93122 await asyncio .sleep (0.1 )
94- return {"task" : "task1" , "accuracy" : 0.92 }
123+ return {"metadata" : { " task" : "task1" , "accuracy" : 0.92 } }
95124
96125 async def task2 ():
97126 await asyncio .sleep (0.1 )
98- return {"task" : "task2" , "accuracy" : 0.94 }
127+ return {"metadata" : { " task" : "task2" , "accuracy" : 0.94 } }
99128
100129 # Configure experiment with sync_metadata hook
101130 config = ExperimentConfig (post_run_hooks = [PostRunHookFn .sync_metadata ])
@@ -138,7 +167,7 @@ def add_custom_info(run_id, result):
138167
139168 async def train_model ():
140169 await asyncio .sleep (0.1 )
141- return {"accuracy" : 0.95 }
170+ return {"metadata" : { " accuracy" : 0.95 } }
142171
143172 async with CraftExperiment .start ("test_custom_hook" ) as exp :
144173 # Use both built-in and custom hooks
@@ -169,7 +198,7 @@ async def test_hook_merges_with_existing_metadata(
169198
170199 async def train_model ():
171200 await asyncio .sleep (0.1 )
172- return {"accuracy" : 0.96 , "loss" : 0.04 }
201+ return {"metadata" : { " accuracy" : 0.96 , "loss" : 0.04 } }
173202
174203 async with CraftExperiment .start ("test_merge_metadata" ) as exp :
175204 run = exp .run (train_model , post_run_hooks = [PostRunHookFn .sync_metadata ])
@@ -205,7 +234,7 @@ def buggy_hook(run_id, result):
205234
206235 async def train_model ():
207236 await asyncio .sleep (0.1 )
208- return {"accuracy" : 0.95 }
237+ return {"metadata" : { " accuracy" : 0.95 } }
209238
210239 async with CraftExperiment .start ("test_hook_failure" ) as exp :
211240 run = exp .run (
@@ -220,3 +249,60 @@ async def train_model():
220249 metadb = global_runtime ().metadb
221250 run_obj = metadb .get_run (run_id = run .id )
222251 assert run_obj .meta ["accuracy" ] == 0.95
252+
253+
254+ @pytest .mark .asyncio
255+ async def test_both_hooks_together (test_org_id , test_user_id , test_team_id ):
256+ """Test sync_metadata and sync_status hooks working together"""
257+ alpha .init (org_id = test_org_id , team_id = test_team_id , user_id = test_user_id )
258+
259+ async def train_model ():
260+ await asyncio .sleep (0.1 )
261+ return {
262+ "metadata" : {"accuracy" : 0.95 , "loss" : 0.05 , "num_epochs" : 10 },
263+ "status" : "failed" ,
264+ }
265+
266+ async with CraftExperiment .start ("test_both_hooks" ) as exp :
267+ # Use both hooks
268+ run = exp .run (
269+ train_model ,
270+ post_run_hooks = [PostRunHookFn .sync_metadata , PostRunHookFn .sync_status ],
271+ )
272+ await exp .wait ()
273+
274+ # Verify both hooks ran
275+ metadb = global_runtime ().metadb
276+ run_obj = metadb .get_run (run_id = run .id )
277+
278+ # From sync_metadata hook
279+ assert run_obj .meta ["accuracy" ] == 0.95
280+ assert run_obj .meta ["loss" ] == 0.05
281+ assert run_obj .meta ["num_epochs" ] == 10
282+
283+ assert run_obj .status == Status .FAILED
284+
285+
286+ @pytest .mark .asyncio
287+ async def test_sync_metadata_with_none (test_org_id , test_user_id , test_team_id ):
288+ """Test that sync_metadata with None result doesn't update metadata"""
289+ alpha .init (org_id = test_org_id , team_id = test_team_id , user_id = test_user_id )
290+
291+ async def task_with_none_result ():
292+ """Function that returns None"""
293+ await asyncio .sleep (0.1 )
294+
295+ async with CraftExperiment .start ("test_hook_none_result" ) as exp :
296+ run = exp .run (
297+ task_with_none_result ,
298+ post_run_hooks = [PostRunHookFn .sync_metadata , PostRunHookFn .sync_status ],
299+ )
300+ await exp .wait ()
301+
302+ # Verify metadata was not updated
303+ metadb = global_runtime ().metadb
304+ run_obj = metadb .get_run (run_id = run .id )
305+
306+ # Metadata should be None or empty (hook didn't update it)
307+ assert run_obj .meta is None or run_obj .meta == {}
308+ assert run_obj .status == Status .COMPLETED
0 commit comments