Skip to content

Commit 99e765e

Browse files
authored
Merge pull request #367 from CITCOM-project/jmafoster1/filter-inputs
Added filter parameter for testing inputs
2 parents bc0e6ee + defe5da commit 99e765e

2 files changed

Lines changed: 29 additions & 3 deletions

File tree

causal_testing/testing/metamorphic_relation.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,12 @@ def generate_metamorphic_relations(
204204

205205

206206
def generate_causal_tests(
207-
dag_path: str, output_path: str, ignore_cycles: bool = False, threads: int = 0, **json_stub_kargs
207+
dag_path: str,
208+
output_path: str,
209+
ignore_cycles: bool = False,
210+
threads: int = 0,
211+
test_inputs: bool = False,
212+
**json_stub_kargs,
208213
):
209214
"""
210215
Generate and output causal tests for a given DAG.
@@ -216,6 +221,8 @@ def generate_causal_tests(
216221
be omitted from the test set.
217222
:param threads: The number of threads to use to generate tests in parallel. If unspecified, tests are generated in
218223
serial. This is tylically fine unless the number of tests to be generated is >10000.
224+
:param test_inputs: Whether to test independences between inputs (i.e. root nodes in the DAG). Defaults to False
225+
as they will typically be independent by construction.
219226
:param json_stub_kargs: Kwargs to pass into `to_json_stub` (see docstring for details.)
220227
"""
221228
causal_dag = CausalDAG(dag_path, ignore_cycles=ignore_cycles)
@@ -241,7 +248,7 @@ def generate_causal_tests(
241248
tests = [
242249
relation.to_json_stub(**json_stub_kargs)
243250
for relation in relations
244-
if len(list(causal_dag.predecessors(relation.base_test_case.outcome_variable))) > 0
251+
if test_inputs or len(list(causal_dag.predecessors(relation.base_test_case.outcome_variable))) > 0
245252
]
246253

247254
logger.warning(

tests/testing_tests/test_metamorphic_relations.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,10 @@ def test_should_cause_logistic_json_stub(self):
120120
should_cause_mr = ShouldCause(BaseTestCase("X1", "Z"), adj_set)
121121
self.assertEqual(
122122
should_cause_mr.to_json_stub(
123-
effect_type="total", estimate_type="unit_odds_ratio", estimator="LogisticRegressionEstimator", skip=False
123+
effect_type="total",
124+
estimate_type="unit_odds_ratio",
125+
estimator="LogisticRegressionEstimator",
126+
skip=False,
124127
),
125128
{
126129
"effect": "total",
@@ -291,6 +294,22 @@ def test_generate_causal_tests(self):
291294
)
292295
self.assertEqual(tests["tests"], expected)
293296

297+
def test_generate_causal_tests_test_inputs(self):
298+
dag = CausalDAG(self.dag_dot_path)
299+
relations = generate_metamorphic_relations(dag)
300+
with tempfile.TemporaryDirectory() as tmp:
301+
tests_file = os.path.join(tmp, "causal_tests.json")
302+
generate_causal_tests(self.dag_dot_path, tests_file, test_inputs=True)
303+
with open(tests_file, encoding="utf8") as f:
304+
tests = json.load(f)
305+
expected = list(
306+
map(
307+
lambda x: x.to_json_stub(skip=False),
308+
relations,
309+
)
310+
)
311+
self.assertEqual(tests["tests"], expected)
312+
294313
def test_shoud_cause_string(self):
295314
sc_mr = ShouldCause(BaseTestCase("X", "Y"), ["A", "B", "C"])
296315
self.assertEqual(str(sc_mr), "X --> Y | ['A', 'B', 'C']")

0 commit comments

Comments
 (0)