Skip to content

Commit 94e2c1e

Browse files
committed
Add generic file for NNet
1 parent 5156c7f commit 94e2c1e

2 files changed

Lines changed: 86 additions & 0 deletions

File tree

pytorch/__init__.py

Whitespace-only changes.

pytorch/models.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
6+
class Model(nn.Module):
7+
def __init__(self, game, args, layers: int = 4):
8+
# game params
9+
self.board_x, self.board_y = game.getBoardSize()
10+
self.action_size = game.getActionSize()
11+
self.args = args
12+
13+
# nnet params
14+
self.layers = layers
15+
assert layers > 2
16+
self.shrink = 2 * (self.layers - 2)
17+
18+
self.conv = []
19+
self.batchnorm = []
20+
self.fc = []
21+
self.fcbn = []
22+
23+
super(Model, self).__init__()
24+
25+
self._setup()
26+
27+
def _setup(self):
28+
# Create Conv layers
29+
in_channels = 1
30+
kernel_size = int(float(min(self.board_x, self.board_y)) / self.layers)
31+
if kernel_size < 3:
32+
kernel_size = 3
33+
paddings = [0] * self.layers
34+
paddings[0] = 1
35+
paddings[1] = 1
36+
for i in range(self.layers):
37+
conv = nn.Conv2d(in_channels, self.args.num_channels, kernel_size, stride=1, padding=paddings[i])
38+
self.add_module(f'conv{i}', conv)
39+
self.conv.append(conv)
40+
in_channels = self.args.num_channels
41+
42+
# Prepare Batch Normalization
43+
for i in range(self.layers):
44+
bn = nn.BatchNorm2d(self.args.num_channels)
45+
self.batchnorm.append(bn)
46+
self.add_module(f'batchnorm{i}', bn)
47+
48+
# Prepare features
49+
in_features = self.args.num_channels * (self.board_x - self.shrink) * (self.board_y - self.shrink)
50+
# self.fc1 = nn.Linear(self.args.num_channels * (self.board_x-4)*(self.board_y-4), 1024)
51+
52+
out_features = 512 * 2 ** (self.layers - 2)
53+
for i in range(self.layers - 2):
54+
out_features = int(out_features / 2.0) # needs to be unchanged same outside of the loop
55+
linear = nn.Linear(in_features, out_features)
56+
self.fc.append(linear)
57+
self.add_module(f'fc{i}', linear)
58+
59+
bn = nn.BatchNorm1d(out_features)
60+
self.fcbn.append(bn)
61+
self.add_module(f'batchnorm1d{i}', bn)
62+
63+
in_features = out_features
64+
65+
self.fc_pi = nn.Linear(out_features, self.action_size)
66+
self.fc_v = nn.Linear(out_features, 1)
67+
68+
def forward(self, s: torch.Tensor):
69+
s = s.view(-1, 1, self.board_x, self.board_y)
70+
71+
for i in range(self.layers):
72+
s = F.relu(self.batchnorm[i](self.conv[i](s)))
73+
74+
size = self.args.num_channels * (self.board_x - self.shrink) * (self.board_y - self.shrink)
75+
s = s.view(-1, size)
76+
77+
fs = self.fc[0](s)
78+
bs = self.fcbn[0](fs)
79+
tensor = F.relu(bs)
80+
s = F.dropout(tensor, p=self.args.dropout, training=self.training)
81+
s = F.dropout(F.relu(self.fcbn[1](self.fc[1](s))), p=self.args.dropout, training=self.training)
82+
83+
pi = self.fc_pi(s)
84+
v = self.fc_v(s)
85+
86+
return F.log_softmax(pi, dim=1), torch.tanh(v)

0 commit comments

Comments
 (0)