Skip to content

Commit c2fdb05

Browse files
chore: Adding Spark34 support (#2052) (#2116)
* chore: bump to spark 3.4.1 --------- Co-authored-by: Jessica Wang <jessiwang@microsoft.com> Co-authored-by: Scott Votaw <svotaw@gmail.com> Co-authored-by: Brendan Walsh <37676373+BrendanWalsh@users.noreply.github.com> Co-authored-by: JessicaXYWang <108437381+JessicaXYWang@users.noreply.github.com> fixes Co-authored-by: Keerthi Yanda <98137159+KeerthiYandaOS@users.noreply.github.com>
1 parent 903dc6b commit c2fdb05

File tree

46 files changed

+288
-262
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

46 files changed

+288
-262
lines changed

build.sbt

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@ import org.apache.commons.io.FileUtils
33
import sbt.ExclusionRule
44

55
import java.io.File
6-
import java.net.URL
76
import scala.xml.transform.{RewriteRule, RuleTransformer}
87
import scala.xml.{Node => XmlNode, NodeSeq => XmlNodeSeq, _}
98

109
val condaEnvName = "synapseml"
11-
val sparkVersion = "3.2.3"
10+
val sparkVersion = "3.4.1"
1211
name := "synapseml"
1312
ThisBuild / organization := "com.microsoft.azure"
14-
ThisBuild / scalaVersion := "2.12.15"
13+
ThisBuild / scalaVersion := "2.12.17"
1514

1615
val scalaMajorVersion = 2.12
1716

@@ -21,23 +20,24 @@ val excludes = Seq(
2120
)
2221

2322
val coreDependencies = Seq(
24-
"org.apache.spark" %% "spark-core" % sparkVersion % "compile",
23+
// Excluding protobuf-java, as spark-core is bringing the older version transitively.
24+
"org.apache.spark" %% "spark-core" % sparkVersion % "compile" exclude("com.google.protobuf", "protobuf-java"),
2525
"org.apache.spark" %% "spark-mllib" % sparkVersion % "compile",
26-
"org.apache.spark" %% "spark-avro" % sparkVersion % "provided",
26+
"org.apache.spark" %% "spark-avro" % sparkVersion % "compile",
2727
"org.apache.spark" %% "spark-tags" % sparkVersion % "test",
2828
"com.globalmentor" % "hadoop-bare-naked-local-fs" % "0.1.0" % "test",
2929
"org.scalatest" %% "scalatest" % "3.2.14" % "test")
3030
val extraDependencies = Seq(
31+
"commons-lang" % "commons-lang" % "2.6",
3132
"org.scalactic" %% "scalactic" % "3.2.14",
3233
"io.spray" %% "spray-json" % "1.3.5",
3334
"com.jcraft" % "jsch" % "0.1.54",
3435
"org.apache.httpcomponents.client5" % "httpclient5" % "5.1.3",
3536
"org.apache.httpcomponents" % "httpmime" % "4.5.13",
36-
"com.linkedin.isolation-forest" %% "isolation-forest_3.2.0" % "2.0.8",
37-
// Although breeze 1.2 is already provided by Spark, this is needed for Azure Synapse Spark 3.2 pools.
38-
// Otherwise a NoSuchMethodError will be thrown by interpretability code. This problem only happens
39-
// to Azure Synapse Spark 3.2 pools.
40-
"org.scalanlp" %% "breeze" % "1.2"
37+
"com.linkedin.isolation-forest" %% "isolation-forest_3.4.1" % "3.0.3",
38+
// Although breeze 2.1.0 is already provided by Spark, this is needed for Azure Synapse Spark 3.4 pools.
39+
// Otherwise a NoSuchMethodError will be thrown by interpretability code.
40+
"org.scalanlp" %% "breeze" % "2.1.0"
4141
).map(d => d excludeAll (excludes: _*))
4242
val dependencies = coreDependencies ++ extraDependencies
4343

@@ -70,7 +70,7 @@ pomPostProcess := pomPostFunc
7070

7171
val getDatasetsTask = TaskKey[Unit]("getDatasets", "download datasets used for testing")
7272
val datasetName = "datasets-2023-04-03.tgz"
73-
val datasetUrl = new URI(s"https://mmlspark.blob.core.windows.net/installers/$datasetName").toURL()
73+
val datasetUrl = new URI(s"https://mmlspark.blob.core.windows.net/installers/$datasetName").toURL
7474
val datasetDir = settingKey[File]("The directory that holds the dataset")
7575
ThisBuild / datasetDir := {
7676
join((Compile / packageBin / artifactPath).value.getParentFile,

core/src/main/scala/com/microsoft/azure/synapse/ml/codegen/PyCodegen.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ object PyCodegen {
7070
// There's `Already borrowed` error found in transformers 4.16.2 when using tokenizers
7171
s"""extras_require={"extras": [
7272
| "cmake",
73-
| "horovod==0.25.0",
73+
| "horovod==0.28.1",
7474
| "pytorch_lightning>=1.5.0,<1.5.10",
75-
| "torch==1.11.0",
76-
| "torchvision>=0.12.0",
77-
| "transformers==4.15.0",
75+
| "torch==1.13.1",
76+
| "torchvision>=0.14.1",
77+
| "transformers==4.32.1",
7878
| "petastorm>=0.12.0",
7979
| "huggingface-hub>=0.8.1",
8080
|]},

core/src/main/scala/com/microsoft/azure/synapse/ml/core/env/PackageUtils.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ object PackageUtils {
1818

1919
val PackageName = s"synapseml_$ScalaVersionSuffix"
2020
val PackageMavenCoordinate = s"$PackageGroup:$PackageName:${BuildInfo.version}"
21-
private val AvroCoordinate = "org.apache.spark:spark-avro_2.12:3.3.1"
21+
private val AvroCoordinate = "org.apache.spark:spark-avro_2.12:3.4.1"
2222
val PackageRepository: String = SparkMLRepository
2323

2424
// If testing onnx package with snapshots repo, make sure to switch to using

core/src/main/scala/com/microsoft/azure/synapse/ml/exploratory/DistributionBalanceMeasure.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
package com.microsoft.azure.synapse.ml.exploratory
55

6-
import breeze.stats.distributions.ChiSquared
6+
import breeze.stats.distributions.{ChiSquared, RandBasis}
77
import com.microsoft.azure.synapse.ml.codegen.Wrappable
88
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
99
import com.microsoft.azure.synapse.ml.logging.{FeatureNames, SynapseMLLogging}
@@ -261,6 +261,7 @@ private[exploratory] case class DistributionMetrics(numFeatures: Int,
261261

262262
// Calculates left-tailed p-value from degrees of freedom and chi-squared test statistic
263263
def chiSquaredPValue: Column = {
264+
implicit val rand: RandBasis = RandBasis.mt0
264265
val degOfFreedom = numFeatures - 1
265266
val scoreCol = chiSquaredTestStatistic
266267
val chiSqPValueUdf = udf(

core/src/main/scala/com/microsoft/azure/synapse/ml/io/binary/BinaryFileFormat.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ private[ml] class HadoopFileReader(file: PartitionedFile,
194194

195195
private val iterator = {
196196
val fileSplit = new FileSplit(
197-
new Path(new URI(file.filePath)),
197+
new Path(new URI(file.filePath.toString())),
198198
file.start,
199199
file.length,
200200
Array.empty)

core/src/main/scala/com/microsoft/azure/synapse/ml/nn/BallTree.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
package com.microsoft.azure.synapse.ml.nn
55

6-
import breeze.linalg.functions.euclideanDistance
76
import breeze.linalg.{DenseVector, norm, _}
87
import com.microsoft.azure.synapse.ml.core.env.StreamUtilities.using
98

core/src/main/scala/org/apache/spark/ml/recommendation/RecommendationHelper.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,17 +199,20 @@ object SparkHelpers {
199199

200200
def flatten(ratings: Dataset[_], num: Int, dstOutputColumn: String, srcOutputColumn: String): DataFrame = {
201201
import ratings.sparkSession.implicits._
202-
203-
val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
204-
val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
205-
.toDF("id", "recommendations")
202+
import org.apache.spark.sql.functions.{collect_top_k, struct}
206203

207204
val arrayType = ArrayType(
208205
new StructType()
209206
.add(dstOutputColumn, IntegerType)
210-
.add("rating", FloatType)
207+
.add(Constants.RatingCol, FloatType)
211208
)
212-
recs.select(col("id").as(srcOutputColumn), col("recommendations").cast(arrayType))
209+
210+
ratings.toDF(srcOutputColumn, dstOutputColumn, Constants.RatingCol).groupBy(srcOutputColumn)
211+
.agg(collect_top_k(struct(Constants.RatingCol, dstOutputColumn), num, false))
212+
.as[(Int, Seq[(Float, Int)])]
213+
.map(t => (t._1, t._2.map(p => (p._2, p._1))))
214+
.toDF(srcOutputColumn, Constants.Recommendations)
215+
.withColumn(Constants.Recommendations, col(Constants.Recommendations).cast(arrayType))
213216
}
214217
}
215218

core/src/main/scala/org/apache/spark/ml/source/image/PatchedImageFileFormat.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ class PatchedImageFileFormat extends ImageFileFormat with Serializable with Logg
9898
Iterator(emptyUnsafeRow)
9999
} else {
100100
val origin = file.filePath
101-
val path = new Path(origin)
101+
val path = new Path(origin.toString())
102102
val fs = path.getFileSystem(broadcastedHadoopConf.value.value)
103103
val stream = fs.open(path)
104104
val bytes = try {
@@ -107,11 +107,12 @@ class PatchedImageFileFormat extends ImageFileFormat with Serializable with Logg
107107
IOUtils.close(stream)
108108
}
109109

110-
val resultOpt = catchFlakiness(5)(ImageSchema.decode(origin, bytes)) //scalastyle:ignore magic.number
110+
val resultOpt = catchFlakiness(5)( //scalastyle:ignore magic.number
111+
ImageSchema.decode(origin.toString(), bytes))
111112
val filteredResult = if (imageSourceOptions.dropInvalid) {
112113
resultOpt.toIterator
113114
} else {
114-
Iterator(resultOpt.getOrElse(ImageSchema.invalidImageRow(origin)))
115+
Iterator(resultOpt.getOrElse(ImageSchema.invalidImageRow(origin.toString())))
115116
}
116117

117118
if (requiredSchema.isEmpty) {

core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/HTTPSinkV2.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import org.apache.spark.sql.internal.connector.{SimpleTableProvider, SupportsStr
1515
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister}
1616
import org.apache.spark.sql.types._
1717
import org.apache.spark.sql.util.CaseInsensitiveStringMap
18+
import org.sparkproject.dmg.pmml.False
1819

1920
import java.util
2021
import scala.collection.JavaConverters._
@@ -107,8 +108,12 @@ private[streaming] class HTTPDataWriter(val partitionId: Int,
107108
val replyColIndex: Int,
108109
val name: String)
109110
extends DataWriter[InternalRow] with Logging {
110-
logInfo(s"Creating writer on PID:$partitionId")
111-
HTTPSourceStateHolder.getServer(name).commit(epochId - 1, partitionId)
111+
logDebug(s"Creating writer on parition:$partitionId epoch $epochId")
112+
113+
val server = HTTPSourceStateHolder.getServer(name)
114+
if (server.isContinuous) {
115+
server.commit(epochId - 1, partitionId)
116+
}
112117

113118
private val ids: mutable.ListBuffer[(String, Int)] = new mutable.ListBuffer[(String, Int)]()
114119

core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/HTTPSourceV2.scala

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,6 @@ class HTTPSourceTable(options: CaseInsensitiveStringMap)
6868
override def readSchema(): StructType = HTTPSourceV2.Schema
6969

7070
override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
71-
logInfo("Creating Microbatch reader")
7271
new HTTPMicroBatchReader(continuous = false, options = options)
7372
}
7473

@@ -136,8 +135,8 @@ private[streaming] object DriverServiceUtils {
136135
host: String,
137136
handler: HttpHandler): HttpServer = {
138137
val port: Int = StreamUtilities.using(new ServerSocket(0))(_.getLocalPort).get
139-
val server = HttpServer.create(new InetSocketAddress(host, port), 100) //scalastyle:ignore magic.number
140-
server.setExecutor(Executors.newFixedThreadPool(100)) //scalastyle:ignore magic.number
138+
val server = HttpServer.create(new InetSocketAddress(host, port), 100) //scalastyle:ignore magic.number
139+
server.setExecutor(Executors.newFixedThreadPool(100)) //scalastyle:ignore magic.number
141140
server.createContext(s"/$path", handler)
142141
server.start()
143142
server
@@ -208,10 +207,10 @@ private[streaming] class HTTPMicroBatchReader(continuous: Boolean, options: Case
208207

209208
val numPartitions: Int = options.getInt(HTTPSourceV2.NumPartitions, 2)
210209
val host: String = options.get(HTTPSourceV2.Host, "localhost")
211-
val port: Int = options.getInt(HTTPSourceV2.Port, 8888) //scalastyle:ignore magic.number
210+
val port: Int = options.getInt(HTTPSourceV2.Port, 8888) //scalastyle:ignore magic.number
212211
val path: String = options.get(HTTPSourceV2.Path)
213212
val name: String = options.get(HTTPSourceV2.NAME)
214-
val epochLength: Long = options.getLong(HTTPSourceV2.EpochLength, 30000) //scalastyle:ignore magic.number
213+
val epochLength: Long = options.getLong(HTTPSourceV2.EpochLength, 30000) //scalastyle:ignore magic.number
215214

216215
val forwardingOptions: collection.Map[String, String] = options.asCaseSensitiveMap().asScala
217216
.filter { case (k, _) => k.startsWith("forwarding") }
@@ -270,8 +269,9 @@ private[streaming] class HTTPMicroBatchReader(continuous: Boolean, options: Case
270269

271270
val config = WorkerServiceConfig(host, port, path, forwardingOptions,
272271
DriverServiceUtils.getDriverHost, driverService.getAddress.getPort, epochLength)
272+
273273
Range(0, numPartitions).map { i =>
274-
HTTPInputPartition(continuous, name, config, startMap(i), endMap.map(_ (i)), i)
274+
HTTPInputPartition(continuous, name, config, startMap(i), endMap.map(_(i)), i)
275275
: InputPartition
276276
}.toArray
277277
}
@@ -318,7 +318,7 @@ private[streaming] class HTTPContinuousReader(options: CaseInsensitiveStringMap)
318318
}
319319

320320
override def planInputPartitions(start: Offset): Array[InputPartition] =
321-
planInputPartitions(start, null) //scalastyle:ignore null
321+
planInputPartitions(start, null) //scalastyle:ignore null
322322

323323
override def createContinuousReaderFactory(): ContinuousPartitionReaderFactory = {
324324
HTTPSourceReaderFactory
@@ -332,7 +332,13 @@ private[streaming] case class HTTPInputPartition(continuous: Boolean,
332332
endValue: Option[Long],
333333
partitionIndex: Int
334334
)
335-
extends InputPartition
335+
extends InputPartition {
336+
if (!HTTPSourceStateHolder.hasServer(name)) {
337+
val client = HTTPSourceStateHolder.getOrCreateClient(name)
338+
HTTPSourceStateHolder.getOrCreateServer(name, startValue - 1, partitionIndex, continuous, client, config)
339+
}
340+
341+
}
336342

337343
object HTTPSourceStateHolder {
338344

@@ -381,6 +387,10 @@ object HTTPSourceStateHolder {
381387
HTTPSourceStateHolder.Servers(name)
382388
}
383389

390+
private[streaming] def hasServer(name: String): Boolean = {
391+
HTTPSourceStateHolder.Servers.contains(name)
392+
}
393+
384394
private[streaming] def getOrCreateServer(name: String,
385395
epoch: Long,
386396
partitionId: Int,
@@ -487,10 +497,10 @@ private[streaming] class WorkerServer(val name: String,
487497

488498
def registerPartition(localEpoch: Epoch, partitionId: PID): Unit = synchronized {
489499
if (!registeredPartitions.contains(partitionId)) {
490-
logInfo(s"registering $partitionId localEpoch:$localEpoch globalEpoch:$epoch")
500+
logDebug(s"registering $partitionId localEpoch:$localEpoch globalEpoch:$epoch")
491501
registeredPartitions.update(partitionId, localEpoch)
492502
} else {
493-
logInfo(s"re-registering $partitionId localEpoch:$localEpoch globalEpoch:$epoch")
503+
logDebug(s"re-registering $partitionId localEpoch:$localEpoch globalEpoch:$epoch")
494504
val previousEpoch = registeredPartitions(partitionId)
495505
registeredPartitions.update(partitionId, localEpoch)
496506
//there has been a failed partition and we need to rehydrate the queue
@@ -514,14 +524,16 @@ private[streaming] class WorkerServer(val name: String,
514524
@GuardedBy("this")
515525
private val historyQueues = new mutable.HashMap[(Epoch, PID), mutable.ListBuffer[CachedRequest]]
516526

527+
@GuardedBy("this")
517528
private[streaming] val recoveredPartitions = new mutable.HashMap[(Epoch, PID), LinkedBlockingQueue[CachedRequest]]
518529

519530
private class PublicHandler extends HttpHandler {
520531
override def handle(request: HttpExchange): Unit = {
521-
logDebug(s"handling epoch: $epoch")
532+
logDebug(s"handling request epoch: $epoch")
522533
val uuid = UUID.randomUUID().toString
523534
val cReq = new CachedRequest(request, uuid)
524535
requestQueues(epoch).put(cReq)
536+
logDebug(s"handled request epoch: $epoch")
525537
}
526538
}
527539

@@ -540,6 +552,7 @@ private[streaming] class WorkerServer(val name: String,
540552
None
541553
}
542554
.foreach { request =>
555+
logDebug(s"Replying to request")
543556
HTTPServerUtils.respond(request.e, data)
544557
request.e.close()
545558
routingTable.remove(id)
@@ -582,7 +595,7 @@ private[streaming] class WorkerServer(val name: String,
582595
}
583596
try {
584597
val server = HttpServer.create(new InetSocketAddress(InetAddress.getByName(host), startingPort),
585-
100) //scalastyle:ignore magic.number
598+
100) //scalastyle:ignore magic.number
586599
(server, startingPort)
587600
} catch {
588601
case _: java.net.BindException =>
@@ -624,22 +637,24 @@ private[streaming] class WorkerServer(val name: String,
624637
}
625638

626639
timeout.map {
627-
case Left(0L) => Option(queue.poll())
628-
case Right(t) =>
629-
Option(queue.poll(t, TimeUnit.MILLISECONDS)).orElse {
630-
synchronized {
631-
//If the queue times out then we move to the next epoch
632-
epoch += 1
633-
val lbq = new LinkedBlockingQueue[CachedRequest]()
634-
requestQueues.update(epoch, lbq)
635-
epochStart = System.currentTimeMillis()
640+
case Left(0L) => Option(queue.poll())
641+
case Right(t) =>
642+
val polled = queue.poll(t, TimeUnit.MILLISECONDS)
643+
Option(polled).orElse {
644+
synchronized {
645+
//If the queue times out then we move to the next epoch
646+
epoch += 1
647+
val lbq = new LinkedBlockingQueue[CachedRequest]()
648+
requestQueues.update(epoch, lbq)
649+
epochStart = System.currentTimeMillis()
650+
}
636651
None
637652
}
638-
}
639-
case _ => throw new IllegalArgumentException("Should not hit this path")
640-
}
641-
.orElse(Some(Some(queue.take())))
642-
.flatten
653+
654+
case _ => throw new IllegalArgumentException("Should not hit this path")
655+
}
656+
.orElse(Some(Some(queue.take())))
657+
.flatten
643658
}
644659
}
645660

@@ -650,7 +665,8 @@ private[streaming] class WorkerServer(val name: String,
650665
if (TaskContext.get().attemptNumber() == 0) {
651666
// If the request has never been materialized add it to the cache, otherwise we are in a retry and
652667
// should not modify the history
653-
historyQueues.getOrElseUpdate((localEpoch, partitionIndex), new ListBuffer[CachedRequest]())
668+
historyQueues
669+
.getOrElseUpdate((localEpoch, partitionIndex), new ListBuffer[CachedRequest]())
654670
.append(request)
655671
}
656672
InternalRow(
@@ -702,7 +718,6 @@ private[streaming] class HTTPInputPartitionReader(continuous: Boolean,
702718
val endEpoch: Option[Long],
703719
val partitionIndex: Int)
704720
extends ContinuousPartitionReader[InternalRow] with Logging {
705-
706721
val client: WorkerClient = HTTPSourceStateHolder.getOrCreateClient(name)
707722
val server: WorkerServer = HTTPSourceStateHolder.getOrCreateServer(
708723
name, startEpoch, partitionIndex, continuous, client, config)

0 commit comments

Comments (0)