Skip to content

Commit b07a821

Browse files
committed
beginning to clean up analysis code
1 parent 010c7cb commit b07a821

1 file changed

Lines changed: 80 additions & 34 deletions

File tree

src/examples/colors/utils/ols_model.py

Lines changed: 80 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -1,33 +1,88 @@
1+
import sys
2+
from collections.abc import Iterable
3+
14
import statsmodels.formula.api as smf
25
import pandas as pd
3-
import sys
46
import numpy as np
57
import matplotlib.pyplot as plt
68

79
from scipy import stats
10+
from statsmodels.regression.linear_model import RegressionResults
11+
from statsmodels.regression.mixed_linear_model import MixedLMResults
812

913

10-
def ols_from_model(model: pd.DataFrame):
14+
def ols_from_model(
    model: pd.DataFrame,
    dependent_var: str = "convexity_qmw",
    independent_vars: Iterable[str] = ("optimality", "complexity", "accuracy"),
    interactions: bool = True,
) -> RegressionResults:
    """
    Fits an OLS model of `dependent_var` on `independent_vars`, prints the fit summary and returns the fitted result.

    With the default arguments the fitted formula is `convexity_qmw ~ optimality * complexity * accuracy`.
    To analyse another outcome (e.g. `convexity_quw`), pass it as `dependent_var`.

    Args:
        model (DataFrame): The input DataFrame, the format should be one that is loaded from one of the `.csv` model files.
            Column names are used verbatim in the formula, so they must already be valid formula identifiers
            (e.g. `convexity_qmw`, not `convexity-qmw`).
        dependent_var (str): Name of the column on the left-hand side of the formula.
        independent_vars (Iterable[str]): Terms for the right-hand side of the formula. A single string is treated
            as one term (it is not split into characters). If empty, an intercept-only model (`~ 1`) is fitted.
        interactions (bool): If True the terms are joined with ` * ` (main effects plus all their interactions in
            formula syntax); if False they are joined with ` + ` (main effects only).

    Returns:
        RegressionResults: The fitted OLS result.
    """
    # A bare string is itself an iterable of characters; joining it directly
    # would turn "type" into "t * y * p * e".
    if isinstance(independent_vars, str):
        independent_vars = (independent_vars,)

    separator = " * " if interactions else " + "
    # Fall back to the intercept-only right-hand side when no terms are given,
    # instead of producing the malformed formula "y ~ ".
    right_hand_side = separator.join(independent_vars) or "1"

    results = smf.ols(f"{dependent_var} ~ {right_hand_side}", data=model).fit()
    print(results.summary())
    return results
41+
42+
43+
def mixed_lm_from_model(
    model: pd.DataFrame,
    dependent_var: str = "convexity_qmw",
    independent_vars: Iterable[str] = ("optimality", "complexity", "accuracy"),
    groups_col: str = "base_item_id",
    interactions: bool = True,
) -> MixedLMResults:
    """
    Fits a mixed linear model of `dependent_var` on `independent_vars` with a random intercept per `groups_col` value,
    prints the fit summary and returns the fitted result.

    With the default arguments the fitted formula is `convexity_qmw ~ optimality * complexity * accuracy`,
    grouped by `base_item_id`.

    Args:
        model (DataFrame): The input DataFrame, the format should be one that is loaded from one of the `.csv` model files.
            Column names are used verbatim in the formula, so they must already be valid formula identifiers.
        dependent_var (str): Name of the column on the left-hand side of the formula.
        independent_vars (Iterable[str]): Fixed-effect terms for the right-hand side of the formula. A single string
            is treated as one term (it is not split into characters). If empty, only an intercept (`~ 1`) is fitted.
        groups_col (str): Name of the column in `model` whose values define the groups.
        interactions (bool): If True the terms are joined with ` * ` (main effects plus all their interactions in
            formula syntax); if False they are joined with ` + ` (main effects only).

    Returns:
        MixedLMResults: The fitted mixed linear model result.
    """
    # A bare string is itself an iterable of characters; joining it directly
    # would turn "type" into "t * y * p * e".
    if isinstance(independent_vars, str):
        independent_vars = (independent_vars,)

    separator = " * " if interactions else " + "
    # Fall back to the intercept-only right-hand side when no terms are given,
    # instead of producing the malformed formula "y ~ ".
    right_hand_side = separator.join(independent_vars) or "1"

    lm = smf.mixedlm(
        f"{dependent_var} ~ {right_hand_side}",
        data=model,
        groups=model[groups_col],  # random intercept
    )

    # Fit by maximum likelihood (reml=False) rather than the REML default.
    result = lm.fit(reml=False)
    print(result.summary())
    return result
67+
68+
69+
def likelihood_ratio_test(
    mixed_model_result: MixedLMResults,
    ols_model_result: RegressionResults,
    df: int = 1,
) -> tuple[float, float, int]:
    """
    Performs a likelihood ratio test between a mixed linear model and an OLS model

    The statistic is `2 * (LL_mixed - LL_OLS)`, referred to a chi-square distribution with `df` degrees of freedom.
    Only the `.llf` attribute of each result is read.

    Args:
        mixed_model_result (MixedLMResults): The fitted mixed linear model result
        ols_model_result (RegressionResults): The fitted OLS model result
        df (int): Degrees of freedom of the reference chi-square distribution, i.e. the number of extra
            parameters in the mixed model. Defaults to 1 (one random variance parameter).

    Returns:
        tuple[float, float, int]: The likelihood ratio statistic, p-value, and degrees of freedom

    Note:
        The two log-likelihoods are only comparable if the mixed model was fitted by maximum likelihood
        (`reml=False`), not REML.
        NOTE(review): a variance component is tested on the boundary of its parameter space, so the plain
        chi-square p-value is presumably conservative; confirm whether a mixture correction is wanted.
    """

    # test statistic = 2 * (LL_mixed - LL_OLS)
    lr_stat = 2 * (mixed_model_result.llf - ols_model_result.llf)
    # Survival function = upper-tail probability of the chi-square reference distribution.
    p_value = stats.chi2.sf(lr_stat, df=df)
    return lr_stat, p_value, df
3186

3287

3388
if __name__ == "__main__":
@@ -50,37 +105,31 @@ def ols_from_model(model: pd.DataFrame):
50105
columns={"convexity-qmw": "convexity_qmw", "convexity-quw": "convexity_quw"}
51106
)
52107

53-
lm = smf.mixedlm(
54-
"convexity_qmw ~ optimality + type + base_type",
55-
data=model,
56-
groups=model["base_item_id"], # random intercept for base_item_id
57-
)
108+
# first analysis: mixed vs. fixed effects for type + base_type
58109

59-
"""
60-
lm_reduced = smf.mixedlm(
61-
"convexity_qmw ~ optimality + base_type",
62-
data=model,
63-
groups=model["base_item_id"], # random intercept for base_item_id
64-
)
65-
"""
66-
67-
result = lm.fit(reml=False)
68-
print(result.summary())
110+
mixed_model_results = mixed_lm_from_model(
111+
model,
112+
dependent_var="convexity_qmw",
113+
independent_vars=("type", "base_type"),
114+
interactions=False,
115+
)
69116

70-
fixed_model_results = smf.ols(
71-
"convexity_qmw ~ optimality + type + base_type", data=model
72-
).fit()
73-
print(fixed_model_results.summary())
117+
fixed_model_results = ols_from_model(
118+
model,
119+
dependent_var="convexity_qmw",
120+
independent_vars=("type", "base_type"),
121+
interactions=False,
122+
)
74123

75-
# test statistic = 2 * (LL_mixed - LL_OLS)
76-
lr_stat = 2 * (result.llf - fixed_model_results.llf)
77-
p_value = stats.chi2.sf(lr_stat, df=1) # 1 df for one random variance parameter
124+
lr_stat, p_value, df = likelihood_ratio_test(
125+
mixed_model_results, fixed_model_results
126+
)
78127

79128
print(f"Likelihood Ratio Test: χ²(1) = {lr_stat:.3f}, p = {p_value:.4g}")
80129

81-
print(f"Mixed model AIC: {result.aic:.2f}")
130+
print(f"Mixed model AIC: {mixed_model_results.aic:.2f}")
82131
print(f"OLS model AIC: {fixed_model_results.aic:.2f}")
83-
print(f"Mixed model BIC: {result.bic:.2f}")
132+
print(f"Mixed model BIC: {mixed_model_results.bic:.2f}")
84133
print(f"OLS model BIC: {fixed_model_results.bic:.2f}")
85134

86135
big_model = smf.ols(
@@ -210,7 +259,7 @@ def predict_convexity(opt, comp, acc):
210259
)
211260

212261
# Predict using fixed effects only
213-
plot_df["predicted_convexity"] = result.predict(exog=plot_df)
262+
plot_df["predicted_convexity"] = fixed_model_results.predict(exog=plot_df)
214263

215264
# Define colors for types
216265
type_colors = {"natural": "blue", "optimal": "green", "suboptimal": "red"}
@@ -257,6 +306,3 @@ def predict_convexity(opt, comp, acc):
257306
ax.legend(loc="best", fontsize=8, ncol=2)
258307
plt.tight_layout()
259308
plt.show()
260-
261-
262-
# ols_from_model(model)

0 commit comments

Comments
 (0)