Skip to content

Commit e84b9f9

Browse files
committed
big mess of models and visualizations
1 parent 99e603b commit e84b9f9

1 file changed

Lines changed: 160 additions & 33 deletions

File tree

src/examples/colors/utils/ols_model.py

Lines changed: 160 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
import numpy as np
55
import matplotlib.pyplot as plt
66

7+
from scipy import stats
8+
79

810
def ols_from_model(model: pd.DataFrame):
911
"""
1012
Prints the output for the OLS model describing `convexity_qmw ~ optimality + complexity * accuracy` and `convexity_quw ~ optimality + complexity * accuracy`
1113
for a given input DataFrame
1214
1315
Args:
14-
model (DataFrame): The input DataFrame, the format should be one that is loaded from one of the `.csv` model files
16+
model (DataFrame): The input DataFrame, the format should be one that is loaded from one of the `.csv` model files
1517
"""
1618
model = model.rename(
1719
columns={"convexity-qmw": "convexity_qmw", "convexity-quw": "convexity_quw"}
@@ -49,19 +51,130 @@ def ols_from_model(model: pd.DataFrame):
4951
)
5052

5153
lm = smf.mixedlm(
52-
"optimality ~ convexity_qmw + type + base_type",
54+
"convexity_qmw ~ optimality + type + base_type",
5355
data=model,
5456
groups=model["base_item_id"], # random intercept for base_item_id
5557
)
5658

57-
result = lm.fit(reml=True)
59+
"""
60+
lm_reduced = smf.mixedlm(
61+
"convexity_qmw ~ optimality + base_type",
62+
data=model,
63+
groups=model["base_item_id"], # random intercept for base_item_id
64+
)
65+
"""
66+
67+
result = lm.fit(reml=False)
5868
print(result.summary())
5969

60-
df = model
61-
# Generate a grid of predictor values for plotting
62-
convexity_range = np.linspace(
63-
df["convexity_qmw"].min(), df["convexity_qmw"].max(), 100
70+
fixed_model_results = smf.ols(
71+
"convexity_qmw ~ optimality + type + base_type", data=model
72+
).fit()
73+
print(fixed_model_results.summary())
74+
75+
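# Likelihood-ratio test for the random intercept: statistic = 2 * (LL_mixed - LL_OLS).
# The mixed model is fit with reml=False, so both log-likelihoods are maximum-likelihood values and comparable.
# NOTE: the null hypothesis (zero random-intercept variance) lies on the boundary of the parameter space,
# so the plain chi-square(1) p-value computed below is conservative.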
76+
lr_stat = 2 * (result.llf - fixed_model_results.llf)
77+
p_value = stats.chi2.sf(lr_stat, df=1) # 1 df for one random variance parameter
78+
79+
print(f"Likelihood Ratio Test: χ²(1) = {lr_stat:.3f}, p = {p_value:.4g}")
80+
81+
print(f"Mixed model AIC: {result.aic:.2f}")
82+
print(f"OLS model AIC: {fixed_model_results.aic:.2f}")
83+
print(f"Mixed model BIC: {result.bic:.2f}")
84+
print(f"OLS model BIC: {fixed_model_results.bic:.2f}")
85+
86+
big_model = smf.ols(
87+
"convexity_qmw ~ type + base_type + optimality + complexity + accuracy",
88+
data=model,
89+
)
90+
print(big_model.fit().summary())
91+
92+
big_model = smf.ols(
93+
"convexity_qmw ~ type + base_type + optimality * complexity * accuracy",
94+
data=model,
95+
)
96+
big_model_result = big_model.fit()
97+
print(big_model_result.summary())
98+
99+
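# --- Define predictor values: a 100-point optimality grid, plus 3 levels each of complexity and accuracy spanning their 10th-90th percentiles ---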
100+
opt_range = np.linspace(model["optimality"].min(), model["optimality"].max(), 100)
101+
comp_levels = np.linspace(
102+
model["complexity"].quantile(0.1), model["complexity"].quantile(0.9), 3
103+
)
104+
acc_levels = np.linspace(
105+
model["accuracy"].quantile(0.1), model["accuracy"].quantile(0.9), 3
106+
)
107+
108+
# --- Helper: predict convexity from fitted model ---
109+
def predict_convexity(opt, comp, acc):
    """Predict convexity from the fitted interaction model.

    Builds a prediction frame from the given predictor values and hands it
    to ``big_model_result.predict()`` so the formula terms are handled
    consistently with the fit.

    Args:
        opt: Optimality value(s); a scalar or array-like.
        comp: Complexity value(s); a scalar or array-like that broadcasts
            against ``opt``.
        acc: Accuracy value(s); a scalar or array-like that broadcasts
            against ``opt``.

    Returns:
        Whatever ``big_model_result.predict()`` returns for the constructed
        frame, with one row per broadcast element.
    """
    # Broadcast everything to equal-length 1-D arrays first: pandas cannot
    # build a DataFrame from a dict that contains only scalars (it raises
    # "If using all scalar values, you must pass an index"), so a call with
    # three scalar predictors used to fail.
    opt_values, comp_values, acc_values = (
        np.ravel(values) for values in np.broadcast_arrays(opt, comp, acc)
    )
    df = pd.DataFrame(
        {
            "optimality": opt_values,
            "complexity": comp_values,
            "accuracy": acc_values,
            # Categorical predictors are pinned to "natural".
            # NOTE(review): confirm this is the intended level for both.
            "type": "natural",
            "base_type": "natural",
        }
    )
    return big_model_result.predict(df)
122+
123+
# --- Plot convexity vs. optimality at different levels ---
124+
plt.figure(figsize=(9, 6))
125+
for comp in comp_levels:
126+
for acc in acc_levels:
127+
df_pred = pd.DataFrame(
128+
{
129+
"optimality": opt_range,
130+
"complexity": comp,
131+
"accuracy": acc,
132+
"type": "natural",
133+
"base_type": "natural",
134+
}
135+
)
136+
preds = big_model_result.predict(df_pred)
137+
label = f"Complexity={comp:.2f}, Accuracy={acc:.2f}"
138+
plt.plot(opt_range, preds, label=label)
139+
140+
plt.xlabel("Optimality")
141+
plt.ylabel("Predicted Convexity (qmw)")
142+
plt.title(
143+
"Predicted convexity vs optimality at different complexity and accuracy levels"
144+
)
145+
plt.legend(fontsize=8)
146+
plt.tight_layout()
147+
plt.show()
148+
149+
# --- 3D surface: Convexity as a function of Optimality × Complexity ---
150+
opt_grid, comp_grid = np.meshgrid(
151+
np.linspace(model["optimality"].min(), model["optimality"].max(), 60),
152+
np.linspace(model["complexity"].min(), model["complexity"].max(), 60),
64153
)
154+
acc_fixed = model["accuracy"].median()
155+
156+
df_surface = pd.DataFrame(
157+
{
158+
"optimality": opt_grid.ravel(),
159+
"complexity": comp_grid.ravel(),
160+
"accuracy": acc_fixed,
161+
"type": "natural",
162+
"base_type": "natural",
163+
}
164+
)
165+
conv_grid = big_model_result.predict(df_surface).values.reshape(opt_grid.shape)
166+
167+
fig = plt.figure(figsize=(9, 6))
168+
ax = fig.add_subplot(111, projection="3d")
169+
ax.plot_surface(opt_grid, comp_grid, conv_grid, cmap="viridis", alpha=0.9)
170+
ax.set_xlabel("Optimality")
171+
ax.set_ylabel("Complexity")
172+
ax.set_zlabel("Predicted Convexity")
173+
ax.set_title(f"Interaction surface (Accuracy={acc_fixed:.2f})")
174+
plt.tight_layout()
175+
plt.show()
176+
177+
df = model
65178

66179
# Create all combinations of type × base_type
67180
types = df["type"].unique()
@@ -76,58 +189,72 @@ def ols_from_model(model: pd.DataFrame):
76189
("suboptimal", "optimal"),
77190
]
78191

79-
# Build plot dataframe
192+
# Generate a grid of optimality values for prediction
193+
optimality_range = np.linspace(df["optimality"].min(), df["optimality"].max(), 100)
194+
195+
# Define valid type/base_type combinations
196+
valid_combinations = [
197+
("natural", "natural"),
198+
("optimal", "optimal"),
199+
("suboptimal", "natural"),
200+
("suboptimal", "optimal"),
201+
]
202+
203+
# Build plot dataframe for predictions
80204
plot_df = pd.DataFrame(
81205
[
82-
{"convexity_qmw": c, "type": t, "base_type": b}
83-
for c in convexity_range
206+
{"optimality": o, "type": t, "base_type": b}
207+
for o in optimality_range
84208
for t, b in valid_combinations
85209
]
86210
)
87211

88212
# Predict using fixed effects only
89-
plot_df["predicted_optimality"] = result.predict(exog=plot_df)
213+
plot_df["predicted_convexity"] = result.predict(exog=plot_df)
90214

91215
# Define colors for types
92216
type_colors = {"natural": "blue", "optimal": "green", "suboptimal": "red"}
93217

94218
# Define line styles (predicted lines) and markers (raw data points) for base_type
95219
base_styles = {"natural": "-", "optimal": "--"}
220+
base_markers = {"natural": "o", "optimal": "s"}
96221

97222
# Plot
98223
fig, ax = plt.subplots(figsize=(9, 6))
99224

225+
# Plot predicted lines
100226
for t, b in valid_combinations:
101227
subset = plot_df[(plot_df["type"] == t) & (plot_df["base_type"] == b)]
102228
ax.plot(
103-
subset["convexity_qmw"],
104-
subset["predicted_optimality"],
229+
subset["optimality"],
230+
subset["predicted_convexity"],
105231
color=type_colors[t],
106232
linestyle=base_styles[b],
107-
label=f"type={t}, base_type={b}",
233+
label=f"Predicted: type={t}, base_type={b}",
108234
)
109235

110236
# Overlay raw data
111-
# Define marker styles for base_type in raw data
112-
base_markers = {"natural": "o", "optimal": "s"}
113-
for t, b in valid_combinations:
114-
subset = model[(model["type"] == t) & (model["base_type"] == b)]
115-
if len(subset) > 0: # skip impossible combinations
116-
ax.scatter(
117-
subset["convexity_qmw"],
118-
subset["optimality"],
119-
color=type_colors[t],
120-
marker=base_markers[b],
121-
alpha=0.2,
122-
edgecolor="k",
123-
s=10,
124-
label=f"Raw: type={t}, base_type={b}",
125-
)
237+
for t in df["type"].unique():
238+
for b in df["base_type"].unique():
239+
subset = df[(df["type"] == t) & (df["base_type"] == b)]
240+
if len(subset) > 0:
241+
ax.scatter(
242+
subset["optimality"],
243+
subset["convexity_qmw"],
244+
color=type_colors[t],
245+
marker=base_markers[b],
246+
alpha=0.5,
247+
edgecolor="k",
248+
s=30,
249+
label=f"Raw: type={t}, base_type={b}",
250+
)
126251

127-
ax.set_xlabel("Convexity (qmw)")
128-
ax.set_ylabel("Predicted optimality")
129-
ax.set_title("Predicted optimality vs convexity by type and base_type")
130-
ax.legend()
252+
ax.set_xlabel("Optimality")
253+
ax.set_ylabel("Convexity (qmw)")
254+
ax.set_title(
255+
"Predicted convexity vs optimality by type and base_type with raw data"
256+
)
257+
ax.legend(loc="best", fontsize=8, ncol=2)
131258
plt.tight_layout()
132259
plt.show()
133260

0 commit comments

Comments
 (0)