Skip to content

Commit 09c5814

Browse files
authored
Merge pull request #642 from lck6055/knn-algo
Add KNN-Algo
2 parents 7aebddc + 8a68809 commit 09c5814

1 file changed

Lines changed: 140 additions & 0 deletions

File tree

Python/machine_learning/knn.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
"""
2+
K-Nearest Neighbors (KNN) - Simple Implementation from Scratch
3+
-------------------------------------------------------------
4+
This script implements a basic version of the KNN algorithm for classification
5+
using only Python and NumPy (no sklearn); Matplotlib is used only to plot the demo.
6+
7+
Concept Summary:
8+
----------------
9+
1. KNN is a supervised learning algorithm used for classification & regression.
10+
2. It finds the K nearest data points to a test point using a distance metric
11+
(usually Euclidean distance).
12+
3. For classification → predicts the majority label among neighbors.
13+
4. For regression → predicts the average value among neighbors.
14+
5. It is a **lazy learner** (no explicit training phase, prediction happens at query time).
15+
16+
Steps in this code:
17+
-------------------
18+
1. Compute Euclidean distance between the test point and all training points.
19+
2. Sort training points by their distance to the test point.
20+
3. Select top 'k' nearest points.
21+
4. Use majority voting to determine the predicted class.
22+
5. Return the predicted label.
23+
24+
Time Complexity (per prediction):
25+
1 Distance calculation → O(n * d)
26+
2 Sorting distances → O(n log n)
27+
3 Selecting k nearest points → O(k)
28+
4 Majority voting (Counter) → O(k)
29+
🔹 Overall → O(n * d + n log n)
30+
31+
Space Complexity:
32+
O(n * d) → to store all distances and intermediate computations
33+
34+
Contributor:
35+
---------------------
36+
💻 Contributed by: **Lakhinana Chaturvedi Kashyap**
37+
"""
38+
39+
import numpy as np
40+
import matplotlib.pyplot as plt # ✅ Correct import
41+
from collections import Counter
42+
43+
# Function: Euclidean Distance
44+
45+
def euclidean_distance(p1, p2):
    """
    Calculates the Euclidean distance between two points.

    Formula:
        √( (x2 - x1)² + (y2 - y1)² + ... )

    Args:
        p1: Coordinates of the first point (sequence or NumPy array of numbers).
        p2: Coordinates of the second point; must broadcast against ``p1``.

    Returns:
        The distance as a NumPy floating-point scalar.
    """
    # Convert to float BEFORE subtracting: fixed-width integer arrays
    # (e.g. uint8 or int8) wrap around silently on subtraction and squaring,
    # which would yield a wrong distance instead of an error.
    diff = np.asarray(p1, dtype=float) - np.asarray(p2, dtype=float)
    return np.sqrt(np.sum(diff ** 2))
53+
54+
55+
# Function: KNN Prediction
56+
57+
def knn_prediction(training_data, training_labels, test_point, k):
    """
    Predicts the class of a test point using the K-Nearest Neighbors algorithm.
    Returns both the predicted label and the k nearest points (for visualization).

    Args:
        training_data: Sequence (or array) of training points.
        training_labels: Labels aligned index-by-index with ``training_data``.
        test_point: The point to classify.
        k: Number of neighbors that vote; must be at least 1. If ``k`` exceeds
            the number of training points, all of them vote.

    Returns:
        A tuple ``(prediction, nearest_points)``: the majority label among the
        selected neighbors, and those neighbors ordered nearest first.

    Raises:
        ValueError: If there is no training data, if data and labels differ in
            length, or if ``k`` is smaller than 1.
    """
    n_samples = len(training_data)
    if n_samples == 0:
        raise ValueError("training_data must contain at least one point")
    if len(training_labels) != n_samples:
        raise ValueError(
            "training_data and training_labels must have the same length"
        )
    # A non-positive k would either leave no voters (k == 0) or, through
    # negative slicing, silently drop the farthest points instead of
    # selecting the nearest ones.
    if k < 1:
        raise ValueError("k must be a positive integer")

    # Pair every training point with its distance to the test point and
    # order by distance only; the sort is stable, so equally distant points
    # keep their original relative order.
    neighbors = sorted(
        (
            (euclidean_distance(test_point, point), label, point)
            for point, label in zip(training_data, training_labels)
        ),
        key=lambda item: item[0],
    )[:k]

    k_neighbors = [label for _, label, _ in neighbors]
    nearest_points = [point for _, _, point in neighbors]

    # Majority voting. On a tie, Counter returns the label encountered
    # first, i.e. the tied label whose closest member is nearest.
    prediction = Counter(k_neighbors).most_common(1)[0][0]

    return prediction, nearest_points
78+
79+
# Example Usage
# NOTE(review): this demo runs at import time; consider moving it under an
# `if __name__ == "__main__":` guard if the module is meant to be imported.

# Toy data set: six 2-D points, the first three labelled 0 and the last
# three labelled 1. NumPy arrays allow easy slicing of the points.
training_data = np.array(
    [[1.0, 2.0], [2.0, 3.0], [3.0, 1.0], [6.0, 5.0], [7.0, 7.0], [8.0, 6.0]]
)
training_labels = np.array([0] * 3 + [1] * 3)

# Query point and neighborhood size
test_point = np.array([5.0, 5.0])
k = 3

# Classify the query point and keep its k closest training points
prediction, nearest_points = knn_prediction(
    training_data=training_data,
    training_labels=training_labels,
    test_point=test_point,
    k=k,
)

print("Predicted label:", prediction)
print("Nearest neighbors:", nearest_points)
101+
102+
# Visualization

plt.figure(figsize=(8, 6))

# Training points of class 0, drawn in blue
class_0_points = training_data[training_labels == 0]
plt.scatter(
    class_0_points[:, 0],
    class_0_points[:, 1],
    s=100,
    color='blue',
    label='Class 0',
)

# Training points of class 1, drawn in red
class_1_points = training_data[training_labels == 1]
plt.scatter(
    class_1_points[:, 0],
    class_1_points[:, 1],
    s=100,
    color='red',
    label='Class 1',
)

# The selected neighbors, drawn as larger yellow markers with a black edge.
# They are converted to an array first so the columns can be sliced.
nearest_points = np.array(nearest_points)
neighbor_x = nearest_points[:, 0]
neighbor_y = nearest_points[:, 1]
plt.scatter(
    neighbor_x,
    neighbor_y,
    s=200,
    facecolor='yellow',
    edgecolor='black',
    label=f'{k} Nearest Neighbors',
)

# The query point, drawn as a green star
plt.scatter(
    test_point[0],
    test_point[1],
    s=250,
    marker='*',
    color='green',
    label='Test Point (Predicted)',
)

# Title, axis labels, legend and grid
plt.title(f"KNN Visualization (k={k}) — Predicted Label: {prediction}")
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.legend()
plt.grid(True)
plt.show()

0 commit comments

Comments
 (0)