---
title: K-Fold Cross-Validation
sidebar_label: K-Fold Cross-Validation
description: "Mastering robust model evaluation by rotating training and testing sets to maximize data utility."
tags: [machine-learning, model-evaluation, cross-validation, k-fold, generalization]
---

While a [Train-Test Split](./train-test-split) is a great starting point, it has a major weakness: your results can vary significantly depending on which specific rows end up in the test set.

**K-Fold Cross-Validation** solves this by repeating the split process multiple times and averaging the results, ensuring every single data point serves in the "test set" exactly once.

## 1. How the Algorithm Works

The process follows a simple rotation logic:
1. **Split** the data into **K** equal-sized "folds" (usually $K=5$ or $K=10$).
2. **Iterate:** For each fold $i$:
   * Treat Fold $i$ as the **Test Set**.
   * Treat the remaining $K-1$ folds as the **Training Set**.
   * Train the model and record the score.
3. **Aggregate:** Calculate the mean and standard deviation of all $K$ scores.
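
The rotation in the steps above can be sketched by hand in a few lines of NumPy. This is a minimal illustration of the index bookkeeping only (the helper name `k_fold_indices` is ours, not a library function); in practice you would use scikit-learn's `KFold`, shown later on this page.

```python
import numpy as np

def k_fold_indices(n_samples, k):
    """Yield (train_idx, test_idx) pairs, rotating the held-out fold."""
    indices = np.arange(n_samples)
    folds = np.array_split(indices, k)
    for i in range(k):
        test_idx = folds[i]
        # All folds except fold i form the training set
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        yield train_idx, test_idx

# Across all K iterations, every index lands in exactly one test fold
splits = list(k_fold_indices(10, 5))
all_test = np.concatenate([test for _, test in splits])
print(sorted(all_test.tolist()))  # → [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Note that each training set here holds $K-1$ folds (8 of the 10 samples), matching step 2 above.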

## 2. Visualizing the Process

```mermaid
graph TB
    TITLE["$$\text{K-Fold Cross-Validation}$$"]

    %% Dataset
    TITLE --> DATA["$$\text{Full Dataset}$$"]

    %% Folds
    DATA --> F1["$$\text{Fold 1}$$"]
    DATA --> F2["$$\text{Fold 2}$$"]
    DATA --> F3["$$\text{Fold 3}$$"]
    DATA --> Fk["$$\text{Fold } k$$"]

    %% Iterations
    F1 --> I1["$$\text{Iteration 1}$$<br/>$$\text{Validation: Fold 1}$$<br/>$$\text{Training: Others}$$"]
    F2 --> I2["$$\text{Iteration 2}$$<br/>$$\text{Validation: Fold 2}$$<br/>$$\text{Training: Others}$$"]
    F3 --> I3["$$\text{Iteration 3}$$<br/>$$\text{Validation: Fold 3}$$<br/>$$\text{Training: Others}$$"]
    Fk --> Ik["$$\text{Iteration } k$$<br/>$$\text{Validation: Fold } k$$<br/>$$\text{Training: Others}$$"]

    %% Model Training & Evaluation
    I1 --> M1["$$\text{Train Model}$$"]
    I2 --> M2["$$\text{Train Model}$$"]
    I3 --> M3["$$\text{Train Model}$$"]
    Ik --> Mk["$$\text{Train Model}$$"]

    M1 --> S1["$$\text{Score}_1$$"]
    M2 --> S2["$$\text{Score}_2$$"]
    M3 --> S3["$$\text{Score}_3$$"]
    Mk --> Sk["$$\text{Score}_k$$"]

    %% Final Result
    S1 --> AVG["$$\text{Average Score}$$"]
    S2 --> AVG
    S3 --> AVG
    Sk --> AVG

    AVG --> PERF["$$\text{Cross-Validated Performance}$$"]
```

## 3. Why Use K-Fold?

### A. Reliability (Reducing Variance)

By averaging $K$ different test scores, you get a much more stable estimate of how the model will perform on new data. It eliminates the "luck of the draw."
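
You can see that "luck of the draw" directly: the exact same model, scored on single train-test splits that differ only in their random seed, gives noticeably different accuracies. This sketch uses a synthetic dataset from `make_classification` purely for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Small synthetic dataset, so split-to-split variance is visible
X, y = make_classification(n_samples=200, random_state=0)

scores = []
for seed in range(5):
    # Identical model and data; only the split's random seed changes
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=0.2, random_state=seed
    )
    model = DecisionTreeClassifier(random_state=0).fit(X_tr, y_tr)
    scores.append(model.score(X_te, y_te))

print([round(s, 2) for s in scores])  # accuracies fluctuate split to split
```

Averaging over all folds, as K-Fold does, smooths this fluctuation into one stable number.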

### B. Maximum Data Utility

In a standard split, 20% of your data is never used for training. In K-Fold, every data point is used for training $K-1$ times and for testing exactly once. This is especially vital for small datasets.

### C. Hyperparameter Tuning

K-Fold is the foundation for **Grid Search**. It helps you find the best settings for your model (like the depth of a tree) without overfitting to one specific validation set.
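
As a sketch of that idea, scikit-learn's `GridSearchCV` scores every candidate setting with K-Fold internally and keeps the one with the best average score. The dataset (`load_iris`) and candidate depths here are illustrative choices, not prescriptions.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Each candidate depth is evaluated with 5-fold CV; the best mean score wins
param_grid = {"max_depth": [2, 3, 4, 5]}
search = GridSearchCV(DecisionTreeClassifier(random_state=0), param_grid, cv=5)
search.fit(X, y)

print(search.best_params_)             # depth with the best cross-validated score
print(round(search.best_score_, 4))    # that mean accuracy across the 5 folds
```

Because every depth is judged on all 5 folds rather than one lucky validation set, the winner is far less likely to be a fluke.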

## 4. Implementation with Scikit-Learn

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, cross_val_score

# 1. Load data and initialize the model
X, y = load_iris(return_X_y=True)
model = RandomForestClassifier(random_state=42)

# 2. Define the K-Fold strategy
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# 3. Perform cross-validation (returns an array of 5 scores)
scores = cross_val_score(model, X, y, cv=kf, scoring='accuracy')

print(f"Scores for each fold: {scores}")
print(f"Mean Accuracy: {scores.mean():.4f}")
print(f"Standard Deviation: {scores.std():.4f}")
```

## 5. Variations of Cross-Validation

* **Stratified K-Fold:** Used for imbalanced data. It ensures each fold has the same percentage of samples from each class as the whole dataset.
* **Leave-One-Out (LOOCV):** An extreme case where $K$ equals the total number of samples ($N$). Extremely computationally expensive, but it uses the most data possible for training.
* **Time-Series Split:** Unlike random K-Fold, this respects the chronological order of the data (training on the past, testing on the future).
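
A quick sketch of the first variation: with a 90/10 class imbalance and 5 folds, `StratifiedKFold` gives every test fold exactly the same 90/10 ratio as the full dataset. The toy labels here are constructed for illustration.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Imbalanced labels: 90 samples of class 0, 10 of class 1
y = np.array([0] * 90 + [1] * 10)
X = np.zeros((100, 1))  # features are irrelevant to the split itself

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for train_idx, test_idx in skf.split(X, y):
    # Each 20-sample test fold keeps the 90/10 ratio: 18 zeros, 2 ones
    print(np.bincount(y[test_idx]))  # → [18  2] on every iteration
```

A plain `KFold` on the same labels could easily produce a test fold with zero minority samples, making the score for that fold meaningless.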

## 6. Pros and Cons

| Advantages | Disadvantages |
| --- | --- |
| **Robustness:** Provides a more accurate measure of model generalization. | **Computationally Expensive:** Training the model $K$ times takes $K$ times longer. |
| **Confidence:** The standard deviation tells you how "stable" the model is. | **Not for Big Data:** If your model takes 10 hours to train, doing it 10 times is often impractical. |

## References

* **Scikit-Learn:** [Cross-Validation Guide](https://scikit-learn.org/stable/modules/cross_validation.html)
* **StatQuest:** [K-Fold Cross-Validation Explained](https://www.youtube.com/watch?v=fSytzGwwBVw)

---

**Now that you have a robust way to validate your model, how do you handle data where the classes are heavily skewed (e.g., 99% vs 1%)?**