@@ -95,6 +95,9 @@ async def test_log_params():
9595async def test_log_metrics ():
9696 alpha .init (project_id = uuid .uuid4 (), artifact_insecure = True , init_tables = True )
9797
98+ async def log_metric (metrics : dict ):
99+ await alpha .log_metrics (metrics )
100+
98101 async with alpha .CraftExperiment .start (name = "log_metrics_exp" ) as exp :
99102 trial = exp .start_trial (name = "first-trial" , params = {"param1" : 0.1 })
100103
@@ -105,7 +108,8 @@ async def test_log_metrics():
105108 metrics = exp ._runtime ._metadb .list_metrics (trial_id = trial ._id )
106109 assert len (metrics ) == 0
107110
108- await alpha .log_metrics ({"accuracy" : 0.95 , "loss" : 0.1 })
111+ run = trial .start_run (lambda : log_metric ({"accuracy" : 0.95 , "loss" : 0.1 }))
112+ await run .wait ()
109113
110114 metrics = exp ._runtime ._metadb .list_metrics (trial_id = trial ._id )
111115 assert len (metrics ) == 2
@@ -115,14 +119,21 @@ async def test_log_metrics():
115119 assert metrics [1 ].key == "loss"
116120 assert metrics [1 ].value == 0.1
117121 assert metrics [1 ].step == 1
122+ run_id_1 = metrics [0 ].run_id
123+ assert run_id_1 is not None
124+ assert metrics [0 ].run_id == metrics [1 ].run_id
118125
119- await alpha .log_metrics ({"accuracy" : 0.96 })
126+ run = trial .start_run (lambda : log_metric ({"accuracy" : 0.96 }))
127+ await run .wait ()
120128
121129 metrics = exp ._runtime ._metadb .list_metrics (trial_id = trial ._id )
122130 assert len (metrics ) == 3
123131 assert metrics [2 ].key == "accuracy"
124132 assert metrics [2 ].value == 0.96
125133 assert metrics [2 ].step == 2
134+ run_id_2 = metrics [2 ].run_id
135+ assert run_id_2 is not None
136+ assert run_id_2 != run_id_1
126137
127138 trial .cancel ()
128139
@@ -131,6 +142,9 @@ async def test_log_metrics():
131142async def test_log_metrics_with_save_on_max ():
132143 alpha .init (project_id = uuid .uuid4 (), artifact_insecure = True , init_tables = True )
133144
145+ async def log_metric (value : float ):
146+ await alpha .log_metrics ({"accuracy" : value })
147+
134148 async with alpha .CraftExperiment .start (
135149 name = "log_metrics_with_save_on_max" ,
136150 description = "Context manager test" ,
@@ -139,7 +153,7 @@ async def test_log_metrics_with_save_on_max():
139153 with tempfile .TemporaryDirectory () as tmpdir :
140154 os .chdir (tmpdir )
141155
142- _ = exp .start_trial (
156+ trial = exp .start_trial (
143157 name = "trial-with-save_on_best" ,
144158 config = alpha .TrialConfig (
145159 checkpoint = alpha .CheckpointConfig (
@@ -156,35 +170,47 @@ async def test_log_metrics_with_save_on_max():
156170 with open (file1 , "w" ) as f :
157171 f .write ("This is file1." )
158172
159- await alpha .log_metrics ({"accuracy" : 0.90 })
173+ run = trial .start_run (lambda : log_metric (0.90 ))
174+ await run .wait ()
160175
161176 versions = exp ._runtime ._artifact .list_versions (exp .id )
162177 assert len (versions ) == 1
163178
164179 # To avoid the same timestamp hash, we wait for 1 second
165180 time .sleep (1 )
166181
167- await alpha .log_metrics ({"accuracy" : 0.78 })
182+ run = trial .start_run (lambda : log_metric (0.78 ))
183+ await run .wait ()
184+
168185 versions = exp ._runtime ._artifact .list_versions (exp .id )
169186 assert len (versions ) == 1
170187
171188 time .sleep (1 )
172189
173- await alpha .log_metrics ({"accuracy" : 0.91 })
190+ run = trial .start_run (lambda : log_metric (0.91 ))
191+ await run .wait ()
192+
174193 versions = exp ._runtime ._artifact .list_versions (exp .id )
175194 assert len (versions ) == 2
176195
177196 time .sleep (1 )
178197
179- await alpha .log_metrics ({"accuracy2" : 0.98 })
198+ run = trial .start_run (lambda : log_metric (0.98 ))
199+ await run .wait ()
200+
180201 versions = exp ._runtime ._artifact .list_versions (exp .id )
181- assert len (versions ) == 2
202+ assert len (versions ) == 3
203+
204+ trial .cancel ()
182205
183206
184207@pytest .mark .asyncio
185208async def test_log_metrics_with_save_on_min ():
186209 alpha .init (project_id = uuid .uuid4 (), artifact_insecure = True , init_tables = True )
187210
211+ async def log_metric (value : float ):
212+ await alpha .log_metrics ({"accuracy" : value })
213+
188214 async with alpha .CraftExperiment .start (
189215 name = "log_metrics_with_save_on_min" ,
190216 description = "Context manager test" ,
@@ -193,7 +219,7 @@ async def test_log_metrics_with_save_on_min():
193219 with tempfile .TemporaryDirectory () as tmpdir :
194220 os .chdir (tmpdir )
195221
196- _ = exp .start_trial (
222+ trial = exp .start_trial (
197223 name = "trial-with-save_on_best" ,
198224 config = alpha .TrialConfig (
199225 checkpoint = alpha .CheckpointConfig (
@@ -210,29 +236,37 @@ async def test_log_metrics_with_save_on_min():
210236 with open (file1 , "w" ) as f :
211237 f .write ("This is file1." )
212238
213- await alpha .log_metrics ({"accuracy" : 0.30 })
239+ run = trial .start_run (lambda : log_metric (0.30 ))
240+ await run .wait ()
214241
215242 versions = exp ._runtime ._artifact .list_versions (exp .id )
216243 assert len (versions ) == 1
217244
218245 # To avoid the same timestamp hash, we wait for 1 second
219246 time .sleep (1 )
220247
221- await alpha .log_metrics ({"accuracy" : 0.58 })
248+ run = trial .start_run (lambda : log_metric (0.58 ))
249+ await run .wait ()
250+
222251 versions = exp ._runtime ._artifact .list_versions (exp .id )
223252 assert len (versions ) == 1
224253
225254 time .sleep (1 )
226255
227- await alpha .log_metrics ({"accuracy" : 0.21 })
256+ run = trial .start_run (lambda : log_metric (0.21 ))
257+ await run .wait ()
258+
228259 versions = exp ._runtime ._artifact .list_versions (exp .id )
229260 assert len (versions ) == 2
230261
231262 time .sleep (1 )
232263
233- await alpha .log_metrics ({"accuracy2" : 0.18 })
264+ run = trial .start_run (lambda : log_metric (0.18 ))
265+ await run .wait ()
234266 versions = exp ._runtime ._artifact .list_versions (exp .id )
235- assert len (versions ) == 2
267+ assert len (versions ) == 3
268+
269+ trial .cancel ()
236270
237271
238272@pytest .mark .asyncio