55"""
66
from collections import Counter

import numpy as np
89
910
class DecisionTree:
    """
    One-dimensional decision tree supporting both regression (splits chosen
    by summed mean-squared error) and classification (splits chosen by
    information gain under a gini or entropy criterion).
    """

    def __init__(self, depth=5, min_leaf_size=5, task="regression", criterion="gini"):
        """
        :param depth: remaining depth budget for this (sub)tree; a depth of 1
            forces this node to be a leaf.
        :param min_leaf_size: minimum number of samples each side of a split
            must retain.
        :param task: "regression" or "classification".
        :param criterion: impurity measure used for classification splits,
            "gini" or "entropy" (validated lazily in ``information_gain`` so
            regression users may ignore it).
        :raises ValueError: if ``task`` is not a recognised mode (previously an
            unknown task string silently behaved as classification).
        """
        if task not in ("regression", "classification"):
            raise ValueError("task must be 'regression' or 'classification'")
        self.depth = depth
        self.decision_boundary = 0
        self.left = None
        self.right = None
        self.min_leaf_size = min_leaf_size
        self.prediction = None
        self.task = task
        self.criterion = criterion

    def mean_squared_error(self, labels, prediction):
        """
        Mean squared error between ``labels`` and a scalar ``prediction``.

        :param labels: one-dimensional numpy array of ground-truth values
        :param prediction: a single float used as the estimate for every label
        :return: mean of the squared differences
        :raises ValueError: if ``labels`` is not one-dimensional

        >>> tree = DecisionTree()
        >>> float(tree.mean_squared_error(np.array([1, 2, 3]), 2.0))
        0.6666666666666666
        """
        if labels.ndim != 1:
            raise ValueError("Input labels must be one dimensional")
        return np.mean((labels - prediction) ** 2)

    def gini(self, y):
        """
        Gini impurity of the labels ``y``.

        Gini = 1 - sum(p_i ** 2), where p_i is the empirical probability of
        class i.  0 means a pure node; lower is better for a split.
        """
        _, counts = np.unique(y, return_counts=True)
        prob = counts / counts.sum()
        return 1 - np.sum(prob**2)

    def entropy(self, y):
        """
        Shannon entropy of the labels ``y``.

        Entropy = -sum(p_i * log2(p_i)), where p_i is the empirical
        probability of class i.  A pure node has entropy exactly 0.
        """
        _, counts = np.unique(y, return_counts=True)
        prob = counts / counts.sum()
        # Counts from np.unique are >= 1, so every probability is strictly
        # positive and log2 is well defined -- no epsilon needed (an epsilon
        # would make pure nodes come out slightly negative).
        return -np.sum(prob * np.log2(prob))

    def information_gain(self, parent, left, right):
        """
        Impurity reduction achieved by splitting ``parent`` into ``left``
        and ``right``:

            IG = impurity(parent)
                 - |left|/|parent|  * impurity(left)
                 - |right|/|parent| * impurity(right)

        Higher information gain indicates a better split.

        :raises ValueError: if ``self.criterion`` is neither "gini" nor
            "entropy".
        """
        if self.criterion == "gini":
            impurity = self.gini
        elif self.criterion == "entropy":
            impurity = self.entropy
        else:
            raise ValueError("Invalid criterion")

        weight_left = len(left) / len(parent)
        weight_right = len(right) / len(parent)
        return impurity(parent) - (
            weight_left * impurity(left) + weight_right * impurity(right)
        )

    def most_common_label(self, y):
        """Majority class label in ``y`` (ties broken by first occurrence)."""
        return Counter(y).most_common(1)[0][0]

    def _leaf_value(self, y):
        """Leaf prediction: mean of ``y`` for regression, majority class otherwise."""
        if self.task == "regression":
            return np.mean(y)
        return self.most_common_label(y)

    def train(self, x, y):
        """
        Fit this node (and recursively its children) on the data.

        Candidate splits are index based -- every prefix/suffix partition of
        ``x`` is tried -- so ``x`` is expected to be sorted ascending for the
        stored ``decision_boundary`` to be meaningful (the demo in ``main``
        uses ``np.arange``; TODO confirm with other callers).

        :param x: one-dimensional numpy array of feature values
        :param y: one-dimensional numpy array of labels, same length as ``x``
        :raises ValueError: if ``x`` or ``y`` is not one-dimensional, or if
            their lengths differ
        """
        if x.ndim != 1:
            raise ValueError("Input data set must be one-dimensional")
        if len(x) != len(y):
            raise ValueError("x and y have different lengths")
        if y.ndim != 1:
            raise ValueError("Data set labels must be one-dimensional")

        # Stop splitting when there are too few samples for two leaves or the
        # depth budget is exhausted -- this node becomes a leaf.
        if len(x) < 2 * self.min_leaf_size or self.depth == 1:
            self.prediction = self._leaf_value(y)
            return

        # best_split == 0 means "no valid split found" (index 0 can never be
        # a valid split because the left side would be empty).
        best_split = 0
        # Regression minimises the summed MSE of the two sides;
        # classification maximises information gain.
        best_score = float("inf") if self.task == "regression" else -float("inf")

        for i in range(len(x)):
            # Both sides of a candidate split must keep >= min_leaf_size samples.
            if i < self.min_leaf_size or len(x) - i < self.min_leaf_size:
                continue

            left_y = y[:i]
            right_y = y[i:]

            if self.task == "regression":
                score = self.mean_squared_error(
                    left_y, np.mean(left_y)
                ) + self.mean_squared_error(right_y, np.mean(right_y))
                if score < best_score:
                    best_score = score
                    best_split = i
            else:
                gain = self.information_gain(y, left_y, right_y)
                if gain > best_score:
                    best_score = gain
                    best_split = i

        if best_split != 0:
            left_x, right_x = x[:best_split], x[best_split:]
            left_y, right_y = y[:best_split], y[best_split:]

            self.decision_boundary = x[best_split]
            self.left = DecisionTree(
                depth=self.depth - 1,
                min_leaf_size=self.min_leaf_size,
                task=self.task,
                criterion=self.criterion,
            )
            self.right = DecisionTree(
                depth=self.depth - 1,
                min_leaf_size=self.min_leaf_size,
                task=self.task,
                criterion=self.criterion,
            )
            self.left.train(left_x, left_y)
            self.right.train(right_x, right_y)
        else:
            # No split improved on a single leaf; predict from all of y.
            self.prediction = self._leaf_value(y)

    def predict(self, x):
        """
        Return the tree's prediction for a single feature value ``x``.

        Walks the tree by comparing ``x`` against each node's decision
        boundary until a leaf's stored prediction is reached.

        :raises ValueError: if the tree has not been trained.
        """
        if self.prediction is not None:
            return self.prediction
        if self.left is not None and self.right is not None:
            if x >= self.decision_boundary:
                return self.right.predict(x)
            return self.left.predict(x)
        raise ValueError("Decision tree not yet trained")
157236
237+
158238class TestDecisionTree :
    """Decision Tree test class."""
160240
@@ -172,7 +252,7 @@ def helper_mean_squared_error_test(labels, prediction):
172252
173253 return float (squared_error_sum / labels .size )
174254
175-
255+
def main():
    """
    Demonstration: fit a regression tree to samples of sin(x) on [-1, 1),
    then fit a tiny classification tree, and print a prediction from each.
    """
    # Regression demo: dense, sorted samples of the sine function.
    x = np.arange(-1.0, 1.0, 0.005)
    y = np.sin(x)

    tree = DecisionTree(depth=10, min_leaf_size=10, task="regression")
    tree.train(x, y)
    print("Regression prediction:", tree.predict(0.5))

    # Classification demo: two cleanly separable classes.
    x_cls = np.array([1, 2, 3, 4, 5, 6])
    y_cls = np.array([0, 0, 0, 1, 1, 1])

    clf = DecisionTree(
        depth=3, min_leaf_size=1, task="classification", criterion="gini"
    )
    clf.train(x_cls, y_cls)

    print("Classification prediction (2):", clf.predict(2))
    print("Classification prediction (5):", clf.predict(5))
197278
198279
if __name__ == "__main__":
    # Run the demo first, then the module's doctests.
    main()
    import doctest

    doctest.testmod(name="mean_squared_error", verbose=True)
0 commit comments