-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy path3. gain.py
More file actions
28 lines (22 loc) · 772 Bytes
/
3. gain.py
File metadata and controls
28 lines (22 loc) · 772 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from collections import Counter
# Class labels of the full (parent) set: 6 "unacc", 4 "good", 3 "vgood".
unsplit_labels = ["unacc"] * 6 + ["good"] * 4 + ["vgood"] * 3

# The same 13 labels grouped into three subsets.
split_labels_1 = [
    ["unacc"] * 6 + ["good"] * 2 + ["vgood"],
    ["good"] * 2,
    ["vgood"] * 2,
]

# The same 13 labels grouped into two subsets.
split_labels_2 = [
    ["unacc"] * 6 + ["good"] * 4,
    ["vgood"] * 3,
]
def gini(dataset):
    """Return the Gini impurity of a collection of class labels.

    The impurity is ``1 - sum(p ** 2)`` over the distinct labels, where ``p``
    is the fraction of items in ``dataset`` carrying that label.

    Args:
        dataset: A sized collection of hashable labels (e.g. a list of str).

    Returns:
        ``0.0`` for an empty dataset, otherwise ``1`` minus the sum of the
        squared label fractions (``0.0`` when every item shares one label).
    """
    total = len(dataset)
    if total == 0:
        # With no items there are no label fractions to subtract; without
        # this guard the loop is skipped and the starting value 1 (maximal
        # impurity) would be returned for an empty set.
        return 0.0

    impurity = 1
    for count in Counter(dataset).values():
        prob_of_label = count / total
        impurity -= prob_of_label ** 2
    return impurity
# Information gain of split_labels_1: the Gini impurity of the unsplit labels
# minus the size-weighted Gini impurity of each subset produced by the split.
info_gain = gini(unsplit_labels)
for subset in split_labels_1:
    # Weight each subset by its share of the unsplit labels so that a small
    # pure subset does not count as much as a large impure one.
    weight = len(subset) / len(unsplit_labels)
    info_gain -= weight * gini(subset)
print(info_gain)