Skip to content

Commit d3bbdb2

Browse files
authored
Fix overflow warning in nearest neighbors code (#755)
While running the unit tests we were seeing the following warning:

```
tests/knn_test.py::BM25Test::test_rank_items_batch
tests/knn_test.py::BM25Test::test_similar_items_filter
tests/knn_test.py::TFIDFTest::test_rank_items_batch
tests/knn_test.py::TFIDFTest::test_similar_items_filter
tests/knn_test.py::CosineTest::test_rank_items_batch
tests/knn_test.py::CosineTest::test_similar_items_filter
  /home/ben/code/implicit/implicit/utils.py:134: RuntimeWarning: overflow encountered in cast
    output_scores[i] = batch_scores[:N]
```

This happened because `_batch_call` was generating its output in float32, while the KNN models were returning float64 results. Fix this by adding `id_dtype` and `score_dtype` parameters to `_batch_call` and having the KNN models pass `score_dtype=np.float64`.
1 parent 3584470 commit d3bbdb2

2 files changed

Lines changed: 10 additions & 4 deletions

File tree

implicit/nearest_neighbours.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def recommend(
6262
userid,
6363
user_items=user_items,
6464
N=N,
65+
score_dtype=np.float64,
6566
filter_already_liked_items=filter_already_liked_items,
6667
filter_items=filter_items,
6768
recalculate_user=recalculate_user,
@@ -115,7 +116,12 @@ def similar_items(
115116

116117
if not np.isscalar(itemid):
117118
return _batch_call(
118-
self.similar_items, itemid, N=N, filter_items=filter_items, items=items
119+
self.similar_items,
120+
itemid,
121+
N=N,
122+
score_dtype=np.float64,
123+
filter_items=filter_items,
124+
items=items,
119125
)
120126

121127
if filter_items is not None and items is not None:

implicit/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,11 @@ def augment_inner_product_matrix(factors):
103103
return max_norm, np.append(factors, extra_dimension.reshape(norms.shape[0], 1), axis=1)
104104

105105

106-
def _batch_call(func, ids, *args, N=10, **kwargs):
106+
def _batch_call(func, ids, *args, N=10, id_dtype=np.int32, score_dtype=np.float32, **kwargs):
107107
# we're running in batch mode, just loop over each item and call the scalar version of the
108108
# function
109-
output_ids = np.zeros((len(ids), N), dtype=np.int32)
110-
output_scores = np.zeros((len(ids), N), dtype=np.float32)
109+
output_ids = np.zeros((len(ids), N), dtype=id_dtype)
110+
output_scores = np.zeros((len(ids), N), dtype=score_dtype)
111111

112112
user_items = kwargs.pop("user_items") if "user_items" in kwargs else None
113113
item_users = kwargs.pop("item_users") if "item_users" in kwargs else None

0 commit comments

Comments
 (0)