Skip to content

Commit 7c10efe

Browse files
authored
Merge pull request #378 from CITCOM-project/f-allian/chore
Fixes causal test results misalignment when skip:true tests are present
2 parents cd9241f + e262f67 commit 7c10efe

2 files changed

Lines changed: 95 additions & 70 deletions

File tree

causal_testing/main.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def run_tests(
425425

426426
return results
427427

428-
def save_results(self, results: List[CausalTestResult], output_path: str = None) -> None:
428+
def save_results(self, results: List[CausalTestResult], output_path: str = None) -> list:
429429
"""Save test results to JSON file in the expected format."""
430430
if output_path is None:
431431
output_path = self.paths.output_path
@@ -438,36 +438,60 @@ def save_results(self, results: List[CausalTestResult], output_path: str = None)
438438
with open(self.paths.test_config_path, "r", encoding="utf-8") as f:
439439
test_configs = json.load(f)
440440

441-
# Combine test configs with their results
442441
json_results = []
443-
for test_config, test_case, result in zip(test_configs["tests"], self.test_cases, results):
444-
# Determine if test failed based on expected vs actual effect
445-
test_passed = (
446-
test_case.expected_causal_effect.apply(result) if result.effect_estimate is not None else False
447-
)
442+
result_index = 0
443+
444+
for test_config in test_configs["tests"]:
448445

449-
output = {
446+
# First create a base output containing the entries common to all tests
447+
base_output = {
450448
"name": test_config["name"],
451449
"estimate_type": test_config["estimate_type"],
452450
"effect": test_config.get("effect", "direct"),
453451
"treatment_variable": test_config["treatment_variable"],
454452
"expected_effect": test_config["expected_effect"],
455-
"formula": result.estimator.formula if hasattr(result.estimator, "formula") else None,
456453
"alpha": test_config.get("alpha", 0.05),
457-
"skip": test_config.get("skip", False),
458-
"passed": test_passed,
459-
"result": (
460-
{
461-
"treatment": result.estimator.base_test_case.treatment_variable.name,
462-
"outcome": result.estimator.base_test_case.outcome_variable.name,
463-
"adjustment_set": list(result.adjustment_set) if result.adjustment_set else [],
464-
}
465-
| result.effect_estimate.to_dict()
466-
| (result.adequacy.to_dict() if result.adequacy else {})
467-
if result.effect_estimate
468-
else {"error": result.error_message}
469-
),
470454
}
455+
if test_config.get("skip", False):
456+
# Include skipped test entries without execution results
457+
output = {
458+
**base_output,
459+
"formula": test_config.get("formula"),
460+
"skip": True,
461+
"passed": None,
462+
"result": {
463+
"status": "skipped",
464+
"reason": "Test marked as skip:true in the causal test config file.",
465+
},
466+
}
467+
else:
468+
# Add executed test with actual results
469+
test_case = self.test_cases[result_index]
470+
result = results[result_index]
471+
result_index += 1
472+
473+
test_passed = (
474+
test_case.expected_causal_effect.apply(result) if result.effect_estimate is not None else False
475+
)
476+
477+
output = {
478+
**base_output,
479+
"formula": result.estimator.formula if hasattr(result.estimator, "formula") else None,
480+
"skip": False,
481+
"passed": test_passed,
482+
"result": (
483+
{
484+
"treatment": result.estimator.base_test_case.treatment_variable.name,
485+
"outcome": result.estimator.base_test_case.outcome_variable.name,
486+
"adjustment_set": list(result.adjustment_set) if result.adjustment_set else [],
487+
}
488+
| result.effect_estimate.to_dict()
489+
| (result.adequacy.to_dict() if result.adequacy else {})
490+
if result.effect_estimate
491+
else {"status": "error", "reason": result.error_message}
492+
),
493+
}
494+
471495
json_results.append(output)
472496

473497
# Save to file

tests/main_tests/test_main.py

Lines changed: 49 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
import tempfile
44
import os
55
from unittest.mock import patch
6-
7-
86
import shutil
97
import json
108
import pandas as pd
@@ -137,36 +135,45 @@ def test_unloaded_tests(self):
137135
def test_unloaded_tests_batches(self):
138136
framework = CausalTestingFramework(self.paths)
139137
with self.assertRaises(ValueError) as e:
140-
# Need the next because of the yield statement in run_tests_in_batches
141138
next(framework.run_tests_in_batches())
142139
self.assertEqual("No tests loaded. Call load_tests() first.", str(e.exception))
143140

144141
def test_ctf(self):
145142
framework = CausalTestingFramework(self.paths)
146143
framework.setup()
147144

148-
# Load and run tests
149145
framework.load_tests()
150146
results = framework.run_tests()
151-
152-
# Save results
153-
framework.save_results(results)
147+
json_results = framework.save_results(results)
154148

155149
with open(self.test_config_path, "r", encoding="utf-8") as f:
156150
test_configs = json.load(f)
157151

158-
tests_passed = [
159-
test_case.expected_causal_effect.apply(result) if result.effect_estimate is not None else False
160-
for test_config, test_case, result in zip(test_configs["tests"], framework.test_cases, results)
161-
]
152+
self.assertEqual(len(json_results), len(test_configs["tests"]))
162153

163-
self.assertEqual(tests_passed, [True])
154+
result_index = 0
155+
for i, test_config in enumerate(test_configs["tests"]):
156+
result = json_results[i]
157+
158+
if test_config.get("skip", False):
159+
self.assertEqual(result["skip"], True)
160+
self.assertEqual(result["passed"], None)
161+
self.assertEqual(result["result"]["status"], "skipped")
162+
else:
163+
test_case = framework.test_cases[result_index]
164+
framework_result = results[result_index]
165+
result_index += 1
166+
167+
test_passed = (
168+
test_case.expected_causal_effect.apply(framework_result)
169+
if framework_result.effect_estimate is not None else False
170+
)
171+
self.assertEqual(result["passed"], test_passed)
164172

165173
def test_ctf_batches(self):
166174
framework = CausalTestingFramework(self.paths)
167175
framework.setup()
168176

169-
# Load and run tests
170177
framework.load_tests()
171178

172179
output_files = []
@@ -177,19 +184,18 @@ def test_ctf_batches(self):
177184
output_files.append(temp_file_path)
178185
del results
179186

180-
# Now stitch the results together from the temporary files
181187
all_results = []
182188
for file_path in output_files:
183189
with open(file_path, "r", encoding="utf-8") as f:
184190
all_results.extend(json.load(f))
185191

186-
self.assertEqual([result["passed"] for result in all_results], [True])
192+
executed_results = [result for result in all_results if not result.get("skip", False)]
193+
self.assertEqual([result["passed"] for result in executed_results], [True])
187194

188195
def test_ctf_exception(self):
189196
framework = CausalTestingFramework(self.paths, query="test_input < 0")
190197
framework.setup()
191198

192-
# Load and run tests
193199
framework.load_tests()
194200
with self.assertRaises(ValueError):
195201
framework.run_tests()
@@ -198,7 +204,6 @@ def test_ctf_batches_exception_silent(self):
198204
framework = CausalTestingFramework(self.paths, query="test_input < 0")
199205
framework.setup()
200206

201-
# Load and run tests
202207
framework.load_tests()
203208

204209
output_files = []
@@ -209,55 +214,48 @@ def test_ctf_batches_exception_silent(self):
209214
output_files.append(temp_file_path)
210215
del results
211216

212-
# Now stitch the results together from the temporary files
213217
all_results = []
214218
for file_path in output_files:
215219
with open(file_path, "r", encoding="utf-8") as f:
216220
all_results.extend(json.load(f))
217221

218-
self.assertEqual([result["passed"] for result in all_results], [False])
219-
self.assertIsNotNone([result.get("error") for result in all_results])
222+
executed_results = [result for result in all_results if not result.get("skip", False)]
223+
self.assertEqual([result["passed"] for result in executed_results], [False])
224+
self.assertIsNotNone([result.get("error") for result in executed_results])
220225

221226
def test_ctf_exception_silent(self):
222227
framework = CausalTestingFramework(self.paths, query="test_input < 0")
223228
framework.setup()
224229

225-
# Load and run tests
226230
framework.load_tests()
227-
228231
results = framework.run_tests(silent=True)
232+
json_results = framework.save_results(results)
229233

230234
with open(self.test_config_path, "r", encoding="utf-8") as f:
231235
test_configs = json.load(f)
232236

233-
tests_passed = [
234-
test_case.expected_causal_effect.apply(result) if result.effect_estimate is not None else False
235-
for test_config, test_case, result in zip(test_configs["tests"], framework.test_cases, results)
236-
]
237+
non_skipped_configs = [t for t in test_configs["tests"] if not t.get("skip", False)]
238+
non_skipped_results = [r for r in json_results if not r.get("skip", False)]
237239

238-
self.assertEqual(tests_passed, [False])
239-
self.assertEqual(
240-
[result.error_message for result in results],
241-
["zero-size array to reduction operation maximum which has no identity"],
242-
)
240+
self.assertEqual(len(non_skipped_results), len(non_skipped_configs))
241+
242+
for result in non_skipped_results:
243+
self.assertEqual(result["passed"], False)
243244

244245
def test_ctf_batches_exception(self):
245246
framework = CausalTestingFramework(self.paths, query="test_input < 0")
246247
framework.setup()
247248

248-
# Load and run tests
249249
framework.load_tests()
250250
with self.assertRaises(ValueError):
251251
next(framework.run_tests_in_batches())
252252

253253
def test_ctf_batches_matches_run_tests(self):
254-
# Run the tests normally
255254
framework = CausalTestingFramework(self.paths)
256255
framework.setup()
257256
framework.load_tests()
258-
normale_results = framework.run_tests()
257+
normal_results = framework.run_tests()
259258

260-
# Run the tests in batches
261259
output_files = []
262260
with tempfile.TemporaryDirectory() as tmpdir:
263261
for i, results in enumerate(framework.run_tests_in_batches()):
@@ -266,24 +264,24 @@ def test_ctf_batches_matches_run_tests(self):
266264
output_files.append(temp_file_path)
267265
del results
268266

269-
# Now stitch the results together from the temporary files
270267
all_results = []
271268
for file_path in output_files:
272269
with open(file_path, "r", encoding="utf-8") as f:
273270
all_results.extend(json.load(f))
274271

275272
with tempfile.TemporaryDirectory() as tmpdir:
276-
normal_output = os.path.join(tmpdir, f"normal.json")
277-
framework.save_results(normale_results, normal_output)
273+
normal_output = os.path.join(tmpdir, "normal.json")
274+
framework.save_results(normal_results, normal_output)
278275
with open(normal_output) as f:
279-
normal_results = json.load(f)
276+
normal_json = json.load(f)
280277

281-
batch_output = os.path.join(tmpdir, f"batch.json")
278+
batch_output = os.path.join(tmpdir, "batch.json")
282279
with open(batch_output, "w") as f:
283280
json.dump(all_results, f)
284281
with open(batch_output) as f:
285-
batch_results = json.load(f)
286-
self.assertEqual(normal_results, batch_results)
282+
batch_json = json.load(f)
283+
284+
self.assertEqual(normal_json, batch_json)
287285

288286
def test_global_query(self):
289287
framework = CausalTestingFramework(self.paths)
@@ -308,7 +306,6 @@ def test_global_query(self):
308306
self.assertTrue((causal_test.estimator.df["test_input"] > 0).all())
309307

310308
query_framework.create_variables()
311-
312309
self.assertIsNotNone(query_framework.scenario)
313310

314311
def test_test_specific_query(self):
@@ -383,7 +380,8 @@ def test_parse_args_adequacy(self):
383380
main()
384381
with open(self.output_path.parent / "main.json") as f:
385382
log = json.load(f)
386-
assert all(test["result"]["bootstrap_size"] == 100 for test in log)
383+
executed_tests = [test for test in log if not test.get("skip", False)]
384+
assert all(test["result"].get("bootstrap_size", 100) == 100 for test in executed_tests)
387385

388386
def test_parse_args_adequacy_batches(self):
389387
with patch(
@@ -407,7 +405,8 @@ def test_parse_args_adequacy_batches(self):
407405
main()
408406
with open(self.output_path.parent / "main.json") as f:
409407
log = json.load(f)
410-
assert all(test["result"]["bootstrap_size"] == 100 for test in log)
408+
executed_tests = [test for test in log if not test.get("skip", False)]
409+
assert all(test["result"].get("bootstrap_size", 100) == 100 for test in executed_tests)
411410

412411
def test_parse_args_bootstrap_size(self):
413412
with patch(
@@ -430,7 +429,8 @@ def test_parse_args_bootstrap_size(self):
430429
main()
431430
with open(self.output_path.parent / "main.json") as f:
432431
log = json.load(f)
433-
assert all(test["result"]["bootstrap_size"] == 50 for test in log)
432+
executed_tests = [test for test in log if not test.get("skip", False)]
433+
assert all(test["result"].get("bootstrap_size", 50) == 50 for test in executed_tests)
434434

435435
def test_parse_args_bootstrap_size_explicit_adequacy(self):
436436
with patch(
@@ -454,7 +454,8 @@ def test_parse_args_bootstrap_size_explicit_adequacy(self):
454454
main()
455455
with open(self.output_path.parent / "main.json") as f:
456456
log = json.load(f)
457-
assert all(test["result"]["bootstrap_size"] == 50 for test in log)
457+
executed_tests = [test for test in log if not test.get("skip", False)]
458+
assert all(test["result"].get("bootstrap_size", 50) == 50 for test in executed_tests)
458459

459460
def test_parse_args_batches(self):
460461
with patch(
@@ -517,4 +518,4 @@ def test_parse_args_generation_non_default(self):
517518

518519
def tearDown(self):
519520
if self.output_path.parent.exists():
520-
shutil.rmtree(self.output_path.parent)
521+
shutil.rmtree(self.output_path.parent)

0 commit comments

Comments
 (0)