Skip to content

Commit 54531c8

Browse files
Update run_grouped_ablation.py
1 parent 0e0762f commit 54531c8

1 file changed

Lines changed: 8 additions & 0 deletions

File tree

ablation_studies/run_grouped_ablation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,10 +155,18 @@ def main():
155155
spikes_list = []
156156

157157
import concurrent.futures
158+
import torch
158159

159160
# Adjust range to respect start_seed and num_seeds
160161
seeds_to_run = list(range(args.start_seed, args.start_seed + args.num_seeds))
161162

163+
# Auto-detect single GPU environment (e.g., Colab T4) and force serial execution
164+
# to prevent OOM/Thrashing when trying to run multiple training jobs on one GPU.
165+
if args.max_workers > 1:
166+
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
167+
print(f"{Colors.WARNING}Single GPU detected. Forcing max_workers=1 to prevent crashes.{Colors.ENDC}")
168+
args.max_workers = 1
169+
162170
print(f"Starting {len(seeds_to_run)} runs with {args.max_workers} workers...\n")
163171

164172
with concurrent.futures.ProcessPoolExecutor(max_workers=args.max_workers) as executor:

0 commit comments

Comments
 (0)