Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,16 @@ object CollationTypeCoercion extends SQLConfHelper {
None
}

// ElementAt extracts a map value or an array element. Propagate the collation
// context through the extraction: look up the context of the collection child and
// merge (via mergeWinner) the extractor's own data type with the value/element
// type found there.
case elementAt: ElementAt =>
  findCollationContext(elementAt.left) match {
    // Map access: merge with the value type from the map child's context.
    case Some(MapType(_, valueType, _)) =>
      mergeWinner(elementAt.dataType, valueType)
    // Array access: merge with the element type from the array child's context.
    case Some(ArrayType(elementType, _)) =>
      mergeWinner(elementAt.dataType, elementType)
    // No usable context on the collection child — nothing to propagate.
    case _ =>
      None
  }

case struct: CreateNamedStruct =>
val childrenContexts = struct.valExprs.map(findCollationContext)
if (childrenContexts.isEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.parser

import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.util.SchemaUtils

/**
* Handler for parameter substitution across different Spark SQL contexts.
Expand Down Expand Up @@ -107,14 +109,37 @@ object ParameterHandler {
* @param expr The expression to convert (must be a Literal)
* @return SQL string representation
*/
private def convertToSql(expr: Expression): String = expr match {
case lit: Literal => lit.sql
case other =>
throw new IllegalArgumentException(
s"ParameterHandler only accepts resolved Literal expressions. " +
s"Received: ${other.getClass.getSimpleName}. " +
s"All parameters must be resolved using SparkSession.resolveAndValidateParameters " +
s"before being passed to the pre-parser.")
/**
 * Converts a resolved literal expression to its SQL string representation.
 *
 * If the literal's type contains collated STRING/CHAR/VARCHAR types, collations are
 * stripped from nested literals and the whole expression is wrapped in a CAST to
 * preserve the collation with implicit strength. Without this, Literal.sql produces
 * `'value' COLLATE collationName`, which re-parses with explicit strength.
 *
 * @param expr the expression to convert (must be a resolved Literal)
 * @return the SQL string representation
 * @throws IllegalArgumentException if `expr` is not a Literal
 */
private def convertToSql(expr: Expression): String = {
  // Renders `e` so that any collation in its type re-parses with implicit strength:
  // nested collated literals are rewritten to plain STRING and the original (collated)
  // type is reintroduced through a single outer CAST.
  def toSqlWithImplicitCollation(e: Expression): String = {
    if (!DataTypeUtils.hasNonDefaultStringCharOrVarcharType(e.dataType)) {
      e.sql
    } else {
      val stripped = e.transform {
        case lit: Literal
            if DataTypeUtils.hasNonDefaultStringCharOrVarcharType(lit.dataType) =>
          Literal.create(
            lit.value, SchemaUtils.replaceCollatedStringWithString(lit.dataType))
      }
      s"CAST(${stripped.sql} AS ${e.dataType.sql})"
    }
  }

  expr match {
    // Delegating to lit.sql already yields implicit (Default) collation strength for
    // nulls: a typed null renders as CAST(NULL AS <type>) whose child is a NullType
    // literal, and an untyped null renders as plain NULL — neither emits a COLLATE
    // clause, so no stripping is needed here.
    case lit: Literal if lit.value == null =>
      lit.sql
    case lit: Literal =>
      toSqlWithImplicitCollation(lit)
    case other =>
      throw new IllegalArgumentException(
        s"ParameterHandler only accepts resolved Literal expressions. " +
        s"Received: ${other.getClass.getSimpleName}. " +
        s"All parameters must be resolved using SparkSession.resolveAndValidateParameters " +
        s"before being passed to the pre-parser.")
  }
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,17 @@ object DataTypeUtils {
}
}

/**
 * Returns true if the given data type contains any STRING/CHAR/VARCHAR with explicit collation
 * (including explicit `UTF8_BINARY`), anywhere in the (possibly nested) type.
 *
 * @param dataType the data type to inspect recursively
 * @return true iff some nested string type is not a default string/char/varchar type
 */
def hasNonDefaultStringCharOrVarcharType(dataType: DataType): Boolean = {
  // A string type counts when it is not one of the default (uncollated) forms;
  // every non-string leaf is ignored.
  def isNonDefaultString(dt: DataType): Boolean = dt match {
    case st: StringType => !isDefaultStringCharOrVarcharType(st)
    case _ => false
  }
  dataType.existsRecursively(isNonDefaultString)
}

/**
* Recursively replaces all STRING, CHAR and VARCHAR types that do not have an explicit collation
* with the same type but with explicit `UTF8_BINARY` collation.
Expand Down
Loading