Skip to content

Commit d75cfdf

Browse files
committed
it works and dont ask why
1 parent 3b1ee9b commit d75cfdf

1 file changed

Lines changed: 10 additions & 3 deletions

File tree

python/gputils_api/gputils_api.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@ def read_array_from_gputils_binary_file(path, dt=np.dtype('d')):
1414
nc = int.from_bytes(f.read(8), byteorder='little', signed=False) # read number of columns
1515
nm = int.from_bytes(f.read(8), byteorder='little', signed=False) # read number of matrices
1616
dat = np.fromfile(f, dtype=np.dtype(dt)) # read data
17-
dat = dat.reshape((nr, nc, nm)) # reshape
18-
dat = np.dstack(np.split(dat.reshape(6, -1), 2)) # I'll explain this to you when you grow up
17+
18+
if nm >= 2: # if we actually have a 3D tensor (not a matrix or a vector)
19+
dat = dat.reshape((nm, nc, nr)).swapaxes(0, 2) # I'll explain this to you when you grow up
20+
else:
21+
dat = dat.reshape((nr, nc, nm)) # reshape
1922
return dat
2023

2124

@@ -28,6 +31,7 @@ def write_array_to_gputils_binary_file(x, path):
2831
:raises ValueError: if `x` has more than 3 dimensions
2932
:raises ValueError: if the file name specified `path` does not have the .bt extension
3033
"""
34+
3135
if not path.endswith(".bt"):
3236
raise ValueError("The file must have the .bt extension")
3337
x_shape = x.shape
@@ -37,7 +41,10 @@ def write_array_to_gputils_binary_file(x, path):
3741
nr = x_shape[0]
3842
nc = x_shape[1] if x_dims >= 2 else 1
3943
nm = x_shape[2] if x_dims == 3 else 1
40-
x = np.vstack(np.dsplit(x, 2)).reshape(-1)
44+
if x_dims == 3:
45+
x = x.swapaxes(0, 2).reshape(-1) # column-major storage; axis 2 last
46+
else:
47+
x = x.T.reshape(-1) # column-major storage
4148
with open(path, 'wb') as f:
4249
f.write(nr.to_bytes(8, 'little')) # write number of rows
4350
f.write(nc.to_bytes(8, 'little')) # write number of columns

0 commit comments

Comments
 (0)