Skip to content

Commit 3b1ee9b

Browse files
committed
allegedly
1 parent 5257d5d commit 3b1ee9b

4 files changed

Lines changed: 33 additions & 5 deletions

File tree

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,16 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
<!-- ---------------------
9+
v1.7.1
10+
--------------------- -->
11+
## v1.7.1 - 4-12-2024
12+
13+
### Fixed
14+
15+
- Compatibility between Python and C++ in how the data is stored in bt files
16+
17+
818
<!-- ---------------------
919
v1.7.0
1020
--------------------- -->

python/gputils_api/gputils_api.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def read_array_from_gputils_binary_file(path, dt=np.dtype('d')):
1515
nm = int.from_bytes(f.read(8), byteorder='little', signed=False) # read number of matrices
1616
dat = np.fromfile(f, dtype=np.dtype(dt)) # read data
1717
dat = dat.reshape((nr, nc, nm)) # reshape
18+
dat = np.dstack(np.split(dat.reshape(6, -1), 2)) # I'll explain this to you when you grow up
1819
return dat
1920

2021

@@ -36,8 +37,9 @@ def write_array_to_gputils_binary_file(x, path):
3637
nr = x_shape[0]
3738
nc = x_shape[1] if x_dims >= 2 else 1
3839
nm = x_shape[2] if x_dims == 3 else 1
40+
x = np.vstack(np.dsplit(x, 2)).reshape(-1)
3941
with open(path, 'wb') as f:
4042
f.write(nr.to_bytes(8, 'little')) # write number of rows
4143
f.write(nc.to_bytes(8, 'little')) # write number of columns
4244
f.write(nm.to_bytes(8, 'little')) # write number of matrices
43-
x.reshape(nr*nc*nm, 1).tofile(f) # write data
45+
x.tofile(f) # write data

python/test/test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def setUpClass(cls):
3737

3838
a = np.linspace(-100, 100, 4 * 5).reshape((4, 5)).astype('d')
3939
gpuapi.write_array_to_gputils_binary_file(a, os.path.join(base_dir, 'a_d.bt'))
40+
4041
gpuapi.write_array_to_gputils_binary_file(cls._B, os.path.join(base_dir, 'b_d.bt'))
4142

4243
def __test_read_eye(self, dt):

test/testTensor.cu

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,25 @@ TEST_F(TensorTest, parseTensorFromFileBinary) {
182182

183183
TEST_F(TensorTest, parseTensorFromBinaryPython) {
184184
std::string fName = "../../python/b_d.bt";
185-
DTensor<double> b = DTensor<double>::parseFromFile(fName);
186-
std::vector<double> vb(12);
187-
b.download(vb);
188-
for (size_t i = 0; i < 12; i++) EXPECT_NEAR(i + 1., vb[i], PRECISION_HIGH);
185+
DTensor<double> b = DTensor<double>::parseFromFile(fName, rowMajor);
186+
for (size_t i=0; i<3; i++) {
187+
for (size_t j=0; j<3; j++) {
188+
EXPECT_NEAR(1 + 2*j + 6*i, b(i, j, 0), PRECISION_HIGH);
189+
EXPECT_NEAR(2 + 2*j + 6*i, b(i, j, 1), PRECISION_HIGH);
190+
}
191+
}
192+
}
193+
194+
195+
/* ---------------------------------------
196+
* Parse not existing file
197+
* --------------------------------------- */
198+
199+
TEST_F(TensorTest, parseTensorFromNonexistentFile) {
200+
std::string fName = "../../python/whatever.bt";
201+
EXPECT_THROW(DTensor<double> b = DTensor<double>::parseFromFile(fName, rowMajor), std::invalid_argument);
202+
std::string fName2 = "../../python/whatever.txt";
203+
EXPECT_THROW(DTensor<double> b = DTensor<double>::parseFromFile(fName2, rowMajor), std::invalid_argument);
189204
}
190205

191206

0 commit comments

Comments
 (0)