Skip to content

Commit 39b9f19

Browse files
jmorlockJan Morlockbenfred
authored
check user and item matrix for nan entries also in bpr gpu version (#731)
Co-authored-by: Jan Morlock <jan.morlock@1und1.de> Co-authored-by: Ben Frederickson <github@benfrederickson.com>
1 parent 7c36141 commit 39b9f19

4 files changed

Lines changed: 14 additions & 5 deletions

File tree

implicit/cpu/matrix_factorization_base.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
from scipy.sparse import csr_matrix, lil_matrix
77

8-
from ..recommender_base import ModelFitError, RecommenderBase
8+
from ..recommender_base import RecommenderBase
99
from .topk import topk
1010

1111

@@ -247,10 +247,7 @@ def item_norms(self):
247247
return self._item_norms
248248

249249
def _check_fit_errors(self):
250-
is_nan = np.any(np.isnan(self.user_factors), axis=None)
251-
is_nan |= np.any(np.isnan(self.item_factors), axis=None)
252-
if is_nan:
253-
raise ModelFitError("NaN encountered in factors")
250+
self._check_factors(self.user_factors, self.item_factors)
254251

255252

256253
def _filter_items_from_sparse_matrix(items, query_items):

implicit/gpu/bpr.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ def fit(self, user_items, show_progress=True, callback=None):
160160
if callback:
161161
callback(_epoch, time.time() - s, correct, skipped)
162162

163+
self._check_fit_errors()
164+
163165
def to_cpu(self) -> implicit.cpu.bpr.BayesianPersonalizedRanking:
164166
"""Converts this model to an equivalent version running on the cpu"""
165167
ret = implicit.cpu.bpr.BayesianPersonalizedRanking(

implicit/gpu/matrix_factorization_base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,9 @@ def similar_items(
201201

202202
similar_items.__doc__ = RecommenderBase.similar_items.__doc__
203203

204+
def _check_fit_errors(self):
205+
self._check_factors(self.user_factors.to_numpy(), self.item_factors.to_numpy())
206+
204207
def recalculate_user(self, userid, user_items):
205208
raise NotImplementedError("recalculate_user is not supported with this model")
206209

implicit/recommender_base.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,10 @@ def rank_items(self, userid, user_items, selected_items, recalculate_user=False)
214214
items=selected_items,
215215
filter_already_liked_items=False,
216216
)
217+
218+
@staticmethod
219+
def _check_factors(user_factors, item_factors):
220+
is_nan = np.any(np.isnan(user_factors), axis=None)
221+
is_nan |= np.any(np.isnan(item_factors), axis=None)
222+
if is_nan:
223+
raise ModelFitError("NaN encountered in factors")

0 commit comments

Comments
 (0)