@@ -128,6 +128,19 @@ def test_create_base_test_case_missing_outcome(self):
128128 framework .create_base_test ({"treatment_variable" : "test_input" , "expected_effect" : {"missing" : "NoEffect" }})
129129 self .assertEqual ("\" Outcome variable 'missing' not found in inputs or outputs\" " , str (e .exception ))
130130
131+ def test_unloaded_tests (self ):
132+ framework = CausalTestingFramework (self .paths )
133+ with self .assertRaises (ValueError ) as e :
134+ framework .run_tests ()
135+ self .assertEqual ("No tests loaded. Call load_tests() first." , str (e .exception ))
136+
137+ def test_unloaded_tests_batches (self ):
138+ framework = CausalTestingFramework (self .paths )
139+ with self .assertRaises (ValueError ) as e :
140+ # Need the next because of the yield statement in run_tests_in_batches
141+ next (framework .run_tests_in_batches ())
142+ self .assertEqual ("No tests loaded. Call load_tests() first." , str (e .exception ))
143+
131144 def test_ctf (self ):
132145 framework = CausalTestingFramework (self .paths )
133146 framework .setup ()
@@ -136,8 +149,6 @@ def test_ctf(self):
136149 framework .load_tests ()
137150 results = framework .run_tests ()
138151
139- print (results )
140-
141152 # Save results
142153 framework .save_results (results )
143154
@@ -205,7 +216,30 @@ def test_ctf_batches_exception_silent(self):
205216 all_results .extend (json .load (f ))
206217
207218 self .assertEqual ([result ["passed" ] for result in all_results ], [False ])
208- self .assertIsNotNone ([result ["result" ].get ("error" ) for result in all_results ])
219+ self .assertIsNotNone ([result .get ("error" ) for result in all_results ])
220+
221+ def test_ctf_exception_silent (self ):
222+ framework = CausalTestingFramework (self .paths , query = "test_input < 0" )
223+ framework .setup ()
224+
225+ # Load and run tests
226+ framework .load_tests ()
227+
228+ results = framework .run_tests (silent = True )
229+
230+ with open (self .test_config_path , "r" , encoding = "utf-8" ) as f :
231+ test_configs = json .load (f )
232+
233+ tests_passed = [
234+ test_case .expected_causal_effect .apply (result ) if result .effect_estimate is not None else False
235+ for test_config , test_case , result in zip (test_configs ["tests" ], framework .test_cases , results )
236+ ]
237+
238+ self .assertEqual (tests_passed , [False ])
239+ self .assertEqual (
240+ [result .error_message for result in results ],
241+ ["zero-size array to reduction operation maximum which has no identity" ],
242+ )
209243
210244 def test_ctf_batches_exception (self ):
211245 framework = CausalTestingFramework (self .paths , query = "test_input < 0" )
@@ -214,7 +248,7 @@ def test_ctf_batches_exception(self):
214248 # Load and run tests
215249 framework .load_tests ()
216250 with self .assertRaises (ValueError ):
217- list (framework .run_tests_in_batches ())
251+ next (framework .run_tests_in_batches ())
218252
219253 def test_ctf_batches_matches_run_tests (self ):
220254 # Run the tests normally
@@ -318,11 +352,11 @@ def test_parse_args(self):
318352 [
319353 "causal_testing" ,
320354 "test" ,
321- "--dag_path " ,
355+ "--dag-path " ,
322356 str (self .dag_path ),
323- "--data_paths " ,
357+ "--data-paths " ,
324358 str (self .data_paths [0 ]),
325- "--test_config " ,
359+ "--test-config " ,
326360 str (self .test_config_path ),
327361 "--output" ,
328362 str (self .output_path .parent / "main.json" ),
@@ -331,17 +365,110 @@ def test_parse_args(self):
331365 main ()
332366 self .assertTrue ((self .output_path .parent / "main.json" ).exists ())
333367
368+ def test_parse_args_adequacy (self ):
369+ with patch (
370+ "sys.argv" ,
371+ [
372+ "causal_testing" ,
373+ "test" ,
374+ "--dag-path" ,
375+ str (self .dag_path ),
376+ "--data-paths" ,
377+ str (self .data_paths [0 ]),
378+ "--test-config" ,
379+ str (self .test_config_path ),
380+ "--output" ,
381+ str (self .output_path .parent / "main.json" ),
382+ "-a" ,
383+ ],
384+ ):
385+ main ()
386+ with open (self .output_path .parent / "main.json" ) as f :
387+ log = json .load (f )
388+ assert all (test ["result" ]["bootstrap_size" ] == 100 for test in log )
389+
390+ def test_parse_args_adequacy_batches (self ):
391+ with patch (
392+ "sys.argv" ,
393+ [
394+ "causal_testing" ,
395+ "test" ,
396+ "--dag-path" ,
397+ str (self .dag_path ),
398+ "--data-paths" ,
399+ str (self .data_paths [0 ]),
400+ "--test-config" ,
401+ str (self .test_config_path ),
402+ "--output" ,
403+ str (self .output_path .parent / "main.json" ),
404+ "-a" ,
405+ "--batch-size" ,
406+ "5" ,
407+ ],
408+ ):
409+ main ()
410+ with open (self .output_path .parent / "main.json" ) as f :
411+ log = json .load (f )
412+ assert all (test ["result" ]["bootstrap_size" ] == 100 for test in log )
413+
414+ def test_parse_args_bootstrap_size (self ):
415+ with patch (
416+ "sys.argv" ,
417+ [
418+ "causal_testing" ,
419+ "test" ,
420+ "--dag-path" ,
421+ str (self .dag_path ),
422+ "--data-paths" ,
423+ str (self .data_paths [0 ]),
424+ "--test-config" ,
425+ str (self .test_config_path ),
426+ "--output" ,
427+ str (self .output_path .parent / "main.json" ),
428+ "-b" ,
429+ "50" ,
430+ ],
431+ ):
432+ main ()
433+ with open (self .output_path .parent / "main.json" ) as f :
434+ log = json .load (f )
435+ assert all (test ["result" ]["bootstrap_size" ] == 50 for test in log )
436+
437+ def test_parse_args_bootstrap_size_explicit_adequacy (self ):
438+ with patch (
439+ "sys.argv" ,
440+ [
441+ "causal_testing" ,
442+ "test" ,
443+ "--dag-path" ,
444+ str (self .dag_path ),
445+ "--data-paths" ,
446+ str (self .data_paths [0 ]),
447+ "--test-config" ,
448+ str (self .test_config_path ),
449+ "--output" ,
450+ str (self .output_path .parent / "main.json" ),
451+ "-a" ,
452+ "-b" ,
453+ "50" ,
454+ ],
455+ ):
456+ main ()
457+ with open (self .output_path .parent / "main.json" ) as f :
458+ log = json .load (f )
459+ assert all (test ["result" ]["bootstrap_size" ] == 50 for test in log )
460+
334461 def test_parse_args_batches (self ):
335462 with patch (
336463 "sys.argv" ,
337464 [
338465 "causal_testing" ,
339466 "test" ,
340- "--dag_path " ,
467+ "--dag-path " ,
341468 str (self .dag_path ),
342- "--data_paths " ,
469+ "--data-paths " ,
343470 str (self .data_paths [0 ]),
344- "--test_config " ,
471+ "--test-config " ,
345472 str (self .test_config_path ),
346473 "--output" ,
347474 str (self .output_path .parent / "main_batch.json" ),
@@ -359,7 +486,7 @@ def test_parse_args_generation(self):
359486 [
360487 "causal_testing" ,
361488 "generate" ,
362- "--dag_path " ,
489+ "--dag-path " ,
363490 str (self .dag_path ),
364491 "--output" ,
365492 os .path .join (tmp , "tests.json" ),
@@ -375,15 +502,15 @@ def test_parse_args_generation_non_default(self):
375502 [
376503 "causal_testing" ,
377504 "generate" ,
378- "--dag_path " ,
505+ "--dag-path " ,
379506 str (self .dag_path ),
380507 "--output" ,
381508 os .path .join (tmp , "tests_non_default.json" ),
382509 "--estimator" ,
383510 "LogisticRegressionEstimator" ,
384- "--estimate_type " ,
511+ "--estimate-type " ,
385512 "unit_odds_ratio" ,
386- "--effect_type " ,
513+ "--effect-type " ,
387514 "total" ,
388515 ],
389516 ):
0 commit comments