Skip to content

Latest commit

 

History

History
60 lines (46 loc) · 3.25 KB

File metadata and controls

60 lines (46 loc) · 3.25 KB

2.7 序列化张量

动态创建张量是很不错的,但是如果其中的数据对你来说具有价值,那么你可能希望将其保存到文件中并在某个时候加载回去。毕竟你可不想每次开始运行程序时都从头开始重新训练模型!PyTorch内部使用pickle来序列化张量对象和实现用于存储的专用序列化代码。下面展示怎样将points张量保存到ourpoints.t文件中:

torch.save(points, '../../data/chapter2/ourpoints.t')

或者,你也可以传递文件描述符代替文件名:

with open('../../data/chapter2/ourpoints.t','wb') as f:
    torch.save(points, f)

points加载回来也是一行类似代码:

points = torch.load('../../data/chapter2/ourpoints.t')

等价于

with open('../../data/chapter2/ourpoints.t','rb') as f:
    points = torch.load(f)

如果只想通过PyTorch加载张量,则上述例子可让你快速保存张量,但这个文件格式本身是不互通(interoperable)的,你无法使用除PyTorch外其他软件读取它。根据实际使用情况,上述情况可能问题不大,但应该学习一下如何在有的时候(即想用其他软件读取的时候)互通地保存张量。尽管实际情况都是独一无二的,但当你想将PyTorch引入已经依赖于不同库的现有系统中时,上述情况会很常见;而全新的项目可能不需要经常互通地保存张量。

对于需要(互通)的情况,你可以使用HDF5格式和库。HDF5是一种可移植的、广泛支持的格式,用于表示以嵌套键值字典形式组织的序列化多维数组。Python通过h5py库支持HDF5,该库以NumPy数组的形式接收和返回数据。

你可以使用以下命令安装h5py

conda install h5py

此时,你可以通过将points张量转换为NumPy数组(如前所述,此操作几乎零花费)并将其传递给create_dataset函数来保存points张量:

import h5py

f = h5py.File('../../data/chapter2/ourpoints.hdf5', 'w')
dset = f.create_dataset('coords', data=points.numpy())
f.close()

这里,coords是传入HDF5文件的键值。你还可以有其他键值,甚至是嵌套键值。HDF5中的一件有趣的事情是,你可以索引在磁盘的数据并且仅访问你感兴趣的元素。例如你只想加载数据集中的最后两个点数据:

f = h5py.File('../../data/chapter2/ourpoints.hdf5', 'r')
dset = f['coords']
last_points = dset[1:]

上例中,当你打开文件或需要数据集时并未加载数据。相反,数据一直保留在磁盘上,直到你请求数据集中的第二行和最后一行。此时,h5py才访问这两行并返回了一个包含你想要数据的类似NumPy数组的对象,该对象的行为类似于NumPy数组,并且具有相同的API。

基于这个事实,你可以将返回的对象传递给torch.from_numpy函数以直接获取张量。需要注意的是,在这种情况下,数据将复制到张量存储中:

last_points = torch.from_numpy(dset[1:])
f.close()
# last_points = torch.from_numpy(dset[1:]) # 会报错, 因为f已经关了

完成数据加载后,必须关闭文件。