-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodels.py
More file actions
95 lines (72 loc) · 3.3 KB
/
models.py
File metadata and controls
95 lines (72 loc) · 3.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 27 15:21:15 2022
@author: user
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from efficientnet_pytorch import EfficientNet
class EffNet(nn.Module):
    """EfficientNet multi-label classifier adapted to single-channel inputs.

    This is the class used for the transcoders effnet_b0 and effnet_b7 in
    exp_train_model/main_doce_training.py. The stock EfficientNet stem
    (which takes 3 input channels) is replaced by a 1-channel convolution,
    and the classifier logits are passed through a sigmoid.
    """

    # EfficientNet variants accepted for the ``type`` argument of __init__.
    _SUPPORTED_TYPES = ('b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7')

    def __init__(self, n_labels, dtype=torch.FloatTensor, device=torch.device("cpu"), type='b0'):
        """
        Initializes the EffNet nn model class.

        Args:
        - n_labels: Number of classifier outputs, i.e. the number of labels
          (e.g. 527 for PANN, 521 for YamNet).
        - dtype: Currently unused; kept for backward compatibility with callers.
        - device: The device the wrapped EfficientNet is moved to
          (default: torch.device("cpu")).
        - type: EfficientNet variant, one of 'b0' ... 'b7' (default: 'b0').
          The name shadows the builtin but is kept because callers pass it
          by keyword.

        Raises:
        - ValueError: if ``type`` is not a supported variant.
        """
        super().__init__()
        print(f'Efficient net {type}')
        if type not in self._SUPPORTED_TYPES:
            # Previously an unknown type left self.model unset and the
            # constructor died later with an opaque AttributeError.
            raise ValueError(
                f"Unsupported EfficientNet type {type!r}; "
                f"expected one of {self._SUPPORTED_TYPES}"
            )
        self.model = EfficientNet.from_name(f'efficientnet-{type}', num_classes=n_labels)
        # Swap the 3-channel stem for a 1-channel one. Its output width must
        # match what the following layers expect, and that width depends on
        # the variant (32 for b0/b1, 48 for b5, 64 for b7). So read it from
        # the stem the library built instead of hard-coding it.
        stem_channels = self.model._conv_stem.out_channels
        self.model._conv_stem = nn.Conv2d(1, stem_channels, kernel_size=3, stride=2, bias=False)
        self.model.to(device)

    def forward(self, x):
        """Return per-label sigmoid probabilities for a batch of inputs.

        Args:
        - x: Input tensor without a channel axis; a channel axis of size 1 is
          inserted at dim 1 before the EfficientNet is applied. Presumably
          (batch, height, width), e.g. spectrograms -- TODO confirm with callers.

        Returns:
        - Tensor of values in [0, 1], one per label (n_labels outputs per item).
        """
        x = torch.unsqueeze(x, dim=1)
        y_pred = self.model(x)
        y_pred = torch.sigmoid(y_pred)
        return y_pred
# class MLP(nn.Module):
# def __init__(self, input_shape, output_shape, dtype=torch.FloatTensor,
# hl_1=100, hl_2=50):
# super().__init__()
# self.input_shape = input_shape
# self.output_shape = output_shape
# self.hl_1 = hl_1
# self.hl_2 = hl_2
# self.input_fc = nn.Linear(input_shape, hl_1)
# self.hidden_fc = nn.Linear(hl_1, hl_2)
# self.output_fc = nn.Linear(hl_2, output_shape)
# self.dtype = dtype
# def forward(self, x):
# # x = [batch size, height, width]
# # MT: useless lines (maybe when 2d spectrogramms given ?)
# #batch_size = x.shape[0]
# #x = x.view(batch_size, -1)
# # x = [batch size, height * width]
# x = torch.squeeze(x, dim=-1)
# h_1 = F.relu(self.input_fc(x))
# # h_1 = [batch size, 250]
# h_2 = F.relu(self.hidden_fc(h_1))
# # h_2 = [batch size, 100]
# y_pred = self.output_fc(h_2)
# y_pred = torch.sigmoid(y_pred)
# # y_pred = torch.reshape(y_pred, (y_pred.shape[0], self.output_shape[0], self.output_shape[1]))
# # y_pred = [batch size, output dim]
# return y_pred