|
16 | 16 | parser.add_argument('--testBatchSize', type=int, default=10, help='testing batch size') |
17 | 17 | parser.add_argument('--nEpochs', type=int, default=2, help='number of epochs to train for') |
18 | 18 | parser.add_argument('--lr', type=float, default=0.01, help='Learning Rate. Default=0.01') |
19 | | -parser.add_argument('--cuda', action='store_true', help='use cuda?') |
20 | | -parser.add_argument('--mps', action='store_true', default=False, help='enables macOS GPU training') |
| 19 | +parser.add_argument('--accel', action='store_true', help='Enables acceleration for training, if available') |
21 | 20 | parser.add_argument('--threads', type=int, default=4, help='number of threads for data loader to use') |
22 | 21 | parser.add_argument('--seed', type=int, default=123, help='random seed to use. Default=123') |
23 | 22 | opt = parser.parse_args() |
24 | 23 |
|
25 | 24 | print(opt) |
26 | 25 |
|
27 | | -if opt.cuda and not torch.cuda.is_available(): |
28 | | - raise Exception("No GPU found, please run without --cuda") |
29 | | -if not opt.mps and torch.backends.mps.is_available(): |
30 | | - raise Exception("Found mps device, please run with --mps to enable macOS GPU") |
31 | | - |
32 | 26 | torch.manual_seed(opt.seed) |
33 | | -use_mps = opt.mps and torch.backends.mps.is_available() |
34 | 27 |
|
35 | | -if opt.cuda: |
36 | | - device = torch.device("cuda") |
37 | | -elif use_mps: |
38 | | - device = torch.device("mps") |
| 28 | + |
| 29 | +if opt.accel and torch.accelerator.is_available(): |
| 30 | + device = torch.accelerator.current_accelerator() |
39 | 31 | else: |
40 | 32 | device = torch.device("cpu") |
41 | 33 |
|
|
0 commit comments