@@ -120,7 +120,10 @@ def test_should_cause_logistic_json_stub(self):
120120 should_cause_mr = ShouldCause (BaseTestCase ("X1" , "Z" ), adj_set )
121121 self .assertEqual (
122122 should_cause_mr .to_json_stub (
123- effect_type = "total" , estimate_type = "unit_odds_ratio" , estimator = "LogisticRegressionEstimator" , skip = False
123+ effect_type = "total" ,
124+ estimate_type = "unit_odds_ratio" ,
125+ estimator = "LogisticRegressionEstimator" ,
126+ skip = False ,
124127 ),
125128 {
126129 "effect" : "total" ,
@@ -291,6 +294,22 @@ def test_generate_causal_tests(self):
291294 )
292295 self .assertEqual (tests ["tests" ], expected )
293296
297+ def test_generate_causal_tests_test_inputs (self ):
298+ dag = CausalDAG (self .dag_dot_path )
299+ relations = generate_metamorphic_relations (dag )
300+ with tempfile .TemporaryDirectory () as tmp :
301+ tests_file = os .path .join (tmp , "causal_tests.json" )
302+ generate_causal_tests (self .dag_dot_path , tests_file , test_inputs = True )
303+ with open (tests_file , encoding = "utf8" ) as f :
304+ tests = json .load (f )
305+ expected = list (
306+ map (
307+ lambda x : x .to_json_stub (skip = False ),
308+ relations ,
309+ )
310+ )
311+ self .assertEqual (tests ["tests" ], expected )
312+
294313 def test_shoud_cause_string (self ):
295314 sc_mr = ShouldCause (BaseTestCase ("X" , "Y" ), ["A" , "B" , "C" ])
296315 self .assertEqual (str (sc_mr ), "X --> Y | ['A', 'B', 'C']" )
0 commit comments