Skip to content

Commit 89f517c

Browse files
committed
fix(lmf): correct four bugs in lmf_update negative sampling
Bug A: item_vectors.shape[1] returned n_factors+2, not n_items. Fix: use shape[0]. Bug B: the RNGVector range was [0, nnz-1] and i = indices[index] only sampled from already-interacted items (popularity-biased, never zero-interaction items). Fix: sample i directly from [0, n_items). Bug C: the outer negative-sample loop and the inner factor loops all used _ as the loop variable. Each inner loop left _ == n_factors, so the outer loop ran at most once regardless of neg_prop. Fix: use f for the inner factor loops. Bug D: a single RNG with range [0, nnz-1] was shared by the user-update pass (needs item IDs) and the item-update pass (needs user IDs). Fix: two separate RNGVector instances with correct ranges.
1 parent 7a80fb7 commit 89f517c

1 file changed

Lines changed: 30 additions & 18 deletions

File tree

implicit/cpu/lmf.pyx

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -174,20 +174,24 @@ class LogisticMatrixFactorization(MatrixFactorizationBase):
174174

175175
# initialize RNG's, one per thread. Also pass the seeds for each thread's RNG
176176
cdef long[:] rng_seeds = rs.integers(0, 2**31, size=num_threads, dtype="long")
177-
cdef RNGVector rng = RNGVector(num_threads, len(user_items.data) - 1, rng_seeds)
177+
cdef long[:] rng_seeds2 = rs.integers(0, 2**31, size=num_threads, dtype="long")
178+
# Separate RNG per update direction: user update samples item IDs [0, items),
179+
# item update samples user IDs [0, users).
180+
cdef RNGVector user_neg_rng = RNGVector(num_threads, items - 1, rng_seeds)
181+
cdef RNGVector item_neg_rng = RNGVector(num_threads, users - 1, rng_seeds2)
178182

179183
log.debug("Running %i LMF training epochs", self.iterations)
180184
with tqdm(total=self.iterations, disable=not show_progress) as progress:
181185
for epoch in range(self.iterations):
182186
s = time.time()
183187
# user update
184-
lmf_update(rng, user_vec_deriv_sum,
188+
lmf_update(user_neg_rng, user_vec_deriv_sum,
185189
self.user_factors, self.item_factors,
186190
user_items.indices, user_items.indptr, user_items.data,
187191
self.learning_rate, self.regularization, self.neg_prop, num_threads)
188192
self.user_factors[:, -2] = 1.0
189193
# item update
190-
lmf_update(rng, item_vec_deriv_sum,
194+
lmf_update(item_neg_rng, item_vec_deriv_sum,
191195
self.item_factors, self.user_factors,
192196
item_users.indices, item_users.indptr, item_users.data,
193197
self.learning_rate, self.regularization, self.neg_prop, num_threads)
@@ -235,7 +239,9 @@ def lmf_update(RNGVector rng, floating[:, :] deriv_sum_sq,
235239
integral num_threads):
236240

237241
cdef integral n_users = user_vectors.shape[0]
238-
cdef integral n_items = item_vectors.shape[1]
242+
# item_vectors rows = number of opposite-side entities (items during user update,
243+
# users during item update). shape[1] was wrong — that gives n_factors+2.
244+
cdef integral n_items = item_vectors.shape[0]
239245
cdef integral n_factors = user_vectors.shape[1]
240246

241247
cdef integral u, i, it, c, _, index, f
@@ -272,21 +278,27 @@ def lmf_update(RNGVector rng, floating[:, :] deriv_sum_sq,
272278
deriv[_] = deriv[_] - z * item_vectors[i, _]
273279

274280
# Negative(Sampled) Item Indices exp(y_ui) / (1 + exp(y_ui)) * y_i
275-
for _ in range(min(n_items, user_seen_item * neg_prop)):
276-
index = rng.generate(thread_id)
277-
i = indices[index]
278-
exp_r = 0
279-
for _ in range(n_factors):
280-
exp_r = exp_r + (user_vectors[u, _] * item_vectors[i, _])
281-
z = sigmoid(exp_r)
282-
283-
for _ in range(n_factors):
284-
deriv[_] = deriv[_] - z * item_vectors[i, _]
285-
for _ in range(n_factors):
286-
deriv[_] -= reg * user_vectors[u, _]
287-
deriv_sum_sq[u, _] += deriv[_] * deriv[_]
281+
# Sample uniformly from [0, n_items); reject any item the user has
282+
# actually interacted with. Guard against users who have seen every
283+
# item (no valid negative exists).
284+
if user_seen_item < n_items:
285+
for c in range(user_seen_item * neg_prop):
286+
i = rng.generate(thread_id)
287+
# indices[indptr[u]:indptr[u+1]] is sorted (guaranteed by fit()),
288+
# so binary_search gives O(log k) rejection per sample.
289+
while binary_search(&indices[indptr[u]], &indices[indptr[u + 1]], i):
290+
i = rng.generate(thread_id)
291+
exp_r = 0
292+
for f in range(n_factors):
293+
exp_r = exp_r + (user_vectors[u, f] * item_vectors[i, f])
294+
z = sigmoid(exp_r)
295+
for f in range(n_factors):
296+
deriv[f] = deriv[f] - z * item_vectors[i, f]
297+
for f in range(n_factors):
298+
deriv[f] -= reg * user_vectors[u, f]
299+
deriv_sum_sq[u, f] += deriv[f] * deriv[f]
288300

289301
# a small constant is added for numerical stability
290-
user_vectors[u, _] += (lr / (sqrt(1e-6 + deriv_sum_sq[u, _]))) * deriv[_]
302+
user_vectors[u, f] += (lr / (sqrt(1e-6 + deriv_sum_sq[u, f]))) * deriv[f]
291303
finally:
292304
free(deriv)

0 commit comments

Comments
 (0)