-
Notifications
You must be signed in to change notification settings - Fork 931
Expand file tree
/
Copy pathutils.py
More file actions
20 lines (17 loc) · 814 Bytes
/
utils.py
File metadata and controls
20 lines (17 loc) · 814 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import collections
import logging
import os
import torch
def _write_tensor(tensor, path):
    """Write the raw element data of ``tensor`` to ``path`` (no header)."""
    # detach() drops autograd tracking; cpu() is a no-op for CPU tensors but is
    # required before numpy() for tensors that live on an accelerator.
    tensor.detach().cpu().numpy().tofile(path)


def save_tensors(tensors, prefix, artifact_dir):
    """Dump a tensor, or a collection of tensors, to raw ``.bin`` files.

    Args:
        tensors: A ``torch.Tensor``; a tuple or list of tensors; or a dict
            (including ``collections.OrderedDict``) whose values are tensors.
        prefix: File-name stem. A single tensor is written to
            ``<prefix>.bin``; the i-th element of a collection (dict values in
            iteration order) is written to ``<prefix>_<i>.bin``.
        artifact_dir: Directory the files are written into; it is not created
            here, so it must already exist.

    Any other input type is skipped and a warning is logged. The files hold
    only raw element data (``numpy.ndarray.tofile``), so the reader must know
    the shape and dtype independently.
    """
    if isinstance(tensors, torch.Tensor):
        _write_tensor(tensors, os.path.join(artifact_dir, prefix + ".bin"))
        return

    if isinstance(tensors, (tuple, list)):
        items = tensors
    elif isinstance(tensors, dict):
        # OrderedDict is a dict subclass, and plain dicts preserve insertion
        # order, so enumeration order is the mapping's iteration order.
        items = tensors.values()
    else:
        # Use %-style lazy args: passing extra positional args without
        # placeholders makes logging fail while formatting the record.
        logging.warning("Unsupported type (%s), skip saving tensor.", type(tensors))
        return

    for index, tensor in enumerate(items):
        _write_tensor(tensor, os.path.join(artifact_dir, f"{prefix}_{index}.bin"))