Skip to content

Commit b3fcd31

Browse files
author
lck6055
committed
updated linear_regression
1 parent 5fb638d commit b3fcd31

1 file changed

Lines changed: 42 additions & 17 deletions

File tree

Python/machine_learning/linear_regression.py

Lines changed: 42 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -2,34 +2,54 @@
22
Algorithm: Linear Regression (Analytical / Closed-form Solution)
33
44
Description:
5-
This script implements Linear Regression using the closed-form solution (Analytical method).
6-
It calculates the best-fit line for a dataset by directly computing the slope (b1) and intercept (b0)
7-
using the formulas derived from minimizing the Mean Squared Error (MSE).
5+
This script implements Linear Regression using the closed-form solution.
6+
It calculates the best-fit line by directly computing the slope (b1) and intercept (b0)
7+
using formulas derived from minimizing the Mean Squared Error (MSE).
88
9-
Mathematical Formulation:
10-
y_pred = b0 + b1 * x
11-
b1 = Σ((x - mean(x)) * (y - mean(y))) / Σ((x - mean(x))^2)
12-
b0 = mean(y) - b1 * mean(x)
13-
14-
Time Complexity: O(n) # Single pass through the data
9+
Time Complexity: O(n) # A constant number of vectorized passes over the data
1510
Space Complexity: O(n) # Predictions and intermediate arrays are the same size as the input
1611
"""
1712

1813
import numpy as np
1914
import matplotlib.pyplot as plt
2015

21-
# Function for Analytical Linear Regression
16+
# --- Function for Analytical Linear Regression ---
2217
def linear_regression_analytical(x, y):
    """
    Fit a simple linear regression y ≈ b0 + b1 * x using the closed-form solution.

    The slope and intercept are the least-squares estimates:
        b1 = Σ((x - mean(x)) * (y - mean(y))) / Σ((x - mean(x))^2)
        b0 = mean(y) - b1 * mean(x)

    Parameters:
        x (array-like): Feature values.
        y (array-like): Target values; must have the same shape as x.

    Returns:
        tuple: b0 (intercept), b1 (slope), y_pred (predictions), SSE, R²

    Raises:
        ValueError: If x and y differ in shape, are empty, or if every x value
            is identical (the slope is undefined in that case).
    """
    # Coerce to float arrays so lists/tuples work and integer inputs do not
    # affect the arithmetic below.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # Reject mismatched inputs explicitly; otherwise a length-1 array would be
    # silently broadcast against a longer one.
    if x.shape != y.shape:
        raise ValueError(
            f"x and y must have the same shape, got {x.shape} and {y.shape}"
        )
    if x.size == 0:
        raise ValueError("x and y must not be empty")

    # Compute mean of x and y
    x_mean, y_mean = np.mean(x), np.mean(y)

    # Deviations from the mean are reused by the slope and by SST
    x_dev = x - x_mean
    y_dev = y - y_mean

    # Σ(x - mean(x))^2 is the slope's denominator; zero means x has no variance
    sxx = np.sum(x_dev**2)
    if sxx == 0:
        raise ValueError("x has zero variance; the slope is undefined")

    # Compute slope (b1) and intercept (b0) using the closed-form formulas
    b1 = np.sum(x_dev * y_dev) / sxx
    b0 = y_mean - b1 * x_mean

    # Compute predicted y values
    y_pred = b0 + b1 * x

    # Compute Sum of Squared Errors (SSE)
    SSE = np.sum((y - y_pred)**2)

    # Compute Total Sum of Squares (SST) for R²
    SST = np.sum(y_dev**2)

    # Compute R² score (coefficient of determination). When y is constant,
    # SST is zero and 1 - SSE/SST would be 0/0; report a perfect fit as 1.0
    # and anything else as 0.0 instead of returning NaN.
    if SST == 0:
        R2 = 1.0 if SSE == 0 else 0.0
    else:
        R2 = 1 - (SSE / SST)

    return b0, b1, y_pred, SSE, R2
3150

32-
# Test Cases
51+
52+
# --- Test Cases ---
3353
test_cases = {
3454
"Simple Linear": {
3555
"x": np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
@@ -49,19 +69,24 @@ def linear_regression_analytical(x, y):
4969
}
5070
}
5171

72+
# Fit, report and plot every registered test case in turn
for name, data in test_cases.items():
    print(f"\n=== {name} ===")
    x = data["x"]
    y = data["y"]

    # Closed-form fit: coefficients, predictions and goodness-of-fit metrics
    b0, b1, y_pred, SSE, R2 = linear_regression_analytical(x, y)

    # Report the fitted coefficients and the error metrics
    print(f"Intercept (b0): {b0:.4f}, Slope (b1): {b1:.4f}")
    print(f"SSE: {SSE:.4f}, R²: {R2:.4f}")

    # One figure per test case: raw observations plus the fitted line
    plt.figure(figsize=(6, 4))
    plt.title(f"Linear Regression - {name}")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.scatter(x, y, color="blue", label="Data Points")
    plt.plot(x, y_pred, "r-", label="Analytical Solution")
    plt.legend()
    plt.show()

0 commit comments

Comments
 (0)