6767 .cache ()
6868)
6969
70+ ratings_with_strings = (
71+ spark .createDataFrame (
72+ [
73+ ("user0" , "item1" , 4 , 4 ),
74+ ("user0" , "item3" , 1 , 1 ),
75+ ("user0" , "item4" , 5 , 5 ),
76+ ("user0" , "item5" , 3 , 3 ),
77+ ("user0" , "item7" , 3 , 3 ),
78+ ("user0" , "item9" , 3 , 3 ),
79+ ("user0" , "item10" , 3 , 3 ),
80+ ("user1" , "item1" , 4 , 4 ),
81+ ("user1" , "item2" , 5 , 5 ),
82+ ("user1" , "item3" , 1 , 1 ),
83+ ("user1" , "item6" , 4 , 4 ),
84+ ("user1" , "item7" , 5 , 5 ),
85+ ("user1" , "item8" , 1 , 1 ),
86+ ("user1" , "item10" , 3 , 3 ),
87+ ("user2" , "item1" , 4 , 4 ),
88+ ("user2" , "item2" , 1 , 1 ),
89+ ("user2" , "item3" , 1 , 1 ),
90+ ("user2" , "item4" , 5 , 5 ),
91+ ("user2" , "item5" , 3 , 3 ),
92+ ("user2" , "item6" , 4 , 4 ),
93+ ("user2" , "item8" , 1 , 1 ),
94+ ("user2" , "item9" , 5 , 5 ),
95+ ("user2" , "item10" , 3 , 3 ),
96+ ("user3" , "item2" , 5 , 5 ),
97+ ("user3" , "item3" , 1 , 1 ),
98+ ("user3" , "item4" , 5 , 5 ),
99+ ("user3" , "item5" , 3 , 3 ),
100+ ("user3" , "item6" , 4 , 4 ),
101+ ("user3" , "item7" , 5 , 5 ),
102+ ("user3" , "item8" , 1 , 1 ),
103+ ("user3" , "item9" , 5 , 5 ),
104+ ("user3" , "item10" , 3 , 3 ),
105+ ],
106+ ["originalCustomerID" , "newCategoryID" , "rating" , "notTime" ],
107+ )
108+ .coalesce (1 )
109+ .cache ()
110+ )
111+
70112
71113class RankingSpec (unittest .TestCase ):
72114 @staticmethod
73- def adapter_evaluator (algo ):
115+ def adapter_evaluator (algo , data ):
74116 recommendation_indexer = RecommendationIndexer (
75117 userInputCol = USER_ID ,
76118 userOutputCol = USER_ID_INDEX ,
@@ -80,7 +122,7 @@ def adapter_evaluator(algo):
80122
81123 adapter = RankingAdapter (mode = "allUsers" , k = 5 , recommender = algo )
82124 pipeline = Pipeline (stages = [recommendation_indexer , adapter ])
83- output = pipeline .fit (ratings ).transform (ratings )
125+ output = pipeline .fit (data ).transform (data )
84126 print (str (output .take (1 )) + "\n " )
85127
86128 metrics = ["ndcgAt" , "fcp" , "mrr" ]
@@ -91,13 +133,17 @@ def adapter_evaluator(algo):
91133 + str (RankingEvaluator (k = 3 , metricName = metric ).evaluate (output )),
92134 )
93135
94- # def test_adapter_evaluator_als(self):
95- # als = ALS(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
96- # self.adapter_evaluator(als)
97- #
98- # def test_adapter_evaluator_sar(self):
99- # sar = SAR(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
100- # self.adapter_evaluator(sar)
136+ def test_adapter_evaluator_als (self ):
137+ als = ALS (userCol = USER_ID_INDEX , itemCol = ITEM_ID_INDEX , ratingCol = RATING_ID )
138+ self .adapter_evaluator (als , ratings )
139+
140+ def test_adapter_evaluator_sar (self ):
141+ sar = SAR (userCol = USER_ID_INDEX , itemCol = ITEM_ID_INDEX , ratingCol = RATING_ID )
142+ self .adapter_evaluator (sar , ratings )
143+
144+ def test_adapter_evaluator_sar_with_strings (self ):
145+ sar = SAR (userCol = USER_ID_INDEX , itemCol = ITEM_ID_INDEX , ratingCol = RATING_ID )
146+ self .adapter_evaluator (sar , ratings_with_strings )
101147
102148 def test_all_tiny (self ):
103149 customer_index = StringIndexer (inputCol = USER_ID , outputCol = USER_ID_INDEX )
0 commit comments