-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclassifier.py
More file actions
63 lines (55 loc) · 2.31 KB
/
Copy pathclassifier.py
File metadata and controls
63 lines (55 loc) · 2.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import tensorflow as tf
from tensorflow.keras.layers import Dense, Softmax
class Classifier(tf.keras.Model):
    """Keras model that feeds inputs through an encoder and then a ClassifierHead."""

    def __init__(self, encoder, num_classes, name="classifier", **kwargs):
        """
        Store the encoder and build the classification head.

        :param encoder: Callable applied to the inputs first; may or may not be pre-trained.
        :param num_classes: Number of classes, forwarded to ClassifierHead.
        :param name: Name of the model.
        :param kwargs: Extra keyword arguments forwarded to the Model constructor.
        """
        # The Keras base constructor must run before any attribute is assigned.
        super().__init__(name=name, **kwargs)
        self.encoder = encoder
        self.classifier_head = ClassifierHead(num_classes)

    def call(self, inputs, training=None, mask=None):
        """
        Run the forward pass: encoder first, then the classifier head.

        :param inputs: Input data handed to the encoder.
        :param training: Accepted for the Keras ``call`` signature; not referenced in this method.
        :param mask: Accepted for the Keras ``call`` signature; not referenced in this method.
        :return: Output of the classifier head applied to the encoder output.
        """
        return self.classifier_head(self.encoder(inputs))

    def get_config(self):
        """Config export is not supported for this model; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")
class ClassifierHead(tf.keras.Model):
    """Keras model made of two Dense layers followed by a Softmax layer."""

    def __init__(self, num_classes, name="classifierhead", **kwargs):
        """
        Build the layers of the head.

        :param num_classes: Number of classes; sets the width of the second Dense layer.
        :param name: Name of the model.
        :param kwargs: Extra keyword arguments forwarded to the Model constructor.
        """
        super().__init__(name=name, **kwargs)
        # Hidden layer: 60 units with tanh activation.
        self.dense1 = Dense(60, activation="tanh")
        # NOTE(review): this layer applies a sigmoid and its output is then passed
        # through Softmax in call(), so the softmax inputs are confined to (0, 1).
        # Confirm this double squashing is intended.
        self.dense2 = Dense(num_classes, activation="sigmoid")
        self.softmax = Softmax()

    def call(self, inputs, training=None, mask=None):
        """
        Run the forward pass through dense1, dense2 and softmax, in that order.

        :param inputs: Input data for the first Dense layer.
        :param training: Accepted for the Keras ``call`` signature; not referenced in this method.
        :param mask: Accepted for the Keras ``call`` signature; not referenced in this method.
        :return: Output of the Softmax layer.
        """
        outputs = inputs
        for layer in (self.dense1, self.dense2, self.softmax):
            outputs = layer(outputs)
        return outputs

    def get_config(self):
        """Config export is not supported for this model; always raises NotImplementedError."""
        raise NotImplementedError("Not implemented")