|
| 1 | +import torch |
| 2 | +import torch.nn as nn |
| 3 | +import torch.nn.functional as F |
| 4 | + |
| 5 | + |
| 6 | +class Model(nn.Module): |
| 7 | + def __init__(self, game, args, layers: int = 4): |
| 8 | + # game params |
| 9 | + self.board_x, self.board_y = game.getBoardSize() |
| 10 | + self.action_size = game.getActionSize() |
| 11 | + self.args = args |
| 12 | + |
| 13 | + # nnet params |
| 14 | + self.layers = layers |
| 15 | + assert layers > 2 |
| 16 | + self.shrink = 2 * (self.layers - 2) |
| 17 | + |
| 18 | + self.conv = [] |
| 19 | + self.batchnorm = [] |
| 20 | + self.fc = [] |
| 21 | + self.fcbn = [] |
| 22 | + |
| 23 | + super(Model, self).__init__() |
| 24 | + |
| 25 | + self._setup() |
| 26 | + |
| 27 | + def _setup(self): |
| 28 | + # Create Conv layers |
| 29 | + in_channels = 1 |
| 30 | + kernel_size = int(float(min(self.board_x, self.board_y)) / self.layers) |
| 31 | + if kernel_size < 3: |
| 32 | + kernel_size = 3 |
| 33 | + paddings = [0] * self.layers |
| 34 | + paddings[0] = 1 |
| 35 | + paddings[1] = 1 |
| 36 | + for i in range(self.layers): |
| 37 | + conv = nn.Conv2d(in_channels, self.args.num_channels, kernel_size, stride=1, padding=paddings[i]) |
| 38 | + self.add_module(f'conv{i}', conv) |
| 39 | + self.conv.append(conv) |
| 40 | + in_channels = self.args.num_channels |
| 41 | + |
| 42 | + # Prepare Batch Normalization |
| 43 | + for i in range(self.layers): |
| 44 | + bn = nn.BatchNorm2d(self.args.num_channels) |
| 45 | + self.batchnorm.append(bn) |
| 46 | + self.add_module(f'batchnorm{i}', bn) |
| 47 | + |
| 48 | + # Prepare features |
| 49 | + in_features = self.args.num_channels * (self.board_x - self.shrink) * (self.board_y - self.shrink) |
| 50 | + # self.fc1 = nn.Linear(self.args.num_channels * (self.board_x-4)*(self.board_y-4), 1024) |
| 51 | + |
| 52 | + out_features = 512 * 2 ** (self.layers - 2) |
| 53 | + for i in range(self.layers - 2): |
| 54 | + out_features = int(out_features / 2.0) # needs to be unchanged same outside of the loop |
| 55 | + linear = nn.Linear(in_features, out_features) |
| 56 | + self.fc.append(linear) |
| 57 | + self.add_module(f'fc{i}', linear) |
| 58 | + |
| 59 | + bn = nn.BatchNorm1d(out_features) |
| 60 | + self.fcbn.append(bn) |
| 61 | + self.add_module(f'batchnorm1d{i}', bn) |
| 62 | + |
| 63 | + in_features = out_features |
| 64 | + |
| 65 | + self.fc_pi = nn.Linear(out_features, self.action_size) |
| 66 | + self.fc_v = nn.Linear(out_features, 1) |
| 67 | + |
| 68 | + def forward(self, s: torch.Tensor): |
| 69 | + s = s.view(-1, 1, self.board_x, self.board_y) |
| 70 | + |
| 71 | + for i in range(self.layers): |
| 72 | + s = F.relu(self.batchnorm[i](self.conv[i](s))) |
| 73 | + |
| 74 | + size = self.args.num_channels * (self.board_x - self.shrink) * (self.board_y - self.shrink) |
| 75 | + s = s.view(-1, size) |
| 76 | + |
| 77 | + fs = self.fc[0](s) |
| 78 | + bs = self.fcbn[0](fs) |
| 79 | + tensor = F.relu(bs) |
| 80 | + s = F.dropout(tensor, p=self.args.dropout, training=self.training) |
| 81 | + s = F.dropout(F.relu(self.fcbn[1](self.fc[1](s))), p=self.args.dropout, training=self.training) |
| 82 | + |
| 83 | + pi = self.fc_pi(s) |
| 84 | + v = self.fc_v(s) |
| 85 | + |
| 86 | + return F.log_softmax(pi, dim=1), torch.tanh(v) |
0 commit comments