@@ -14,7 +14,11 @@ def read_array_from_gputils_binary_file(path, dt=np.dtype('d')):
1414 nc = int .from_bytes (f .read (8 ), byteorder = 'little' , signed = False ) # read number of columns
1515 nm = int .from_bytes (f .read (8 ), byteorder = 'little' , signed = False ) # read number of matrices
1616 dat = np .fromfile (f , dtype = np .dtype (dt )) # read data
17- dat = dat .reshape ((nr , nc , nm )) # reshape
17+
18+ if nm >= 2 : # if we actually have a 3D tensor (not a matrix or a vector)
19+ dat = dat .reshape ((nm , nc , nr )).swapaxes (0 , 2 ) # I'll explain this to you when you grow up
20+ else :
21+ dat = dat .reshape ((nr , nc , nm )) # reshape
1822 return dat
1923
2024
@@ -27,6 +31,7 @@ def write_array_to_gputils_binary_file(x, path):
2731 :raises ValueError: if `x` has more than 3 dimensions
2832 :raises ValueError: if the file name specified `path` does not have the .bt extension
2933 """
34+
3035 if not path .endswith (".bt" ):
3136 raise ValueError ("The file must have the .bt extension" )
3237 x_shape = x .shape
@@ -36,8 +41,12 @@ def write_array_to_gputils_binary_file(x, path):
3641 nr = x_shape [0 ]
3742 nc = x_shape [1 ] if x_dims >= 2 else 1
3843 nm = x_shape [2 ] if x_dims == 3 else 1
44+ if x_dims == 3 :
45+ x = x .swapaxes (0 , 2 ).reshape (- 1 ) # column-major storage; axis 2 last
46+ else :
47+ x = x .T .reshape (- 1 ) # column-major storage
3948 with open (path , 'wb' ) as f :
4049 f .write (nr .to_bytes (8 , 'little' )) # write number of rows
4150 f .write (nc .to_bytes (8 , 'little' )) # write number of columns
4251 f .write (nm .to_bytes (8 , 'little' )) # write number of matrices
43- x .reshape ( nr * nc * nm , 1 ). tofile (f ) # write data
52+ x .tofile (f ) # write data
0 commit comments