From b4dea6c1333c929d54ad9faa36197023c8b95aa9 Mon Sep 17 00:00:00 2001 From: Aayush Pratap Singh <141538111+Aayushongit@users.noreply.github.com> Date: Sun, 27 Apr 2025 21:13:10 +0530 Subject: [PATCH] Update test_vae.py --- tests/test_vae.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/tests/test_vae.py b/tests/test_vae.py index beb30b44..31d21a26 100644 --- a/tests/test_vae.py +++ b/tests/test_vae.py @@ -3,9 +3,7 @@ from models import VanillaVAE from torchsummary import summary - class TestVAE(unittest.TestCase): - def setUp(self) -> None: # self.model2 = VAE(3, 10) self.model = VanillaVAE(3, 10) @@ -22,11 +20,15 @@ def test_forward(self): def test_loss(self): x = torch.randn(16, 3, 64, 64) - result = self.model(x) loss = self.model.loss_function(*result, M_N = 0.005) print(loss) - + + def test_reconstruction(self): + x = torch.randn(1, 3, 64, 64) + reconstructed, _, _ = self.model(x) + mse = torch.nn.functional.mse_loss(reconstructed, x) + print(f"Reconstruction MSE: {mse.item()}") if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main()