-
-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Expand file tree
/
Copy pathtrain.py
More file actions
80 lines (54 loc) · 1.84 KB
/
train.py
File metadata and controls
80 lines (54 loc) · 1.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import pickle
import pandas as pd
import sklearn
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
# Log the library versions used for this training run; a pickled model is
# generally only safe to load with the same library versions that produced it.
print(
    f'pandas=={pd.__version__}',
    f'sklearn=={sklearn.__version__}',
    sep='\n',
)
def load_data(
    data_url='https://raw.githubusercontent.com/alexeygrigorev/mlbookcamp-code/master/chapter-03-churn-prediction/WA_Fn-UseC_-Telco-Customer-Churn.csv',
):
    """Load the Telco customer-churn CSV and normalise it for training.

    Args:
        data_url: Anything ``pandas.read_csv`` accepts (URL, local path or
            file-like object). Defaults to the public dataset URL, so
            calling ``load_data()`` with no arguments behaves as before.

    Returns:
        A DataFrame where:
          * column names are lower-cased with spaces replaced by ``_``;
          * every text column has its values lower-cased with spaces
            replaced by ``_``;
          * ``totalcharges`` is numeric, with unparseable entries set to 0;
          * ``churn`` is an int column: 1 where the value was ``yes``, else 0.
    """
    df = pd.read_csv(data_url)
    df.columns = df.columns.str.lower().str.replace(' ', '_', regex=False)

    # Detect text columns with the dtype helpers instead of `== 'object'`:
    # when pandas infers its dedicated string dtype, the equality check
    # matches nothing, the values keep their original casing and the
    # `== 'yes'` comparison below silently produces an all-zero target.
    is_object = pd.api.types.is_object_dtype
    is_string = pd.api.types.is_string_dtype
    text_columns = [
        name
        for name, dtype in df.dtypes.items()
        if is_object(dtype) or is_string(dtype)
    ]
    for name in text_columns:
        df[name] = df[name].str.lower().str.replace(' ', '_', regex=False)

    # Blank / non-numeric charges become NaN under errors='coerce' and are
    # then treated as zero.
    df['totalcharges'] = pd.to_numeric(
        df['totalcharges'], errors='coerce'
    ).fillna(0)
    df['churn'] = (df['churn'] == 'yes').astype(int)
    return df
def train_model(df):
    """Fit the churn classifier on the full data frame.

    Args:
        df: DataFrame holding the feature columns listed below plus a
            ``churn`` target column.

    Returns:
        The fitted pipeline: a ``DictVectorizer`` followed by a
        ``LogisticRegression`` using the ``liblinear`` solver.
    """
    categorical_features = [
        'gender',
        'seniorcitizen',
        'partner',
        'dependents',
        'phoneservice',
        'multiplelines',
        'internetservice',
        'onlinesecurity',
        'onlinebackup',
        'deviceprotection',
        'techsupport',
        'streamingtv',
        'streamingmovies',
        'contract',
        'paperlessbilling',
        'paymentmethod',
    ]
    numerical_features = ['tenure', 'monthlycharges', 'totalcharges']
    feature_columns = categorical_features + numerical_features

    target = df.churn
    # The vectorizer step takes one {column: value} dict per row.
    records = df[feature_columns].to_dict(orient='records')

    vectorizer = DictVectorizer()
    classifier = LogisticRegression(solver='liblinear')
    model = make_pipeline(vectorizer, classifier)
    model.fit(records, target)
    return model
def save_model(pipeline, output_file):
    """Pickle ``pipeline`` to ``output_file``.

    Args:
        pipeline: The object to serialise (the trained model pipeline).
        output_file: Path of the file to write; it is opened in binary
            write mode, so any existing content is overwritten.
    """
    with open(output_file, mode='wb') as sink:
        pickle.Pickler(sink).dump(pipeline)
# Script body: fetch and clean the data, fit the pipeline, then persist it
# to model.bin in the current working directory.
df = load_data()
pipeline = train_model(df)
save_model(pipeline, output_file='model.bin')
print('Model saved to', 'model.bin')