Skip to content

Commit 530ba42

Browse files
nsiddharth and hieuddo authored
Support hyper-parameter search for next-item recommenders (#643) (#702)
* Support hyper-parameter search for next-item recommenders (#643)

  GridSearch/RandomSearch could not tune NextItemRecommender models evaluated with NextItemEvaluation:
  - NextItemEvaluation.evaluate() rejected the search wrapper because it is a Recommender, not a NextItemRecommender.
  - BaseSearch.fit() scored next-item models with the standard ranking_eval, whose rank()/score() path is incompatible with the session-based score(history_items=...) signature.

  Fix:
  - Accept a search wrapper whose .model is a NextItemRecommender.
  - Route next-item models through next_item_evaluation.ranking_eval during search, using the eval_method's exclude_unknowns/mode.
  - Delegate transform/score/rank from BaseSearch to the best model so the fitted wrapper evaluates transparently.

  Add GridSearch/RandomSearch next-item tests.

* update get model seed

* add seed for reproducible hyperopt search

---------

Co-authored-by: hieuddo <hieu.dd.1998@gmail.com>
1 parent a620c82 commit 530ba42

4 files changed

Lines changed: 98 additions & 9 deletions

File tree

cornac/eval_methods/next_item_evaluation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,10 @@ def evaluate(self, model, metrics, user_based, show_validation=True):
335335
-------
336336
res: :obj:`cornac.experiment.Result`
337337
"""
338-
if not isinstance(model, NextItemRecommender):
338+
base_model = getattr(model, "model", None)
339+
if not isinstance(model, NextItemRecommender) and not isinstance(
340+
base_model, NextItemRecommender
341+
):
339342
raise ValueError("model must be a NextItemRecommender but '%s' is provided" % type(model))
340343

341344
if self.train_set is None:

cornac/hyperopt.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
import numpy as np
1818
from itertools import product
1919

20-
from .models import Recommender
20+
from .models import Recommender, NextItemRecommender
2121
from .metrics import RatingMetric, RankingMetric
2222
from .eval_methods import rating_eval, ranking_eval
23+
from .eval_methods.next_item_evaluation import ranking_eval as next_item_ranking_eval
2324
from .utils import get_rng
2425

2526

@@ -147,6 +148,16 @@ def fit(self, train_set, val_set=None):
147148

148149
if isinstance(self.metric, RatingMetric):
149150
score = rating_eval(model, [self.metric], val_set)[0][0]
151+
elif isinstance(model, NextItemRecommender):
152+
score = next_item_ranking_eval(
153+
model,
154+
[self.metric],
155+
train_set,
156+
val_set,
157+
exclude_unknowns=self.eval_method.exclude_unknowns,
158+
mode=self.eval_method.mode,
159+
verbose=False,
160+
)[0][0]
150161
else:
151162
score = ranking_eval(
152163
model,
@@ -171,9 +182,17 @@ def fit(self, train_set, val_set=None):
171182

172183
return self
173184

174-
def score(self, user_idx, item_idx=None):
185+
def transform(self, test_set):
186+
"""Delegate test-set transformation to the best searched model."""
187+
return self.best_model.transform(test_set)
188+
189+
def score(self, user_idx, *args, **kwargs):
175190
"""Scoring using the best searched model"""
176-
return self.best_model.score(user_idx, item_idx)
191+
return self.best_model.score(user_idx, *args, **kwargs)
192+
193+
def rank(self, user_idx, item_indices=None, k=-1, **kwargs):
194+
"""Ranking using the best searched model"""
195+
return self.best_model.rank(user_idx, item_indices, k, **kwargs)
177196

178197

179198
class GridSearch(BaseSearch):
@@ -263,7 +282,7 @@ def _build_param_set(self):
263282
"""Generate searching points"""
264283
param_set = []
265284
keys = [d.name for d in self.space]
266-
rng = get_rng(self.model.seed)
285+
rng = get_rng(getattr(self.model, "seed", None))
267286
while len(param_set) < self.n_trails:
268287
params = [d._sample(rng) for d in self.space]
269288
param_set.append(dict(zip(keys, params)))

docs/source/user/iamadeveloper.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ this, we can use the `cornac.hyperopt` module to perform the searches.
111111
As shown in the above code, we have defined two methods for hyper-parameter search,
112112
``GridSearch`` and ``RandomSearch``.
113113

114+
The same search classes support next-item recommenders when paired with
115+
``NextItemEvaluation``. In both cases, the evaluation method must include a
116+
validation split, which is used to select the best parameter settings.
117+
114118
+------------------------------------------+---------------------------------------------+
115119
| Grid Search | Random Search |
116120
+==========================================+=============================================+
@@ -719,4 +723,4 @@ Cornac.
719723

720724
No matter who you are, you could also consider contributing to Cornac,
721725
with our contributors guide.
722-
View :doc:`/developer/index`.
726+
View :doc:`/developer/index`.

tests/cornac/test_hyperopt.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
import numpy.testing as npt
2020

2121
from cornac.data import Reader
22-
from cornac.models import MF, BPR
23-
from cornac.metrics import RMSE, AUC
24-
from cornac.eval_methods import RatioSplit
22+
from cornac.models import MF, BPR, SPop
23+
from cornac.metrics import RMSE, AUC, HitRatio
24+
from cornac.eval_methods import RatioSplit, NextItemEvaluation
2525
from cornac.hyperopt import Discrete, Continuous
2626
from cornac.hyperopt import GridSearch, RandomSearch
2727
from cornac import Experiment
@@ -70,6 +70,69 @@ def test_random_search(self):
7070
user_based=False,
7171
).run()
7272

73+
def test_random_search_next_item_recommender(self):
74+
data = Reader().read("./tests/sequence.txt", fmt="USIT", sep=" ")
75+
eval_method = NextItemEvaluation.from_splits(
76+
train_data=data[:35],
77+
val_data=data[35:50],
78+
test_data=data[50:],
79+
fmt="USIT",
80+
exclude_unknowns=False,
81+
mode="next",
82+
)
83+
metric = HitRatio(k=5)
84+
spop = SPop()
85+
spop.seed = 123 # for reproducible RandomSearch sampling
86+
rs_spop = RandomSearch(
87+
model=spop,
88+
space=[Discrete("use_session_popularity", [False, True])],
89+
metric=metric,
90+
eval_method=eval_method,
91+
n_trails=2,
92+
)
93+
94+
test_result, _ = eval_method.evaluate(
95+
model=rs_spop,
96+
metrics=[metric],
97+
user_based=False,
98+
show_validation=False,
99+
)
100+
101+
self.assertIsNotNone(rs_spop.best_model)
102+
self.assertEqual(rs_spop.best_params, {"use_session_popularity": False})
103+
self.assertAlmostEqual(rs_spop.best_score, 11 / 12)
104+
self.assertTrue(np.isfinite(test_result.metric_avg_results["HitRatio@5"]))
105+
106+
def test_grid_search_next_item_recommender(self):
107+
data = Reader().read("./tests/sequence.txt", fmt="USIT", sep=" ")
108+
eval_method = NextItemEvaluation.from_splits(
109+
train_data=data[:35],
110+
val_data=data[35:50],
111+
test_data=data[50:],
112+
fmt="USIT",
113+
exclude_unknowns=False,
114+
mode="next",
115+
)
116+
metric = HitRatio(k=5)
117+
gs_spop = GridSearch(
118+
model=SPop(),
119+
space=[Discrete("use_session_popularity", [False, True])],
120+
metric=metric,
121+
eval_method=eval_method,
122+
)
123+
124+
test_result, _ = eval_method.evaluate(
125+
model=gs_spop,
126+
metrics=[metric],
127+
user_based=False,
128+
show_validation=False,
129+
)
130+
131+
self.assertIsNotNone(gs_spop.best_model)
132+
self.assertEqual(gs_spop.best_params, {"use_session_popularity": False})
133+
self.assertAlmostEqual(gs_spop.best_score, 11 / 12)
134+
self.assertTrue(np.isfinite(test_result.metric_avg_results["HitRatio@5"]))
135+
73136

74137
if __name__ == "__main__":
75138
unittest.main()

0 commit comments

Comments
 (0)