-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathBertClassifier.py
More file actions
50 lines (42 loc) · 1.88 KB
/
BertClassifier.py
File metadata and controls
50 lines (42 loc) · 1.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import numpy as np
import tensorflow as tf
from transformers import BertTokenizer, TFBertModel
from tensorflow import keras
class BertClassifier:
    """Two-class text classifier trained on pre-computed BERT embeddings.

    ``pre_process_text`` turns one sentence into a flat vector of
    ``bert-base-uncased`` hidden states. ``train`` fits a small Keras head
    (dropout -> 2-unit dense -> softmax) on such vectors and saves it to
    ``MODEL_FILE_PATH``; ``evaluate`` and ``predict`` reload it from there.
    """

    # Where the trained Keras head is saved to and loaded from.
    # NOTE(review): newer Keras releases may only accept ``.keras``/``.h5``
    # file names here -- confirm against the installed version.
    MODEL_FILE_PATH = './model/bert_model.pkl'
    # Number of epochs passed to ``fit`` in ``train``.
    EPOCHS = 100
    # Token count every sentence is padded to before it is fed to BERT.
    MAX_LENGTH = 50
    # Length of one flattened feature vector: MAX_LENGTH tokens times 768
    # hidden values per token (= 38400, the head's input size).
    FEATURE_SIZE = MAX_LENGTH * 768

    def __init__(self):
        """Load the pre-trained ``bert-base-uncased`` tokenizer and model."""
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.model = TFBertModel.from_pretrained('bert-base-uncased')

    def pre_process_text(self, input):
        """Return the flattened BERT hidden states for one sentence.

        Args:
            input: Text to embed. The name shadows the builtin but is kept
                so that existing keyword callers keep working.

        Returns:
            A 1-D numpy array containing the model's first output (the last
            hidden states) for the padded token sequence.
        """
        # NOTE(review): ``pad_to_max_length`` is reported as deprecated by
        # recent transformers releases (``padding='max_length'`` replaces
        # it) -- confirm the installed version before switching.
        token_ids = self.tokenizer.encode(
            input,
            add_special_tokens=True,
            max_length=self.MAX_LENGTH,
            pad_to_max_length=True,
        )
        # [None, :] adds a leading axis of size 1, i.e. a batch of one.
        input_ids = tf.constant(token_ids)[None, :]
        outputs = self.model(input_ids)
        last_hidden_states = outputs[0]
        return np.array(last_hidden_states).flatten()

    def train(self, inputs, labels):
        """Fit the classification head and save it to ``MODEL_FILE_PATH``.

        Args:
            inputs: Feature vectors of length ``FEATURE_SIZE`` each
                (presumably stacked ``pre_process_text`` results -- confirm
                with the caller).
            labels: Integer class ids, as required by the sparse
                categorical cross-entropy loss.
        """
        self.nn_model = keras.Sequential([
            keras.layers.Flatten(input_shape=(self.FEATURE_SIZE,)),
            keras.layers.Dropout(0.7),
            keras.layers.Dense(2),
            keras.layers.Softmax(),
        ])
        self.nn_model.compile(
            optimizer='adam',
            # The network ends in Softmax, so its outputs are already
            # probabilities. from_logits must therefore be False; with True
            # the loss would apply a second softmax to them.
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
            metrics=['accuracy'],
        )
        self.nn_model.fit(inputs, labels, epochs=self.EPOCHS)
        keras.models.save_model(self.nn_model, self.MODEL_FILE_PATH)

    def _load_saved_model(self):
        """Reload the model saved at ``MODEL_FILE_PATH`` into ``self.nn_model``."""
        loaded = keras.models.load_model(self.MODEL_FILE_PATH)
        if loaded is not None:
            self.nn_model = loaded

    def evaluate(self, inputs, labels):
        """Evaluate the saved model and return its accuracy.

        Args:
            inputs: Feature vectors to evaluate on.
            labels: Integer class ids matching ``inputs``.

        Returns:
            The accuracy metric reported by Keras ``evaluate``.
        """
        self._load_saved_model()
        _, test_acc = self.nn_model.evaluate(inputs, labels, verbose=2)
        return test_acc

    def predict(self, inputs):
        """Return the saved model's softmax outputs for ``inputs``.

        Args:
            inputs: Feature vectors to classify.

        Returns:
            Whatever Keras ``predict`` returns for the model: one pair of
            class probabilities per input row.
        """
        self._load_saved_model()
        return self.nn_model.predict(inputs)