@@ -52,43 +52,51 @@ def __init__(
5252 super ().__init__ ()
5353 self .conv1 = nn .Conv2d (1 , 20 , 5 , 1 )
5454 self .conv2 = nn .Conv2d (20 , 50 , 5 , 1 )
55- self .dropout1 = nn .Dropout2d (0.25 )
56- self .dropout2 = nn .Dropout2d (0.5 )
57- self .wavelet = wavelet
55+ self .max_pool_2s_k2 = torch .nn .MaxPool2d (2 )
56+
57+ self .log_softmax = torch .nn .LogSoftmax (dim = 1 )
58+ self .flatten = torch .nn .Flatten (start_dim = 1 )
59+ self .relu = torch .nn .ReLU ()
60+
5861 if compression == "None" :
59- self .fc1 = torch .nn .Linear (4 * 4 * 50 , 500 )
60- self .fc2 = torch .nn .Linear (500 , 10 )
61- self .do_dropout = True
62+ fc1 = torch .nn .Linear (4 * 4 * 50 , 500 )
63+ fc2 = torch .nn .Linear (500 , 10 )
64+ self .sequence = torch .nn .Sequential (
65+ self .conv1 ,
66+ self .max_pool_2s_k2 ,
67+ self .conv2 ,
68+ self .max_pool_2s_k2 ,
69+ nn .Dropout2d (0.25 ),
70+ self .flatten ,
71+ fc1 ,
72+ self .relu ,
73+ nn .Dropout2d (0.5 ),
74+ fc2 ,
75+ self .log_softmax ,
76+ )
6277 elif compression == "Wavelet" :
6378 assert wavelet is not None , "initial wavelet must be set."
64- self .fc1 = WaveletLayer (
79+ self .wavelet = wavelet
80+ fc1 = WaveletLayer (
6581 init_wavelet = wavelet , scales = 6 , depth = 800 , p_drop = wave_dropout
6682 )
67- self .fc2 = torch .nn .Linear (800 , 10 )
68- self .do_dropout = False
83+ fc2 = torch .nn .Linear (800 , 10 )
84+ self .sequence = torch .nn .Sequential (
85+ self .conv1 ,
86+ self .max_pool_2s_k2 ,
87+ self .conv2 ,
88+ self .max_pool_2s_k2 ,
89+ self .flatten ,
90+ fc1 ,
91+ self .relu ,
92+ fc2 ,
93+ self .log_softmax ,
94+ )
6995 else :
7096 raise ValueError (f"invalid compression: { compression } " )
7197
72- self .log_softmax = torch .nn .LogSoftmax (dim = 1 )
73- self .max_pool_2s_k2 = torch .nn .MaxPool2d (2 )
74- self .flatten = torch .nn .Flatten (start_dim = 1 )
75- self .relu = torch .nn .ReLU ()
76-
77- def forward (self , x ):
78- x = self .conv1 (x )
79- x = self .max_pool_2s_k2 (x )
80- x = self .conv2 (x )
81- x = self .max_pool_2s_k2 (x )
82- if self .do_dropout :
83- x = self .dropout1 (x )
84- x = self .flatten (x )
85- x = self .fc1 (x )
86- x = self .relu (x )
87- if self .do_dropout :
88- x = self .dropout2 (x )
89- x = self .fc2 (x )
90- x = self .log_softmax (x )
91- return x
98+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
99+ return self .sequence (x )
92100
93101 def wavelet_loss (self ):
94102 if self .wavelet is None :
0 commit comments