-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_classifier.py
More file actions
46 lines (36 loc) · 1.45 KB
/
train_classifier.py
File metadata and controls
46 lines (36 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import pickle
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
from data.preprocess_tweets import get_twitter_dataset
def train_tweet_classifier(X_train, Y_train, X_test, Y_test, *, filename="tweet_classifier.sav"):
    """
    Train several tweet classifiers and save the best-performing one.

    Three candidate models -- logistic regression, a support vector machine
    and Gaussian naive Bayes -- are fitted on the training split and scored
    by accuracy on the test split. The model with the highest test accuracy
    is pickled to ``filename`` and returned.

    :param X_train: train input features
    :param Y_train: train labels
    :param X_test: test input features
    :param Y_test: test labels
    :param filename: path of the pickle file the best classifier is written
        to; defaults to "tweet_classifier.sav" (the previously hard-coded name)
    :return: the best-performing fitted classifier
    """
    # Candidate models, fitted in this order: LR, SVM, NB.
    classifiers = [
        LogisticRegression(random_state=42),
        SVC(random_state=42),
        # NOTE(review): GaussianNB requires dense features; if the dataset
        # yields a sparse matrix (e.g. from a text vectorizer) this fit will
        # raise -- confirm against get_twitter_dataset.
        GaussianNB(),
    ]
    for classifier in classifiers:
        classifier.fit(X_train, Y_train)

    # Score every candidate on the held-out test split.
    accuracies = [accuracy_score(Y_test, classifier.predict(X_test)) for classifier in classifiers]

    # np.argmax returns the first maximum, so ties favour the earlier model
    # in the list above.
    best_classifier = classifiers[int(np.argmax(accuracies))]

    # A context manager guarantees the handle is flushed and closed, even if
    # pickling raises part-way through.
    with open(filename, "wb") as model_file:
        pickle.dump(best_classifier, model_file)
    return best_classifier
if __name__ == "__main__":
    # Script entry point: load the preprocessed tweet train/test splits,
    # then train the candidate models and persist whichever scores best.
    x_tr, y_tr, x_te, y_te = get_twitter_dataset()
    train_tweet_classifier(X_train=x_tr, Y_train=y_tr, X_test=x_te, Y_test=y_te)