Commit 66f0afd

BrendanWalsh and Copilot committed

Eliminate RDD usage across SynapseML for Spark 4.0 compatibility

Replace SparkContext/RDD APIs with DataFrame/SparkSession equivalents per the SPARK-48909 pattern. This enables SynapseML to work in environments where RDDs are restricted (e.g., Databricks Unity Catalog shared mode) and improves forward compatibility with Spark 4.0+.

Key changes:

- ComplexParamsSerializer/Serializer: Replace sc.parallelize().saveAsTextFile() with spark.createDataFrame().write.text() for metadata serialization
- LightGBMBooster: Replace sc.parallelize() with Seq().toDS() for model I/O
- Lambda: Replace SparkContext.getOrCreate() + sc.emptyRDD with SparkSession
- ONNXModel: Replace sc.binaryFiles() with the Hadoop FileSystem API
- ONNXHub: Replace SparkContext.hadoopConfiguration with SparkSession
- StratifiedRepartition: Replace RDD keyBy/sampleByKeyExact/RangePartitioner with DataFrame-based oversampling and round-robin partitioning
- Repartition: Replace .rdd.repartition() with DataFrame repartition
- ClusterUtil: Replace .rdd.mapPartitionsWithIndex with spark_partition_id()
- VectorOps: Replace sparkContext.parallelize with spark.range()
- SyntheticEstimator: Replace df.rdd.zipWithIndex with monotonically_increasing_id
- TuneHyperparameters: Replace MLUtils.kFold(df.rdd) with DataFrame-based k-fold
- VowpalWabbitBase: Refactor prepareDataSet to return a (DataFrame, Int) tuple
- LightGBMBase/Ranker: Simplify partition management without .rdd.getNumPartitions
- DistributedHTTPSource: Replace sparkContext.parallelize with spark.range()

Remaining RDD usage (no DataFrame API alternatives):

- Barrier execution: VowpalWabbitBaseLearner, LightGBMBase (df.rdd.barrier())
- MLlib evaluators: ComputeModelStatistics, RankingEvaluator (require RDD input)
- MLlib linalg: SARModel (CoordinateMatrix requires an RDD)
- Streaming internals: HTTPSource, DistributedHTTPSource, HTTPSourceV2

Closes #2401

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 895752c · commit 66f0afd

25 files changed

Lines changed: 152 additions & 129 deletions
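
Most of the bullets above are the same mechanical substitution: a driver-side RDD construction becomes a Dataset construction. A minimal sketch of the metadata-serialization case described in the commit message (the payload and path below are illustrative, not the serializer's actual values):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    import spark.implicits._

    val metadataJson = """{"class":"...","paramMap":{}}"""  // illustrative payload
    val path = "/tmp/synapseml-metadata"                    // illustrative path

    // Before (RDD-based, blocked in Unity Catalog shared mode):
    //   spark.sparkContext.parallelize(Seq(metadataJson), 1).saveAsTextFile(path)

    // After (Dataset-based, per the commit description):
    Seq(metadataJson).toDS().coalesce(1).write.text(path)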


core/src/main/scala/com/microsoft/azure/synapse/ml/automl/TuneHyperparameters.scala

Lines changed: 16 additions & 5 deletions
@@ -17,7 +17,6 @@ import org.apache.spark.ml.classification.ClassificationModel
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.regression.RegressionModel
 import org.apache.spark.ml.util._
-import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql._
 import org.apache.spark.sql.types.StructType

@@ -149,9 +148,21 @@ class TuneHyperparameters(override val uid: String) extends Estimator[TuneHyperp
   override def fit(dataset: Dataset[_]): TuneHyperparametersModel = { //scalastyle:ignore cyclomatic.complexity
     logFit({
       val sparkSession = dataset.sparkSession
-      val splits = MLUtils.kFold(dataset.toDF.rdd, getNumFolds, getSeed)
+      import org.apache.spark.sql.functions.{rand, lit}
+      val df = dataset.toDF
+      val nFolds = getNumFolds
+      // DataFrame-based k-fold splitting: assign each row to a fold from a seeded random draw
+      val dfWithFold = df.withColumn("_kfold_rand", rand(getSeed))
+      val splits = (0 until nFolds).map { fold =>
+        val training = dfWithFold
+          .filter((dfWithFold("_kfold_rand") * lit(nFolds)).cast("int") =!= lit(fold))
+          .drop("_kfold_rand")
+        val validation = dfWithFold
+          .filter((dfWithFold("_kfold_rand") * lit(nFolds)).cast("int") === lit(fold))
+          .drop("_kfold_rand")
+        (training, validation)
+      }.toArray
       val hyperParams = getParamSpace.paramMaps
-      val schema = dataset.schema
       val executionContext = getExecutionContext
       val (evaluationMetricColumnName, operator): (String, Ordering[Double]) =
         EvaluationUtils.getMetricWithOperator(getModels.head, getEvaluationMetric)

@@ -163,8 +174,8 @@ class TuneHyperparameters(override val uid: String) extends Estimator[TuneHyperp
       val numModels = getModels.length

       val metrics = splits.zipWithIndex.map { case ((training, validation), _) =>
-        val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
-        val validationDataset = sparkSession.createDataFrame(validation, schema).cache()
+        val trainingDataset = training.cache()
+        val validationDataset = validation.cache()

         val modelParams = ListBuffer[ParamMap]()
         for (n <- 0 until getNumRuns) {
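
The replacement turns one seeded uniform draw per row into a fold id: multiplying rand(getSeed) in [0, 1) by the fold count and casting to int yields an integer in {0, ..., nFolds - 1}. A standalone sketch of the same bucketing idea (names and values here are illustrative):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{lit, rand}

    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    val k = 3
    val data = spark.range(0, 12).toDF("x")
    // One seeded draw per row; (rand * k).cast("int") buckets it into a fold.
    val withFold = data.withColumn("fold", (rand(42) * lit(k)).cast("int"))
    withFold.groupBy("fold").count().show()

Like MLUtils.kFold, the split is randomized rather than exact, so fold sizes are only approximately equal.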

core/src/main/scala/com/microsoft/azure/synapse/ml/causal/SyntheticEstimator.scala

Lines changed: 2 additions & 8 deletions
@@ -218,14 +218,8 @@ object SyntheticEstimator {
   }

   private[causal] def assignRowIndex(df: DataFrame, colName: String): DataFrame = {
-    df.sparkSession.createDataFrame(
-      df.rdd.zipWithIndex.map(element =>
-        Row.fromSeq(Seq(element._2) ++ element._1.toSeq)
-      ),
-      StructType(
-        Array(StructField(colName, LongType, nullable = false)) ++ df.schema.fields
-      )
-    )
+    df.withColumn(colName, monotonically_increasing_id())
+      .select(col(colName) +: df.columns.map(col): _*)
   }

   private[causal] def createIndex(data: DataFrame, inputCol: String, indexCol: String): DataFrame = {
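
One semantic difference worth noting: rdd.zipWithIndex produced consecutive ids 0..N-1, whereas monotonically_increasing_id() only guarantees unique, increasing values (the partition id lives in the high bits, so ids jump between partitions). If a caller ever needs contiguous indices, a windowed row_number is the usual DataFrame-side substitute; a hypothetical variant, not part of this commit:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{col, monotonically_increasing_id, row_number}

    // Hypothetical helper: consecutive 0-based row ids. The unpartitioned
    // window funnels all rows through a single task, trading parallelism
    // for contiguity.
    def assignConsecutiveRowIndex(df: DataFrame, colName: String): DataFrame = {
      val w = Window.orderBy(monotonically_increasing_id())
      df.withColumn(colName, (row_number().over(w) - 1).cast("long"))
        .select(col(colName) +: df.columns.map(col): _*)
    }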

core/src/main/scala/com/microsoft/azure/synapse/ml/causal/linalg/VectorOps.scala

Lines changed: 1 addition & 3 deletions
@@ -96,10 +96,8 @@ object DVectorOps extends VectorOps[DVector] {
   def make(size: Long, value: => Double): DVector = {
     val spark = SparkSession.active
     import spark.implicits._
-    val data = 0L until size
     spark
-      .sparkContext
-      .parallelize(data)
+      .range(0, size)
       .toDF("i")
       .withColumn("value", lit(value))
       .as[VectorEntry]
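
A small detail that makes this one-line swap line up: spark.range produces a single long column named id, so the existing .toDF("i") rename is still needed before .as[VectorEntry]. For illustration:

    val spark = org.apache.spark.sql.SparkSession.active
    spark.range(0, 3).printSchema()            // |-- id: long (nullable = false)
    spark.range(0, 3).toDF("i").printSchema()  // |-- i: long (nullable = false)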

core/src/main/scala/com/microsoft/azure/synapse/ml/core/utils/ClusterUtil.scala

Lines changed: 14 additions & 11 deletions
@@ -5,9 +5,7 @@ package com.microsoft.azure.synapse.ml.core.utils

 import java.net.InetAddress
 import org.apache.http.conn.util.InetAddressUtils
-import org.apache.spark.SparkContext
 import org.apache.spark.injections.BlockManagerUtils
-import org.apache.spark.sql.functions.typedLit
 import org.apache.spark.sql.{Column, DataFrame, SparkSession}
 import org.slf4j.Logger

@@ -20,7 +18,7 @@ object ClusterUtil {
    * @return The number of tasks per executor.
    */
   def getNumTasksPerExecutor(spark: SparkSession, log: Logger): Int = {
-    val confTaskCpus = getTaskCpus(spark.sparkContext, log)
+    val confTaskCpus = getTaskCpus(spark, log)
     try {
       val confCores = spark.sparkContext.getConf.get("spark.executor.cores").toInt
       val tasksPerExec = confCores / confTaskCpus

@@ -44,13 +42,18 @@ object ClusterUtil {
    * @return The number of rows per partition (where partitionId is the array index).
    */
   def getNumRowsPerPartition(df: DataFrame, labelCol: Column): Array[Long] = {
-    val indexedRowCounts: Array[(Int, Long)] = df
-      .select(typedLit(0.toByte))
-      .rdd
-      .mapPartitionsWithIndex({case (i,rows) => Iterator((i,rows.size.toLong))}, true)
+    import org.apache.spark.sql.functions.{spark_partition_id, count, lit}
+    val partitionCounts = df
+      .select(spark_partition_id().as("partId"))
+      .groupBy("partId")
+      .agg(count(lit(1)).as("cnt"))
       .collect()
-    // Get an array where the index is implicitly the partition id
-    indexedRowCounts.sortBy(pair => pair._1).map(pair => pair._2)
+    val maxPartId = if (partitionCounts.isEmpty) 0 else partitionCounts.map(_.getInt(0)).max + 1
+    val result = Array.fill[Long](maxPartId)(0L)
+    partitionCounts.foreach { row =>
+      result(row.getInt(0)) = row.getLong(1)
+    }
+    result
   }

   /** Get number of default cores from sparkSession(required) or master(optional) for 1 executor.

@@ -104,9 +107,9 @@ object ClusterUtil {
     }
   }

-  def getTaskCpus(sparkContext: SparkContext, log: Logger): Int = {
+  def getTaskCpus(spark: SparkSession, log: Logger): Int = {
     try {
-      val taskCpusConfig = sparkContext.getConf.getOption("spark.task.cpus")
+      val taskCpusConfig = spark.sparkContext.getConf.getOption("spark.task.cpus")
       if (taskCpusConfig.isEmpty) {
         log.info("ClusterUtils did not detect spark.task.cpus config set, using default 1 instead")
       }
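
The zero-filled result array covers a real gap between the two approaches: mapPartitionsWithIndex visited every partition, including empty ones, while a groupBy over spark_partition_id only sees partitions that emit at least one row. A small sketch of the effect (the setup is illustrative):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.spark_partition_id

    val spark = SparkSession.builder().master("local[4]").getOrCreate()
    // Four rows spread over eight partitions, so at least four are empty.
    val df = spark.range(0, 4, 1, 8).toDF("x")
    df.select(spark_partition_id().as("partId"))
      .groupBy("partId").count()
      .show()  // empty partitions are simply absent from this output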

core/src/main/scala/com/microsoft/azure/synapse/ml/recommendation/RankingEvaluator.scala

Lines changed: 1 addition & 0 deletions
@@ -130,6 +130,7 @@ class RankingEvaluator(override val uid: String)
   /** @group setParam */
   def setPredictionCol(value: String): this.type = set(predictionCol, value)

+  // Note: RankingMetrics from MLlib requires RDD input - .rdd conversion is necessary
   def getMetrics(dataset: Dataset[_]): AdvancedRankingMetrics = {
     val predictionAndLabels = dataset
       .select(getPredictionCol, getLabelCol)

core/src/main/scala/com/microsoft/azure/synapse/ml/recommendation/SARModel.scala

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@ class SARModel(override val uid: String) extends Model[SARModel]
                          dstOutputColumn: String,
                          num: Int): DataFrame = {

+    // Note: CoordinateMatrix from MLlib requires RDD input - .rdd conversion is necessary
     def dfToRDDMatrxEntry(dataframe: DataFrame) = {
       dataframe.rdd
         .flatMap(row =>

core/src/main/scala/com/microsoft/azure/synapse/ml/stages/Lambda.scala

Lines changed: 2 additions & 3 deletions
@@ -6,7 +6,6 @@ package com.microsoft.azure.synapse.ml.stages
 import com.microsoft.azure.synapse.ml.codegen.Wrappable
 import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
 import com.microsoft.azure.synapse.ml.param.UDFParam
-import org.apache.spark.SparkContext
 import org.apache.spark.injections.UDFUtils
 import org.apache.spark.ml.param.ParamMap
 import org.apache.spark.ml.util.Identifiable

@@ -59,8 +58,8 @@ class Lambda(val uid: String) extends Transformer with Wrappable with ComplexPar

   def transformSchema(schema: StructType): StructType = {
     if (get(transformSchemaFunc).isEmpty) {
-      val sc = SparkContext.getOrCreate()
-      val df = SparkSession.builder().getOrCreate().createDataFrame(sc.emptyRDD[Row], schema)
+      val spark = SparkSession.builder().getOrCreate()
+      val df = spark.createDataFrame(java.util.Collections.emptyList[Row](), schema)
       transform(df).schema
     } else {
       getTransformSchema(schema)

core/src/main/scala/com/microsoft/azure/synapse/ml/stages/Repartition.scala

Lines changed: 2 additions & 7 deletions
@@ -8,9 +8,8 @@ import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset}

 object Repartition extends DefaultParamsReadable[Repartition]

@@ -50,12 +49,8 @@ class Repartition(val uid: String) extends Transformer with Wrappable with Defau
     logTransform[DataFrame]({
       if (getDisable)
         dataset.toDF
-      else if (getN < dataset.rdd.getNumPartitions)
-        dataset.coalesce(getN).toDF()
       else
-        dataset.sqlContext.createDataFrame(
-          dataset.rdd.repartition(getN).asInstanceOf[RDD[Row]],
-          dataset.schema)
+        dataset.repartition(getN).toDF()
     }, dataset.columns.length)
   }
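
Dataset.repartition(n) handles both growing and shrinking the partition count, which is why the separate branches could collapse into one; the trade-off is that repartition always performs a full shuffle, whereas the removed coalesce path avoided one when only reducing partitions. A quick way to check the resulting count without touching the RDD API (illustrative):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.spark_partition_id

    val spark = SparkSession.builder().master("local[*]").getOrCreate()
    val df = spark.range(0, 10000, 1, 16).toDF("x")
    // Count the distinct physical partitions after repartitioning:
    println(df.repartition(4).select(spark_partition_id()).distinct().count())   // 4
    println(df.repartition(64).select(spark_partition_id()).distinct().count())  // 64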

core/src/main/scala/com/microsoft/azure/synapse/ml/stages/StratifiedRepartition.scala

Lines changed: 52 additions & 25 deletions
@@ -6,13 +6,14 @@ package com.microsoft.azure.synapse.ml.stages
 import com.microsoft.azure.synapse.ml.codegen.Wrappable
 import com.microsoft.azure.synapse.ml.core.contracts.HasLabelCol
 import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
-import org.apache.spark.RangePartitioner
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.HasSeed
 import org.apache.spark.ml.util._
 import org.apache.spark.sql.types._
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.functions.{col, lit, max => sqlMax, rand, row_number, spark_partition_id}

 /** Constants for <code>StratifiedRepartition</code>. */
 object SPConstants {

@@ -49,32 +50,58 @@ class StratifiedRepartition(val uid: String) extends Transformer with Wrappable
    */
   override def transform(dataset: Dataset[_]): DataFrame = {
     logTransform[DataFrame]({
-      // Count unique values in label column
-      val distinctLabelCounts = dataset.select(getLabelCol).groupBy(getLabelCol).count().collect()
-      val labelToCount = distinctLabelCounts.map(row => (row.getInt(0), row.getLong(1)))
-      val labelToFraction =
-        getMode match {
-          case SPConstants.Equal => getEqualLabelCount(labelToCount, dataset)
-          case SPConstants.Mixed =>
-            val equalLabelToCount = getEqualLabelCount(labelToCount, dataset)
-            val normalizedRatio = equalLabelToCount.map { case (label, count) => count }.sum / labelToCount.length
-            labelToCount.map { case (label, count) => (label, count / normalizedRatio) }.toMap
-          case SPConstants.Original => labelToCount.map { case (label, count) => (label, 1.0) }.toMap
-          case _ => throw new Exception(s"Unknown mode specified to StratifiedRepartition: $getMode")
-        }
-      val labelColIndex = dataset.schema.fieldIndex(getLabelCol)
-      val spdata = dataset.toDF().rdd.keyBy(row => row.getInt(labelColIndex))
-        .sampleByKeyExact(true, labelToFraction, getSeed)
-        .mapPartitions(keyToRow => keyToRow.zipWithIndex.map { case ((key, row), index) => (index, row) })
-      val rangePartitioner = new RangePartitioner(dataset.rdd.getNumPartitions, spdata)
-      val rspdata = spdata.partitionBy(rangePartitioner).mapPartitions(keyToRow =>
-        keyToRow.map { case (key, row) => row }).persist()
-      dataset.sqlContext.createDataFrame(rspdata, dataset.schema)
+      val df = dataset.toDF()
+      val labelToFraction = computeLabelFractions(df)
+      val sampled = stratifiedSample(df, labelToFraction)
+      val numPartitions = getNumPartitions(df)
+      roundRobinRepartition(sampled, numPartitions)
     }, dataset.columns.length)
   }

-  private def getEqualLabelCount(labelToCount: Array[(Int, Long)], dataset: Dataset[_]): Map[Int, Double] = {
-    val maxLabelCount = Math.max(labelToCount.map { case (label, count) => count }.max, dataset.rdd.getNumPartitions)
+  private def computeLabelFractions(df: DataFrame): Map[Int, Double] = {
+    val distinctLabelCounts = df.select(getLabelCol).groupBy(getLabelCol).count().collect()
+    val labelToCount = distinctLabelCounts.map(row => (row.getInt(0), row.getLong(1)))
+    getMode match {
+      case SPConstants.Equal => getEqualLabelCount(labelToCount, df)
+      case SPConstants.Mixed =>
+        val equalLabelToCount = getEqualLabelCount(labelToCount, df)
+        val normalizedRatio = equalLabelToCount.map { case (_, count) => count }.sum / labelToCount.length
+        labelToCount.map { case (label, count) => (label, count / normalizedRatio) }.toMap
+      case SPConstants.Original => labelToCount.map { case (label, _) => (label, 1.0) }.toMap
+      case _ => throw new Exception(s"Unknown mode specified to StratifiedRepartition: $getMode")
+    }
+  }
+
+  private def stratifiedSample(df: DataFrame, labelToFraction: Map[Int, Double]): DataFrame = {
+    val spark = df.sparkSession
+    val emptyDF = spark.createDataFrame(java.util.Collections.emptyList[Row](), df.schema)
+    val labelDFs = labelToFraction.map { case (label, fraction) =>
+      val labelData = df.filter(col(getLabelCol) === lit(label))
+      val wholeReplicates = math.floor(fraction).toInt
+      val fractionalPart = fraction - wholeReplicates
+      val wholePart = if (wholeReplicates > 0) {
+        (1 to wholeReplicates).map(_ => labelData).reduce(_ union _)
+      } else emptyDF
+      val fracPart = if (fractionalPart > 0) {
+        labelData.sample(withReplacement = false, fractionalPart, getSeed)
+      } else emptyDF
+      wholePart.union(fracPart)
+    }
+    labelDFs.reduce(_ union _)
+  }
+
+  private def getNumPartitions(df: DataFrame): Int =
+    df.select(spark_partition_id().as("_pid")).agg(sqlMax("_pid")).head().getInt(0) + 1
+
+  private def roundRobinRepartition(df: DataFrame, numPartitions: Int): DataFrame = {
+    val windowSpec = Window.partitionBy(col(getLabelCol)).orderBy(rand(getSeed))
+    val withPartition = df.withColumn("_rr_idx", row_number().over(windowSpec) % lit(numPartitions))
+    withPartition.repartitionByRange(numPartitions, col("_rr_idx")).drop("_rr_idx")
+  }
+
+  private def getEqualLabelCount(labelToCount: Array[(Int, Long)], df: DataFrame): Map[Int, Double] = {
+    val numPartitions = getNumPartitions(df)
+    val maxLabelCount = Math.max(labelToCount.map { case (_, count) => count }.max, numPartitions)
     labelToCount.map { case (label, count) => (label, maxLabelCount.toDouble / count) }.toMap
   }
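
For orientation, a usage sketch of the transformer whose internals changed above, assuming the transformer's usual SynapseML setter names (setLabelCol, setMode, setSeed) and illustrative data:

    import com.microsoft.azure.synapse.ml.stages.{SPConstants, StratifiedRepartition}
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[4]").getOrCreate()
    import spark.implicits._

    // Imbalanced two-class data; labels are Int because the code above
    // reads them with row.getInt.
    val df = (Seq.fill(90)(0) ++ Seq.fill(10)(1)).zipWithIndex.toDF("label", "feature")

    // Equal mode oversamples minority labels toward parity, then the
    // round-robin step spreads each label across all partitions.
    val balanced = new StratifiedRepartition()
      .setLabelCol("label")
      .setMode(SPConstants.Equal)
      .setSeed(42)
      .transform(df)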

core/src/main/scala/com/microsoft/azure/synapse/ml/train/ComputeModelStatistics.scala

Lines changed: 3 additions & 0 deletions
@@ -266,6 +266,9 @@ class ComputeModelStatistics(override val uid: String) extends Transformer
       .drop(Array(predictionColumnName, labelColumnName))
   }

+  // Note: MLlib metrics (BinaryClassificationMetrics, MulticlassMetrics, RegressionMetrics)
+  // require RDD input. These .rdd conversions are necessary until Spark provides
+  // DataFrame-based equivalents for all metrics.
   private def selectAndCastToRDD(dataset: Dataset[_],
                                  predictionColumnName: String,
                                  labelColumnName: String): RDD[(Double, Double)] = {
