Skip to content

Commit 31aa186

Browse files
authored
Merge pull request #365 from CITCOM-project/jmafoster1/run-test-adequacy
Jmafoster1/run test adequacy
2 parents 2ed5027 + 267191c commit 31aa186

4 files changed

Lines changed: 56 additions & 14 deletions

File tree

causal_testing/__main__.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def main() -> None:
3131
effect_type=args.effect_type,
3232
estimate_type=args.estimate_type,
3333
estimator=args.estimator,
34-
skip=True,
34+
skip=False,
3535
)
3636
logging.info("Causal test generation completed successfully")
3737
return
@@ -58,7 +58,14 @@ def main() -> None:
5858
logging.info(f"Running tests in batches of size {args.batch_size}")
5959
with tempfile.TemporaryDirectory() as tmpdir:
6060
output_files = []
61-
for i, results in enumerate(framework.run_tests_in_batches(batch_size=args.batch_size, silent=args.silent)):
61+
for i, results in enumerate(
62+
framework.run_tests_in_batches(
63+
batch_size=args.batch_size,
64+
silent=args.silent,
65+
adequacy=args.adequacy,
66+
bootstrap_size=args.bootstrap_size,
67+
)
68+
):
6269
temp_file_path = os.path.join(tmpdir, f"output_{i}.json")
6370
framework.save_results(results, temp_file_path)
6471
output_files.append(temp_file_path)
@@ -77,7 +84,7 @@ def main() -> None:
7784
json.dump(all_results, f, indent=4)
7885
else:
7986
logging.info("Running tests in regular mode")
80-
results = framework.run_tests(silent=args.silent)
87+
results = framework.run_tests(silent=args.silent, adequacy=args.adequacy, bootstrap_size=args.bootstrap_size)
8188
framework.save_results(results)
8289

8390
logging.info("Causal testing completed successfully.")

causal_testing/main.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from causal_testing.testing.causal_effect import Negative, NoEffect, Positive, SomeEffect
2323
from causal_testing.testing.causal_test_case import CausalTestCase
2424
from causal_testing.testing.causal_test_result import CausalTestResult
25+
from causal_testing.testing.causal_test_adequacy import DataAdequacy
2526

2627
logger = logging.getLogger(__name__)
2728

@@ -335,12 +336,17 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
335336
estimator=estimator,
336337
)
337338

338-
def run_tests_in_batches(self, batch_size: int = 100, silent: bool = False) -> List[CausalTestResult]:
339+
def run_tests_in_batches(
340+
self, batch_size: int = 100, silent: bool = False, adequacy: bool = False, bootstrap_size: int = 100
341+
) -> List[CausalTestResult]:
339342
"""
340343
Run tests in batches to reduce memory usage.
341344
342345
:param batch_size: Number of tests to run in each batch
343346
:param silent: Whether to suppress errors
347+
:param adequacy: Whether to calculate causal test adequacy (defaults to False)
348+
:param bootstrap_size: The number of bootstrap samples to use when calculating causal test adequacy
349+
(defaults to 100)
344350
:return: List of all test results
345351
:raises: ValueError if no tests are loaded
346352
"""
@@ -368,7 +374,12 @@ def run_tests_in_batches(self, batch_size: int = 100, silent: bool = False) -> L
368374
batch_results = []
369375
for test_case in current_batch:
370376
try:
371-
batch_results.append(test_case.execute_test())
377+
result = test_case.execute_test()
378+
if adequacy:
379+
result.adequacy = DataAdequacy(test_case=test_case, bootstrap_size=bootstrap_size)
380+
result.adequacy.measure_adequacy()
381+
382+
batch_results.append(result)
372383
# pylint: disable=broad-exception-caught
373384
except Exception as e:
374385
if not silent:
@@ -383,10 +394,17 @@ def run_tests_in_batches(self, batch_size: int = 100, silent: bool = False) -> L
383394
yield batch_results
384395
logger.info(f"Completed processing in {num_batches} batches")
385396

386-
def run_tests(self, silent=False) -> List[CausalTestResult]:
397+
def run_tests(
398+
self, silent: bool = False, adequacy: bool = False, bootstrap_size: int = 100
399+
) -> List[CausalTestResult]:
387400
"""
388401
Run all test cases and return their results.
389402
403+
:param silent: Whether to suppress errors
404+
:param adequacy: Whether to calculate causal test adequacy (defaults to False)
405+
:param bootstrap_size: The number of bootstrap samples to use when calculating causal test adequacy
406+
(defaults to 100)
407+
390408
:return: List of CausalTestResult objects
391409
:raises: ValueError if no tests are loaded
392410
:raises: Exception if test execution fails
@@ -400,6 +418,9 @@ def run_tests(self, silent=False) -> List[CausalTestResult]:
400418
for test_case in tqdm(self.test_cases):
401419
try:
402420
result = test_case.execute_test()
421+
if adequacy:
422+
result.adequacy = DataAdequacy(test_case=test_case, bootstrap_size=bootstrap_size)
423+
result.adequacy.measure_adequacy()
403424
results.append(result)
404425
# pylint: disable=broad-exception-caught
405426
except Exception as e:
@@ -450,6 +471,7 @@ def save_results(self, results: List[CausalTestResult], output_path: str = None)
450471
"adjustment_set": list(result.adjustment_set) if result.adjustment_set else [],
451472
}
452473
| result.effect_estimate.to_dict()
474+
| (result.adequacy.to_dict() if result.adequacy else {})
453475
if result.effect_estimate
454476
else {"error": result.error_message}
455477
),
@@ -522,6 +544,17 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
522544
parser_test.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True)
523545
parser_test.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False)
524546
parser_test.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str)
547+
parser_test.add_argument(
548+
"-a", "--adequacy", help="Calculate causal test adequacy for each test case", action="store_true", default=False
549+
)
550+
parser_test.add_argument(
551+
"-b",
552+
"--adequacy-bootstrap-size",
553+
dest="bootstrap_size",
554+
help="Number of bootstrap samples for causal test adequacy. Defaults to 100",
555+
type=int,
556+
default=100,
557+
)
525558
parser_test.add_argument(
526559
"-s",
527560
"--silent",
@@ -537,5 +570,10 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
537570
)
538571

539572
args = main_parser.parse_args(args)
573+
574+
# Assume the user wants test adequacy if they're setting bootstrap_size
575+
if hasattr(args, "bootstrap_size") and args.bootstrap_size:
576+
args.adequacy = True
577+
540578
args.command = Command(args.command)
541579
return args

causal_testing/testing/causal_test_adequacy.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from lifelines.exceptions import ConvergenceError
1111
from numpy.linalg import LinAlgError
1212

13-
from causal_testing.estimation.abstract_estimator import Estimator
1413
from causal_testing.specification.causal_dag import CausalDAG
1514
from causal_testing.testing.causal_test_case import CausalTestCase
1615

@@ -79,12 +78,10 @@ class DataAdequacy:
7978
def __init__(
8079
self,
8180
test_case: CausalTestCase,
82-
estimator: Estimator,
8381
bootstrap_size: int = 100,
8482
group_by=None,
8583
):
8684
self.test_case = test_case
87-
self.estimator = estimator
8885
self.kurtosis = None
8986
self.outcomes = None
9087
self.successful = None
@@ -97,7 +94,7 @@ def measure_adequacy(self):
9794
"""
9895
results = []
9996
for i in range(self.bootstrap_size):
100-
estimator = deepcopy(self.estimator)
97+
estimator = deepcopy(self.test_case.estimator)
10198

10299
if self.group_by is not None:
103100
ids = pd.Series(estimator.df[self.group_by].unique())
@@ -120,7 +117,7 @@ def measure_adequacy(self):
120117
results = pd.concat([c.effect_estimate.to_df() for c in results])
121118
results["var"] = results.index
122119

123-
self.kurtosis = results.groupby("var").apply(lambda x: x.kurtosis())["effect_estimate"]
120+
self.kurtosis = results.groupby("var")["effect_estimate"].apply(lambda x: x.kurtosis())
124121
self.outcomes = sum(filter(lambda x: x is not None, outcomes))
125122
self.successful = sum(x is not None for x in outcomes)
126123

tests/testing_tests/test_causal_test_adequacy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def test_data_adequacy_numeric(self):
4646
estimate_type="coefficient",
4747
estimator=estimator,
4848
)
49-
adequacy_metric = DataAdequacy(causal_test_case, estimator)
49+
adequacy_metric = DataAdequacy(causal_test_case)
5050
adequacy_metric.measure_adequacy()
5151
self.assertEqual(
5252
adequacy_metric.to_dict(),
@@ -66,7 +66,7 @@ def test_data_adequacy_categorical(self):
6666
estimate_type="coefficient",
6767
estimator=estimator,
6868
)
69-
adequacy_metric = DataAdequacy(causal_test_case, estimator)
69+
adequacy_metric = DataAdequacy(causal_test_case)
7070
adequacy_metric.measure_adequacy()
7171
self.assertEqual(
7272
adequacy_metric.to_dict(),
@@ -100,7 +100,7 @@ def test_data_adequacy_group_by(self):
100100
estimate_type="hazard_ratio",
101101
estimator=estimation_model,
102102
)
103-
adequacy_metric = DataAdequacy(causal_test_case, estimation_model, group_by="id")
103+
adequacy_metric = DataAdequacy(causal_test_case, group_by="id")
104104
adequacy_metric.measure_adequacy()
105105
adequacy_dict = adequacy_metric.to_dict()
106106
self.assertEqual(round(adequacy_dict["kurtosis"]["trtrand"], 3), -0.857)

0 commit comments

Comments
 (0)