diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
index 8028970193acd..e6f36cf9f7084 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
@@ -474,7 +474,14 @@ trait SQLQueryTestHelper extends SQLConfHelper with Logging {
 
   protected def getSparkSettings(comments: Array[String]): Array[(String, String)] = {
     val settingLines = comments.filter(_.startsWith("--SET ")).map(_.substring(6))
-    settingLines.flatMap(_.split(",").map { kv =>
+    // Split on commas that are followed by what looks like a new `key=`. This preserves
+    // commas inside config values such as
+    //   --SET spark.sql.optimizer.excludedRules=Rule1,Rule2
+    // while still supporting the documented multi-setting form
+    //   --SET key1=v1,key2=v2
+    // Whitespace after the comma or around `=` is tolerated (`key1=v1, key2 = v2`), as it was
+    // when this method split on every comma; keys and values are trimmed below.
+    settingLines.flatMap(_.split(",(?=\\s*[\\w.]+\\s*=)").map { kv =>
       val (conf, value) = kv.span(_ != '=')
       conf.trim -> value.substring(1).trim
     })
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelperSuite.scala
new file mode 100644
index 0000000000000..f6642cd9a5c65
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelperSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.spark.SparkFunSuite
+
+class SQLQueryTestHelperSuite extends SparkFunSuite with SQLQueryTestHelper {
+
+  test("getSparkSettings: single key=value") {
+    val result = getSparkSettings(Array("--SET spark.sql.foo=1"))
+    assert(result.toSeq === Seq("spark.sql.foo" -> "1"))
+  }
+
+  test("getSparkSettings: multiple key=value pairs in one --SET (documented form)") {
+    val result = getSparkSettings(Array("--SET spark.sql.foo=1,spark.sql.bar=2"))
+    assert(result.toSeq === Seq("spark.sql.foo" -> "1", "spark.sql.bar" -> "2"))
+  }
+
+  test("getSparkSettings: whitespace after the separating comma and around '='") {
+    val result = getSparkSettings(Array("--SET spark.sql.foo=1, spark.sql.bar = 2"))
+    assert(result.toSeq === Seq("spark.sql.foo" -> "1", "spark.sql.bar" -> "2"))
+  }
+
+  test("getSparkSettings: multiple --SET statements") {
+    val result = getSparkSettings(
+      Array("--SET spark.sql.foo=1", "--SET spark.sql.bar=2"))
+    assert(result.toSeq === Seq("spark.sql.foo" -> "1", "spark.sql.bar" -> "2"))
+  }
+
+  test("getSparkSettings: value containing commas (e.g. excludedRules list)") {
+    val excludedRules =
+      "org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation," +
+      "org.apache.spark.sql.catalyst.optimizer.ConstantFolding"
+    val result = getSparkSettings(
+      Array(s"--SET spark.sql.optimizer.excludedRules=$excludedRules"))
+    assert(result.toSeq === Seq("spark.sql.optimizer.excludedRules" -> excludedRules))
+  }
+
+  test("getSparkSettings: mixed -- multiple settings where one value contains commas") {
+    val excludedRules =
+      "org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation," +
+      "org.apache.spark.sql.catalyst.optimizer.ConstantFolding"
+    val result = getSparkSettings(
+      Array(s"--SET spark.sql.optimizer.excludedRules=$excludedRules,spark.sql.foo=1"))
+    assert(result.toSeq === Seq(
+      "spark.sql.optimizer.excludedRules" -> excludedRules,
+      "spark.sql.foo" -> "1"))
+  }
+
+  test("getSparkSettings: ignores non --SET comments") {
+    val result = getSparkSettings(
+      Array("-- a comment", "--SET spark.sql.foo=1", "-- another"))
+    assert(result.toSeq === Seq("spark.sql.foo" -> "1"))
+  }
+}