1010// or implied. See the License for the specific language governing permissions and limitations under the License.
1111
1212#include < future>
13+ #include < stdexcept>
1314
1415#include " catch2/catch_approx.hpp"
1516#include " catch2/catch_test_macros.hpp"
@@ -84,18 +85,26 @@ TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") {
8485
8586 auto retrieve_task = [&]() {
8687 auto results = idx_new.GetVectorByIds (ids_ds);
87- REQUIRE (results.has_value ());
88+ if (!results.has_value ()) {
89+ throw std::runtime_error (" GetVectorByIds returned no value" );
90+ }
8891 auto xb = (uint8_t *)train_ds->GetTensor ();
8992 auto res_rows = results.value ()->GetRows ();
9093 auto res_dim = results.value ()->GetDim ();
9194 auto res_data = (uint8_t *)results.value ()->GetTensor ();
92- REQUIRE (res_rows == nq);
93- REQUIRE (res_dim == dim);
95+ if (res_rows != nq) {
96+ throw std::runtime_error (" res_rows mismatch: " + std::to_string (res_rows));
97+ }
98+ if (res_dim != dim) {
99+ throw std::runtime_error (" res_dim mismatch: " + std::to_string (res_dim));
100+ }
94101 const auto data_bytes = dim / 8 ;
95102 for (int i = 0 ; i < nq; ++i) {
96103 auto id = ids_ds->GetIds ()[i];
97104 for (int j = 0 ; j < data_bytes; ++j) {
98- REQUIRE (res_data[i * data_bytes + j] == xb[id * data_bytes + j]);
105+ if (res_data[i * data_bytes + j] != xb[id * data_bytes + j]) {
106+ throw std::runtime_error (" data mismatch at i=" + std::to_string (i) + " j=" + std::to_string (j));
107+ }
99108 }
100109 }
101110 };
@@ -105,7 +114,7 @@ TEST_CASE("Test Binary Get Vector By Ids", "[Binary GetVectorByIds]") {
105114 retrieve_task_list.push_back (std::async (std::launch::async, [&] { return retrieve_task (); }));
106115 }
107116 for (auto & task : retrieve_task_list) {
108- task.wait ( );
117+ REQUIRE_NOTHROW ( task.get () );
109118 }
110119 };
111120}
@@ -218,17 +227,25 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") {
218227
219228 auto retrieve_task = [&]() {
220229 auto results = idx_new.GetVectorByIds (ids_ds);
221- REQUIRE (results.has_value ());
230+ if (!results.has_value ()) {
231+ throw std::runtime_error (" GetVectorByIds returned no value" );
232+ }
222233 auto xb = (float *)train_ds_copy->GetTensor ();
223234 auto res_rows = results.value ()->GetRows ();
224235 auto res_dim = results.value ()->GetDim ();
225236 auto res_data = (float *)results.value ()->GetTensor ();
226- REQUIRE (res_rows == nq);
227- REQUIRE (res_dim == dim);
237+ if (res_rows != nq) {
238+ throw std::runtime_error (" res_rows mismatch: " + std::to_string (res_rows));
239+ }
240+ if (res_dim != dim) {
241+ throw std::runtime_error (" res_dim mismatch: " + std::to_string (res_dim));
242+ }
228243 for (int i = 0 ; i < nq; ++i) {
229244 const auto id = ids_ds->GetIds ()[i];
230245 for (int j = 0 ; j < dim; ++j) {
231- REQUIRE (res_data[i * dim + j] == xb[id * dim + j]);
246+ if (res_data[i * dim + j] != xb[id * dim + j]) {
247+ throw std::runtime_error (" data mismatch at i=" + std::to_string (i) + " j=" + std::to_string (j));
248+ }
232249 }
233250 }
234251 };
@@ -238,7 +255,7 @@ TEST_CASE("Test Float Get Vector By Ids", "[Float GetVectorByIds]") {
238255 retrieve_task_list.push_back (std::async (std::launch::async, [&] { return retrieve_task (); }));
239256 }
240257 for (auto & task : retrieve_task_list) {
241- task.wait ( );
258+ REQUIRE_NOTHROW ( task.get () );
242259 }
243260 }
244261}
0 commit comments