-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathrecurrent_BatchNorm.py
More file actions
56 lines (50 loc) · 2.33 KB
/
recurrent_BatchNorm.py
File metadata and controls
56 lines (50 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import torch.nn as nn
import torch.nn.functional as F
class recurrent_BatchNorm(nn.Module):
    """Batch normalization with separate running statistics per time step.

    The module keeps ``max_len`` pairs of running-statistic buffers, named
    ``running_mean_{i}`` and ``running_var_{i}``. When ``affine`` is True it
    also keeps one shared pair of learnable affine parameters (``weight``,
    ``bias``).

    ``forward(input_, index)`` normalizes ``input_`` using the statistics for
    time step ``index``. Every index at or beyond ``max_len`` reuses the
    statistics of the last step (``max_len - 1``).

    Args:
        num_features: Expected size of dimension 1 of the input (checked in
            ``_check_input_dim``).
        max_len: Number of distinct time steps that get their own running
            statistics. Must be at least 1.
        eps: Value added to the variance for numerical stability
            (passed to ``F.batch_norm``).
        momentum: Update factor for the running statistics in training mode
            (passed to ``F.batch_norm``).
        affine: If True, learn a per-feature scale (``weight``) and shift
            (``bias``); otherwise both are registered as ``None``.

    Raises:
        ValueError: If ``max_len`` is less than 1.
    """

    def __init__(self, num_features, max_len, eps=1e-5, momentum=0.1, affine=True):
        super().__init__()
        if max_len < 1:
            # With no slots, forward() would clamp every index to -1 and look
            # up a nonexistent buffer; fail early with a clear message instead.
            raise ValueError('max_len must be at least 1, got {}'.format(max_len))
        self.num_features = num_features
        self.affine = affine
        self.max_len = max_len
        self.eps = eps
        self.momentum = momentum
        if self.affine:
            # Assigning an nn.Parameter attribute registers it on the module;
            # an explicit register_parameter call is unnecessary.
            # Values are filled in by reset_parameters() below.
            self.weight = nn.Parameter(torch.empty(num_features))
            self.bias = nn.Parameter(torch.empty(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        for i in range(max_len):
            self.register_buffer('running_mean_{}'.format(i), torch.zeros(num_features))
            self.register_buffer('running_var_{}'.format(i), torch.ones(num_features))
        self.reset_parameters()

    def _running_stats(self, index):
        """Return the ``(running_mean, running_var)`` buffers for slot ``index``."""
        return (getattr(self, 'running_mean_{}'.format(index)),
                getattr(self, 'running_var_{}'.format(index)))

    def reset_parameters(self):
        """Reset all running statistics and re-initialize the affine parameters.

        Every slot's running mean is set to 0 and running variance to 1.
        When ``affine`` is True, ``weight`` is drawn from U(0, 1) and ``bias``
        is set to 0.
        """
        for i in range(self.max_len):
            running_mean, running_var = self._running_stats(i)
            running_mean.zero_()
            running_var.fill_(1)
        if self.affine:
            nn.init.uniform_(self.weight)  # defaults a=0.0, b=1.0
            nn.init.zeros_(self.bias)

    def _check_input_dim(self, input_, index):
        """Validate that ``input_`` has ``num_features`` entries along dim 1.

        Raises:
            ValueError: If ``input_`` has fewer than 2 dimensions, or if its
                dim-1 size does not match the statistics for slot ``index``.
        """
        if input_.dim() < 2:
            raise ValueError('expected at least 2D input (got {}D input)'
                             .format(input_.dim()))
        running_mean, _ = self._running_stats(index)
        if input_.size(1) != running_mean.nelement():
            raise ValueError('got {}-feature tensor, expected {}'
                             .format(input_.size(1), self.num_features))

    def forward(self, input_, index):
        """Normalize ``input_`` using the running statistics of time step ``index``.

        Args:
            input_: Tensor whose dim 1 has size ``num_features``. Presumably
                (batch, num_features) for one time step of an RNN; any shape
                accepted by ``F.batch_norm`` works.
            index: Non-negative time-step index. Values >= ``max_len`` are
                clamped to ``max_len - 1``.

        Returns:
            The normalized tensor, with the same shape as ``input_``. In
            training mode the selected slot's running statistics are updated
            in place by ``F.batch_norm``.

        Raises:
            ValueError: If ``index`` is negative or ``input_`` has the wrong
                number of features.
        """
        if index < 0:
            raise ValueError('index must be non-negative, got {}'.format(index))
        # Time steps past max_len share the last slot's statistics.
        index = min(index, self.max_len - 1)
        self._check_input_dim(input_, index)
        running_mean, running_var = self._running_stats(index)
        return F.batch_norm(
            input_, running_mean, running_var, self.weight, self.bias,
            self.training, self.momentum, self.eps)

    def __repr__(self):
        # Name the fields explicitly instead of splatting nn.Module.__dict__;
        # the old '{max_length}' key did not exist and raised KeyError.
        return ('{name}({num_features}, eps={eps}, momentum={momentum},'
                ' max_len={max_len}, affine={affine})'
                .format(name=self.__class__.__name__,
                        num_features=self.num_features, eps=self.eps,
                        momentum=self.momentum, max_len=self.max_len,
                        affine=self.affine))