Skip to content

Commit e55586b

Browse files
committed
Update mnist_compression.py
1 parent d5dbb5d commit e55586b

1 file changed

Lines changed: 2 additions & 4 deletions

File tree

examples/network_compression/mnist_compression.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
100100

101101
def wavelet_loss(self):
102102
if self.wavelet is None:
103-
raise ValueError
104-
acl, _, _ = self.wavelet.alias_cancellation_loss()
105-
prl, _, _ = self.wavelet.perfect_reconstruction_loss()
106-
return acl + prl
103+
return torch.tensor(0.0)
104+
return self.wavelet.wavelet_loss()
107105

108106

109107
def train(args, model, device, train_loader, optimizer, epoch):

0 commit comments

Comments
 (0)