@@ -22,10 +22,11 @@ package org.apache.comet.serde
2222import scala .jdk .CollectionConverters ._
2323
2424import org .apache .spark .sql .catalyst .expressions ._
25- import org .apache .spark .sql .types .{ArrayType , MapType }
25+ import org .apache .spark .sql .types .{ArrayType , DataType , DataTypes , DecimalType , MapType , StructType }
26+ import org .apache .spark .sql .types .DataTypes .{BinaryType , BooleanType , ByteType , DateType , DoubleType , FloatType , IntegerType , LongType , ShortType , StringType , TimestampNTZType , TimestampType }
2627
2728import org .apache .comet .CometSparkSessionExtensions .withInfo
28- import org .apache .comet .serde .QueryPlanSerde .{exprToProtoInternal , optExprWithInfo , scalarFunctionExprToProto , scalarFunctionExprToProtoWithReturnType }
29+ import org .apache .comet .serde .QueryPlanSerde .{exprToProtoInternal , optExprWithInfo , scalarFunctionExprToProto , scalarFunctionExprToProtoWithReturnType , serializeDataType }
2930
3031object CometMapKeys extends CometExpressionSerde [MapKeys ] {
3132
@@ -94,26 +95,38 @@ object CometMapFromArrays extends CometExpressionSerde[MapFromArrays] {
9495}
9596
9697object CometCreateMap extends CometExpressionSerde [CreateMap ] {
98+
99+ override def getSupportLevel (expr : CreateMap ): SupportLevel = {
100+ Compatible (None )
101+ }
102+
97103 override def convert (
98104 expr : CreateMap ,
99105 inputs : Seq [Attribute ],
100106 binding : Boolean ): Option [ExprOuterClass .Expr ] = {
101- val keysProtoExpr = expr.keys.map(exprToProtoInternal(_, inputs, binding))
102- val valuesProtoExpr = expr.values.map(exprToProtoInternal(_, inputs, binding))
103- if (keysProtoExpr.forall(_.isDefined) && valuesProtoExpr.forall(_.isDefined)) {
104- val createMapProtoExpr = ExprOuterClass .CreateMap
105- .newBuilder()
106- .addAllValues(keysProtoExpr.map(_.get).asJava)
107- .addAllValues(valuesProtoExpr.map(_.get).asJava)
108- .build()
109- Some (
110- ExprOuterClass .Expr
111- .newBuilder()
112- .setCreateMap(createMapProtoExpr)
113- .build())
114- } else {
115- withInfo(expr, expr.children: _* )
116- None
107+ val keysArray = CreateArray (expr.keys)
108+ val valuesArray = CreateArray (expr.values)
109+ val keysExprProto = exprToProtoInternal(keysArray, inputs, binding)
110+ val valuesExprProto = exprToProtoInternal(valuesArray, inputs, binding)
111+ val createMapExprProto =
112+ scalarFunctionExprToProtoWithReturnType(
113+ " map_from_arrays" ,
114+ expr.dataType,
115+ false ,
116+ keysExprProto,
117+ valuesExprProto)
118+ optExprWithInfo(createMapExprProto, expr, expr.children: _* )
119+ }
120+ }
121+
122+ sealed trait MapBase {
123+
124+ def containsBinary (dataType : DataType ): Boolean = {
125+ dataType match {
126+ case BinaryType => true
127+ case StructType (fields) => fields.exists(field => containsBinary(field.dataType))
128+ case ArrayType (elementType, _) => containsBinary(elementType)
129+ case _ => false
117130 }
118131 }
119132}
0 commit comments