Skip to content

Commit 6325ced

Browse files
authored
Merge pull request #374 from CITCOM-project/jmafoster1/better-adequacy
More detailed reporting of causal test adequacy for SHaRR collaboration
2 parents 178ffe6 + 4a3e9d7 commit 6325ced

3 files changed

Lines changed: 41 additions & 16 deletions

File tree

causal_testing/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,6 @@ def parse_args(args: Optional[Sequence[str]] = None) -> argparse.Namespace:
563563
args = main_parser.parse_args(args)
564564

565565
# Assume the user wants test adequacy if they're setting bootstrap_size
566-
print(args)
567566
if getattr(args, "bootstrap_size", None) is not None:
568567
args.adequacy = True
569568
if getattr(args, "adequacy", False) and getattr(args, "bootstrap_size", None) is None:

causal_testing/testing/causal_test_adequacy.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ def __init__(
8383
):
8484
self.test_case = test_case
8585
self.kurtosis = None
86-
self.outcomes = None
86+
self.passing = None
87+
self.results = None
8788
self.successful = None
8889
self.bootstrap_size = bootstrap_size
8990
self.group_by = group_by
@@ -93,6 +94,7 @@ def measure_adequacy(self):
9394
Calculate the adequacy measurement, and populate the data_adequacy field.
9495
"""
9596
results = []
97+
outcomes = []
9698
for i in range(self.bootstrap_size):
9799
estimator = deepcopy(self.test_case.estimator)
98100

@@ -103,7 +105,9 @@ def measure_adequacy(self):
103105
else:
104106
estimator.df = estimator.df.sample(len(estimator.df), replace=True, random_state=i)
105107
try:
106-
results.append(self.test_case.execute_test(estimator))
108+
result = self.test_case.execute_test(estimator)
109+
outcomes.append(self.test_case.expected_causal_effect.apply(result))
110+
results.append(result.effect_estimate.to_df())
107111
except LinAlgError:
108112
logger.warning("Adequacy LinAlgError")
109113
continue
@@ -113,19 +117,23 @@ def measure_adequacy(self):
113117
except ValueError as e:
114118
logger.warning(f"Adequacy ValueError: {e}")
115119
continue
116-
outcomes = [self.test_case.expected_causal_effect.apply(c) for c in results]
117-
results = pd.concat([c.effect_estimate.to_df() for c in results])
120+
# outcomes = [self.test_case.expected_causal_effect.apply(c) for c in results]
121+
# results = pd.concat([c.effect_estimate.to_df() for c in results])
122+
results = pd.concat(results)
118123
results["var"] = results.index
124+
results["passed"] = outcomes
119125

126+
self.results = results
120127
self.kurtosis = results.groupby("var")["effect_estimate"].apply(lambda x: x.kurtosis())
121-
self.outcomes = sum(filter(lambda x: x is not None, outcomes))
128+
self.passing = sum(filter(lambda x: x is not None, outcomes))
122129
self.successful = sum(x is not None for x in outcomes)
123130

124131
def to_dict(self):
125132
"""Returns the adequacy object as a dictionary."""
126133
return {
127134
"kurtosis": self.kurtosis.to_dict(),
128135
"bootstrap_size": self.bootstrap_size,
129-
"passing": self.outcomes,
136+
"passing": self.passing,
130137
"successful": self.successful,
138+
"results": self.results.reset_index(drop=True).to_dict(),
131139
}

tests/testing_tests/test_causal_test_adequacy.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,17 @@ def test_data_adequacy_numeric(self):
4848
)
4949
adequacy_metric = DataAdequacy(causal_test_case)
5050
adequacy_metric.measure_adequacy()
51+
52+
self.assertEqual(
53+
adequacy_metric.kurtosis["test_input"],
54+
0,
55+
f"Expected kurtosis not {adequacy_metric.kurtosis['test_input']}",
56+
)
5157
self.assertEqual(
52-
adequacy_metric.to_dict(),
53-
{"kurtosis": {"test_input": 0.0}, "bootstrap_size": 100, "passing": 100, "successful": 100},
58+
adequacy_metric.bootstrap_size, 100, f"Expected bootstrap size 100 not {adequacy_metric.bootstrap_size}"
5459
)
60+
self.assertEqual(adequacy_metric.passing, 100, f"Expected passing 100 not {adequacy_metric.passing}")
61+
self.assertEqual(adequacy_metric.successful, 100, f"Expected successful 100 not {adequacy_metric.successful}")
5562

5663
def test_data_adequacy_categorical(self):
5764
base_test_case = BaseTestCase(
@@ -68,10 +75,17 @@ def test_data_adequacy_categorical(self):
6875
)
6976
adequacy_metric = DataAdequacy(causal_test_case)
7077
adequacy_metric.measure_adequacy()
78+
7179
self.assertEqual(
72-
adequacy_metric.to_dict(),
73-
{"kurtosis": {"test_input_no_dist[T.b]": 0.0}, "bootstrap_size": 100, "passing": 100, "successful": 100},
80+
adequacy_metric.kurtosis["test_input_no_dist[T.b]"],
81+
0,
82+
f"Expected kurtosis not {adequacy_metric.kurtosis['test_input_no_dist[T.b]']}",
7483
)
84+
self.assertEqual(
85+
adequacy_metric.bootstrap_size, 100, f"Expected bootstrap size 100 not {adequacy_metric.bootstrap_size}"
86+
)
87+
self.assertEqual(adequacy_metric.passing, 100, f"Expected passing 100 not {adequacy_metric.passing}")
88+
self.assertEqual(adequacy_metric.successful, 100, f"Expected successful 100 not {adequacy_metric.successful}")
7589

7690
def test_data_adequacy_group_by(self):
7791
timesteps_per_intervention = 1
@@ -102,13 +116,17 @@ def test_data_adequacy_group_by(self):
102116
)
103117
adequacy_metric = DataAdequacy(causal_test_case, group_by="id")
104118
adequacy_metric.measure_adequacy()
105-
adequacy_dict = adequacy_metric.to_dict()
106-
self.assertEqual(round(adequacy_dict["kurtosis"]["trtrand"], 3), -0.857)
107-
adequacy_dict.pop("kurtosis")
119+
120+
self.assertEqual(
121+
round(adequacy_metric.kurtosis["trtrand"], 3),
122+
-0.857,
123+
f"Expected kurtosis not {round(adequacy_metric.kurtosis['trtrand'], 3)}",
124+
)
108125
self.assertEqual(
109-
adequacy_dict,
110-
{"bootstrap_size": 100, "passing": 32, "successful": 100},
126+
adequacy_metric.bootstrap_size, 100, f"Expected bootstrap size 100 not {adequacy_metric.bootstrap_size}"
111127
)
128+
self.assertEqual(adequacy_metric.passing, 32, f"Expected passing 32 not {adequacy_metric.passing}")
129+
self.assertEqual(adequacy_metric.successful, 100, f"Expected successful 100 not {adequacy_metric.successful}")
112130

113131
def test_dag_adequacy_dependent(self):
114132
base_test_case = BaseTestCase(

0 commit comments

Comments
 (0)