Skip to content

Commit 8d10c36

Browse files
committed
Adding-Classification-On-Decision-Tree
1 parent 841e947 commit 8d10c36

File tree

2 files changed

+122
-39
lines changed

2 files changed

+122
-39
lines changed

.vscode/settings.json

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
{
22
"githubPullRequests.ignoredPullRequestBranches": [
33
"master"
4-
]
4+
],
5+
"python-envs.defaultEnvManager": "ms-python.python:system",
6+
"python-envs.defaultPackageManager": "ms-python.python:pip"
57
}

machine_learning/decision_tree.py

Lines changed: 119 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,20 @@
55
"""
66

77
import numpy as np
8+
from collections import Counter
89

910

1011
class DecisionTree:
11-
def __init__(self, depth=5, min_leaf_size=5):
12+
def __init__(self, depth=5, min_leaf_size=5, task="regression", criterion="gini"):
    """
    Build an untrained decision-tree node.

    :param depth: maximum remaining depth below this node
    :param min_leaf_size: minimum number of samples allowed in a leaf
    :param task: "regression" or "classification"
    :param criterion: classification split criterion, "gini" or "entropy"
    """
    # Hyper-parameters supplied by the caller.
    self.depth = depth
    self.min_leaf_size = min_leaf_size
    self.task = task
    self.criterion = criterion
    # Learned state, populated by train().
    self.decision_boundary = 0
    self.left = None
    self.right = None
    self.prediction = None
21+
1922
def mean_squared_error(self, labels, prediction):
2023
"""
2124
mean_squared_error:
@@ -38,10 +41,62 @@ def mean_squared_error(self, labels, prediction):
3841
True
3942
"""
4043
if labels.ndim != 1:
41-
print("Error: Input labels must be one dimensional")
42-
44+
raise ValueError("Input labels must be one dimensional")
4345
return np.mean((labels - prediction) ** 2)
4446

47+
def gini(self, y):
    """
    Compute the Gini impurity of a set of labels.

    Gini impurity measures how often a randomly chosen element would be
    incorrectly classified: Gini = 1 - sum(p_i ** 2), where p_i is the
    empirical probability of class i. A lower value indicates a purer node
    (a better split).

    :param y: one-dimensional array-like of class labels
    :return: impurity as a float in [0, 1); 0.0 for empty or single-class input
    """
    if len(y) == 0:
        # Guard: an empty node is treated as perfectly pure instead of
        # producing 0/0 -> nan below.
        return 0.0
    # Only the per-class counts are needed; the class values themselves
    # (np.unique's first return value) were an unused local before.
    counts = np.unique(y, return_counts=True)[1]
    prob = counts / counts.sum()
    return 1 - np.sum(prob ** 2)
60+
61+
def entropy(self, y):
    """
    Compute the Shannon entropy (impurity) of a set of labels.

    Entropy = -sum(p_i * log2(p_i)), where p_i is the empirical probability
    of class i. Lower entropy means higher purity; 0.0 means a single class.

    :param y: one-dimensional array-like of class labels
    :return: entropy in bits as a float; 0.0 for empty or single-class input
    """
    if len(y) == 0:
        # Guard: entropy of nothing is defined as 0 rather than nan.
        return 0.0
    counts = np.unique(y, return_counts=True)[1]
    prob = counts / counts.sum()
    # Counts from np.unique are always >= 1, so every probability is > 0 and
    # log2 is well-defined; the previous `prob + 1e-9` fudge term was
    # unnecessary and slightly biased the result.
    return -np.sum(prob * np.log2(prob))
73+
74+
def information_gain(self, parent, left, right):
75+
"""
76+
Computes the information gain from splitting a dataset.
77+
Information gain represents the reduction in impurity
78+
after a dataset is split into left and right subsets.
79+
Formula: IG = Impurity(parent) - [weighted impurity(left) + weighted impurity(right)]
80+
81+
Higher information gain indicates a better split.
82+
"""
83+
if self.criterion == "gini":
84+
func = self.gini
85+
elif self.criterion == "entropy":
86+
func = self.entropy
87+
else:
88+
raise ValueError("Invalid criterion")
89+
90+
weight_l = len(left) / len(parent)
91+
weight_r = len(right) / len(parent)
92+
93+
return func(parent) - (
94+
weight_l * func(left) + weight_r * func(right)
95+
)
96+
97+
def most_common_label(self, y):
    """Return the majority class in `y` (first-seen label wins on ties)."""
    counts = Counter(y)
    return max(counts, key=counts.get)
99+
45100
def train(self, x, y):
46101
"""
47102
train:
@@ -87,35 +142,50 @@ def train(self, x, y):
87142
if y.ndim != 1:
88143
raise ValueError("Data set labels must be one-dimensional")
89144

90-
if len(x) < 2 * self.min_leaf_size:
91-
self.prediction = np.mean(y)
92-
return
93-
94-
if self.depth == 1:
95-
self.prediction = np.mean(y)
145+
if len(x) < 2 * self.min_leaf_size or self.depth == 1:
146+
if self.task == "regression":
147+
self.prediction = np.mean(y)
148+
else:
149+
self.prediction = self.most_common_label(y)
96150
return
97151

98152
best_split = 0
99-
min_error = self.mean_squared_error(x, np.mean(y)) * 2
100-
153+
101154
"""
102155
loop over all possible splits for the decision tree. find the best split.
103156
if no split exists that is less than 2 * error for the entire array
104157
then the data set is not split and the average for the entire array is used as
105158
the predictor
106159
"""
160+
if self.task == "regression":
161+
best_score = float("inf")
162+
else:
163+
best_score = -float("inf")
164+
107165
for i in range(len(x)):
108-
if len(x[:i]) < self.min_leaf_size: # noqa: SIM114
166+
if len(x[:i]) < self.min_leaf_size:
109167
continue
110-
elif len(x[i:]) < self.min_leaf_size:
168+
if len(x[i:]) < self.min_leaf_size:
111169
continue
112-
else:
113-
error_left = self.mean_squared_error(x[:i], np.mean(y[:i]))
114-
error_right = self.mean_squared_error(x[i:], np.mean(y[i:]))
115-
error = error_left + error_right
116-
if error < min_error:
170+
171+
left_y = y[:i]
172+
right_y = y[i:]
173+
174+
if self.task == "regression":
175+
error_left = self.mean_squared_error(left_y, np.mean(left_y))
176+
error_right = self.mean_squared_error(right_y, np.mean(right_y))
177+
score = error_left + error_right
178+
179+
if score < best_score:
180+
best_score = score
181+
best_split = i
182+
183+
else:
184+
gain = self.information_gain(y, left_y, right_y)
185+
186+
if gain > best_score:
187+
best_score = gain
117188
best_split = i
118-
min_error = error
119189

120190
if best_split != 0:
121191
left_x = x[:best_split]
@@ -124,18 +194,28 @@ def train(self, x, y):
124194
right_y = y[best_split:]
125195

126196
self.decision_boundary = x[best_split]
197+
127198
self.left = DecisionTree(
128-
depth=self.depth - 1, min_leaf_size=self.min_leaf_size
199+
depth=self.depth - 1,
200+
min_leaf_size=self.min_leaf_size,
201+
task=self.task,
202+
criterion=self.criterion,
129203
)
130204
self.right = DecisionTree(
131-
depth=self.depth - 1, min_leaf_size=self.min_leaf_size
205+
depth=self.depth - 1,
206+
min_leaf_size=self.min_leaf_size,
207+
task=self.task,
208+
criterion=self.criterion,
132209
)
210+
133211
self.left.train(left_x, left_y)
134212
self.right.train(right_x, right_y)
135-
else:
136-
self.prediction = np.mean(y)
137213

138-
return
214+
else:
215+
if self.task == "regression":
216+
self.prediction = np.mean(y)
217+
else:
218+
self.prediction = self.most_common_label(y)
139219

140220
def predict(self, x):
141221
"""
@@ -146,15 +226,15 @@ def predict(self, x):
146226
"""
147227
if self.prediction is not None:
148228
return self.prediction
149-
elif self.left is not None and self.right is not None:
229+
if self.left is not None and self.right is not None:
150230
if x >= self.decision_boundary:
151231
return self.right.predict(x)
152232
else:
153233
return self.left.predict(x)
154-
else:
155-
raise ValueError("Decision tree not yet trained")
156234

235+
raise ValueError("Decision tree not yet trained")
157236

237+
158238
class TestDecisionTree:
159239
"""Decision Tres test class"""
160240

@@ -172,7 +252,7 @@ def helper_mean_squared_error_test(labels, prediction):
172252

173253
return float(squared_error_sum / labels.size)
174254

175-
255+
176256
def main():
177257
"""
178258
In this demonstration we're generating a sample data set from the sin function in
@@ -183,21 +263,22 @@ def main():
183263
x = np.arange(-1.0, 1.0, 0.005)
184264
y = np.sin(x)
185265

186-
tree = DecisionTree(depth=10, min_leaf_size=10)
266+
tree = DecisionTree(depth=10, min_leaf_size=10, task="regression")
187267
tree.train(x, y)
188268

189-
rng = np.random.default_rng()
190-
test_cases = (rng.random(10) * 2) - 1
191-
predictions = np.array([tree.predict(x) for x in test_cases])
192-
avg_error = np.mean((predictions - test_cases) ** 2)
269+
print("Regression prediction:", tree.predict(0.5))
270+
x_cls = np.array([1, 2, 3, 4, 5, 6])
271+
y_cls = np.array([0, 0, 0, 1, 1, 1])
272+
273+
clf = DecisionTree(depth=3, min_leaf_size=1, task="classification", criterion="gini")
274+
clf.train(x_cls, y_cls)
193275

194-
print("Test values: " + str(test_cases))
195-
print("Predictions: " + str(predictions))
196-
print("Average error: " + str(avg_error))
276+
print("Classification prediction (2):", clf.predict(2))
277+
print("Classification prediction (5):", clf.predict(5))
197278

198279

199280
if __name__ == "__main__":
200281
main()
201282
import doctest
202283

203-
doctest.testmod(name="mean_squared_error", verbose=True)
284+
doctest.testmod(name="mean_squared_error", verbose=True)

0 commit comments

Comments
 (0)