-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathwinnow.py
More file actions
31 lines (23 loc) · 764 Bytes
/
winnow.py
File metadata and controls
31 lines (23 loc) · 764 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import numpy as np
from helper_functions import error
class Winnow:
    """Winnow online linear-threshold classifier for 0/1 labels.

    All weights start at 1. Training walks the samples in order; on a
    mistake every weight is multiplied by ``alpha ** x_i`` after a false
    negative (label 1, predicted 0) or divided by it after a false positive
    (label 0, predicted 1). With 0/1 features this only touches the weights
    of active features. A sample is predicted as 1 when ``x . w >= threshold``.
    """

    def __init__(self, alpha=2, threshold=None):
        """Create an untrained classifier.

        Args:
            alpha: Multiplicative update factor; must be greater than 1.
                The default of 2 reproduces the original fixed behavior.
            threshold: Decision threshold. ``None`` (the default) uses the
                number of features of the input, as the original did.

        Raises:
            ValueError: If ``alpha`` is not greater than 1.
        """
        if alpha <= 1:
            raise ValueError("alpha must be greater than 1")
        self.alpha = alpha
        self.threshold = threshold
        # Learned weight vector; None until train() has run.
        self.w = None

    def _threshold_for(self, n_features):
        """Return the decision threshold to apply for ``n_features`` features."""
        return n_features if self.threshold is None else self.threshold

    def train(self, train_x, train_y, epochs=1):
        """Fit the weights with online Winnow updates and report training error.

        Args:
            train_x: 2-D array-like of shape (n_samples, n_features).
            train_y: Labels indexable by sample position, expected to be 0/1.
                A sample whose label minus prediction is neither +1 nor -1
                causes no update.
            epochs: Number of passes over the data. The default of 1 is a
                single online pass, as in the original implementation.

        Returns:
            ``error(pred, train_y)`` for the final weights' predictions on
            ``train_x``. See ``helper_functions.error`` for what it measures
            (presumably an error rate; confirm there).
        """
        train_x = np.asarray(train_x)
        n = train_x.shape[1]
        theta = self._threshold_for(n)
        w = np.ones(n)
        for _ in range(epochs):
            for index, x in enumerate(train_x):
                y_hat = 1 * (x.dot(w) >= theta)
                diff = train_y[index] - y_hat
                if diff == 1:
                    # False negative: promote the weights of active features.
                    w *= self.alpha ** x
                elif diff == -1:
                    # False positive: demote the weights of active features.
                    w /= self.alpha ** x
        self.w = w
        pred = self.predict(train_x)
        return error(pred, train_y)

    def predict(self, test_x):
        """Predict 0/1 labels with the learned weights.

        Args:
            test_x: 2-D array-like (n_samples, n_features), or a single 1-D
                sample of n_features values.

        Returns:
            Integer array of 0/1 predictions (an integer scalar for a 1-D
            sample).

        Raises:
            RuntimeError: If no weights are available (``train`` not run).
        """
        if self.w is None:
            raise RuntimeError("Winnow.predict called before train")
        test_x = np.asarray(test_x)
        # shape[-1] is the feature count for both 1-D and 2-D inputs.
        theta = self._threshold_for(test_x.shape[-1])
        return 1 * (test_x.dot(self.w) >= theta)