1515from causal_testing .estimation .linear_regression_estimator import LinearRegressionEstimator
1616from causal_testing .estimation .logistic_regression_estimator import LogisticRegressionEstimator
1717from causal_testing .specification .causal_dag import CausalDAG
18- from causal_testing .specification .causal_specification import CausalSpecification
1918from causal_testing .specification .scenario import Scenario
2019from causal_testing .specification .variable import Input , Output
2120from causal_testing .testing .base_test_case import BaseTestCase
2221from causal_testing .testing .causal_effect import Negative , NoEffect , Positive , SomeEffect
22+ from causal_testing .testing .causal_test_adequacy import DataAdequacy
2323from causal_testing .testing .causal_test_case import CausalTestCase
2424from causal_testing .testing .causal_test_result import CausalTestResult
25- from causal_testing .testing .causal_test_adequacy import DataAdequacy
2625
2726logger = logging .getLogger (__name__ )
2827
@@ -106,7 +105,6 @@ def __init__(self, paths: CausalTestingPaths, ignore_cycles: bool = False, query
106105 self .data : Optional [pd .DataFrame ] = None
107106 self .variables : Dict [str , Any ] = {"inputs" : {}, "outputs" : {}, "metas" : {}}
108107 self .scenario : Optional [Scenario ] = None
109- self .causal_specification : Optional [CausalSpecification ] = None
110108 self .test_cases : Optional [List [CausalTestCase ]] = None
111109
112110 def setup (self ) -> None :
@@ -130,8 +128,11 @@ def setup(self) -> None:
130128 # Create variables from DAG
131129 self .create_variables ()
132130
133- # Create scenario and specification
134- self .create_scenario_and_specification ()
131+ # Create scenario
132+ self .scenario = Scenario (
133+ list (self .variables ["inputs" ].values ()) + list (self .variables ["outputs" ].values ()),
134+ {self .query } if self .query else None ,
135+ )
135136
136137 logger .info ("Setup completed successfully" )
137138
@@ -187,18 +188,6 @@ def create_variables(self) -> None:
187188 if self .dag .in_degree (node_name ) > 0 :
188189 self .variables ["outputs" ][node_name ] = Output (name = node_name , datatype = dtype )
189190
190- def create_scenario_and_specification (self ) -> None :
191- """Create scenario and causal specification objects from loaded data."""
192- # Create scenario
193- all_variables = list (self .variables ["inputs" ].values ()) + list (self .variables ["outputs" ].values ())
194- self .scenario = Scenario (variables = all_variables )
195-
196- # Set up treatment variables
197- self .scenario .setup_treatment_variables ()
198-
199- # Create causal specification
200- self .causal_specification = CausalSpecification (scenario = self .scenario , causal_dag = self .dag )
201-
202191 def load_tests (self ) -> None :
203192 """
204193 Load and prepare test configurations from file.
@@ -316,7 +305,10 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
316305 base_test_case = base_test ,
317306 treatment_value = test .get ("treatment_value" ),
318307 control_value = test .get ("control_value" ),
319- adjustment_set = test .get ("adjustment_set" , self .causal_specification .causal_dag .identification (base_test )),
308+ adjustment_set = test .get (
309+ "adjustment_set" ,
310+ self .dag .identification (base_test , self .scenario .hidden_variables ()),
311+ ),
320312 df = filtered_df ,
321313 effect_modifiers = None ,
322314 formula = test .get ("formula" ),
@@ -346,7 +338,7 @@ def run_tests_in_batches(
346338 :param silent: Whether to suppress errors
347339 :param adequacy: Whether to calculate causal test adequacy (defaults to False)
348340 :param bootstrap_size: The number of bootstrap samples to use when calculating causal test adequacy
349- (defaults to 100)
341+ (defaults to 100)
350342 :return: List of all test results
351343 :raises: ValueError if no tests are loaded
352344 """
@@ -403,7 +395,7 @@ def run_tests(
403395 :param silent: Whether to suppress errors
404396 :param adequacy: Whether to calculate causal test adequacy (defaults to False)
405397 :param bootstrap_size: The number of bootstrap samples to use when calculating causal test adequacy
406- (defaults to 100)
398+ (defaults to 100)
407399
408400 :return: List of CausalTestResult objects
409401 :raises: ValueError if no tests are loaded
0 commit comments