Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions core/src/main/python/synapse/ml/recommendation/SARModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,8 @@
class SARModel(_SARModel):
def recommendForAllUsers(self, numItems):
return self._call_java("recommendForAllUsers", numItems)

def recommendForUserSubset(self, dataset, numItems):
if dataset.schema[self.getUserCol()].dataType == StringType():
dataset = dataset.withColumn(self.getUserCol(), dataset[self.getUserCol()].cast("int"))
return self._call_java("recommendForUserSubset", dataset, numItems)
Comment thread
dciborow marked this conversation as resolved.
Outdated
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, I
import org.apache.spark.mllib.linalg
import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseMatrix}
import org.apache.spark.sql.functions.{col, collect_list, sum, udf, _}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.{DataFrame, Dataset}

import java.text.SimpleDateFormat
Expand Down Expand Up @@ -106,8 +106,22 @@ class SAR(override val uid: String) extends Estimator[SARModel]
(0 to numItems.value).map(i => map.getOrElse(i, 0.0).toFloat).toArray
})

dataset
.withColumn(C.AffinityCol, (dataset.columns.contains(getTimeCol), dataset.columns.contains(getRatingCol)) match {
val userColType = dataset.schema(getUserCol).dataType
val itemColType = dataset.schema(getItemCol).dataType

val castedDataset = (userColType, itemColType) match {
case (StringType, StringType) =>
dataset.withColumn(getUserCol, col(getUserCol).cast("int"))
.withColumn(getItemCol, col(getItemCol).cast("int"))
case (StringType, _) =>
dataset.withColumn(getUserCol, col(getUserCol).cast("int"))
case (_, StringType) =>
dataset.withColumn(getItemCol, col(getItemCol).cast("int"))
case _ => dataset
}

castedDataset
.withColumn(C.AffinityCol, (castedDataset.columns.contains(getTimeCol), castedDataset.columns.contains(getRatingCol)) match {
case (true, true) => blendWeights(timeDecay(col(getTimeCol)), col(getRatingCol))
case (true, false) => timeDecay(col(getTimeCol))
case (false, true) => col(getRatingCol)
Expand Down Expand Up @@ -197,7 +211,21 @@ class SAR(override val uid: String) extends Estimator[SARModel]
})
})

dataset
val userColType = dataset.schema(getUserCol).dataType
val itemColType = dataset.schema(getItemCol).dataType

val castedDataset = (userColType, itemColType) match {
case (StringType, StringType) =>
dataset.withColumn(getUserCol, col(getUserCol).cast("int"))
.withColumn(getItemCol, col(getItemCol).cast("int"))
case (StringType, _) =>
dataset.withColumn(getUserCol, col(getUserCol).cast("int"))
case (_, StringType) =>
dataset.withColumn(getItemCol, col(getItemCol).cast("int"))
case _ => dataset
}

castedDataset
.select(col(getItemCol), col(getUserCol))
.groupBy(getItemCol).agg(collect_list(getUserCol) as "collect_list")
.withColumn(C.FeaturesCol, createItemFeaturesVector(col("collect_list")))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,52 @@
.cache()
)

ratings_with_strings = (
spark.createDataFrame(
[
("user0", "item1", 4, 4),
("user0", "item3", 1, 1),
("user0", "item4", 5, 5),
("user0", "item5", 3, 3),
("user0", "item7", 3, 3),
("user0", "item9", 3, 3),
("user0", "item10", 3, 3),
("user1", "item1", 4, 4),
("user1", "item2", 5, 5),
("user1", "item3", 1, 1),
("user1", "item6", 4, 4),
("user1", "item7", 5, 5),
("user1", "item8", 1, 1),
("user1", "item10", 3, 3),
("user2", "item1", 4, 4),
("user2", "item2", 1, 1),
("user2", "item3", 1, 1),
("user2", "item4", 5, 5),
("user2", "item5", 3, 3),
("user2", "item6", 4, 4),
("user2", "item8", 1, 1),
("user2", "item9", 5, 5),
("user2", "item10", 3, 3),
("user3", "item2", 5, 5),
("user3", "item3", 1, 1),
("user3", "item4", 5, 5),
("user3", "item5", 3, 3),
("user3", "item6", 4, 4),
("user3", "item7", 5, 5),
("user3", "item8", 1, 1),
("user3", "item9", 5, 5),
("user3", "item10", 3, 3),
],
["originalCustomerID", "newCategoryID", "rating", "notTime"],
)
.coalesce(1)
.cache()
)


class RankingSpec(unittest.TestCase):
@staticmethod
def adapter_evaluator(algo):
def adapter_evaluator(algo, data):
recommendation_indexer = RecommendationIndexer(
userInputCol=USER_ID,
userOutputCol=USER_ID_INDEX,
Expand All @@ -80,7 +122,7 @@ def adapter_evaluator(algo):

adapter = RankingAdapter(mode="allUsers", k=5, recommender=algo)
pipeline = Pipeline(stages=[recommendation_indexer, adapter])
output = pipeline.fit(ratings).transform(ratings)
output = pipeline.fit(data).transform(data)
print(str(output.take(1)) + "\n")

metrics = ["ndcgAt", "fcp", "mrr"]
Expand All @@ -91,13 +133,17 @@ def adapter_evaluator(algo):
+ str(RankingEvaluator(k=3, metricName=metric).evaluate(output)),
)

# def test_adapter_evaluator_als(self):
# als = ALS(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
# self.adapter_evaluator(als)
#
# def test_adapter_evaluator_sar(self):
# sar = SAR(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
# self.adapter_evaluator(sar)
def test_adapter_evaluator_als(self):
als = ALS(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
self.adapter_evaluator(als, ratings)

def test_adapter_evaluator_sar(self):
sar = SAR(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
self.adapter_evaluator(sar, ratings)

def test_adapter_evaluator_sar_with_strings(self):
sar = SAR(userCol=USER_ID_INDEX, itemCol=ITEM_ID_INDEX, ratingCol=RATING_ID)
self.adapter_evaluator(sar, ratings_with_strings)

def test_all_tiny(self):
customer_index = StringIndexer(inputCol=USER_ID, outputCol=USER_ID_INDEX)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ class SARSpec extends RankingTestBase with EstimatorFuzzing[SAR] {
new TestObject(new SAR()
.setUserCol(recommendationIndexer.getUserOutputCol)
.setItemCol(recommendationIndexer.getItemOutputCol)
.setRatingCol(ratingCol), transformedDf)
.setRatingCol(ratingCol), transformedDf),
new TestObject(new SAR()
.setUserCol(recommendationIndexer.getUserOutputCol)
.setItemCol(recommendationIndexer.getItemOutputCol)
.setRatingCol(ratingCol), transformedDfWithStrings)
)
}

Expand Down Expand Up @@ -62,6 +66,41 @@ class SARSpec extends RankingTestBase with EstimatorFuzzing[SAR] {
assert(recs.count == 2)
}

test("SAR with string userCol and itemCol") {

val algo = sar
.setSupportThreshold(1)
.setSimilarityFunction("jacccard")
.setActivityTimeFormat("EEE MMM dd HH:mm:ss Z yyyy")

val adapter: RankingAdapter = new RankingAdapter()
.setK(5)
.setRecommender(algo)

val recopipeline = new Pipeline()
.setStages(Array(recommendationIndexer, adapter))
.fit(ratingsWithStrings)

val output = recopipeline.transform(ratingsWithStrings)

val evaluator: RankingEvaluator = new RankingEvaluator()
.setK(5)
.setNItems(10)

assert(evaluator.setMetricName("ndcgAt").evaluate(output) === 0.602819875812812)
assert(evaluator.setMetricName("fcp").evaluate(output) === 0.05 ||
evaluator.setMetricName("fcp").evaluate(output) === 0.1)
assert(evaluator.setMetricName("mrr").evaluate(output) === 1.0)

val users: DataFrame = spark
.createDataFrame(Seq(("user0","item0"),("user1","item1")))
.toDF(userColIndex, itemColIndex)

val recs = recopipeline.stages(1).asInstanceOf[RankingAdapterModel].getRecommenderModel
.asInstanceOf[SARModel].recommendForUserSubset(users, 10)
assert(recs.count == 2)
}

lazy val testFile: String = getClass.getResource("/demoUsage.csv.gz").getPath
lazy val simCount1: String = getClass.getResource("/sim_count1.csv.gz").getPath
lazy val simLift1: String = getClass.getResource("/sim_lift1.csv.gz").getPath
Expand Down Expand Up @@ -115,7 +154,12 @@ class SARModelSpec extends RankingTestBase with TransformerFuzzing[SARModel] {
.setUserCol(recommendationIndexer.getUserOutputCol)
.setItemCol(recommendationIndexer.getItemOutputCol)
.setRatingCol(ratingCol)
.fit(transformedDf), transformedDf)
.fit(transformedDf), transformedDf),
new TestObject(new SAR()
.setUserCol(recommendationIndexer.getUserOutputCol)
.setItemCol(recommendationIndexer.getItemOutputCol)
.setRatingCol(ratingCol)
.fit(transformedDfWithStrings), transformedDfWithStrings)
)
}

Expand Down
119 changes: 119 additions & 0 deletions docs/Quick Examples/estimators/core/_Recommendation.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,43 @@ ratings = (spark.createDataFrame([
.dropDuplicates()
.cache())

ratings_with_strings = (spark.createDataFrame([
("user0", "item1", 4, 4),
("user0", "item3", 1, 1),
("user0", "item4", 5, 5),
("user0", "item5", 3, 3),
("user0", "item7", 3, 3),
("user0", "item9", 3, 3),
("user0", "item10", 3, 3),
("user1", "item1", 4, 4),
("user1", "item2", 5, 5),
("user1", "item3", 1, 1),
("user1", "item6", 4, 4),
("user1", "item7", 5, 5),
("user1", "item8", 1, 1),
("user1", "item10", 3, 3),
("user2", "item1", 4, 4),
("user2", "item2", 1, 1),
("user2", "item3", 1, 1),
("user2", "item4", 5, 5),
("user2", "item5", 3, 3),
("user2", "item6", 4, 4),
("user2", "item8", 1, 1),
("user2", "item9", 5, 5),
("user2", "item10", 3, 3),
("user3", "item2", 5, 5),
("user3", "item3", 1, 1),
("user3", "item4", 5, 5),
("user3", "item5", 3, 3),
("user3", "item6", 4, 4),
("user3", "item7", 5, 5),
("user3", "item8", 1, 1),
("user3", "item9", 5, 5),
("user3", "item10", 3, 3)
], ["originalCustomerID", "newCategoryID", "rating", "notTime"])
.coalesce(1)
.cache())

Comment thread
dciborow marked this conversation as resolved.
Outdated
recommendationIndexer = (RecommendationIndexer()
.setUserInputCol("customerIDOrg")
.setUserOutputCol("customerID")
Expand Down Expand Up @@ -275,6 +312,43 @@ ratings = (spark.createDataFrame([
.dropDuplicates()
.cache())

ratings_with_strings = (spark.createDataFrame([
("user0", "item1", 4, 4),
("user0", "item3", 1, 1),
("user0", "item4", 5, 5),
("user0", "item5", 3, 3),
("user0", "item7", 3, 3),
("user0", "item9", 3, 3),
("user0", "item10", 3, 3),
("user1", "item1", 4, 4),
("user1", "item2", 5, 5),
("user1", "item3", 1, 1),
("user1", "item6", 4, 4),
("user1", "item7", 5, 5),
("user1", "item8", 1, 1),
("user1", "item10", 3, 3),
("user2", "item1", 4, 4),
("user2", "item2", 1, 1),
("user2", "item3", 1, 1),
("user2", "item4", 5, 5),
("user2", "item5", 3, 3),
("user2", "item6", 4, 4),
("user2", "item8", 1, 1),
("user2", "item9", 5, 5),
("user2", "item10", 3, 3),
("user3", "item2", 5, 5),
("user3", "item3", 1, 1),
("user3", "item4", 5, 5),
("user3", "item5", 3, 3),
("user3", "item6", 4, 4),
("user3", "item7", 5, 5),
("user3", "item8", 1, 1),
("user3", "item9", 5, 5),
("user3", "item10", 3, 3)
], ["originalCustomerID", "newCategoryID", "rating", "notTime"])
.coalesce(1)
.cache())

Comment thread
dciborow marked this conversation as resolved.
Outdated
recommendationIndexer = (RecommendationIndexer()
.setUserInputCol("customerIDOrg")
.setUserOutputCol("customerID")
Expand All @@ -298,6 +372,10 @@ adapter = (RankingAdapter()
res1 = recommendationIndexer.fit(ratings).transform(ratings).cache()

adapter.fit(res1).transform(res1).show()

res2 = recommendationIndexer.fit(ratings_with_strings).transform(ratings_with_strings).cache()

adapter.fit(res2).transform(res2).show()
Comment thread
dciborow marked this conversation as resolved.
Outdated
```

</TabItem>
Expand Down Expand Up @@ -344,6 +422,43 @@ val ratings = (Seq(
.dropDuplicates()
.cache())

val ratings_with_strings = (Seq(
("user0", "item1", 4, 4),
("user0", "item3", 1, 1),
("user0", "item4", 5, 5),
("user0", "item5", 3, 3),
("user0", "item7", 3, 3),
("user0", "item9", 3, 3),
("user0", "item10", 3, 3),
("user1", "item1", 4, 4),
("user1", "item2", 5, 5),
("user1", "item3", 1, 1),
("user1", "item6", 4, 4),
("user1", "item7", 5, 5),
("user1", "item8", 1, 1),
("user1", "item10", 3, 3),
("user2", "item1", 4, 4),
("user2", "item2", 1, 1),
("user2", "item3", 1, 1),
("user2", "item4", 5, 5),
("user2", "item5", 3, 3),
("user2", "item6", 4, 4),
("user2", "item8", 1, 1),
("user2", "item9", 5, 5),
("user2", "item10", 3, 3),
("user3", "item2", 5, 5),
("user3", "item3", 1, 1),
("user3", "item4", 5, 5),
("user3", "item5", 3, 3),
("user3", "item6", 4, 4),
("user3", "item7", 5, 5),
("user3", "item8", 1, 1),
("user3", "item9", 5, 5),
("user3", "item10", 3, 3))
.toDF("originalCustomerID", "newCategoryID", "rating", "notTime")
.coalesce(1)
.cache())

Comment thread
dciborow marked this conversation as resolved.
Outdated
val recommendationIndexer = (new RecommendationIndexer()
.setUserInputCol("customerIDOrg")
.setUserOutputCol("customerID")
Expand All @@ -367,6 +482,10 @@ val adapter = (new RankingAdapter()
val res1 = recommendationIndexer.fit(ratings).transform(ratings).cache()

adapter.fit(res1).transform(res1).show()

val res2 = recommendationIndexer.fit(ratings_with_strings).transform(ratings_with_strings).cache()

adapter.fit(res2).transform(res2).show()
Comment thread
dciborow marked this conversation as resolved.
Outdated
```

</TabItem>
Expand Down