Skip to content

Commit cbc6d6b

Browse files
committed
temp
1 parent e42a561 commit cbc6d6b

File tree

4 files changed

+440
-8
lines changed

4 files changed

+440
-8
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,16 @@ object CollationTypeCoercion extends SQLConfHelper {
288288
None
289289
}
290290

291+
case elementAt: ElementAt =>
292+
findCollationContext(elementAt.left) match {
293+
case Some(MapType(_, valueType, _)) =>
294+
mergeWinner(elementAt.dataType, valueType)
295+
case Some(ArrayType(elementType, _)) =>
296+
mergeWinner(elementAt.dataType, elementType)
297+
case _ =>
298+
None
299+
}
300+
291301
case struct: CreateNamedStruct =>
292302
val childrenContexts = struct.valExprs.map(findCollationContext)
293303
if (childrenContexts.isEmpty) {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParameterHandler.scala

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
package org.apache.spark.sql.catalyst.parser
1818

1919
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
20+
import org.apache.spark.sql.catalyst.types.DataTypeUtils
21+
import org.apache.spark.sql.util.SchemaUtils
2022

2123
/**
2224
* Handler for parameter substitution across different Spark SQL contexts.
@@ -107,14 +109,37 @@ object ParameterHandler {
107109
* @param expr The expression to convert (must be a Literal)
108110
* @return SQL string representation
109111
*/
110-
private def convertToSql(expr: Expression): String = expr match {
111-
case lit: Literal => lit.sql
112-
case other =>
113-
throw new IllegalArgumentException(
114-
s"ParameterHandler only accepts resolved Literal expressions. " +
115-
s"Received: ${other.getClass.getSimpleName}. " +
116-
s"All parameters must be resolved using SparkSession.resolveAndValidateParameters " +
117-
s"before being passed to the pre-parser.")
112+
private def convertToSql(expr: Expression): String = {
  // Why the CAST: Literal.sql renders a collated string as `'value' COLLATE collationName`,
  // and that text re-parses with explicit collation strength. Rendering the literal with its
  // collations replaced by plain strings and casting the result back to the original type
  // keeps the collation while leaving its strength implicit.
  def castToOriginalType(collated: Expression): String = {
    val uncollated = collated.transform {
      case candidate: Literal
          if DataTypeUtils.hasNonDefaultStringCharOrVarcharType(candidate.dataType) =>
        val plainType = SchemaUtils.replaceCollatedStringWithString(candidate.dataType)
        Literal.create(candidate.value, plainType)
    }
    s"CAST(${uncollated.sql} AS ${collated.dataType.sql})"
  }

  expr match {
    case lit: Literal =>
      // Null literals and literals without collated types are rendered by Literal.sql as-is;
      // only non-null literals whose type contains a collated string take the CAST path.
      val needsCast = lit.value != null &&
        DataTypeUtils.hasNonDefaultStringCharOrVarcharType(lit.dataType)
      if (needsCast) castToOriginalType(lit) else lit.sql

    case other =>
      throw new IllegalArgumentException(
        s"ParameterHandler only accepts resolved Literal expressions. " +
        s"Received: ${other.getClass.getSimpleName}. " +
        s"All parameters must be resolved using SparkSession.resolveAndValidateParameters " +
        s"before being passed to the pre-parser.")
  }
}
119144

120145
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,17 @@ object DataTypeUtils {
302302
}
303303
}
304304

305+
/**
 * Tells whether `dataType`, or any type nested inside it, is a STRING/CHAR/VARCHAR type that
 * carries an explicit collation (an explicit `UTF8_BINARY` counts as well).
 *
 * @param dataType the type to inspect, including all of its nested types
 * @return true if at least one such string type is found, false otherwise
 */
def hasNonDefaultStringCharOrVarcharType(dataType: DataType): Boolean = {
  // A single type matches only when it is a string type that is not the default one;
  // every non-string type is ignored.
  def isExplicitlyCollated(candidate: DataType): Boolean = candidate match {
    case stringLike: StringType => !isDefaultStringCharOrVarcharType(stringLike)
    case _ => false
  }

  dataType.existsRecursively(isExplicitlyCollated)
}
315+
305316
/**
306317
* Recursively replaces all STRING, CHAR and VARCHAR types that do not have an explicit collation
307318
* with the same type but with explicit `UTF8_BINARY` collation.

0 commit comments

Comments
 (0)