-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathconvolution_layer_from_scratch.py
More file actions
108 lines (78 loc) · 3.87 KB
/
Copy pathconvolution_layer_from_scratch.py
File metadata and controls
108 lines (78 loc) · 3.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import numpy as np
# credits to https://wiseodd.github.io/techblog/2016/07/16/convnet-conv-layer/
'''
The conv layer accepts an input X of shape DxCxHxW, a filter bank W of shape NFxCxHFxHW, and a bias b of shape NFx1, where:
D is the number of inputs (the batch size),
C is the number of image channels,
H is the height of the image,
W is the width of the image (in the input shape; not to be confused with the filter bank W),
NF is the number of filters in the filter bank W,
HF is the height of each filter, and finally
HW is the width of each filter.'''
def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):
    """Build the fancy-indexing arrays shared by im2col and col2im.

    Args:
        x_shape: 4-tuple (N, C, H, W) describing the unpadded input.
        field_height: height of the receptive field (filter).
        field_width: width of the receptive field (filter).
        padding: number of zero rows/columns added on each spatial side.
        stride: step between consecutive receptive fields.

    Returns:
        Tuple (k, i, j) of integer arrays. ``k`` has shape
        (C * field_height * field_width, 1) and holds channel indices;
        ``i`` and ``j`` have shape
        (C * field_height * field_width, out_height * out_width) and hold
        row / column indices into the padded input.

    Raises:
        AssertionError: if the receptive field does not tile the padded
            input evenly along the height or the width for this stride.
    """
    # First figure out what the size of the output should be
    _, C, H, W = x_shape
    assert (H + 2 * padding - field_height) % stride == 0
    # The width check must use field_width; using field_height here breaks
    # validation for non-square filters.
    assert (W + 2 * padding - field_width) % stride == 0
    out_height = int((H + 2 * padding - field_height) / stride + 1)
    out_width = int((W + 2 * padding - field_width) / stride + 1)

    # Row offsets inside one receptive field, repeated for every channel.
    i0 = np.repeat(np.arange(field_height), field_width)
    i0 = np.tile(i0, C)
    # Row of the top-left corner of every output position.
    i1 = stride * np.repeat(np.arange(out_height), out_width)
    # Column offsets inside one receptive field, repeated for every channel.
    j0 = np.tile(np.arange(field_width), field_height * C)
    # Column of the top-left corner of every output position.
    j1 = stride * np.tile(np.arange(out_width), out_height)

    # Broadcasting offset (column vector) + corner (row vector) yields one
    # index per (patch element, output position) pair.
    i = i0.reshape(-1, 1) + i1.reshape(1, -1)
    j = j0.reshape(-1, 1) + j1.reshape(1, -1)

    # Channel index for each patch element.
    k = np.repeat(np.arange(C), field_height * field_width).reshape(-1, 1)

    return (k.astype(int), i.astype(int), j.astype(int))
def im2col_indices(x, field_height, field_width, padding=1, stride=1):
    """Unroll every receptive field of ``x`` into a column via fancy indexing.

    Args:
        x: 4-D array indexed as (sample, channel, row, column).
        field_height: height of the receptive field.
        field_width: width of the receptive field.
        padding: zero padding applied to both sides of each spatial axis.
        stride: step between consecutive receptive fields.

    Returns:
        2-D array with ``field_height * field_width * C`` rows, where ``C``
        is ``x.shape[1]``; the remaining axes are flattened into columns.
    """
    # Pad only the two spatial axes with zeros; batch and channel axes untouched.
    pad_spec = ((0, 0), (0, 0), (padding, padding), (padding, padding))
    padded = np.pad(x, pad_spec, mode='constant')

    chan_idx, row_idx, col_idx = get_im2col_indices(
        x.shape, field_height, field_width, padding, stride)

    # Gather all patches at once: result is (sample, patch element, position).
    patches = padded[:, chan_idx, row_idx, col_idx]

    n_channels = x.shape[1]
    patch_size = field_height * field_width * n_channels
    # Move the patch-element axis first, then flatten (position, sample).
    return patches.transpose(1, 2, 0).reshape(patch_size, -1)
def col2im_indices(cols, x_shape, field_height=3, field_width=3, padding=1,
                   stride=1):
    """Scatter-add columns back into image layout (inverse of im2col).

    Overlapping receptive fields are accumulated with ``np.add.at``.

    Args:
        cols: 2-D array that can be reshaped to
            (C * field_height * field_width, -1, N).
        x_shape: 4-tuple (N, C, H, W) of the unpadded image batch.
        field_height: height of the receptive field.
        field_width: width of the receptive field.
        padding: zero padding that was applied to each spatial side.
        stride: step between consecutive receptive fields.

    Returns:
        Array of shape (N, C, H, W) with the padding border stripped
        (the padded buffer itself is returned when ``padding`` is 0).
    """
    n_samples, n_channels, height, width = x_shape
    padded_shape = (n_samples, n_channels,
                    height + 2 * padding, width + 2 * padding)
    accum = np.zeros(padded_shape, dtype=cols.dtype)

    chan_idx, row_idx, col_idx = get_im2col_indices(
        x_shape, field_height, field_width, padding, stride)

    # Split the flattened column axis back into (position, sample) and put
    # the sample axis first so it lines up with the leading slice below.
    per_sample = cols.reshape(
        n_channels * field_height * field_width, -1, n_samples
    ).transpose(2, 0, 1)

    # np.add.at accumulates repeated indices instead of overwriting them.
    np.add.at(accum, (slice(None), chan_idx, row_idx, col_idx), per_sample)

    if padding != 0:
        return accum[:, :, padding:-padding, padding:-padding]
    # A ``0:-0`` slice would be empty, so the unpadded case returns as is.
    return accum
def conv_forward(X, W, b, stride=1, padding=1):
    """Forward pass of a convolution layer implemented with im2col.

    Args:
        X: input array of shape (n_x, d_x, h_x, w_x).
        W: filter array of shape (n_filters, d_filter, h_filter, w_filter).
        b: bias added to the (n_filters, h_out * w_out * n_x) product, so it
            must broadcast against that shape (e.g. (n_filters, 1)).
        stride: step between consecutive receptive fields.
        padding: zero padding applied to each spatial side of ``X``.

    Returns:
        Tuple ``(out, cache)`` where ``out`` has shape
        (n_x, n_filters, h_out, w_out) and ``cache`` is
        ``(X, W, b, stride, padding, X_col)`` for use by ``conv_backward``.

    Raises:
        Exception: if the filter, stride and padding do not yield an
            integer output height and width.
    """
    n_filters, _, h_filter, w_filter = W.shape
    n_x, _, h_x, w_x = X.shape

    # True division on purpose: a fractional result signals an invalid
    # combination of input size, filter size, stride and padding.
    h_out = (h_x - h_filter + 2 * padding) / stride + 1
    w_out = (w_x - w_filter + 2 * padding) / stride + 1

    if not h_out.is_integer() or not w_out.is_integer():
        raise Exception('Invalid output dimension!')

    h_out, w_out = int(h_out), int(w_out)

    # Each column of X_col is one flattened receptive field; each row of
    # W_col is one flattened filter, so the convolution is a single matmul.
    X_col = im2col_indices(X, h_filter, w_filter, padding=padding, stride=stride)
    W_col = W.reshape(n_filters, -1)

    out = W_col @ X_col + b
    # Columns are ordered (row, column, sample); restore that layout and
    # then move the sample axis to the front.
    out = out.reshape(n_filters, h_out, w_out, n_x)
    out = out.transpose(3, 0, 1, 2)

    cache = (X, W, b, stride, padding, X_col)

    return out, cache
def conv_backward(dout, cache):
    """Backward pass of the im2col-based convolution layer.

    Args:
        dout: upstream gradient, indexed as (sample, filter, row, column).
        cache: tuple ``(X, W, b, stride, padding, X_col)`` as produced by
            ``conv_forward``.

    Returns:
        Tuple ``(dX, dW, db)``: gradients with respect to the input, the
        filters (same shape as ``W``) and the bias (shape (n_filter, 1)).
    """
    X, W, _bias, stride, padding, X_col = cache
    n_filter, _, h_filter, w_filter = W.shape

    # Bias gradient: sum over every axis except the filter axis.
    db = np.sum(dout, axis=(0, 2, 3)).reshape(n_filter, -1)

    # Flatten the upstream gradient to (n_filter, row * column * sample),
    # matching the column ordering of X_col.
    dout_flat = dout.transpose(1, 2, 3, 0).reshape(n_filter, -1)

    # Filter gradient: correlate upstream gradient with the unrolled input.
    dW = (dout_flat @ X_col.T).reshape(W.shape)

    # Input gradient in column form, then folded back to image layout.
    W_flat = W.reshape(n_filter, -1)
    dX_col = W_flat.T @ dout_flat
    dX = col2im_indices(dX_col, X.shape, h_filter, w_filter,
                        padding=padding, stride=stride)

    return dX, dW, db