Skip to content

Commit d5dbb5d

Browse files
committed
Update mnist_compression.py
1 parent 3477a8f commit d5dbb5d

1 file changed

Lines changed: 37 additions & 29 deletions

File tree

examples/network_compression/mnist_compression.py

Lines changed: 37 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -52,43 +52,51 @@ def __init__(
5252
super().__init__()
5353
self.conv1 = nn.Conv2d(1, 20, 5, 1)
5454
self.conv2 = nn.Conv2d(20, 50, 5, 1)
55-
self.dropout1 = nn.Dropout2d(0.25)
56-
self.dropout2 = nn.Dropout2d(0.5)
57-
self.wavelet = wavelet
55+
self.max_pool_2s_k2 = torch.nn.MaxPool2d(2)
56+
57+
self.log_softmax = torch.nn.LogSoftmax(dim=1)
58+
self.flatten = torch.nn.Flatten(start_dim=1)
59+
self.relu = torch.nn.ReLU()
60+
5861
if compression == "None":
59-
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
60-
self.fc2 = torch.nn.Linear(500, 10)
61-
self.do_dropout = True
62+
fc1 = torch.nn.Linear(4 * 4 * 50, 500)
63+
fc2 = torch.nn.Linear(500, 10)
64+
self.sequence = torch.nn.Sequential(
65+
self.conv1,
66+
self.max_pool_2s_k2,
67+
self.conv2,
68+
self.max_pool_2s_k2,
69+
nn.Dropout2d(0.25),
70+
self.flatten,
71+
fc1,
72+
self.relu,
73+
nn.Dropout2d(0.5),
74+
fc2,
75+
self.log_softmax,
76+
)
6277
elif compression == "Wavelet":
6378
assert wavelet is not None, "initial wavelet must be set."
64-
self.fc1 = WaveletLayer(
79+
self.wavelet = wavelet
80+
fc1 = WaveletLayer(
6581
init_wavelet=wavelet, scales=6, depth=800, p_drop=wave_dropout
6682
)
67-
self.fc2 = torch.nn.Linear(800, 10)
68-
self.do_dropout = False
83+
fc2 = torch.nn.Linear(800, 10)
84+
self.sequence = torch.nn.Sequential(
85+
self.conv1,
86+
self.max_pool_2s_k2,
87+
self.conv2,
88+
self.max_pool_2s_k2,
89+
self.flatten,
90+
fc1,
91+
self.relu,
92+
fc2,
93+
self.log_softmax,
94+
)
6995
else:
7096
raise ValueError(f"invalid compression: {compression}")
7197

72-
self.log_softmax = torch.nn.LogSoftmax(dim=1)
73-
self.max_pool_2s_k2 = torch.nn.MaxPool2d(2)
74-
self.flatten = torch.nn.Flatten(start_dim=1)
75-
self.relu = torch.nn.ReLU()
76-
77-
def forward(self, x):
78-
x = self.conv1(x)
79-
x = self.max_pool_2s_k2(x)
80-
x = self.conv2(x)
81-
x = self.max_pool_2s_k2(x)
82-
if self.do_dropout:
83-
x = self.dropout1(x)
84-
x = self.flatten(x)
85-
x = self.fc1(x)
86-
x = self.relu(x)
87-
if self.do_dropout:
88-
x = self.dropout2(x)
89-
x = self.fc2(x)
90-
x = self.log_softmax(x)
91-
return x
98+
def forward(self, x: torch.Tensor) -> torch.Tensor:
99+
return self.sequence(x)
92100

93101
def wavelet_loss(self):
94102
if self.wavelet is None:

0 commit comments

Comments
 (0)