2222from causal_testing .testing .causal_effect import Negative , NoEffect , Positive , SomeEffect
2323from causal_testing .testing .causal_test_case import CausalTestCase
2424from causal_testing .testing .causal_test_result import CausalTestResult
25+ from causal_testing .testing .causal_test_adequacy import DataAdequacy
2526
2627logger = logging .getLogger (__name__ )
2728
@@ -335,12 +336,17 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
335336 estimator = estimator ,
336337 )
337338
338- def run_tests_in_batches (self , batch_size : int = 100 , silent : bool = False ) -> List [CausalTestResult ]:
339+ def run_tests_in_batches (
340+ self , batch_size : int = 100 , silent : bool = False , adequacy : bool = False , bootstrap_size : int = 100
341+ ) -> List [CausalTestResult ]:
339342 """
340343 Run tests in batches to reduce memory usage.
341344
342345 :param batch_size: Number of tests to run in each batch
343346 :param silent: Whether to suppress errors
347+ :param adequacy: Whether to calculate causal test adequacy (defaults to False)
348+ :param bootstrap_size: The number of bootstrap samples to use when calculating causal test adequacy
349+ (defaults to 100)
344350 :return: List of all test results
345351 :raises: ValueError if no tests are loaded
346352 """
@@ -368,7 +374,12 @@ def run_tests_in_batches(self, batch_size: int = 100, silent: bool = False) -> L
368374 batch_results = []
369375 for test_case in current_batch :
370376 try :
371- batch_results .append (test_case .execute_test ())
377+ result = test_case .execute_test ()
378+ if adequacy :
379+ result .adequacy = DataAdequacy (test_case = test_case , bootstrap_size = bootstrap_size )
380+ result .adequacy .measure_adequacy ()
381+
382+ batch_results .append (result )
372383 # pylint: disable=broad-exception-caught
373384 except Exception as e :
374385 if not silent :
@@ -383,10 +394,17 @@ def run_tests_in_batches(self, batch_size: int = 100, silent: bool = False) -> L
383394 yield batch_results
384395 logger .info (f"Completed processing in { num_batches } batches" )
385396
386- def run_tests (self , silent = False ) -> List [CausalTestResult ]:
397+ def run_tests (
398+ self , silent : bool = False , adequacy : bool = False , bootstrap_size : int = 100
399+ ) -> List [CausalTestResult ]:
387400 """
388401 Run all test cases and return their results.
389402
403+ :param silent: Whether to suppress errors
404+ :param adequacy: Whether to calculate causal test adequacy (defaults to False)
405+ :param bootstrap_size: The number of bootstrap samples to use when calculating causal test adequacy
406+ (defaults to 100)
407+
390408 :return: List of CausalTestResult objects
391409 :raises: ValueError if no tests are loaded
392410 :raises: Exception if test execution fails
@@ -400,6 +418,9 @@ def run_tests(self, silent=False) -> List[CausalTestResult]:
400418 for test_case in tqdm (self .test_cases ):
401419 try :
402420 result = test_case .execute_test ()
421+ if adequacy :
422+ result .adequacy = DataAdequacy (test_case = test_case , bootstrap_size = bootstrap_size )
423+ result .adequacy .measure_adequacy ()
403424 results .append (result )
404425 # pylint: disable=broad-exception-caught
405426 except Exception as e :
@@ -450,6 +471,7 @@ def save_results(self, results: List[CausalTestResult], output_path: str = None)
450471 "adjustment_set" : list (result .adjustment_set ) if result .adjustment_set else [],
451472 }
452473 | result .effect_estimate .to_dict ()
474+ | (result .adequacy .to_dict () if result .adequacy else {})
453475 if result .effect_estimate
454476 else {"error" : result .error_message }
455477 ),
@@ -522,6 +544,17 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
522544 parser_test .add_argument ("-t" , "--test_config" , help = "Path to test configuration file (.json)" , required = True )
523545 parser_test .add_argument ("-v" , "--verbose" , help = "Enable verbose logging" , action = "store_true" , default = False )
524546 parser_test .add_argument ("-q" , "--query" , help = "Query string to filter data (e.g. 'age > 18')" , type = str )
547+ parser_test .add_argument (
548+ "-a" , "--adequacy" , help = "Calculate causal test adequacy for each test case" , action = "store_true" , default = False
549+ )
550+ parser_test .add_argument (
551+ "-b" ,
552+ "--adequacy-bootstrap-size" ,
553+ dest = "bootstrap_size" ,
554+ help = "Number of bootstrap samples for causal test adequacy. Defaults to 100" ,
555+ type = int ,
556+ default = 100 ,
557+ )
525558 parser_test .add_argument (
526559 "-s" ,
527560 "--silent" ,
@@ -537,5 +570,10 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
537570 )
538571
539572 args = main_parser .parse_args (args )
573+
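574+ # Assume the user wants test adequacy if they're setting bootstrap_size. NOTE(review): --adequacy-bootstrap-size is declared with default=100 (truthy), so this check passes even when the option was not given and adequacy ends up enabled on every run of this sub-command unless 0 is passed; to detect an explicit value, declare the option with default=None, test `is not None` here, and fall back to 100 otherwise.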
575+ if hasattr (args , "bootstrap_size" ) and args .bootstrap_size :
576+ args .adequacy = True
577+
540578 args .command = Command (args .command )
541579 return args
0 commit comments