Skip to content

Commit da30c57

Browse files
committed
Makes sure to load checkpoints onto cpu memory in rs export
1 parent f67f3a7 commit da30c57

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

robosat/tools/export.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ def main(args):
2727
num_classes = len(dataset["common"]["classes"])
2828
net = UNet(num_classes)
2929

30-
chkpt = torch.load(args.checkpoint, map_location="cpu")
30+
def map_location(storage, _):
31+
return storage.cpu()
32+
33+
chkpt = torch.load(args.checkpoint, map_location=map_location)
3134
net = torch.nn.DataParallel(net)
3235
net.load_state_dict(chkpt)
3336

0 commit comments

Comments
 (0)