Skip to content

Commit d3b54ce

Browse files
fix ASAN problems with unit tests (#1554)
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
1 parent c21d72a commit d3b54ce

2 files changed

Lines changed: 40 additions & 14 deletions

File tree

tests/ut/test_get_vector.cc

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
// or implied. See the License for the specific language governing permissions and limitations under the License.
1111

1212
#include <future>
13+
#include <stdexcept>
1314

1415
#include "catch2/catch_approx.hpp"
1516
#include "catch2/catch_test_macros.hpp"
@@ -84,18 +85,26 @@ TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") {
8485

8586
auto retrieve_task = [&]() {
8687
auto results = idx_new.GetVectorByIds(ids_ds);
87-
REQUIRE(results.has_value());
88+
if (!results.has_value()) {
89+
throw std::runtime_error("GetVectorByIds returned no value");
90+
}
8891
auto xb = (uint8_t*)train_ds->GetTensor();
8992
auto res_rows = results.value()->GetRows();
9093
auto res_dim = results.value()->GetDim();
9194
auto res_data = (uint8_t*)results.value()->GetTensor();
92-
REQUIRE(res_rows == nq);
93-
REQUIRE(res_dim == dim);
95+
if (res_rows != nq) {
96+
throw std::runtime_error("res_rows mismatch: " + std::to_string(res_rows));
97+
}
98+
if (res_dim != dim) {
99+
throw std::runtime_error("res_dim mismatch: " + std::to_string(res_dim));
100+
}
94101
const auto data_bytes = dim / 8;
95102
for (int i = 0; i < nq; ++i) {
96103
auto id = ids_ds->GetIds()[i];
97104
for (int j = 0; j < data_bytes; ++j) {
98-
REQUIRE(res_data[i * data_bytes + j] == xb[id * data_bytes + j]);
105+
if (res_data[i * data_bytes + j] != xb[id * data_bytes + j]) {
106+
throw std::runtime_error("data mismatch at i=" + std::to_string(i) + " j=" + std::to_string(j));
107+
}
99108
}
100109
}
101110
};
@@ -105,7 +114,7 @@ TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") {
105114
retrieve_task_list.push_back(std::async(std::launch::async, [&] { return retrieve_task(); }));
106115
}
107116
for (auto& task : retrieve_task_list) {
108-
task.wait();
117+
REQUIRE_NOTHROW(task.get());
109118
}
110119
};
111120
}
@@ -218,17 +227,25 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") {
218227

219228
auto retrieve_task = [&]() {
220229
auto results = idx_new.GetVectorByIds(ids_ds);
221-
REQUIRE(results.has_value());
230+
if (!results.has_value()) {
231+
throw std::runtime_error("GetVectorByIds returned no value");
232+
}
222233
auto xb = (float*)train_ds_copy->GetTensor();
223234
auto res_rows = results.value()->GetRows();
224235
auto res_dim = results.value()->GetDim();
225236
auto res_data = (float*)results.value()->GetTensor();
226-
REQUIRE(res_rows == nq);
227-
REQUIRE(res_dim == dim);
237+
if (res_rows != nq) {
238+
throw std::runtime_error("res_rows mismatch: " + std::to_string(res_rows));
239+
}
240+
if (res_dim != dim) {
241+
throw std::runtime_error("res_dim mismatch: " + std::to_string(res_dim));
242+
}
228243
for (int i = 0; i < nq; ++i) {
229244
const auto id = ids_ds->GetIds()[i];
230245
for (int j = 0; j < dim; ++j) {
231-
REQUIRE(res_data[i * dim + j] == xb[id * dim + j]);
246+
if (res_data[i * dim + j] != xb[id * dim + j]) {
247+
throw std::runtime_error("data mismatch at i=" + std::to_string(i) + " j=" + std::to_string(j));
248+
}
232249
}
233250
}
234251
};
@@ -238,7 +255,7 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") {
238255
retrieve_task_list.push_back(std::async(std::launch::async, [&] { return retrieve_task(); }));
239256
}
240257
for (auto& task : retrieve_task_list) {
241-
task.wait();
258+
REQUIRE_NOTHROW(task.get());
242259
}
243260
}
244261
}

tests/ut/test_sparse.cc

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
// or implied. See the License for the specific language governing permissions and limitations under the License.
1111

1212
#include <future>
13+
#include <stdexcept>
1314
#include <string>
1415
#include <thread>
1516

@@ -510,7 +511,11 @@ TEST_CASE("Test Mem Sparse Index CC", "[float metrics]") {
510511
for (auto i = 0; i < nq; ++i) {
511512
for (auto j = 0; j < k; ++j) {
512513
auto base = ids[i * k + j] / nb;
513-
REQUIRE(base == expected_id_base);
514+
if (base != expected_id_base) {
515+
throw std::runtime_error("id base mismatch at i=" + std::to_string(i) + " j=" + std::to_string(j) +
516+
": got " + std::to_string(base) + " expected " +
517+
std::to_string(expected_id_base));
518+
}
514519
}
515520
}
516521
};
@@ -538,7 +543,9 @@ TEST_CASE("Test Mem Sparse Index CC", "[float metrics]") {
538543
test_time) {
539544
auto doc_ds = doc_vector_gen(nb, dim);
540545
auto res = idx.Add(doc_ds, json);
541-
REQUIRE(res == knowhere::Status::success);
546+
if (res != knowhere::Status::success) {
547+
throw std::runtime_error("Add failed with status " + std::to_string(static_cast<int>(res)));
548+
}
542549
}
543550
};
544551

@@ -547,7 +554,9 @@ TEST_CASE("Test Mem Sparse Index CC", "[float metrics]") {
547554
while (std::chrono::duration_cast<std::chrono::seconds>(std::chrono::steady_clock::now() - start).count() <
548555
test_time) {
549556
auto results = idx.Search(query_ds, json, nullptr);
550-
REQUIRE(results.has_value());
557+
if (!results.has_value()) {
558+
throw std::runtime_error("Search returned no value");
559+
}
551560
check_result(*results.value());
552561
}
553562
};
@@ -559,7 +568,7 @@ TEST_CASE("Test Mem Sparse Index CC", "[float metrics]") {
559568
}
560569
task_list.push_back(std::async(std::launch::async, add_task));
561570
for (auto& task : task_list) {
562-
task.wait();
571+
REQUIRE_NOTHROW(task.get());
563572
}
564573
}
565574

0 commit comments

Comments
 (0)