@@ -14,8 +14,11 @@ def read_array_from_gputils_binary_file(path, dt=np.dtype('d')):
1414 nc = int .from_bytes (f .read (8 ), byteorder = 'little' , signed = False ) # read number of columns
1515 nm = int .from_bytes (f .read (8 ), byteorder = 'little' , signed = False ) # read number of matrices
1616 dat = np .fromfile (f , dtype = np .dtype (dt )) # read data
17- dat = dat .reshape ((nr , nc , nm )) # reshape
18- dat = np .dstack (np .split (dat .reshape (6 , - 1 ), 2 )) # I'll explain this to you when you grow up
17+
18+ if nm >= 2 : # if we actually have a 3D tensor (not a matrix or a vector)
19+ dat = dat .reshape ((nm , nc , nr )).swapaxes (0 , 2 ) # I'll explain this to you when you grow up
20+ else :
21+ dat = dat .reshape ((nr , nc , nm )) # reshape
1922 return dat
2023
2124
@@ -28,6 +31,7 @@ def write_array_to_gputils_binary_file(x, path):
2831 :raises ValueError: if `x` has more than 3 dimensions
2932 :raises ValueError: if the file name specified `path` does not have the .bt extension
3033 """
34+
3135 if not path .endswith (".bt" ):
3236 raise ValueError ("The file must have the .bt extension" )
3337 x_shape = x .shape
@@ -37,7 +41,10 @@ def write_array_to_gputils_binary_file(x, path):
3741 nr = x_shape [0 ]
3842 nc = x_shape [1 ] if x_dims >= 2 else 1
3943 nm = x_shape [2 ] if x_dims == 3 else 1
40- x = np .vstack (np .dsplit (x , 2 )).reshape (- 1 )
44+ if x_dims == 3 :
45+ x = x .swapaxes (0 , 2 ).reshape (- 1 ) # column-major storage; axis 2 last
46+ else :
47+ x = x .T .reshape (- 1 ) # column-major storage
4148 with open (path , 'wb' ) as f :
4249 f .write (nr .to_bytes (8 , 'little' )) # write number of rows
4350 f .write (nc .to_bytes (8 , 'little' )) # write number of columns
0 commit comments