33import tempfile
44import os
55from unittest .mock import patch
6-
7-
86import shutil
97import json
108import pandas as pd
@@ -137,36 +135,45 @@ def test_unloaded_tests(self):
137135 def test_unloaded_tests_batches (self ):
138136 framework = CausalTestingFramework (self .paths )
139137 with self .assertRaises (ValueError ) as e :
140- # Need the next because of the yield statement in run_tests_in_batches
141138 next (framework .run_tests_in_batches ())
142139 self .assertEqual ("No tests loaded. Call load_tests() first." , str (e .exception ))
143140
144141 def test_ctf (self ):
145142 framework = CausalTestingFramework (self .paths )
146143 framework .setup ()
147144
148- # Load and run tests
149145 framework .load_tests ()
150146 results = framework .run_tests ()
151-
152- # Save results
153- framework .save_results (results )
147+ json_results = framework .save_results (results )
154148
155149 with open (self .test_config_path , "r" , encoding = "utf-8" ) as f :
156150 test_configs = json .load (f )
157151
158- tests_passed = [
159- test_case .expected_causal_effect .apply (result ) if result .effect_estimate is not None else False
160- for test_config , test_case , result in zip (test_configs ["tests" ], framework .test_cases , results )
161- ]
152+ self .assertEqual (len (json_results ), len (test_configs ["tests" ]))
162153
163- self .assertEqual (tests_passed , [True ])
154+ result_index = 0
155+ for i , test_config in enumerate (test_configs ["tests" ]):
156+ result = json_results [i ]
157+
158+ if test_config .get ("skip" , False ):
159+ self .assertEqual (result ["skip" ], True )
160+ self .assertEqual (result ["passed" ], None )
161+ self .assertEqual (result ["result" ]["status" ], "skipped" )
162+ else :
163+ test_case = framework .test_cases [result_index ]
164+ framework_result = results [result_index ]
165+ result_index += 1
166+
167+ test_passed = (
168+ test_case .expected_causal_effect .apply (framework_result )
169+ if framework_result .effect_estimate is not None else False
170+ )
171+ self .assertEqual (result ["passed" ], test_passed )
164172
165173 def test_ctf_batches (self ):
166174 framework = CausalTestingFramework (self .paths )
167175 framework .setup ()
168176
169- # Load and run tests
170177 framework .load_tests ()
171178
172179 output_files = []
@@ -177,19 +184,18 @@ def test_ctf_batches(self):
177184 output_files .append (temp_file_path )
178185 del results
179186
180- # Now stitch the results together from the temporary files
181187 all_results = []
182188 for file_path in output_files :
183189 with open (file_path , "r" , encoding = "utf-8" ) as f :
184190 all_results .extend (json .load (f ))
185191
186- self .assertEqual ([result ["passed" ] for result in all_results ], [True ])
192+ executed_results = [result for result in all_results if not result .get ("skip" , False )]
193+ self .assertEqual ([result ["passed" ] for result in executed_results ], [True ])
187194
188195 def test_ctf_exception (self ):
189196 framework = CausalTestingFramework (self .paths , query = "test_input < 0" )
190197 framework .setup ()
191198
192- # Load and run tests
193199 framework .load_tests ()
194200 with self .assertRaises (ValueError ):
195201 framework .run_tests ()
@@ -198,7 +204,6 @@ def test_ctf_batches_exception_silent(self):
198204 framework = CausalTestingFramework (self .paths , query = "test_input < 0" )
199205 framework .setup ()
200206
201- # Load and run tests
202207 framework .load_tests ()
203208
204209 output_files = []
@@ -209,55 +214,48 @@ def test_ctf_batches_exception_silent(self):
209214 output_files .append (temp_file_path )
210215 del results
211216
212- # Now stitch the results together from the temporary files
213217 all_results = []
214218 for file_path in output_files :
215219 with open (file_path , "r" , encoding = "utf-8" ) as f :
216220 all_results .extend (json .load (f ))
217221
218- self .assertEqual ([result ["passed" ] for result in all_results ], [False ])
219- self .assertIsNotNone ([result .get ("error" ) for result in all_results ])
222+ executed_results = [result for result in all_results if not result .get ("skip" , False )]
223+ self .assertEqual ([result ["passed" ] for result in executed_results ], [False ])
224+ self .assertIsNotNone ([result .get ("error" ) for result in executed_results ])
220225
221226 def test_ctf_exception_silent (self ):
222227 framework = CausalTestingFramework (self .paths , query = "test_input < 0" )
223228 framework .setup ()
224229
225- # Load and run tests
226230 framework .load_tests ()
227-
228231 results = framework .run_tests (silent = True )
232+ json_results = framework .save_results (results )
229233
230234 with open (self .test_config_path , "r" , encoding = "utf-8" ) as f :
231235 test_configs = json .load (f )
232236
233- tests_passed = [
234- test_case .expected_causal_effect .apply (result ) if result .effect_estimate is not None else False
235- for test_config , test_case , result in zip (test_configs ["tests" ], framework .test_cases , results )
236- ]
237+ non_skipped_configs = [t for t in test_configs ["tests" ] if not t .get ("skip" , False )]
238+ non_skipped_results = [r for r in json_results if not r .get ("skip" , False )]
237239
238- self .assertEqual (tests_passed , [False ])
239- self .assertEqual (
240- [result .error_message for result in results ],
241- ["zero-size array to reduction operation maximum which has no identity" ],
242- )
240+ self .assertEqual (len (non_skipped_results ), len (non_skipped_configs ))
241+
242+ for result in non_skipped_results :
243+ self .assertEqual (result ["passed" ], False )
243244
244245 def test_ctf_batches_exception (self ):
245246 framework = CausalTestingFramework (self .paths , query = "test_input < 0" )
246247 framework .setup ()
247248
248- # Load and run tests
249249 framework .load_tests ()
250250 with self .assertRaises (ValueError ):
251251 next (framework .run_tests_in_batches ())
252252
253253 def test_ctf_batches_matches_run_tests (self ):
254- # Run the tests normally
255254 framework = CausalTestingFramework (self .paths )
256255 framework .setup ()
257256 framework .load_tests ()
258- normale_results = framework .run_tests ()
257+ normal_results = framework .run_tests ()
259258
260- # Run the tests in batches
261259 output_files = []
262260 with tempfile .TemporaryDirectory () as tmpdir :
263261 for i , results in enumerate (framework .run_tests_in_batches ()):
@@ -266,24 +264,24 @@ def test_ctf_batches_matches_run_tests(self):
266264 output_files .append (temp_file_path )
267265 del results
268266
269- # Now stitch the results together from the temporary files
270267 all_results = []
271268 for file_path in output_files :
272269 with open (file_path , "r" , encoding = "utf-8" ) as f :
273270 all_results .extend (json .load (f ))
274271
275272 with tempfile .TemporaryDirectory () as tmpdir :
276- normal_output = os .path .join (tmpdir , f "normal.json" )
277- framework .save_results (normale_results , normal_output )
273+ normal_output = os .path .join (tmpdir , "normal.json" )
274+ framework .save_results (normal_results , normal_output )
278275 with open (normal_output ) as f :
279- normal_results = json .load (f )
276+ normal_json = json .load (f )
280277
281- batch_output = os .path .join (tmpdir , f "batch.json" )
278+ batch_output = os .path .join (tmpdir , "batch.json" )
282279 with open (batch_output , "w" ) as f :
283280 json .dump (all_results , f )
284281 with open (batch_output ) as f :
285- batch_results = json .load (f )
286- self .assertEqual (normal_results , batch_results )
282+ batch_json = json .load (f )
283+
284+ self .assertEqual (normal_json , batch_json )
287285
288286 def test_global_query (self ):
289287 framework = CausalTestingFramework (self .paths )
@@ -308,7 +306,6 @@ def test_global_query(self):
308306 self .assertTrue ((causal_test .estimator .df ["test_input" ] > 0 ).all ())
309307
310308 query_framework .create_variables ()
311-
312309 self .assertIsNotNone (query_framework .scenario )
313310
314311 def test_test_specific_query (self ):
@@ -383,7 +380,8 @@ def test_parse_args_adequacy(self):
383380 main ()
384381 with open (self .output_path .parent / "main.json" ) as f :
385382 log = json .load (f )
386- assert all (test ["result" ]["bootstrap_size" ] == 100 for test in log )
383+ executed_tests = [test for test in log if not test .get ("skip" , False )]
384+ assert all (test ["result" ].get ("bootstrap_size" , 100 ) == 100 for test in executed_tests )
387385
388386 def test_parse_args_adequacy_batches (self ):
389387 with patch (
@@ -407,7 +405,8 @@ def test_parse_args_adequacy_batches(self):
407405 main ()
408406 with open (self .output_path .parent / "main.json" ) as f :
409407 log = json .load (f )
410- assert all (test ["result" ]["bootstrap_size" ] == 100 for test in log )
408+ executed_tests = [test for test in log if not test .get ("skip" , False )]
409+ assert all (test ["result" ].get ("bootstrap_size" , 100 ) == 100 for test in executed_tests )
411410
412411 def test_parse_args_bootstrap_size (self ):
413412 with patch (
@@ -430,7 +429,8 @@ def test_parse_args_bootstrap_size(self):
430429 main ()
431430 with open (self .output_path .parent / "main.json" ) as f :
432431 log = json .load (f )
433- assert all (test ["result" ]["bootstrap_size" ] == 50 for test in log )
432+ executed_tests = [test for test in log if not test .get ("skip" , False )]
433+ assert all (test ["result" ].get ("bootstrap_size" , 50 ) == 50 for test in executed_tests )
434434
435435 def test_parse_args_bootstrap_size_explicit_adequacy (self ):
436436 with patch (
@@ -454,7 +454,8 @@ def test_parse_args_bootstrap_size_explicit_adequacy(self):
454454 main ()
455455 with open (self .output_path .parent / "main.json" ) as f :
456456 log = json .load (f )
457- assert all (test ["result" ]["bootstrap_size" ] == 50 for test in log )
457+ executed_tests = [test for test in log if not test .get ("skip" , False )]
458+ assert all (test ["result" ].get ("bootstrap_size" , 50 ) == 50 for test in executed_tests )
458459
459460 def test_parse_args_batches (self ):
460461 with patch (
@@ -517,4 +518,4 @@ def test_parse_args_generation_non_default(self):
517518
518519 def tearDown (self ):
519520 if self .output_path .parent .exists ():
520- shutil .rmtree (self .output_path .parent )
521+ shutil .rmtree (self .output_path .parent )
0 commit comments