|
35 | 35 | "sns.set_palette(\"husl\")" |
36 | 36 | ] |
37 | 37 | }, |
| 38 | + { |
| 39 | + "cell_type": "markdown", |
| 40 | + "id": "804371b4", |
| 41 | + "metadata": {}, |
| 42 | + "source": [ |
| 43 | + "## Load Data\n", |
| 44 | + "Load the raw patient treatment dataset and check basic structure.\n", |
| 45 | + "Dataset contains individual patient records with demographics, substance use patterns, and treatment info" |
| 46 | + ] |
| 47 | + }, |
38 | 48 | { |
39 | 49 | "cell_type": "code", |
40 | 50 | "execution_count": 6, |
|
45 | 55 | "df = pd.read_csv(\"1_datasets/processed/teds_ml_ready.csv\")" |
46 | 56 | ] |
47 | 57 | }, |
| 58 | + { |
| 59 | + "cell_type": "markdown", |
| 60 | + "id": "e461f05f", |
| 61 | + "metadata": {}, |
| 62 | + "source": [ |
| 63 | + "## Initial Data Quality Checks\n", |
| 64 | + "Understand data structure, missing values, and data types.\n", |
| 65 | + "Missing values inform our imputation strategy (high/medium/low missingness)\n" |
| 66 | + ] |
| 67 | + }, |
48 | 68 | { |
49 | 69 | "cell_type": "code", |
50 | 70 | "execution_count": 7, |
|
74 | 94 | "print(missing_pct[missing_pct > 0].head(10))\n" |
75 | 95 | ] |
76 | 96 | }, |
| 97 | + { |
| 98 | + "cell_type": "markdown", |
| 99 | + "id": "aad309cb", |
| 100 | + "metadata": {}, |
| 101 | + "source": [ |
| 102 | + "## Data Cleaning & Type Conversion\n", |
| 103 | + "Handle missing values and convert problematic columns from object to numeric.\n", |
| 104 | + "Numeric cols to fill with median | Categorical cols to fill with mode.\n", |
| 105 | + "This prepares clean data for feature engineering." |
| 106 | + ] |
| 107 | + }, |
77 | 108 | { |
78 | 109 | "cell_type": "code", |
79 | 110 | "execution_count": 8, |
|
103 | 134 | " df_clean[col] = converted.fillna(0)" |
104 | 135 | ] |
105 | 136 | }, |
| 137 | + { |
| 138 | + "cell_type": "markdown", |
| 139 | + "id": "d08bd2d0", |
| 140 | + "metadata": {}, |
| 141 | + "source": [ |
| 142 | + "## Pre-Aggregation One-Hot Encoding\n", |
| 143 | + "Convert categorical demographics to binary columns BEFORE aggregating.\n", |
| 144 | + "This preserves distributions (e.g., % Female) instead of losing info via mode().\n", |
| 145 | + "When aggregated by mean, these become percentages for each facility." |
| 146 | + ] |
| 147 | + }, |
106 | 148 | { |
107 | 149 | "cell_type": "code", |
108 | 150 | "execution_count": 9, |
|
117 | 159 | ")" |
118 | 160 | ] |
119 | 161 | }, |
| 162 | + { |
| 163 | + "cell_type": "markdown", |
| 164 | + "id": "887728c5", |
| 165 | + "metadata": {}, |
| 166 | + "source": [ |
| 167 | + "## Aggregate by State & Service Type\n", |
| 168 | + "Group individual patient records by (state, service_type).\n", |
| 169 | + "Target: Count of patients = total_admissions.\n", |
| 170 | + "Features: Mean of binary indicators = prevalence rates for that facility type.\n", |
| 171 | + "This is where we CREATE the training dataset." |
| 172 | + ] |
| 173 | + }, |
120 | 174 | { |
121 | 175 | "cell_type": "code", |
122 | 176 | "execution_count": 10, |
|
135 | 189 | "df_grouped.rename(columns={\"patient_id\": \"total_admissions\"}, inplace=True)\n" |
136 | 190 | ] |
137 | 191 | }, |
| 192 | + { |
| 193 | + "cell_type": "markdown", |
| 194 | + "id": "786b8e51", |
| 195 | + "metadata": {}, |
| 196 | + "source": [ |
| 197 | + "## Feature Engineering - Complexity Score\n", |
| 198 | + "Create a composite metric: complexity_score.\n", |
| 199 | + "Weights reflect severity of each condition (chronic=2.0 is most serious).\n", |
| 200 | + "Since we aggregated by mean, values are 0.0-1.0 (prevalence rates).\n", |
| 201 | + "Score reflects average complexity of facility's patient population." |
| 202 | + ] |
| 203 | + }, |
138 | 204 | { |
139 | 205 | "cell_type": "code", |
140 | | - "execution_count": 11, |
| 206 | + "execution_count": null, |
141 | 207 | "id": "168ff8d3", |
142 | 208 | "metadata": {}, |
143 | 209 | "outputs": [], |
144 | 210 | "source": [ |
145 | | - "# Complexity Score (now percentages since we aggregated by mean)\n", |
| 211 | + "# Complexity Score\n", |
146 | 212 | "df_grouped[\"complexity_score\"] = (\n", |
147 | 213 | " df_grouped.get(\"is_polysubstance\", 0) * 1.5\n", |
148 | 214 | " + df_grouped.get(\"is_chronic_treatment\", 0) * 2.0\n", |
149 | 215 | " + df_grouped.get(\"has_mental_health_disorder\", 0) * 1.8\n", |
150 | 216 | " + df_grouped.get(\"is_homeless\", 0) * 1.5\n", |
151 | 217 | " + df_grouped.get(\"is_injection_user\", 0) * 2.0\n", |
152 | 218 | ")\n", |
| 219 | + "# One-hot encode remaining categorical features\n", |
153 | 220 | "df_final = pd.get_dummies(\n", |
154 | 221 | " df_grouped, columns=[\"state\", \"service_type\"], drop_first=True\n", |
155 | 222 | ")" |
156 | 223 | ] |
157 | 224 | }, |
| 225 | + { |
| 226 | + "cell_type": "markdown", |
| 227 | + "id": "08cd2488", |
| 228 | + "metadata": {}, |
| 229 | + "source": [ |
| 230 | + "## Train-Test Split\n", |
| 231 | + "Split before imputation/scaling to prevent leakage.\n", |
| 232 | + "Imputer & scaler will fit only on training data.\n", |
| 233 | + "This ensures test data statistics don't influence preprocessing." |
| 234 | + ] |
| 235 | + }, |
158 | 236 | { |
159 | 237 | "cell_type": "code", |
160 | 238 | "execution_count": 12, |
|
198 | 276 | "print(f\"Target Skew (Log): {y_log.skew():.2f}\")" |
199 | 277 | ] |
200 | 278 | }, |
| 279 | + { |
| 280 | + "cell_type": "markdown", |
| 281 | + "id": "588ec192", |
| 282 | + "metadata": {}, |
| 283 | + "source": [ |
| 284 | + "## Imputation & Scaling (Fit on Train Only)\n", |
| 285 | + "Imputer & scaler learn statistics ONLY from training data.\n", |
| 286 | + "Test data is transformed using train statistics (simulates true production scenario).\n", |
| 287 | + "This gives honest estimate of how model performs on unseen data." |
| 288 | + ] |
| 289 | + }, |
201 | 290 | { |
202 | 291 | "cell_type": "code", |
203 | 292 | "execution_count": 13, |
|
218 | 307 | "X_test_scaled = pd.DataFrame(scaler.transform(X_test_imputed), columns=X_test.columns)" |
219 | 308 | ] |
220 | 309 | }, |
| 310 | + { |
| 311 | + "cell_type": "markdown", |
| 312 | + "id": "0014a72e", |
| 313 | + "metadata": {}, |
| 314 | + "source": [ |
| 315 | + "## Train Multiple Models\n", |
| 316 | + "Compare 3 different algorithms: Ridge, Random Forest, Gradient Boosting\n", |
| 317 | + "Evaluate on TEST data (holds out 20% for final assessment)\n", |
| 318 | + "Use CROSS-VALIDATION on TRAIN data (5-fold, more robust than single split)" |
| 319 | + ] |
| 320 | + }, |
221 | 321 | { |
222 | 322 | "cell_type": "code", |
223 | 323 | "execution_count": 38, |
|
285 | 385 | "print(f\"{name} CV R²: {cv_scores.mean():.4f} (+/- {cv_scores.std():.4f})\")\n" |
286 | 386 | ] |
287 | 387 | }, |
| 388 | + { |
| 389 | + "cell_type": "markdown", |
| 390 | + "id": "7a8476da", |
| 391 | + "metadata": {}, |
| 392 | + "source": [ |
| 393 | + "## Retrain Best Model on Full Data (Production)\n", |
| 394 | + "retrain it on ALL data (no more train/test split).\n", |
| 395 | + "This captures all available signal for production predictions.\n", |
| 396 | + "Use full-data fitted imputer/scaler, not the train-only ones." |
| 397 | + ] |
| 398 | + }, |
288 | 399 | { |
289 | 400 | "cell_type": "code", |
290 | 401 | "execution_count": null, |
|
311 | 422 | "# Retrain Best Model on Full Log Data\n", |
312 | 423 | "y_full_log = np.log1p(y_raw)\n", |
313 | 424 | "best_model_prod = clone(best_model)\n", |
314 | | - "best_model_prod.fit(X_full_scaled, y_full_log)\n", |
315 | | - "\n", |
316 | | - "# Calculate Bias Correction (Fix Log Transformation Under-prediction)\n", |
| 425 | + "best_model_prod.fit(X_full_scaled, y_full_log)" |
| 426 | + ] |
| 427 | + }, |
| 428 | + { |
| 429 | + "cell_type": "markdown", |
| 430 | + "id": "9cc69d17", |
| 431 | + "metadata": {}, |
| 432 | + "source": [ |
| 433 | + "## Generate Predictions\n", |
| 434 | + "Use production model to predict admissions for all facilities.\n", |
| 435 | + "Clip negative predictions to 0 (admissions cannot be negative).\n", |
| 436 | + "Calculate the Bias Correction." |
| 437 | + ] |
| 438 | + }, |
| 439 | + { |
| 440 | + "cell_type": "code", |
| 441 | + "execution_count": null, |
| 442 | + "id": "725331ec", |
| 443 | + "metadata": {}, |
| 444 | + "outputs": [ |
| 445 | + { |
| 446 | + "name": "stdout", |
| 447 | + "output_type": "stream", |
| 448 | + "text": [ |
| 449 | + "Bias Correction Factor: 1.0014\n" |
| 450 | + ] |
| 451 | + } |
| 452 | + ], |
| 453 | + "source": [ |
| 454 | + "# Calculate Bias Correction\n", |
317 | 455 | "train_preds_log = best_model_prod.predict(X_full_scaled)\n", |
318 | 456 | "train_preds_raw = np.maximum(np.expm1(train_preds_log), 0)\n", |
319 | 457 | "\n", |
|
324 | 462 | "print(f\"Bias Correction Factor: {correction_factor:.4f}\")" |
325 | 463 | ] |
326 | 464 | }, |
| 465 | + { |
| 466 | + "cell_type": "markdown", |
| 467 | + "id": "ee72bda7", |
| 468 | + "metadata": {}, |
| 469 | + "source": [ |
| 470 | + "## Generate The Final Values\n", |
| 471 | + "Generate final predictions by converting the model’s log-scaled outputs back to the original scale using `expm1`, then applying a correction factor and clipping negatives to zero. The resulting values are stored in `df_grouped[\"predicted_admissions\"]`.\n" |
| 472 | + ] |
| 473 | + }, |
327 | 474 | { |
328 | 475 | "cell_type": "code", |
329 | 476 | "execution_count": 28, |
|
340 | 487 | "df_grouped[\"predicted_admissions\"] = final_preds" |
341 | 488 | ] |
342 | 489 | }, |
| 490 | + { |
| 491 | + "cell_type": "markdown", |
| 492 | + "id": "53f5b9bd", |
| 493 | + "metadata": {}, |
| 494 | + "source": [ |
| 495 | + "## Calculate Resource Requirements\n", |
| 496 | + "Convert predicted admissions into actionable resource recommendations\n", |
| 497 | + "Beds: Assume 12 patients per bed (standard occupancy rate)\n", |
| 498 | + "Staff: Assume 50 patients per staff member, adjusted by facility complexity\n", |
| 499 | + "High-demand flag: Identifies facilities above median volume for priority" |
| 500 | + ] |
| 501 | + }, |
343 | 502 | { |
344 | 503 | "cell_type": "code", |
345 | 504 | "execution_count": 29, |
|
363 | 522 | ").astype(int)" |
364 | 523 | ] |
365 | 524 | }, |
| 525 | + { |
| 526 | + "cell_type": "markdown", |
| 527 | + "id": "ab50c55a", |
| 528 | + "metadata": {}, |
| 529 | + "source": [ |
| 530 | + "## Feature Importance Analysis\n", |
| 531 | + "Identify which features drive the model's predictions.\n", |
| 532 | + "Top features show what patterns the model learned are most predictive.\n", |
| 533 | + "Use for model interpretability and validation." |
| 534 | + ] |
| 535 | + }, |
366 | 536 | { |
367 | 537 | "cell_type": "code", |
368 | 538 | "execution_count": 32, |
|
405 | 575 | ")" |
406 | 576 | ] |
407 | 577 | }, |
| 578 | + { |
| 579 | + "cell_type": "markdown", |
| 580 | + "id": "d17d2302", |
| 581 | + "metadata": {}, |
| 582 | + "source": [ |
| 583 | + "## Resource Allocation Report\n", |
| 584 | + "Display sample predictions and top high-demand facilities.\n", |
| 585 | + "This is the actionable output for resource planning teams." |
| 586 | + ] |
| 587 | + }, |
408 | 588 | { |
409 | 589 | "cell_type": "code", |
410 | 590 | "execution_count": 36, |
|
460 | 640 | "print(top_demand.to_string(index=False))\n" |
461 | 641 | ] |
462 | 642 | }, |
| 643 | + { |
| 644 | + "cell_type": "markdown", |
| 645 | + "id": "68e413de", |
| 646 | + "metadata": {}, |
| 647 | + "source": [ |
| 648 | + "## Visualization Dashboard\n", |
| 649 | + "Create 4-panel visualization to understand model performance and recommendations\n", |
| 650 | + "Panel 1: Actual vs Predicted (assess accuracy, colored by complexity)\n", |
| 651 | + "Panel 2: Complexity distribution (understand patient case mix)\n", |
| 652 | + "Panel 3: Admissions vs Beds (resource scaling relationship)\n", |
| 653 | + "Panel 4: Model performance comparison (train vs test)" |
| 654 | + ] |
| 655 | + }, |
463 | 656 | { |
464 | 657 | "cell_type": "code", |
465 | 658 | "execution_count": 34, |
|
0 commit comments