Skip to content

Commit e262f67

Browse files
committed
fix: commonalities in save outputs
1 parent 0dd5e51 commit e262f67

1 file changed

Lines changed: 13 additions & 13 deletions

File tree

causal_testing/main.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -442,18 +442,23 @@ def save_results(self, results: List[CausalTestResult], output_path: str = None)
442442
result_index = 0
443443

444444
for test_config in test_configs["tests"]:
445+
446+
# Create a base output first of common entries
447+
base_output = {
448+
"name": test_config["name"],
449+
"estimate_type": test_config["estimate_type"],
450+
"effect": test_config.get("effect", "direct"),
451+
"treatment_variable": test_config["treatment_variable"],
452+
"expected_effect": test_config["expected_effect"],
453+
"alpha": test_config.get("alpha", 0.05),
454+
}
445455
if test_config.get("skip", False):
446456
# Include those skipped test entry without execution results
447457
output = {
448-
"name": test_config["name"],
449-
"estimate_type": test_config["estimate_type"],
450-
"effect": test_config.get("effect", "direct"),
451-
"treatment_variable": test_config["treatment_variable"],
452-
"expected_effect": test_config["expected_effect"],
458+
**base_output,
453459
"formula": test_config.get("formula"),
454-
"alpha": test_config.get("alpha", 0.05),
455460
"skip": True,
456-
"passed": None, # Don't need this for skipped tests
461+
"passed": None,
457462
"result": {
458463
"status": "skipped",
459464
"reason": "Test marked as skip:true in the causal test config file.",
@@ -470,13 +475,8 @@ def save_results(self, results: List[CausalTestResult], output_path: str = None)
470475
)
471476

472477
output = {
473-
"name": test_config["name"],
474-
"estimate_type": test_config["estimate_type"],
475-
"effect": test_config.get("effect", "direct"),
476-
"treatment_variable": test_config["treatment_variable"],
477-
"expected_effect": test_config["expected_effect"],
478+
**base_output,
478479
"formula": result.estimator.formula if hasattr(result.estimator, "formula") else None,
479-
"alpha": test_config.get("alpha", 0.05),
480480
"skip": False,
481481
"passed": test_passed,
482482
"result": (

0 commit comments

Comments
 (0)