-
Notifications
You must be signed in to change notification settings - Fork 26
Expand file tree
/
Copy path LSTM_CNN_tile_by_tile_obsolete.py
More file actions
92 lines (76 loc) · 3.87 KB
/
Copy path LSTM_CNN_tile_by_tile_obsolete.py
File metadata and controls
92 lines (76 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
---------------------------------------OBSOLETE BY CAE LSTM---------------------------------------
---------------------------------------OBSOLETE BY CAE LSTM---------------------------------------
---------------------------------------OBSOLETE BY CAE LSTM---------------------------------------
---------------------------------------OBSOLETE BY CAE LSTM---------------------------------------
from typing import Dict, Any, Tuple, List
import torch
from torch import nn
from algorithms.basic_testing import BasicTesting
from algorithms.learning.online_lstm import BasicLSTMModule, OnlineLSTM
from simulator.services.services import Services
class LSTMCNN(BasicLSTMModule):
    # CNN feature extractor feeding a stacked LSTM (obsolete: superseded by
    # the CAE LSTM).
    #
    # Every time step of the input carries a flattened 9x9 local map followed
    # by a flattened 60x60 global map (9*9 + 60*60 = 3681 features). Each map
    # goes through its own small CNN. The flattened CNN outputs are
    # concatenated: 16 local values (1x4x4) plus 49 global values (1x7x7),
    # i.e. 65, which matches "lstm_input_size". The concatenation is fed to
    # the LSTM, then batch norm, then a linear layer.
    #
    # NOTE: this code lives inside a triple-double-quoted module string, so
    # only '#' comments are used here; a nested docstring would end that
    # string early.
    #
    # NOTE: nn.Module uses attribute names as state_dict keys. That includes
    # the name-mangled __local_pipeline / __global_pipeline. Renaming any
    # submodule attribute would break loading previously saved models.

    # Side lengths of the square maps packed into every input time step.
    _LOCAL_MAP_SIDE = 9
    _GLOBAL_MAP_SIDE = 60

    def __init__(self, services: Services, config: Dict[str, Any]):
        super().__init__(services, config)
        local_features = self._LOCAL_MAP_SIDE * self._LOCAL_MAP_SIDE
        global_features = self._GLOBAL_MAP_SIDE * self._GLOBAL_MAP_SIDE

        # LSTM state returned by the previous forward() call and fed into the
        # next one; None makes nn.LSTM start from zero states.
        self._hidden_state = None
        # Batch norm over the raw, concatenated local + global map features.
        self._normalisation_layer1 = nn.BatchNorm1d(num_features=local_features + global_features)
        # 1x9x9 -> conv(3 ch) -> 2x2 max-pool (4x4) -> conv(1 ch) -> 1x4x4 = 16 values.
        self.__local_pipeline = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3, padding=1),
            nn.ReLU(),
        )
        # 1x60x60 -> three conv + 2x2 max-pool stages (60 -> 30 -> 15 -> 7)
        # -> conv(1 ch) -> 1x7x7 = 49 values. Trailing comments give the
        # spatial size each conv operates on.
        self.__global_pipeline = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3, padding=1),  # 60 * 60
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, padding=1),  # 30 * 30
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=6, out_channels=3, kernel_size=3, padding=1),  # 15 * 15
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3, padding=1),  # 7 * 7
            nn.ReLU(),
        )
        self._lstm_layer = nn.LSTM(
            self.config["lstm_input_size"],
            self.config["lstm_output_size"],
            num_layers=self.config["num_layers"],
            batch_first=True,
        )
        self._normalisation_layer2 = nn.BatchNorm1d(num_features=self.config["lstm_output_size"])
        self._fc = nn.Linear(
            in_features=self.config["lstm_output_size"],
            out_features=self.config["lstm_output_size"],
        )

    def forward(self, x: torch.Tensor, x_len: torch.Tensor, perm: List[int] = None) -> torch.Tensor:
        # Run a padded batch of per-step map features through the CNNs and LSTM.
        #
        # Args:
        #     x: per-step features. The indexing below needs rank 3 with a last
        #         dimension of 9*9 + 60*60. It is assumed to be
        #         (batch, seq_len, features) given batch_first=True.
        #         TODO confirm with the caller.
        #     x_len, perm: passed unchanged to BasicLSTMModule.pack_data
        #         (presumably sequence lengths and an optional ordering).
        # Returns:
        #     The linear layer applied to the batch-normalised .data of the
        #     LSTM's packed output. This is presumably one row of
        #     lstm_output_size per non-padded time step.
        batch_size, seq_len = x.shape[0], x.shape[1]
        num_features = x.shape[-1]
        local_features = self._LOCAL_MAP_SIDE * self._LOCAL_MAP_SIDE

        # Normalise over (batch * time, features), then restore the input
        # shape. reshape (unlike view) also accepts non-contiguous tensors.
        normalized = self._normalisation_layer1(x.reshape(-1, num_features)).reshape(x.shape)

        # Split each step back into its two square maps, one CNN sample per step.
        local_map = normalized[:, :, :local_features].reshape(
            -1, 1, self._LOCAL_MAP_SIDE, self._LOCAL_MAP_SIDE)
        global_map = normalized[:, :, local_features:].reshape(
            -1, 1, self._GLOBAL_MAP_SIDE, self._GLOBAL_MAP_SIDE)

        local_out = self.__local_pipeline(local_map).reshape(batch_size, seq_len, -1)
        global_out = self.__global_pipeline(global_map).reshape(batch_size, seq_len, -1)
        combined = torch.cat((local_out, global_out), dim=-1)

        packed_data, _ = BasicLSTMModule.pack_data(combined, x_len, perm)
        # NOTE(review): the hidden state is reused across calls and is neither
        # reset nor detached here. Presumably the caller or base class manages
        # it (batch size changes or repeated backward would otherwise fail).
        # Confirm.
        packed_out, self._hidden_state = self._lstm_layer(packed_data, self._hidden_state)
        normalized_lstm_out = self._normalisation_layer2(packed_out.data)
        return self._fc(normalized_lstm_out)

    @staticmethod
    def get_config() -> Dict[str, Any]:
        # Configuration dict for this model. __init__ reads "lstm_input_size",
        # "lstm_output_size" and "num_layers" via self.config (presumably this
        # dict, set by the base class). The remaining keys are consumed
        # elsewhere.
        return {
            "data_features": [
                "local_map",
                "global_map",
            ],
            "data_labels": [
                "next_position_index",
            ],
            "save_name": "cnn_tile_by_tile",
            "training_data": "training_1000",
            "epochs": 5,
            "num_layers": 2,
            # 16 local (1x4x4) + 49 global (1x7x7) CNN outputs per time step.
            "lstm_input_size": 65,
            "lstm_output_size": 8,
            "loss": nn.CrossEntropyLoss(),
            "optimizer": lambda model: torch.optim.Adam(model.parameters(), lr=0.01),
        }
class LSTMCNNTileByTile(OnlineLSTM):
    # OnlineLSTM subclass whose only customisation is the name of the model
    # it loads.
    #
    # The load name appears to combine LSTMCNN's "save_name"
    # ("cnn_tile_by_tile") and "training_data" ("training_1000") with a
    # "_model" suffix. NOTE(review): confirm against the model-saving code
    # before changing either side.

    def __init__(self, services: Services, testing: BasicTesting = None) -> None:
        super(LSTMCNNTileByTile, self).__init__(services, testing)
        model_name = "cnn_tile_by_tile_training_1000_model"
        self._load_name = model_name
"""