Skip to content

Commit 178ffe6

Browse files
authored
Merge pull request #371 from CITCOM-project/jmafoster1/remove-causal-spec
Jmafoster1/remove causal spec
2 parents 8fa3189 + eebbb8b commit 178ffe6

35 files changed

Lines changed: 389 additions & 702 deletions

.github/workflows/ci-tests-drafts.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ jobs:
2727
pip install -e .
2828
pip install -e .[test]
2929
pip install pytest pytest-cov
30+
- name: Register Jupyter Kernel
31+
run: |
32+
python -m ipykernel install --user --name python3
3033
- name: Test with pytest
3134
run: |
3235
pytest --cov=causal_testing --cov-report=xml

.github/workflows/ci-tests.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ jobs:
3232
pip install -e .
3333
pip install -e .[test]
3434
pip install pytest pytest-cov
35+
- name: Register Jupyter Kernel
36+
run: |
37+
python -m ipykernel install --user --name python3
3538
- name: Test with pytest
3639
run: |
3740
pytest --cov=causal_testing --cov-report=xml

.pre-commit-config.yaml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,10 @@ repos:
3737
language: system
3838
types: [python]
3939
args: ['--rcfile=.pylintrc', '--max-line-length=120', '--max-positional-arguments=12', '--disable=W1401']
40-
files: ^causal_testing/
40+
files: ^causal_testing/
41+
42+
- repo: https://github.com/jsh9/pydoclint
43+
rev: 0.8.3
44+
hooks:
45+
- id: pydoclint
46+
args: [--style=google, --check-return-types=False]

causal_testing/estimation/ipcw_estimator.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,20 @@ class IPCWEstimator(Estimator):
2121
"""
2222
Class to perform Inverse Probability of Censoring Weighting (IPCW) estimation
2323
for sequences of treatments over time-varying data.
24+
25+
:param: df: Input DataFrame containing time-varying data.
26+
:param: timesteps_per_observation: Number of timesteps per observation.
27+
:param: control_strategy: The control strategy, with entries of the form (timestep, variable, value).
28+
:param: treatment_strategy: The treatment strategy, with entries of the form (timestep, variable, value).
29+
:param: outcome: Name of the outcome column in the DataFrame.
30+
:param: status_column: Name of the status column in the DataFrame, which should be True for operating normally,
31+
False for a fault.
32+
:param: fit_bl_switch_formula: Formula for fitting the baseline switch model.
33+
:param: fit_bltd_switch_formula: Formula for fitting the baseline time-dependent switch model.
34+
:param: eligibility: Function to determine eligibility for treatment. Defaults to None for "always eligible".
35+
:param: alpha: Significance level for hypothesis testing. Defaults to 0.05.
36+
:param: total_time: Total time for the analysis. Defaults to one plus the length of the strategy (control or
37+
treatment) with the most elements multiplied by `timesteps_per_observation`.
2438
"""
2539

2640
# pylint: disable=too-many-arguments
@@ -40,23 +54,6 @@ def __init__(
4054
alpha: float = 0.05,
4155
total_time: float = None,
4256
):
43-
"""
44-
Initialise IPCWEstimator.
45-
46-
:param: df: Input DataFrame containing time-varying data.
47-
:param: timesteps_per_observation: Number of timesteps per observation.
48-
:param: control_strategy: The control strategy, with entries of the form (timestep, variable, value).
49-
:param: treatment_strategy: The treatment strategy, with entries of the form (timestep, variable, value).
50-
:param: outcome: Name of the outcome column in the DataFrame.
51-
:param: status_column: Name of the status column in the DataFrame, which should be True for operating normally,
52-
False for a fault.
53-
:param: fit_bl_switch_formula: Formula for fitting the baseline switch model.
54-
:param: fit_bltd_switch_formula: Formula for fitting the baseline time-dependent switch model.
55-
:param: eligibility: Function to determine eligibility for treatment. Defaults to None for "always eligible".
56-
:param: alpha: Significance level for hypothesis testing. Defaults to 0.05.
57-
:param: total_time: Total time for the analysis. Defaults to one plus the length of the strategy (control or
58-
treatment) with the most elements multiplied by `timesteps_per_observation`.
59-
"""
6057
super().__init__(
6158
base_test_case=BaseTestCase(None, outcome),
6259
treatment_value=[val for _, _, val in treatment_strategy],

causal_testing/main.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@
1515
from causal_testing.estimation.linear_regression_estimator import LinearRegressionEstimator
1616
from causal_testing.estimation.logistic_regression_estimator import LogisticRegressionEstimator
1717
from causal_testing.specification.causal_dag import CausalDAG
18-
from causal_testing.specification.causal_specification import CausalSpecification
1918
from causal_testing.specification.scenario import Scenario
2019
from causal_testing.specification.variable import Input, Output
2120
from causal_testing.testing.base_test_case import BaseTestCase
2221
from causal_testing.testing.causal_effect import Negative, NoEffect, Positive, SomeEffect
22+
from causal_testing.testing.causal_test_adequacy import DataAdequacy
2323
from causal_testing.testing.causal_test_case import CausalTestCase
2424
from causal_testing.testing.causal_test_result import CausalTestResult
25-
from causal_testing.testing.causal_test_adequacy import DataAdequacy
2625

2726
logger = logging.getLogger(__name__)
2827

@@ -106,7 +105,6 @@ def __init__(self, paths: CausalTestingPaths, ignore_cycles: bool = False, query
106105
self.data: Optional[pd.DataFrame] = None
107106
self.variables: Dict[str, Any] = {"inputs": {}, "outputs": {}, "metas": {}}
108107
self.scenario: Optional[Scenario] = None
109-
self.causal_specification: Optional[CausalSpecification] = None
110108
self.test_cases: Optional[List[CausalTestCase]] = None
111109

112110
def setup(self) -> None:
@@ -130,8 +128,11 @@ def setup(self) -> None:
130128
# Create variables from DAG
131129
self.create_variables()
132130

133-
# Create scenario and specification
134-
self.create_scenario_and_specification()
131+
# Create scenario
132+
self.scenario = Scenario(
133+
list(self.variables["inputs"].values()) + list(self.variables["outputs"].values()),
134+
{self.query} if self.query else None,
135+
)
135136

136137
logger.info("Setup completed successfully")
137138

@@ -187,18 +188,6 @@ def create_variables(self) -> None:
187188
if self.dag.in_degree(node_name) > 0:
188189
self.variables["outputs"][node_name] = Output(name=node_name, datatype=dtype)
189190

190-
def create_scenario_and_specification(self) -> None:
191-
"""Create scenario and causal specification objects from loaded data."""
192-
# Create scenario
193-
all_variables = list(self.variables["inputs"].values()) + list(self.variables["outputs"].values())
194-
self.scenario = Scenario(variables=all_variables)
195-
196-
# Set up treatment variables
197-
self.scenario.setup_treatment_variables()
198-
199-
# Create causal specification
200-
self.causal_specification = CausalSpecification(scenario=self.scenario, causal_dag=self.dag)
201-
202191
def load_tests(self) -> None:
203192
"""
204193
Load and prepare test configurations from file.
@@ -316,7 +305,10 @@ def create_causal_test(self, test: dict, base_test: BaseTestCase) -> CausalTestC
316305
base_test_case=base_test,
317306
treatment_value=test.get("treatment_value"),
318307
control_value=test.get("control_value"),
319-
adjustment_set=test.get("adjustment_set", self.causal_specification.causal_dag.identification(base_test)),
308+
adjustment_set=test.get(
309+
"adjustment_set",
310+
self.dag.identification(base_test, self.scenario.hidden_variables()),
311+
),
320312
df=filtered_df,
321313
effect_modifiers=None,
322314
formula=test.get("formula"),
@@ -346,7 +338,7 @@ def run_tests_in_batches(
346338
:param silent: Whether to suppress errors
347339
:param adequacy: Whether to calculate causal test adequacy (defaults to False)
348340
:param bootstrap_size: The number of bootstrap samples to use when calculating causal test adequacy
349-
(defaults to 100)
341+
(defaults to 100)
350342
:return: List of all test results
351343
:raises: ValueError if no tests are loaded
352344
"""
@@ -403,7 +395,7 @@ def run_tests(
403395
:param silent: Whether to suppress errors
404396
:param adequacy: Whether to calculate causal test adequacy (defaults to False)
405397
:param bootstrap_size: The number of bootstrap samples to use when calculating causal test adequacy
406-
(defaults to 100)
398+
(defaults to 100)
407399
408400
:return: List of CausalTestResult objects
409401
:raises: ValueError if no tests are loaded

causal_testing/specification/causal_dag.py

Lines changed: 7 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010

1111
from causal_testing.testing.base_test_case import BaseTestCase
1212

13-
from .scenario import Scenario
14-
from .variable import Output
13+
from .variable import Variable
1514

1615
Node = Union[str, int] # Node type hint: A node is a string or an int
1716

@@ -489,38 +488,12 @@ def get_backdoor_graph(self, treatments: list[str]) -> CausalDAG:
489488
backdoor_graph.add_edges_from(filter(lambda x: x not in outgoing_edges, self.edges))
490489
return backdoor_graph
491490

492-
def depends_on_outputs(self, node: Node, scenario: Scenario) -> bool:
493-
"""Check whether a given node in a given scenario is or depends on a
494-
model output in the given scenario. That is, whether or not the model
495-
needs to be run to determine its value.
496-
497-
NOTE: The graph must be acyclic for this to terminate.
498-
499-
:param node: The node in the DAG representing the variable of interest.
500-
:param scenario: The modelling scenario.
501-
502-
:return: Whether the given variable is or depends on an output.
503-
"""
504-
if isinstance(scenario.variables[node], Output):
505-
return True
506-
return any((self.depends_on_outputs(n, scenario) for n in self.predecessors(node)))
507-
508-
@staticmethod
509-
def remove_hidden_adjustment_sets(minimal_adjustment_sets: list[str], scenario: Scenario):
510-
"""Remove variables labelled as hidden from adjustment set(s)
511-
512-
:param minimal_adjustment_sets: list of minimal adjustment set(s) to have hidden variables removed from
513-
:param scenario: The modelling scenario which informs the variables that are hidden
514-
"""
515-
return [adj for adj in minimal_adjustment_sets if all(not scenario.variables.get(x).hidden for x in adj)]
516-
517-
def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None):
491+
def identification(self, base_test_case: BaseTestCase, avoid_variables: set[Variable] = None):
518492
"""Identify and return the minimum adjustment set
519493
520494
:param base_test_case: A base test case instance containing the outcome_variable and the
521495
treatment_variable required for identification.
522-
:param scenario: The modelling scenario relating to the tests
523-
496+
:param avoid_variables: Variables not to be adjusted for (e.g. hidden variables).
524497
:return: The smallest set of variables which can be adjusted for to obtain a causal
525498
estimate as opposed to a purely associational estimate.
526499
"""
@@ -539,8 +512,10 @@ def identification(self, base_test_case: BaseTestCase, scenario: Scenario = None
539512
else:
540513
raise ValueError("Causal effect should be 'total' or 'direct'")
541514

542-
if scenario is not None:
543-
minimal_adjustment_sets = self.remove_hidden_adjustment_sets(minimal_adjustment_sets, scenario)
515+
if avoid_variables is not None:
516+
minimal_adjustment_sets = [
517+
adj for adj in minimal_adjustment_sets if not {x.name for x in avoid_variables}.intersection(adj)
518+
]
544519

545520
minimal_adjustment_set = min(minimal_adjustment_sets, key=len, default=set())
546521
return set(minimal_adjustment_set)

causal_testing/specification/causal_specification.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

0 commit comments

Comments
 (0)