Skip to content

Commit c0d5c91

Browse files
author
Kazantsev Maksim
committed
work
1 parent a3136ed commit c0d5c91

4 files changed

Lines changed: 59 additions & 27 deletions

File tree

native/core/src/execution/jni_api.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ use std::collections::HashMap;
6767
use std::path::PathBuf;
6868
use std::time::{Duration, Instant};
6969
use std::{sync::Arc, task::Poll};
70+
use datafusion_spark::function::map::map_from_arrays::MapFromArrays;
7071
use tokio::runtime::Runtime;
7172

7273
use crate::execution::memory_pools::{
@@ -339,6 +340,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
339340
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkConcat::default()));
340341
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default()));
341342
session_ctx.register_udf(ScalarUDF::new_from_impl(SparkHex::default()));
343+
session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromArrays::default()));
342344
}
343345

344346
/// Prepares arrow arrays for output.

native/core/src/execution/planner.rs

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,6 @@ use num::{BigInt, ToPrimitive};
134134
use object_store::path::Path;
135135
use std::cmp::max;
136136
use std::{collections::HashMap, sync::Arc};
137-
use datafusion_functions_nested::map::map;
138137
use url::Url;
139138

140139
// For clippy error on type_complexity.
@@ -677,14 +676,6 @@ impl PhysicalPlanner {
677676
ExprStruct::MonotonicallyIncreasingId(_) => Ok(Arc::new(
678677
MonotonicallyIncreasingId::from_partition_id(self.partition),
679678
)),
680-
ExprStruct::CreateMap(expr) => {
681-
let keys = expr.keys.iter().map(|expr| self.create_expr(expr, Arc::clone(&input_schema)))
682-
.collect::<Vec<_>>();
683-
let values = expr.values.iter().map(|expr| self.create_expr(expr, Arc::clone(&input_schema)))
684-
.collect::<Result<Vec<_>, _>>()?;
685-
let create_map = map(keys, values);
686-
Ok(Arc::new(create_map))
687-
},
688679
expr => Err(GeneralError(format!("Not implemented: {expr:?}"))),
689680
}
690681
}

spark/src/main/scala/org/apache/comet/serde/maps.scala

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@ package org.apache.comet.serde
2222
import scala.jdk.CollectionConverters._
2323

2424
import org.apache.spark.sql.catalyst.expressions._
25-
import org.apache.spark.sql.types.{ArrayType, MapType}
25+
import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, MapType, StructType}
26+
import org.apache.spark.sql.types.DataTypes.{BinaryType, BooleanType, ByteType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampNTZType, TimestampType}
2627

2728
import org.apache.comet.CometSparkSessionExtensions.withInfo
28-
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType}
29+
import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType, serializeDataType}
2930

3031
object CometMapKeys extends CometExpressionSerde[MapKeys] {
3132

@@ -94,26 +95,38 @@ object CometMapFromArrays extends CometExpressionSerde[MapFromArrays] {
9495
}
9596

9697
object CometCreateMap extends CometExpressionSerde[CreateMap] {
98+
99+
override def getSupportLevel(expr: CreateMap): SupportLevel = {
100+
Compatible(None)
101+
}
102+
97103
override def convert(
98104
expr: CreateMap,
99105
inputs: Seq[Attribute],
100106
binding: Boolean): Option[ExprOuterClass.Expr] = {
101-
val keysProtoExpr = expr.keys.map(exprToProtoInternal(_, inputs, binding))
102-
val valuesProtoExpr = expr.values.map(exprToProtoInternal(_, inputs, binding))
103-
if (keysProtoExpr.forall(_.isDefined) && valuesProtoExpr.forall(_.isDefined)) {
104-
val createMapProtoExpr = ExprOuterClass.CreateMap
105-
.newBuilder()
106-
.addAllValues(keysProtoExpr.map(_.get).asJava)
107-
.addAllValues(valuesProtoExpr.map(_.get).asJava)
108-
.build()
109-
Some(
110-
ExprOuterClass.Expr
111-
.newBuilder()
112-
.setCreateMap(createMapProtoExpr)
113-
.build())
114-
} else {
115-
withInfo(expr, expr.children: _*)
116-
None
107+
val keysArray = CreateArray(expr.keys)
108+
val valuesArray = CreateArray(expr.values)
109+
val keysExprProto = exprToProtoInternal(keysArray, inputs, binding)
110+
val valuesExprProto = exprToProtoInternal(valuesArray, inputs, binding)
111+
val createMapExprProto =
112+
scalarFunctionExprToProtoWithReturnType(
113+
"map_from_arrays",
114+
expr.dataType,
115+
false,
116+
keysExprProto,
117+
valuesExprProto)
118+
optExprWithInfo(createMapExprProto, expr, expr.children: _*)
119+
}
120+
}
121+
122+
sealed trait MapBase {
123+
124+
def containsBinary(dataType: DataType): Boolean = {
125+
dataType match {
126+
case BinaryType => true
127+
case StructType(fields) => fields.exists(field => containsBinary(field.dataType))
128+
case ArrayType(elementType, _) => containsBinary(elementType)
129+
case _ => false
117130
}
118131
}
119132
}

spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path
2525
import org.apache.spark.sql.CometTestBase
2626
import org.apache.spark.sql.functions._
2727
import org.apache.spark.sql.internal.SQLConf
28+
import org.apache.spark.sql.types.BinaryType
2829

2930
import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}
3031

@@ -125,4 +126,29 @@ class CometMapExpressionSuite extends CometTestBase {
125126
}
126127
}
127128

129+
test("create_map") {
130+
withTempDir { dir =>
131+
val path = new Path(dir.toURI.toString, "test.parquet")
132+
val filename = path.toString
133+
val random = new Random(42)
134+
withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
135+
val schemaGenOptions =
136+
SchemaGenOptions(generateArray = false, generateStruct = false, generateMap = false)
137+
val dataGenOptions = DataGenOptions(allowNull = false, generateNegativeZero = false)
138+
ParquetGenerator.makeParquetFile(
139+
random,
140+
spark,
141+
filename,
142+
100,
143+
schemaGenOptions,
144+
dataGenOptions)
145+
}
146+
val df = spark.read.parquet(filename)
147+
df.createOrReplaceTempView("t1")
148+
for (fieldName <- df.schema.filter(_.dataType != BinaryType).map(_.name)) {
149+
checkSparkAnswerAndOperator(spark.sql(s"SELECT map($fieldName, $fieldName) FROM t1"))
150+
}
151+
}
152+
}
153+
128154
}

0 commit comments

Comments
 (0)