-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain_lda.py
More file actions
64 lines (48 loc) · 2.06 KB
/
train_lda.py
File metadata and controls
64 lines (48 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
This script contains the code to train lda
EMSE_DevInt>python ./python/train_lda.py 5 10 D:\Victoria\EMSE\p2\EMSE_DevInt\python\data\data_processed\processed_data.csv
"""
import pandas as pd
from sklearn.decomposition import LatentDirichletAllocation as LDA
from sklearn.feature_extraction.text import CountVectorizer
import sys
# Read the CSV at `filepath`; return the full DataFrame and the Series of processed text.
def read_data(filepath):
    """Load the dataset and pull out the preprocessed text column.

    Parameters
    ----------
    filepath : str or file-like
        Location of the CSV produced by the preprocessing step. Must
        contain a 'processed_title_and_text' column.

    Returns
    -------
    (pandas.DataFrame, pandas.Series)
        The full table, and the 'processed_title_and_text' column.
    """
    frame = pd.read_csv(filepath)
    return frame, frame['processed_title_and_text']
# Build a bag-of-words representation of the corpus.
def create_count_vectorizer(processed_data):
    """Fit a CountVectorizer on the corpus and vectorize it.

    Parameters
    ----------
    processed_data : iterable of str
        The preprocessed documents.

    Returns
    -------
    (CountVectorizer, scipy sparse matrix)
        The fitted vectorizer and the document-term count matrix.
    """
    vectorizer = CountVectorizer()
    document_term_matrix = vectorizer.fit_transform(processed_data)
    return vectorizer, document_term_matrix
# Print the topics found by the LDA model
def print_topics(model, count_vectorizer, n_top_words, _print):
    """Extract (and optionally print) the top words of each LDA topic.

    Parameters
    ----------
    model : fitted LDA model
        Must expose `components_` (one weight row per topic).
    count_vectorizer : fitted CountVectorizer
        Supplies the vocabulary mapping column index -> word.
    n_top_words : int
        How many of the highest-weighted words to keep per topic.
    _print : bool
        When True, also print each topic to stdout.

    Returns
    -------
    list[list[str]]
        One list of top words per topic, highest weight first.
    """
    # get_feature_names() was removed in scikit-learn 1.2; prefer the
    # replacement get_feature_names_out() and fall back for old versions.
    if hasattr(count_vectorizer, 'get_feature_names_out'):
        words = count_vectorizer.get_feature_names_out()
    else:
        words = count_vectorizer.get_feature_names()
    topics = []
    for topic_idx, topic in enumerate(model.components_):
        # argsort()[:-n-1:-1] yields the indices of the n largest
        # weights in descending order; compute the word list once
        # instead of duplicating the comprehension.
        top_words = [words[i] for i in topic.argsort()[:-n_top_words - 1:-1]]
        topics.append(top_words)
        if _print:
            print("\nTopic #%d:" % topic_idx)
            print(" ".join(top_words))
    return topics
# Create and fit the LDA model
def create_topics(processed_data, number_topics=5, number_words=10, _print=True):
    """Vectorize the corpus, fit an LDA model, and report its topics.

    Parameters
    ----------
    processed_data : iterable of str
        Preprocessed documents.
    number_topics : int, default 5
        Number of latent topics to fit.
    number_words : int, default 10
        Top words to keep per topic.
    _print : bool, default True
        Whether to print the topics to stdout.

    Returns
    -------
    (LDA, list[list[str]], CountVectorizer)
        The fitted model, the per-topic top words, and the vectorizer.
    """
    vectorizer, doc_term_matrix = create_count_vectorizer(processed_data)
    model = LDA(n_components=number_topics, n_jobs=-1)
    model.fit(doc_term_matrix)
    extracted = print_topics(model, vectorizer, number_words, _print)
    return model, extracted, vectorizer
def main(argv):
    """CLI entry point.

    Expects argv = [script, number_topics, number_words, datafilepath].
    Exits with a usage message instead of an IndexError when arguments
    are missing.
    """
    if len(argv) != 4:
        print("Usage: python train_lda.py <number_topics> <number_words> <processed_csv_path>")
        sys.exit(1)
    number_topics = int(argv[1])
    number_words = int(argv[2])
    datafilepath = argv[3]
    # The full DataFrame is not needed here, only the processed text column.
    _, processed_data = read_data(datafilepath)
    create_topics(processed_data, number_topics=number_topics, number_words=number_words)


if __name__ == '__main__':
    main(sys.argv)