77from train_utils import set_seed
88import torch .nn .functional as F
99from torch .cuda .amp import GradScaler , autocast
10+ from test_utils import skip_if_no_gpu
1011
1112
1213# python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss -v
@@ -17,6 +18,7 @@ def setUp(self):
1718 self .loss_fn = QuaidLogDetMILoss (config )
1819
1920 # python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_log_det_happy -v
21+ @skip_if_no_gpu
2022 def test_log_det_happy (self ):
2123 # Create a tensor representing positive infinity
2224 matrix = torch .tensor (
@@ -41,6 +43,7 @@ def test_log_det_happy(self):
4143 ], matrix .grad .tolist ()
4244
4345 # python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_log_det_singular -v
46+ @skip_if_no_gpu
4447 def test_log_det_singular (self ):
4548 # Create a tensor representing positive infinity
4649 matrix = torch .tensor (
@@ -65,6 +68,7 @@ def test_log_det_singular(self):
6568 ], matrix .grad .tolist ()
6669
6770 # python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_log_det_weird -v
71+ @skip_if_no_gpu
6872 def test_log_det_weird (self ):
6973 # Create a tensor representing positive infinity
7074 matrix = torch .tensor (
@@ -95,6 +99,7 @@ def test_log_det_weird(self):
9599 ], matrix .grad .tolist ()
96100
97101 # python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_safe_pinverse_happy -v
102+ @skip_if_no_gpu
98103 def test_safe_pinverse_happy (self ):
99104 # Create a tensor representing positive infinity
100105 matrix = torch .tensor (
@@ -123,6 +128,7 @@ def test_safe_pinverse_happy(self):
123128 ], matrix .grad .tolist ()
124129
125130 # python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_safe_pinverse_singular -v
131+ @skip_if_no_gpu
126132 def test_safe_pinverse_singular (self ):
127133 # Create a tensor representing positive infinity
128134 matrix = torch .tensor (
@@ -151,6 +157,7 @@ def test_safe_pinverse_singular(self):
151157 ], matrix .grad .tolist ()
152158
153159 # python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_safe_pinverse_weird -v
160+ @skip_if_no_gpu
154161 def test_safe_pinverse_weird (self ):
155162 # Create a tensor representing positive infinity
156163 matrix = torch .tensor (
@@ -194,6 +201,7 @@ def test_safe_pinverse_weird(self):
194201 ], matrix .grad .tolist ()
195202
196203 # python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_theoretical_lower_bound -v
204+ @skip_if_no_gpu
197205 def test_theoretical_lower_bound (self ):
198206 # Construct vectors that should ideally minimize mutual information
199207 original_a = torch .tensor (
@@ -220,6 +228,7 @@ def test_theoretical_lower_bound(self):
220228 loss .backward ()
221229
222230 # python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_theoretical_upper_bound -v
231+ @skip_if_no_gpu
223232 def test_theoretical_upper_bound (self ):
224233 original_a = torch .tensor (
225234 [[[1.0 , 0.0 ], [- 1.0 , 0.0 ]]], requires_grad = True , device = "cuda:0"
@@ -246,6 +255,7 @@ def test_theoretical_upper_bound(self):
246255 loss .backward ()
247256
248257 # python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_dimension_mismatch -v
258+ @skip_if_no_gpu
249259 def test_dimension_mismatch (self ):
250260 a = torch .tensor (
251261 [[[1.0 , 0.0 , 0.0 ], [1.0 , 0.0 , 0.0 ]], [[0.0 , 0.0 , 1.0 ], [0.0 , 0.0 , 1.0 ]]],
@@ -266,6 +276,7 @@ def test_dimension_mismatch(self):
266276 loss .backward ()
267277
268278 # python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_submodularity -v
279+ @skip_if_no_gpu
269280 def test_submodularity (self ):
270281 # q = [0.7071, 0.7071, 0.0000] # query
271282 a = [1.0000 , 0.0000 , 0.0000 ] # 0 # partial match
@@ -297,6 +308,7 @@ def test_submodularity(self):
297308 assert best_candidate == e , best_candidate
298309
299310 # python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_submodularity_with_arbitary_order -v
311+ @skip_if_no_gpu
300312 def test_submodularity_with_arbitary_order (self ):
301313 # q = [0.7071, 0.7071, 0.0000] # query
302314 a = [1.0000 , 0.0000 , 0.0000 ] # 0 # partial match
@@ -330,6 +342,7 @@ def test_submodularity_with_arbitary_order(self):
330342 assert best_candidate == e , best_candidate
331343
332344 # python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_overfit -v
345+ @skip_if_no_gpu
333346 def test_overfit (self ):
334347 set_seed (42 )
335348
@@ -416,6 +429,7 @@ def test_overfit(self):
416429 # assert mse_loss.item() <= 1.5, mse_loss.item()
417430
418431 # python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_overfit_amp -v
432+ @skip_if_no_gpu
419433 def test_overfit_amp (self ):
420434 set_seed (42 )
421435
0 commit comments