Skip to content

Commit a39f532

Browse files
CPU fallback if possible
1 parent bb4284e commit a39f532

7 files changed

Lines changed: 96 additions & 22 deletions

File tree

src/config.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pydantic import BaseModel, model_validator, validator
1313
from pathlib import Path
1414
from typing import List
15+
import torch
1516

1617

1718
def type_validator(lut):
@@ -191,6 +192,12 @@ def _load_from_file(filepath):
191192

192193
@staticmethod
193194
def _create_from_dict(config_data, filepath=None):
195+
# Apply CPU fallback for device settings if CUDA is not available
196+
if not torch.cuda.is_available():
197+
print("CUDA not available, falling back to CPU for all devices")
198+
# Recursively replace any cuda device settings with cpu
199+
Config._apply_cpu_fallback(config_data)
200+
194201
# Instantiate RootConfig with the loaded data
195202
root_config = RootConfig(**config_data)
196203

@@ -203,3 +210,22 @@ def _create_from_dict(config_data, filepath=None):
203210
)
204211

205212
return root_config
213+
214+
@staticmethod
215+
def _apply_cpu_fallback(config_data):
216+
"""Recursively replace CUDA device settings with CPU when CUDA is not available"""
217+
if isinstance(config_data, dict):
218+
for key, value in config_data.items():
219+
if (
220+
key == "device"
221+
and isinstance(value, str)
222+
and "cuda" in value.lower()
223+
):
224+
config_data[key] = "cpu"
225+
print(f" {key}: {value} -> cpu")
226+
elif isinstance(value, (dict, list)):
227+
Config._apply_cpu_fallback(value)
228+
elif isinstance(config_data, list):
229+
for item in config_data:
230+
if isinstance(item, (dict, list)):
231+
Config._apply_cpu_fallback(item)

src/losses/quaild_facility_location_loss_test.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,22 +91,22 @@ def test_submodularity(self):
9191

9292
best_diversity, best_candidate = pick_most_diverse([b], [a, c, d, e])
9393

94-
assert best_diversity == 1.0, best_diversity
94+
self.assertAlmostEqual(best_diversity, 1.0, places=4)
9595
assert best_candidate == e, best_candidate
9696

9797
best_diversity, best_candidate = pick_most_diverse([b, e], [a, d, c])
9898

99-
assert best_diversity == 0.5, best_diversity
99+
self.assertAlmostEqual(best_diversity, 0.5, places=4)
100100
assert best_candidate == c, best_candidate
101101

102102
best_diversity, best_candidate = pick_most_diverse([b, c, e], [a, d])
103103

104-
assert best_diversity == 0.2642977237701416, best_diversity
104+
self.assertAlmostEqual(best_diversity, 0.2642977237701416, places=4)
105105
assert best_candidate == a, best_candidate
106106

107107
best_diversity, best_candidate = pick_most_diverse([a, b, c, e], [d])
108108

109-
assert best_diversity == 0.13720381259918213, best_diversity
109+
self.assertAlmostEqual(best_diversity, 0.13720381259918213, places=4)
110110
assert best_candidate == d, best_candidate
111111

112112
# python -m unittest losses.quaild_facility_location_loss_test.TestQuaildFacilityLocation.test_submodularity_with_arbitary_order -v
@@ -124,22 +124,22 @@ def test_submodularity_with_arbitary_order(self):
124124

125125
best_diversity, best_candidate = pick_most_diverse([b], [a, c, d, e])
126126

127-
assert best_diversity == 1.0, best_diversity
127+
self.assertAlmostEqual(best_diversity, 1.0, places=4)
128128
assert best_candidate == e, best_candidate
129129

130130
best_diversity, best_candidate = pick_most_diverse([a, b], [c, d, e])
131131

132-
assert best_diversity == 0.8779611587524414, best_diversity
132+
self.assertAlmostEqual(best_diversity, 0.8779611587524414, places=4)
133133
assert best_candidate == e, best_candidate
134134

135135
best_diversity, best_candidate = pick_most_diverse([a, b, c], [d, e])
136136

137-
assert best_diversity == 0.5948392152786255, best_diversity
137+
self.assertAlmostEqual(best_diversity, 0.5948392152786255, places=4)
138138
assert best_candidate == e, best_candidate
139139

140140
best_diversity, best_candidate = pick_most_diverse([a, b, c, d], [e])
141141

142-
assert best_diversity == 0.6127961277961731, best_diversity
142+
self.assertAlmostEqual(best_diversity, 0.6127961277961731, places=4)
143143
assert best_candidate == e, best_candidate
144144

145145
# python -m unittest losses.quaild_facility_location_loss_test.TestQuaildFacilityLocation.test_overfit -v

src/losses/quaild_log_det_mi_loss_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from train_utils import set_seed
88
import torch.nn.functional as F
99
from torch.cuda.amp import GradScaler, autocast
10+
from test_utils import skip_if_no_gpu
1011

1112

1213
# python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss -v
@@ -17,6 +18,7 @@ def setUp(self):
1718
self.loss_fn = QuaidLogDetMILoss(config)
1819

1920
# python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_log_det_happy -v
21+
@skip_if_no_gpu
2022
def test_log_det_happy(self):
2123
# Create a tensor representing positive infinity
2224
matrix = torch.tensor(
@@ -41,6 +43,7 @@ def test_log_det_happy(self):
4143
], matrix.grad.tolist()
4244

4345
# python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_log_det_singular -v
46+
@skip_if_no_gpu
4447
def test_log_det_singular(self):
4548
# Create a tensor representing positive infinity
4649
matrix = torch.tensor(
@@ -65,6 +68,7 @@ def test_log_det_singular(self):
6568
], matrix.grad.tolist()
6669

6770
# python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_log_det_weird -v
71+
@skip_if_no_gpu
6872
def test_log_det_weird(self):
6973
# Create a tensor representing positive infinity
7074
matrix = torch.tensor(
@@ -95,6 +99,7 @@ def test_log_det_weird(self):
9599
], matrix.grad.tolist()
96100

97101
# python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_safe_pinverse_happy -v
102+
@skip_if_no_gpu
98103
def test_safe_pinverse_happy(self):
99104
# Create a tensor representing positive infinity
100105
matrix = torch.tensor(
@@ -123,6 +128,7 @@ def test_safe_pinverse_happy(self):
123128
], matrix.grad.tolist()
124129

125130
# python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_safe_pinverse_singular -v
131+
@skip_if_no_gpu
126132
def test_safe_pinverse_singular(self):
127133
# Create a tensor representing positive infinity
128134
matrix = torch.tensor(
@@ -151,6 +157,7 @@ def test_safe_pinverse_singular(self):
151157
], matrix.grad.tolist()
152158

153159
# python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_safe_pinverse_weird -v
160+
@skip_if_no_gpu
154161
def test_safe_pinverse_weird(self):
155162
# Create a tensor representing positive infinity
156163
matrix = torch.tensor(
@@ -194,6 +201,7 @@ def test_safe_pinverse_weird(self):
194201
], matrix.grad.tolist()
195202

196203
# python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_theoretical_lower_bound -v
204+
@skip_if_no_gpu
197205
def test_theoretical_lower_bound(self):
198206
# Construct vectors that should ideally minimize mutual information
199207
original_a = torch.tensor(
@@ -220,6 +228,7 @@ def test_theoretical_lower_bound(self):
220228
loss.backward()
221229

222230
# python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_theoretical_upper_bound -v
231+
@skip_if_no_gpu
223232
def test_theoretical_upper_bound(self):
224233
original_a = torch.tensor(
225234
[[[1.0, 0.0], [-1.0, 0.0]]], requires_grad=True, device="cuda:0"
@@ -246,6 +255,7 @@ def test_theoretical_upper_bound(self):
246255
loss.backward()
247256

248257
# python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_dimension_mismatch -v
258+
@skip_if_no_gpu
249259
def test_dimension_mismatch(self):
250260
a = torch.tensor(
251261
[[[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]], [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]]],
@@ -266,6 +276,7 @@ def test_dimension_mismatch(self):
266276
loss.backward()
267277

268278
# python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_submodularity -v
279+
@skip_if_no_gpu
269280
def test_submodularity(self):
270281
# q = [0.7071, 0.7071, 0.0000] # query
271282
a = [1.0000, 0.0000, 0.0000] # 0 # partial match
@@ -297,6 +308,7 @@ def test_submodularity(self):
297308
assert best_candidate == e, best_candidate
298309

299310
# python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_submodularity_with_arbitary_order -v
311+
@skip_if_no_gpu
300312
def test_submodularity_with_arbitary_order(self):
301313
# q = [0.7071, 0.7071, 0.0000] # query
302314
a = [1.0000, 0.0000, 0.0000] # 0 # partial match
@@ -330,6 +342,7 @@ def test_submodularity_with_arbitary_order(self):
330342
assert best_candidate == e, best_candidate
331343

332344
# python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_overfit -v
345+
@skip_if_no_gpu
333346
def test_overfit(self):
334347
set_seed(42)
335348

@@ -416,6 +429,7 @@ def test_overfit(self):
416429
# assert mse_loss.item() <= 1.5, mse_loss.item()
417430

418431
# python -m unittest losses.quaild_log_det_mi_loss_test.TestQuaidLogDetMILoss.test_overfit_amp -v
432+
@skip_if_no_gpu
419433
def test_overfit_amp(self):
420434
set_seed(42)
421435

src/offline_eval_pipeline.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,10 @@ def cleanup(self):
4545
gc.collect()
4646

4747
# Hopefully fix oom
48-
torch.cuda.synchronize()
49-
torch.cuda.empty_cache()
50-
torch.cuda.synchronize()
48+
if torch.cuda.is_available():
49+
torch.cuda.synchronize()
50+
torch.cuda.empty_cache()
51+
torch.cuda.synchronize()
5152

5253
def is_done(self):
5354
if self.current_dataset_name is None:

src/subset_selection_strategies/quaild_submodular_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from subset_selection_strategies.quaild_submodular import QuaildSubmodularStrategy
55
import unittest
66
import torch.nn.functional as F
7+
from test_utils import skip_if_no_gpu
78

89

910
# python -m unittest subset_selection_strategies.quaild_submodular_test.TestQuaildSubmodularStrategy -v
@@ -368,6 +369,7 @@ def test_subset_select_with_similarity_many_fl(self):
368369
assert scores.tolist() == expected_scores, scores.tolist()
369370

370371
# python -m unittest subset_selection_strategies.quaild_submodular_test.TestQuaildSubmodularStrategy.test_subset_select_many_ld -v
372+
@skip_if_no_gpu
371373
def test_subset_select_many_ld(self):
372374
config = Config.from_file("experiments/tests/quaild_test_experiment.yaml")
373375
config.architecture.semantic_search_model.type = "noop"
@@ -426,6 +428,7 @@ def test_subset_select_many_ld(self):
426428
assert scores.tolist() == expected_scores, scores.tolist()
427429

428430
# python -m unittest subset_selection_strategies.quaild_submodular_test.TestQuaildSubmodularStrategy.test_subset_select_with_similarity_many_ld -v
431+
@skip_if_no_gpu
429432
def test_subset_select_with_similarity_many_ld(self):
430433
config = Config.from_file("experiments/tests/quaild_test_experiment.yaml")
431434
config.architecture.semantic_search_model.type = "noop"

src/test_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import torch
2+
3+
4+
def skip_if_no_gpu(test_func):
5+
"""Decorator to skip test if GPU is not available"""
6+
7+
def wrapper(self):
8+
if not torch.cuda.is_available():
9+
self.skipTest("GPU not available")
10+
return test_func(self)
11+
12+
return wrapper

src/training_pipeline.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,14 @@
1919
)
2020
from training_strategies import TRAINING_STRATEGIES_LUT
2121
from config import RootConfig
22-
from torch.cuda.amp import GradScaler
22+
23+
try:
24+
from torch.cuda.amp import GradScaler
25+
26+
CUDA_AVAILABLE = torch.cuda.is_available()
27+
except ImportError:
28+
GradScaler = None
29+
CUDA_AVAILABLE = False
2330
import torch.optim as optim
2431
from tqdm import tqdm
2532

@@ -107,7 +114,7 @@ def _load_parts(self, config: RootConfig):
107114

108115
# Optimizer
109116
print("Preparing optimizer")
110-
self.scaler = GradScaler()
117+
self.scaler = GradScaler() if CUDA_AVAILABLE and GradScaler else None
111118
self.optimizer = optim.AdamW(
112119
self.semantic_search_model.get_all_trainable_parameters(),
113120
lr=config.training.learning_rate,
@@ -171,17 +178,23 @@ def train_one_epoch(self):
171178
self.optimizer.zero_grad()
172179

173180
# Automatic Mixed Precision
174-
with torch.cuda.amp.autocast():
181+
if CUDA_AVAILABLE:
182+
with torch.cuda.amp.autocast():
183+
loss = self.training_strategy.train_step(batch)
184+
else:
175185
loss = self.training_strategy.train_step(batch)
176186

177-
# Bad batch
178-
if check_for_nan_then_dump(loss, batch):
179-
continue
187+
# Bad batch
188+
if check_for_nan_then_dump(loss, batch):
189+
continue
180190

181-
pbar.set_description(f"Loss: {round(loss.item()*10000)/10000}")
191+
pbar.set_description(f"Loss: {round(loss.item()*10000)/10000}")
182192

183193
# Scales loss. Calls backward() on scaled loss to create scaled gradients.
184-
self.scaler.scale(loss).backward()
194+
if self.scaler:
195+
self.scaler.scale(loss).backward()
196+
else:
197+
loss.backward()
185198

186199
extra_metrics = {}
187200
if self.current_step % 100 == 0:
@@ -206,15 +219,20 @@ def train_one_epoch(self):
206219
)
207220

208221
# Unscales gradients and calls or skips optimizer.step()
209-
self.scaler.step(self.optimizer)
222+
if self.scaler:
223+
self.scaler.step(self.optimizer)
224+
else:
225+
self.optimizer.step()
210226
self.lr_scheduler.step()
211227

212228
# Updates the scale for next iteration
213-
self.scaler.update()
229+
if self.scaler:
230+
self.scaler.update()
214231
except Exception as e:
215232
traceback.print_exc()
216233
print("[train_one_epoch]", e)
217-
torch.cuda.empty_cache()
234+
if torch.cuda.is_available():
235+
torch.cuda.empty_cache()
218236

219237
self.current_step += 1
220238

0 commit comments

Comments (0)