diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
index 75619c9c5ce39..8a83b576d7253 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
@@ -288,6 +288,16 @@ object CollationTypeCoercion extends SQLConfHelper {
           None
         }
 
+      case elementAt: ElementAt => // derive context from the array element / map value type
+        findCollationContext(elementAt.left) match {
+          case Some(MapType(_, valueType, _)) =>
+            mergeWinner(elementAt.dataType, valueType)
+          case Some(ArrayType(elementType, _)) =>
+            mergeWinner(elementAt.dataType, elementType)
+          case _ =>
+            None
+        }
+
       case struct: CreateNamedStruct =>
         val childrenContexts = struct.valExprs.map(findCollationContext)
         if (childrenContexts.isEmpty) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParameterHandler.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParameterHandler.scala
index 715dccdc10737..2a55b1474d9b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParameterHandler.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParameterHandler.scala
@@ -17,6 +17,8 @@
 package org.apache.spark.sql.catalyst.parser
 
 import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.util.SchemaUtils
 
 /**
  * Handler for parameter substitution across different Spark SQL contexts.
@@ -107,14 +109,37 @@ object ParameterHandler {
    * @param expr The expression to convert (must be a Literal)
    * @return SQL string representation
    */
-  private def convertToSql(expr: Expression): String = expr match {
-    case lit: Literal => lit.sql
-    case other =>
-      throw new IllegalArgumentException(
-        s"ParameterHandler only accepts resolved Literal expressions. " +
-        s"Received: ${other.getClass.getSimpleName}. " +
-        s"All parameters must be resolved using SparkSession.resolveAndValidateParameters " +
-        s"before being passed to the pre-parser.")
+  private def convertToSql(expr: Expression): String = {
+    // Renders `e` as SQL. When its type contains a non-default (collated) string type, the
+    // collation is removed from every matching Literal and the rendered SQL is wrapped in
+    // CAST(... AS <original type>), so the collation is carried by the cast instead of by
+    // `'value' COLLATE collationName`, which Literal.sql emits and re-parses as explicit.
+    def toSqlWithImplicitCollation(e: Expression): String = {
+      if (!DataTypeUtils.hasNonDefaultStringCharOrVarcharType(e.dataType)) {
+        e.sql
+      } else {
+        val stripped = e.transform {
+          case lit: Literal
+            if DataTypeUtils.hasNonDefaultStringCharOrVarcharType(lit.dataType) =>
+            Literal.create(
+              lit.value, SchemaUtils.replaceCollatedStringWithString(lit.dataType))
+        }
+        s"CAST(${stripped.sql} AS ${e.dataType.sql})"
+      }
+    }
+
+    expr match {
+      case lit: Literal if lit.value == null => // null value: emit Literal.sql unchanged
+        lit.sql
+      case lit: Literal =>
+        toSqlWithImplicitCollation(lit)
+      case other =>
+        throw new IllegalArgumentException(
+          s"ParameterHandler only accepts resolved Literal expressions. " +
+          s"Received: ${other.getClass.getSimpleName}. " +
+          s"All parameters must be resolved using SparkSession.resolveAndValidateParameters " +
+          s"before being passed to the pre-parser.")
+    }
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
index d2d9f8e446263..f7ab8b06baf5b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
@@ -302,6 +302,17 @@ object DataTypeUtils {
     }
   }
 
+  /**
+   * Returns true if `dataType`, or any type nested within it, is a STRING/CHAR/VARCHAR with an
+   * explicit collation (including explicit `UTF8_BINARY`), i.e. not a default string type.
+   */
+  def hasNonDefaultStringCharOrVarcharType(dataType: DataType): Boolean = {
+    dataType.existsRecursively {
+      case st: StringType => !isDefaultStringCharOrVarcharType(st)
+      case _ => false
+    }
+  }
+
   /**
    * Recursively replaces all STRING, CHAR and VARCHAR types that do not have an explicit collation
    * with the same type but with explicit `UTF8_BINARY` collation.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
index 1e9dcdf5854b4..bc04d23436a69 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
@@ -47,6 +47,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
   private val collationNonPreservingSources = Seq("orc", "csv", "json", "text")
   private val allFileBasedDataSources = collationPreservingSources ++ collationNonPreservingSources
   private val fullyQualifiedPrefix = s"${CollationFactory.CATALOG}.${CollationFactory.SCHEMA}."
+  private val collations = Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI")
 
   @inline
   private def isSortMergeForced: Boolean = {
@@ -2724,4 +2725,395 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
 
     }
   }
+
+  test("execute immediate parameter with explicit COLLATE has implicit strength") {
+    collations.foreach { collation =>
+      checkAnswer(
+        sql(
+          s"""EXECUTE IMMEDIATE
+             | 'SELECT COLLATION(? || "world" COLLATE $collation)'
+             | USING 'hello' COLLATE UNICODE""".stripMargin),
+        Row(s"$fullyQualifiedPrefix$collation"))
+
+      checkAnswer(
+        sql(
+          s"""EXECUTE IMMEDIATE
+             | 'SELECT COLLATION(? || "world" COLLATE $collation)'
+             | USING 'hello' COLLATE UTF8_LCASE""".stripMargin),
+        Row(s"$fullyQualifiedPrefix$collation"))
+
+      checkAnswer(
+        sql(
+          s"""EXECUTE IMMEDIATE 'SELECT ? = "HELLO" COLLATE $collation'
+             | USING 'hello' COLLATE UTF8_LCASE""".stripMargin),
+        Row(collation == "UTF8_LCASE" || collation == "UNICODE_CI"))
+    }
+  }
+
+  test("execute immediate parameter without explicit COLLATE") {
+    checkAnswer(
+      sql(
+        """EXECUTE IMMEDIATE 'SELECT COLLATION(? || "world")'
+          | USING 'hello'""".stripMargin),
+      Row(s"${fullyQualifiedPrefix}UTF8_BINARY"))
+
+    collations.foreach { collation =>
+      checkAnswer(
+        sql(
+          s"""EXECUTE IMMEDIATE
+             | 'SELECT COLLATION(? || "world" COLLATE $collation)'
+             | USING 'hello'""".stripMargin),
+        Row(s"$fullyQualifiedPrefix$collation"))
+    }
+  }
+
+  test("execute immediate parameter implicit vs column collation") {
+    withTable("t") {
+      sql(
+        """CREATE TABLE t (
+          | lcase_col STRING COLLATE UTF8_LCASE,
+          | unicode_col STRING COLLATE UNICODE
+          |) USING parquet""".stripMargin)
+      sql("INSERT INTO t VALUES ('hello', 'hello')")
+
+      checkAnswer(
+        sql(
+          """EXECUTE IMMEDIATE
+            | 'SELECT ? = lcase_col FROM t'
+            | USING 'hello' COLLATE UTF8_LCASE""".stripMargin),
+        Row(true))
+
+      checkError(
+        exception = intercept[AnalysisException] {
+          sql(
+            """EXECUTE IMMEDIATE
+              | 'SELECT ? = unicode_col FROM t'
+              | USING 'hello' COLLATE UTF8_LCASE""".stripMargin)
+        },
+        condition = "INDETERMINATE_COLLATION_IN_EXPRESSION",
+        parameters = Map("expr" ->
+          "\"(CAST(hello AS STRING COLLATE UTF8_LCASE) = unicode_col)\""),
+        queryContext = Array(
+          ExpectedContext("EXECUTE IMMEDIATE", "", 7, 21, "? = unicode_col")))
+    }
+  }
+
+  test("execute immediate complex type parameter collation and strength") {
+    withTable("t") {
+      sql(
+        """CREATE TABLE t (
+          | lcase_col STRING COLLATE UTF8_LCASE,
+          | unicode_col STRING COLLATE UNICODE
+          |) USING parquet""".stripMargin)
+      sql("INSERT INTO t VALUES ('hello', 'hello')")
+
+      checkAnswer(
+        sql(
+          """EXECUTE IMMEDIATE
+            | 'SELECT ?[0] = lcase_col FROM t'
+            | USING ARRAY('hello' COLLATE UTF8_LCASE)""".stripMargin),
+        Row(true))
+
+      checkError(
+        exception = intercept[AnalysisException] {
+          sql(
+            """EXECUTE IMMEDIATE
+              | 'SELECT ?[0] = unicode_col FROM t'
+              | USING ARRAY('hello' COLLATE UTF8_LCASE)""".stripMargin)
+        },
+        condition = "INDETERMINATE_COLLATION_IN_EXPRESSION",
+        parameters = Map("expr" ->
+          "\"(array(hello)[0] = unicode_col)\""),
+        queryContext = Array(
+          ExpectedContext("EXECUTE IMMEDIATE", "", 7, 24, "?[0] = unicode_col")))
+
+      checkAnswer(
+        sql(
+          """EXECUTE IMMEDIATE
+            | 'SELECT element_at(?, 1) = lcase_col FROM t'
+            | USING ARRAY('hello' COLLATE UTF8_LCASE)""".stripMargin),
+        Row(true))
+
+      checkError(
+        exception = intercept[AnalysisException] {
+          sql(
+            """EXECUTE IMMEDIATE
+              | 'SELECT element_at(?, 1) = unicode_col FROM t'
+              | USING ARRAY('hello' COLLATE UTF8_LCASE)""".stripMargin)
+        },
+        condition = "INDETERMINATE_COLLATION_IN_EXPRESSION",
+        parameters = Map("expr" ->
+          "\"(element_at(array(hello), 1) = unicode_col)\""),
+        queryContext = Array(
+          ExpectedContext("EXECUTE IMMEDIATE", "", 7, 36,
+            "element_at(?, 1) = unicode_col")))
+
+      checkAnswer(
+        sql(
+          s"""EXECUTE IMMEDIATE
+             | 'SELECT element_at(?, "key") = lcase_col FROM t'
+             | USING MAP('key', 'hello' COLLATE UTF8_LCASE)""".stripMargin),
+        Row(true))
+
+      checkError(
+        exception = intercept[AnalysisException] {
+          sql(
+            s"""EXECUTE IMMEDIATE
+               | 'SELECT element_at(?, "key") = unicode_col FROM t'
+               | USING MAP('key', 'hello' COLLATE UTF8_LCASE)""".stripMargin)
+        },
+        condition = "INDETERMINATE_COLLATION_IN_EXPRESSION",
+        parameters = Map("expr" ->
+          "\"(element_at(map(key, hello), key) = unicode_col)\""),
+        queryContext = Array(
+          ExpectedContext(
+            "EXECUTE IMMEDIATE", "", 7, 40,
+            """element_at(?, "key") = unicode_col""")))
+
+      checkAnswer(
+        sql(
+          """EXECUTE IMMEDIATE
+            | 'SELECT ?.f1 = lcase_col FROM t'
+            | USING NAMED_STRUCT('f1', 'hello' COLLATE UTF8_LCASE)""".stripMargin),
+        Row(true))
+
+      checkError(
+        exception = intercept[AnalysisException] {
+          sql(
+            """EXECUTE IMMEDIATE
+              | 'SELECT ?.f1 = unicode_col FROM t'
+              | USING NAMED_STRUCT('f1', 'hello' COLLATE UTF8_LCASE)""".stripMargin)
+        },
+        condition = "INDETERMINATE_COLLATION_IN_EXPRESSION",
+        parameters = Map("expr" ->
+          "\"(named_struct(f1, hello).f1 = unicode_col)\""),
+        queryContext = Array(
+          ExpectedContext("EXECUTE IMMEDIATE", "", 7, 24, "?.f1 = unicode_col")))
+    }
+  }
+
+  test("execute immediate complex type parameter with explicit COLLATE") {
+    collations.foreach { collation =>
+      checkAnswer(
+        sql(
+          s"""EXECUTE IMMEDIATE 'SELECT COLLATION(?[0])'
+             | USING ARRAY('hello' COLLATE $collation)""".stripMargin),
+        Row(s"$fullyQualifiedPrefix$collation"))
+
+      checkAnswer(
+        sql(
+          s"""EXECUTE IMMEDIATE 'SELECT COLLATION(element_at(?, "value"))'
+             | USING MAP('value', 'hello' COLLATE $collation)""".stripMargin),
+        Row(s"$fullyQualifiedPrefix$collation"))
+
+      checkAnswer(
+        sql(
+          s"""EXECUTE IMMEDIATE 'SELECT COLLATION(?.f1)'
+             | USING NAMED_STRUCT('f1', 'hello' COLLATE $collation)""".stripMargin),
+        Row(s"$fullyQualifiedPrefix$collation"))
+    }
+  }
+
+  test("execute immediate variable parameter preserves collation") {
+    collations.foreach { collation =>
+      withSessionVariable("v1") {
+        sql(s"DECLARE VARIABLE v1 STRING COLLATE $collation DEFAULT 'hello'")
+        checkAnswer(
+          sql("EXECUTE IMMEDIATE 'SELECT COLLATION(?)' USING v1"),
+          Row(s"$fullyQualifiedPrefix$collation"))
+      }
+    }
+  }
+
+  test("execute immediate variable parameter has implicit strength") {
+    collations.foreach { collation =>
+      withSessionVariable("v1") {
+        sql("DECLARE VARIABLE v1 STRING COLLATE UTF8_LCASE DEFAULT 'hello'")
+        checkAnswer(
+          sql(
+            s"""EXECUTE IMMEDIATE 'SELECT ? = "HELLO" COLLATE $collation'
+               | USING v1""".stripMargin),
+          Row(collation == "UTF8_LCASE" || collation == "UNICODE_CI"))
+      }
+    }
+  }
+
+  test("execute immediate variable parameter implicit vs column collation") {
+    withTable("t") {
+      sql(
+        """CREATE TABLE t (
+          | lcase_col STRING COLLATE UTF8_LCASE,
+          | unicode_col STRING COLLATE UNICODE
+          |) USING parquet""".stripMargin)
+      sql("INSERT INTO t VALUES ('hello', 'hello')")
+
+      withSessionVariable("v1") {
+        sql("DECLARE VARIABLE v1 STRING COLLATE UTF8_LCASE DEFAULT 'hello'")
+        checkAnswer(
+          sql("EXECUTE IMMEDIATE 'SELECT ? = lcase_col FROM t' USING v1"),
+          Row(true))
+
+        checkError(
+          exception = intercept[AnalysisException] {
+            sql("EXECUTE IMMEDIATE 'SELECT ? = unicode_col FROM t' USING v1")
+          },
+          condition = "INDETERMINATE_COLLATION_IN_EXPRESSION",
+          parameters = Map("expr" ->
+            "\"(CAST(hello AS STRING COLLATE UTF8_LCASE) = unicode_col)\""),
+          queryContext = Array(
+            ExpectedContext("EXECUTE IMMEDIATE", "", 7, 21, "? = unicode_col")))
+      }
+    }
+  }
+
+  test("execute immediate two parameters with different collations") {
+    checkError(
+      exception = intercept[AnalysisException] {
+        sql(
+          """EXECUTE IMMEDIATE 'SELECT ? = ?'
+            | USING 'hello' COLLATE UTF8_LCASE, 'hello' COLLATE UNICODE""".stripMargin)
+      },
+      condition = "INDETERMINATE_COLLATION_IN_EXPRESSION",
+      parameters = Map("expr" ->
+        "\"(CAST(hello AS STRING COLLATE UTF8_LCASE) = CAST(hello AS STRING COLLATE UNICODE))\""),
+      queryContext = Array(
+        ExpectedContext("EXECUTE IMMEDIATE", "", 7, 11, "? = ?")))
+
+    withSessionVariable("v1", "v2") {
+      sql("DECLARE VARIABLE v1 STRING COLLATE UTF8_LCASE DEFAULT 'hello'")
+      sql("DECLARE VARIABLE v2 STRING COLLATE UNICODE DEFAULT 'hello'")
+
+      checkError(
+        exception = intercept[AnalysisException] {
+          sql("EXECUTE IMMEDIATE 'SELECT ? = ?' USING v1, v2")
+        },
+        condition = "INDETERMINATE_COLLATION_IN_EXPRESSION",
+        parameters = Map("expr" ->
+          "\"(CAST(hello AS STRING COLLATE UTF8_LCASE) = CAST(hello AS STRING COLLATE UNICODE))\""),
+        queryContext = Array(
+          ExpectedContext("EXECUTE IMMEDIATE", "", 7, 11, "? = ?")))
+    }
+
+    withSessionVariable("v1") {
+      sql("DECLARE VARIABLE v1 STRING COLLATE UNICODE DEFAULT 'hello'")
+
+      checkError(
+        exception = intercept[AnalysisException] {
+          sql(
+            """EXECUTE IMMEDIATE 'SELECT ? = ?'
+              | USING v1, 'hello' COLLATE UTF8_LCASE""".stripMargin)
+        },
+        condition = "INDETERMINATE_COLLATION_IN_EXPRESSION",
+        parameters = Map("expr" ->
+          "\"(CAST(hello AS STRING COLLATE UNICODE) = CAST(hello AS STRING COLLATE UTF8_LCASE))\""),
+        queryContext = Array(
+          ExpectedContext("EXECUTE IMMEDIATE", "", 7, 11, "? = ?")))
+    }
+  }
+
+  test("execute immediate null parameter with collation") {
+    checkAnswer(
+      sql(
+        """EXECUTE IMMEDIATE 'SELECT COLLATION(COALESCE(?, "hello"))'
+          | USING NULL""".stripMargin),
+      Row(s"${fullyQualifiedPrefix}UTF8_BINARY"))
+
+    checkAnswer(
+      sql(
+        """EXECUTE IMMEDIATE 'SELECT COALESCE(?, "hello") = "hello"'
+          | USING NULL""".stripMargin),
+      Row(true))
+
+    withSessionVariable("v1") {
+      sql("DECLARE VARIABLE v1 STRING COLLATE UTF8_LCASE")
+      checkAnswer(
+        sql("EXECUTE IMMEDIATE 'SELECT ?, COLLATION(?)' USING v1, v1"),
+        Row(null, s"${fullyQualifiedPrefix}UTF8_LCASE"))
+    }
+
+    checkAnswer(
+      sql(
+        """EXECUTE IMMEDIATE 'SELECT COLLATION(COALESCE(?, "hello"))'
+          | USING CAST(NULL AS STRING COLLATE UNICODE)""".stripMargin),
+      Row("null")) // NOTE(review): siblings expect a qualified collation name; confirm "null"
+
+    withTable("t") {
+      sql(
+        """CREATE TABLE t (
+          | lcase_col STRING COLLATE UTF8_LCASE,
+          | unicode_col STRING COLLATE UNICODE
+          |) USING parquet""".stripMargin)
+      sql("INSERT INTO t VALUES ('hello', 'hello')")
+
+      checkAnswer(
+        sql(
+          s"""EXECUTE IMMEDIATE
+             | 'SELECT COALESCE(?, lcase_col) FROM t'
+             | USING CAST(NULL AS STRING COLLATE UTF8_LCASE)""".stripMargin),
+        Row("hello"))
+
+      checkAnswer(
+        sql(
+          s"""EXECUTE IMMEDIATE
+             | 'SELECT COALESCE(?, unicode_col) FROM t'
+             | USING CAST(NULL AS STRING COLLATE UTF8_LCASE)""".stripMargin),
+        Row("hello"))
+
+      checkAnswer(
+        sql(
+          s"""EXECUTE IMMEDIATE
+             | 'SELECT COALESCE(?, unicode_col) FROM t'
+             | USING NULL""".stripMargin),
+        Row("hello"))
+    }
+  }
+
+  test("execute immediate named parameter collation strength") {
+    collations.foreach { collation =>
+      checkAnswer(
+        sql(
+          s"""EXECUTE IMMEDIATE
+             | 'SELECT COLLATION(:p || "world" COLLATE $collation)'
+             | USING 'hello' COLLATE UNICODE AS p""".stripMargin),
+        Row(s"$fullyQualifiedPrefix$collation"))
+    }
+
+    checkAnswer(
+      sql(
+        """EXECUTE IMMEDIATE
+          | 'SELECT :p = "HELLO" COLLATE UTF8_LCASE'
+          | USING 'hello' COLLATE UNICODE AS p""".stripMargin),
+      Row(true))
+  }
+
+  test("parameterized query vs column collation") {
+    withTable("t") {
+      sql(
+        """CREATE TABLE t (
+          | binary_col STRING,
+          | unicode_col STRING COLLATE UNICODE
+          |) USING parquet""".stripMargin)
+      sql("INSERT INTO t VALUES ('hello', 'hello')")
+
+      checkAnswer(
+        spark.sql(
+          "SELECT :p = binary_col FROM t",
+          Map("p" -> "hello")),
+        Row(true))
+
+      checkAnswer(
+        spark.sql(
+          "SELECT :p = unicode_col FROM t",
+          Map("p" -> "hello")),
+        Row(true))
+    }
+  }
+
+  test("parameterized query collation strength") {
+    checkAnswer(
+      spark.sql(
+        "SELECT :p = 'HELLO', :p = 'HELLO' COLLATE UTF8_LCASE",
+        Map("p" -> "hello")),
+      Row(false, true))
+  }
 }