@@ -20,12 +20,18 @@ def setUp(self) -> None:
2020 with open (self .dag_dot_path , "w" ) as f :
2121 f .write (dag_dot )
2222
23+ def test_graphml (self ):
24+ dot_dag = CausalDAG (self .dag_dot_path )
25+ xml_dag = CausalDAG (os .path .join ("tests" , "resources" , "data" , "dag.xml" ))
26+ self .assertEqual (dot_dag .nodes , xml_dag .nodes )
27+ self .assertEqual (dot_dag .edges , xml_dag .edges )
28+
2329 def test_enumerate_minimal_adjustment_sets (self ):
2430 """Test whether enumerate_minimal_adjustment_sets lists all possible minimum sized adjustment sets."""
2531 causal_dag = CausalDAG (self .dag_dot_path )
2632 xs , ys = ["X" ], ["Y" ]
2733 adjustment_sets = causal_dag .enumerate_minimal_adjustment_sets (xs , ys )
28- self .assertEqual ([{"Z" }], adjustment_sets )
34+ self .assertEqual ([{"Z" }], list ( adjustment_sets ) )
2935
3036 def tearDown (self ) -> None :
3137 shutil .rmtree (self .temp_dir_path )
@@ -46,19 +52,19 @@ def test_valid_iv(self):
4652
4753 def test_unrelated_instrument (self ):
4854 causal_dag = CausalDAG (self .dag_dot_path )
49- causal_dag .graph . remove_edge ("I" , "X" )
55+ causal_dag .remove_edge ("I" , "X" )
5056 with self .assertRaises (ValueError ):
5157 causal_dag .check_iv_assumptions ("X" , "Y" , "I" )
5258
5359 def test_direct_cause (self ):
5460 causal_dag = CausalDAG (self .dag_dot_path )
55- causal_dag .graph . add_edge ("I" , "Y" )
61+ causal_dag .add_edge ("I" , "Y" )
5662 with self .assertRaises (ValueError ):
5763 causal_dag .check_iv_assumptions ("X" , "Y" , "I" )
5864
5965 def test_common_cause (self ):
6066 causal_dag = CausalDAG (self .dag_dot_path )
61- causal_dag .graph . add_edge ("U" , "I" )
67+ causal_dag .add_edge ("U" , "I" )
6268 with self .assertRaises (ValueError ):
6369 causal_dag .check_iv_assumptions ("X" , "Y" , "I" )
6470
@@ -279,12 +285,12 @@ def test_enumerate_minimal_adjustment_sets(self):
279285 causal_dag = CausalDAG (self .dag_dot_path )
280286 xs , ys = ["X1" , "X2" ], ["Y" ]
281287 adjustment_sets = causal_dag .enumerate_minimal_adjustment_sets (xs , ys )
282- self .assertEqual ([{"Z" }], adjustment_sets )
288+ self .assertEqual ([{"Z" }], list ( adjustment_sets ) )
283289
284290 def test_enumerate_minimal_adjustment_sets_multiple (self ):
285291 """Test whether enumerate_minimal_adjustment_sets lists all minimum adjustment sets if multiple are possible."""
286292 causal_dag = CausalDAG ()
287- causal_dag .graph . add_edges_from (
293+ causal_dag .add_edges_from (
288294 [
289295 ("X1" , "X2" ),
290296 ("X2" , "V" ),
@@ -308,7 +314,7 @@ def test_enumerate_minimal_adjustment_sets_multiple(self):
308314 def test_enumerate_minimal_adjustment_sets_two_adjustments (self ):
309315 """Test whether enumerate_minimal_adjustment_sets lists all possible minimum adjustment sets of arity two."""
310316 causal_dag = CausalDAG ()
311- causal_dag .graph . add_edges_from (
317+ causal_dag .add_edges_from (
312318 [
313319 ("X1" , "X2" ),
314320 ("X2" , "V" ),
@@ -335,7 +341,7 @@ def test_enumerate_minimal_adjustment_sets_two_adjustments(self):
335341 def test_dag_with_non_character_nodes (self ):
336342 """Test identification for a DAG whose nodes are not just characters (strings of length greater than 1)."""
337343 causal_dag = CausalDAG ()
338- causal_dag .graph . add_edges_from (
344+ causal_dag .add_edges_from (
339345 [
340346 ("va" , "ba" ),
341347 ("ba" , "ia" ),
@@ -350,7 +356,7 @@ def test_dag_with_non_character_nodes(self):
350356 )
351357 xs , ys = ["ba" ], ["da" ]
352358 adjustment_sets = causal_dag .enumerate_minimal_adjustment_sets (xs , ys )
353- self .assertEqual (adjustment_sets , [{"aa" }, {"la" }, {"va" }])
359+ self .assertEqual (list ( adjustment_sets ) , [{"aa" }, {"la" }, {"va" }])
354360
355361 def tearDown (self ) -> None :
356362 shutil .rmtree (self .temp_dir_path )
@@ -475,3 +481,12 @@ def test_hidden_varaible_adjustment_sets(self):
475481
476482 def tearDown (self ) -> None :
477483 shutil .rmtree (self .temp_dir_path )
484+
485+
486+ def time_it (label , func , * args , ** kwargs ):
487+ import time
488+
489+ start = time .time ()
490+ result = func (* args , ** kwargs )
491+ print (f"{ label } took { time .time () - start :.6f} seconds" )
492+ return result
0 commit comments