Skip to content

Commit f3d2915

Browse files
authored
Merge pull request #369 from CITCOM-project/jmafoster1/named-confidence-intervals
Jmafoster1/named confidence intervals
2 parents 5449dda + 30a4250 commit f3d2915

4 files changed

Lines changed: 168 additions & 28 deletions

File tree

causal_testing/estimation/logistic_regression_estimator.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,20 @@ def estimate_unit_odds_ratio(self) -> EffectEstimate:
4040
:return: The odds ratio. Confidence intervals are not yet supported.
4141
"""
4242
model = self.fit_model(self.df)
43-
ci_low, ci_high = np.exp(model.conf_int(self.alpha).loc[self.base_test_case.treatment_variable.name])
44-
return EffectEstimate(
43+
44+
treatment_columns = [
45+
param
46+
for param in model.params.index
47+
if param == self.base_test_case.treatment_variable.name
48+
or param.startswith(self.base_test_case.treatment_variable.name + "[")
49+
]
50+
51+
confidence_intervals = np.exp(model.conf_int(self.alpha).loc[treatment_columns])
52+
53+
result = EffectEstimate(
4554
"unit_odds_ratio",
46-
pd.Series(np.exp(model.params[self.base_test_case.treatment_variable.name])),
47-
pd.Series(ci_low),
48-
pd.Series(ci_high),
55+
pd.Series(np.exp(model.params[treatment_columns])),
56+
pd.Series(confidence_intervals[0]),
57+
pd.Series(confidence_intervals[1]),
4958
)
59+
return result

causal_testing/main.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
508508

509509
# Generation
510510
parser_generate = subparsers.add_parser(Command.GENERATE.value, help="Generate causal tests from a DAG")
511-
parser_generate.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
511+
parser_generate.add_argument("-D", "--dag-path", help="Path to the DAG file (.dot)", required=True)
512512
parser_generate.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
513513
parser_generate.add_argument(
514514
"-e",
@@ -518,13 +518,13 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
518518
)
519519
parser_generate.add_argument(
520520
"-T",
521-
"--effect_type",
521+
"--effect-type",
522522
help="The effect type to estimate {direct, total}",
523523
default="direct",
524524
)
525525
parser_generate.add_argument(
526526
"-E",
527-
"--estimate_type",
527+
"--estimate-type",
528528
help="The estimate type to use when evaluating tests (defaults to coefficient)",
529529
default="coefficient",
530530
)
@@ -537,11 +537,11 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
537537

538538
# Testing
539539
parser_test = subparsers.add_parser(Command.TEST.value, help="Run causal tests")
540-
parser_test.add_argument("-D", "--dag_path", help="Path to the DAG file (.dot)", required=True)
540+
parser_test.add_argument("-D", "--dag-path", help="Path to the DAG file (.dot)", required=True)
541541
parser_test.add_argument("-o", "--output", help="Path for output file (.json)", required=True)
542542
parser_test.add_argument("-i", "--ignore-cycles", help="Ignore cycles in DAG", action="store_true", default=False)
543-
parser_test.add_argument("-d", "--data_paths", help="Paths to data files (.csv)", nargs="+", required=True)
544-
parser_test.add_argument("-t", "--test_config", help="Path to test configuration file (.json)", required=True)
543+
parser_test.add_argument("-d", "--data-paths", help="Paths to data files (.csv)", nargs="+", required=True)
544+
parser_test.add_argument("-t", "--test-config", help="Path to test configuration file (.json)", required=True)
545545
parser_test.add_argument("-v", "--verbose", help="Enable verbose logging", action="store_true", default=False)
546546
parser_test.add_argument("-q", "--query", help="Query string to filter data (e.g. 'age > 18')", type=str)
547547
parser_test.add_argument(
@@ -553,7 +553,6 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
553553
dest="bootstrap_size",
554554
help="Number of bootstrap samples for causal test adequacy. Defaults to 100",
555555
type=int,
556-
default=100,
557556
)
558557
parser_test.add_argument(
559558
"-s",
@@ -572,8 +571,12 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
572571
args = main_parser.parse_args(args)
573572

574573
# Assume the user wants test adequacy if they're setting bootstrap_size
575-
if hasattr(args, "bootstrap_size") and args.bootstrap_size:
574+
print(args)
575+
if getattr(args, "bootstrap_size", None) is not None:
576576
args.adequacy = True
577+
if getattr(args, "adequacy", False) and getattr(args, "bootstrap_size", None) is None:
578+
# Need this here rather than a default value because otherwise the above always sets adequacy to True
579+
args.bootstrap_size = 100
577580

578581
args.command = Command(args.command)
579582
return args

tests/estimation_tests/test_logistic_regression_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ def test_odds_ratio(self):
2020
BaseTestCase(Input("length_in", float), Output("completed", bool)), 65, 55, set(), df
2121
)
2222
effect_estimate = logistic_regression_estimator.estimate_unit_odds_ratio()
23-
self.assertEqual(round(effect_estimate.value[0], 4), 0.8948)
23+
self.assertEqual(round(effect_estimate.value.iloc[0], 4), 0.8948)

tests/main_tests/test_main.py

Lines changed: 141 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,19 @@ def test_create_base_test_case_missing_outcome(self):
128128
framework.create_base_test({"treatment_variable": "test_input", "expected_effect": {"missing": "NoEffect"}})
129129
self.assertEqual("\"Outcome variable 'missing' not found in inputs or outputs\"", str(e.exception))
130130

131+
def test_unloaded_tests(self):
132+
framework = CausalTestingFramework(self.paths)
133+
with self.assertRaises(ValueError) as e:
134+
framework.run_tests()
135+
self.assertEqual("No tests loaded. Call load_tests() first.", str(e.exception))
136+
137+
def test_unloaded_tests_batches(self):
138+
framework = CausalTestingFramework(self.paths)
139+
with self.assertRaises(ValueError) as e:
140+
# Need the next because of the yield statement in run_tests_in_batches
141+
next(framework.run_tests_in_batches())
142+
self.assertEqual("No tests loaded. Call load_tests() first.", str(e.exception))
143+
131144
def test_ctf(self):
132145
framework = CausalTestingFramework(self.paths)
133146
framework.setup()
@@ -136,8 +149,6 @@ def test_ctf(self):
136149
framework.load_tests()
137150
results = framework.run_tests()
138151

139-
print(results)
140-
141152
# Save results
142153
framework.save_results(results)
143154

@@ -205,7 +216,30 @@ def test_ctf_batches_exception_silent(self):
205216
all_results.extend(json.load(f))
206217

207218
self.assertEqual([result["passed"] for result in all_results], [False])
208-
self.assertIsNotNone([result["result"].get("error") for result in all_results])
219+
self.assertIsNotNone([result.get("error") for result in all_results])
220+
221+
def test_ctf_exception_silent(self):
222+
framework = CausalTestingFramework(self.paths, query="test_input < 0")
223+
framework.setup()
224+
225+
# Load and run tests
226+
framework.load_tests()
227+
228+
results = framework.run_tests(silent=True)
229+
230+
with open(self.test_config_path, "r", encoding="utf-8") as f:
231+
test_configs = json.load(f)
232+
233+
tests_passed = [
234+
test_case.expected_causal_effect.apply(result) if result.effect_estimate is not None else False
235+
for test_config, test_case, result in zip(test_configs["tests"], framework.test_cases, results)
236+
]
237+
238+
self.assertEqual(tests_passed, [False])
239+
self.assertEqual(
240+
[result.error_message for result in results],
241+
["zero-size array to reduction operation maximum which has no identity"],
242+
)
209243

210244
def test_ctf_batches_exception(self):
211245
framework = CausalTestingFramework(self.paths, query="test_input < 0")
@@ -214,7 +248,7 @@ def test_ctf_batches_exception(self):
214248
# Load and run tests
215249
framework.load_tests()
216250
with self.assertRaises(ValueError):
217-
list(framework.run_tests_in_batches())
251+
next(framework.run_tests_in_batches())
218252

219253
def test_ctf_batches_matches_run_tests(self):
220254
# Run the tests normally
@@ -318,11 +352,11 @@ def test_parse_args(self):
318352
[
319353
"causal_testing",
320354
"test",
321-
"--dag_path",
355+
"--dag-path",
322356
str(self.dag_path),
323-
"--data_paths",
357+
"--data-paths",
324358
str(self.data_paths[0]),
325-
"--test_config",
359+
"--test-config",
326360
str(self.test_config_path),
327361
"--output",
328362
str(self.output_path.parent / "main.json"),
@@ -331,17 +365,110 @@ def test_parse_args(self):
331365
main()
332366
self.assertTrue((self.output_path.parent / "main.json").exists())
333367

368+
def test_parse_args_adequacy(self):
369+
with patch(
370+
"sys.argv",
371+
[
372+
"causal_testing",
373+
"test",
374+
"--dag-path",
375+
str(self.dag_path),
376+
"--data-paths",
377+
str(self.data_paths[0]),
378+
"--test-config",
379+
str(self.test_config_path),
380+
"--output",
381+
str(self.output_path.parent / "main.json"),
382+
"-a",
383+
],
384+
):
385+
main()
386+
with open(self.output_path.parent / "main.json") as f:
387+
log = json.load(f)
388+
assert all(test["result"]["bootstrap_size"] == 100 for test in log)
389+
390+
def test_parse_args_adequacy_batches(self):
391+
with patch(
392+
"sys.argv",
393+
[
394+
"causal_testing",
395+
"test",
396+
"--dag-path",
397+
str(self.dag_path),
398+
"--data-paths",
399+
str(self.data_paths[0]),
400+
"--test-config",
401+
str(self.test_config_path),
402+
"--output",
403+
str(self.output_path.parent / "main.json"),
404+
"-a",
405+
"--batch-size",
406+
"5",
407+
],
408+
):
409+
main()
410+
with open(self.output_path.parent / "main.json") as f:
411+
log = json.load(f)
412+
assert all(test["result"]["bootstrap_size"] == 100 for test in log)
413+
414+
def test_parse_args_bootstrap_size(self):
415+
with patch(
416+
"sys.argv",
417+
[
418+
"causal_testing",
419+
"test",
420+
"--dag-path",
421+
str(self.dag_path),
422+
"--data-paths",
423+
str(self.data_paths[0]),
424+
"--test-config",
425+
str(self.test_config_path),
426+
"--output",
427+
str(self.output_path.parent / "main.json"),
428+
"-b",
429+
"50",
430+
],
431+
):
432+
main()
433+
with open(self.output_path.parent / "main.json") as f:
434+
log = json.load(f)
435+
assert all(test["result"]["bootstrap_size"] == 50 for test in log)
436+
437+
def test_parse_args_bootstrap_size_explicit_adequacy(self):
438+
with patch(
439+
"sys.argv",
440+
[
441+
"causal_testing",
442+
"test",
443+
"--dag-path",
444+
str(self.dag_path),
445+
"--data-paths",
446+
str(self.data_paths[0]),
447+
"--test-config",
448+
str(self.test_config_path),
449+
"--output",
450+
str(self.output_path.parent / "main.json"),
451+
"-a",
452+
"-b",
453+
"50",
454+
],
455+
):
456+
main()
457+
with open(self.output_path.parent / "main.json") as f:
458+
log = json.load(f)
459+
assert all(test["result"]["bootstrap_size"] == 50 for test in log)
460+
334461
def test_parse_args_batches(self):
335462
with patch(
336463
"sys.argv",
337464
[
338465
"causal_testing",
339466
"test",
340-
"--dag_path",
467+
"--dag-path",
341468
str(self.dag_path),
342-
"--data_paths",
469+
"--data-paths",
343470
str(self.data_paths[0]),
344-
"--test_config",
471+
"--test-config",
345472
str(self.test_config_path),
346473
"--output",
347474
str(self.output_path.parent / "main_batch.json"),
@@ -359,7 +486,7 @@ def test_parse_args_generation(self):
359486
[
360487
"causal_testing",
361488
"generate",
362-
"--dag_path",
489+
"--dag-path",
363490
str(self.dag_path),
364491
"--output",
365492
os.path.join(tmp, "tests.json"),
@@ -375,15 +502,15 @@ def test_parse_args_generation_non_default(self):
375502
[
376503
"causal_testing",
377504
"generate",
378-
"--dag_path",
505+
"--dag-path",
379506
str(self.dag_path),
380507
"--output",
381508
os.path.join(tmp, "tests_non_default.json"),
382509
"--estimator",
383510
"LogisticRegressionEstimator",
384-
"--estimate_type",
511+
"--estimate-type",
385512
"unit_odds_ratio",
386-
"--effect_type",
513+
"--effect-type",
387514
"total",
388515
],
389516
):

0 commit comments

Comments
 (0)