Skip to content

Commit 82d674e

Browse files
feat: support Spark expression json_array_length (apache#4365)
1 parent 577a793 commit 82d674e

13 files changed

Lines changed: 405 additions & 4 deletions

File tree

native/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

native/spark-expr/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ num = { workspace = true }
3535
regex = { workspace = true }
3636
# preserve_order: needed for get_json_object to match Spark's JSON key ordering
3737
serde_json = { version = "1.0", features = ["preserve_order"] }
38+
serde = { version = "1.0", features = ["derive"] }
3839
datafusion-comet-common = { workspace = true }
3940
datafusion-comet-jni-bridge = { workspace = true }
4041
jni = "0.22.4"

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
use crate::hash_funcs::*;
19+
use crate::json_funcs::JsonArrayLength;
1920
use crate::map_funcs::spark_map_sort;
2021
use crate::math_funcs::abs::abs;
2122
use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
@@ -239,6 +240,7 @@ fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
239240
Arc::new(ScalarUDF::new_from_impl(SparkNextDay::default())),
240241
Arc::new(ScalarUDF::new_from_impl(SparkSecondsToTimestamp::default())),
241242
Arc::new(ScalarUDF::new_from_impl(SparkSizeFunc::default())),
243+
Arc::new(ScalarUDF::new_from_impl(JsonArrayLength::default())),
242244
]
243245
}
244246

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{Array, ArrayRef, Int32Builder, OffsetSizeTrait};
19+
use arrow::datatypes::DataType;
20+
use datafusion::common::cast::as_generic_string_array;
21+
use datafusion::common::{exec_err, Result, ScalarValue};
22+
use datafusion::logical_expr::{
23+
ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
24+
};
25+
26+
use std::any::Any;
27+
28+
use serde::de::{IgnoredAny, SeqAccess, Visitor};
29+
use serde::Deserializer;
30+
use std::fmt;
31+
use std::sync::Arc;
32+
33+
/// Scalar UDF backing the Spark expression `json_array_length`.
///
/// The type only stores the [`Signature`] that is handed back from
/// `ScalarUDFImpl::signature`; the `ScalarUDFImpl` implementation in this file
/// reports `Int32` as the return type.
#[derive(Debug, PartialEq, Eq, Hash)]
34+
pub struct JsonArrayLength {
35+
/// Argument signature, populated by `JsonArrayLength::new`.
signature: Signature,
36+
}
37+
38+
impl Default for JsonArrayLength {
39+
fn default() -> Self {
40+
Self::new()
41+
}
42+
}
43+
44+
impl JsonArrayLength {
45+
pub fn new() -> Self {
46+
Self {
47+
signature: Signature::variadic(
48+
vec![DataType::Utf8, DataType::LargeUtf8],
49+
Volatility::Immutable,
50+
),
51+
}
52+
}
53+
}
54+
55+
impl ScalarUDFImpl for JsonArrayLength {
56+
fn as_any(&self) -> &dyn Any {
57+
self
58+
}
59+
60+
fn name(&self) -> &str {
61+
"json_array_length"
62+
}
63+
64+
fn signature(&self) -> &Signature {
65+
&self.signature
66+
}
67+
68+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
69+
Ok(DataType::Int32)
70+
}
71+
72+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
73+
spark_json_array_length(&args.args)
74+
}
75+
}
76+
77+
fn spark_json_array_length(args: &[ColumnarValue]) -> Result<ColumnarValue> {
78+
if args.len() != 1 {
79+
return exec_err!("json_array_length function takes exactly one argument");
80+
}
81+
match &args[0] {
82+
ColumnarValue::Array(array) => {
83+
let result = spark_json_array_length_array(array)?;
84+
Ok(ColumnarValue::Array(result))
85+
}
86+
ColumnarValue::Scalar(scalar) => {
87+
let result = spark_json_array_length_scalar(scalar)?;
88+
Ok(ColumnarValue::Scalar(result))
89+
}
90+
}
91+
}
92+
93+
fn spark_json_array_length_array(array: &ArrayRef) -> Result<ArrayRef> {
94+
match array.data_type() {
95+
DataType::Utf8 => spark_json_array_length_array_inner::<i32>(array),
96+
DataType::LargeUtf8 => spark_json_array_length_array_inner::<i64>(array),
97+
other => {
98+
exec_err!("Unsupported data type {other:?} for function `json_array_length`")
99+
}
100+
}
101+
}
102+
103+
fn spark_json_array_length_scalar(scalar: &ScalarValue) -> Result<ScalarValue> {
104+
match scalar {
105+
ScalarValue::Utf8(value) => spark_json_array_length_scalar_inner(value),
106+
ScalarValue::LargeUtf8(value) => spark_json_array_length_scalar_inner(value),
107+
other => {
108+
exec_err!("Unsupported data type {other:?} for function `json_array_length`")
109+
}
110+
}
111+
}
112+
113+
fn spark_json_array_length_scalar_inner(json_str: &Option<String>) -> Result<ScalarValue> {
114+
let array_length = json_str
115+
.clone()
116+
.and_then(|json_str| get_json_array_length(&json_str));
117+
Ok(ScalarValue::Int32(array_length))
118+
}
119+
120+
fn spark_json_array_length_array_inner<T: OffsetSizeTrait>(array: &ArrayRef) -> Result<ArrayRef> {
121+
let str_array = as_generic_string_array::<T>(array)?;
122+
let mut builder = Int32Builder::with_capacity(str_array.len());
123+
for row_idx in 0..str_array.len() {
124+
if str_array.is_null(row_idx) {
125+
builder.append_null();
126+
} else {
127+
let json_str = str_array.value(row_idx);
128+
if let Some(json_array_length) = get_json_array_length(json_str) {
129+
builder.append_value(json_array_length);
130+
} else {
131+
builder.append_null()
132+
}
133+
}
134+
}
135+
Ok(Arc::new(builder.finish()))
136+
}
137+
138+
struct ArrayItemCounter;
139+
140+
impl<'de> Visitor<'de> for ArrayItemCounter {
141+
type Value = i32;
142+
143+
fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144+
f.write_str("a JSON array")
145+
}
146+
147+
fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
148+
let mut len = 0i32;
149+
while seq.next_element::<IgnoredAny>()?.is_some() {
150+
len += 1;
151+
}
152+
Ok(len)
153+
}
154+
}
155+
156+
fn get_json_array_length(json: &str) -> Option<i32> {
157+
let mut deserializer = serde_json::Deserializer::from_str(json);
158+
deserializer.deserialize_seq(ArrayItemCounter).ok()
159+
}

native/spark-expr/src/json_funcs/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
// under the License.
1717

1818
mod from_json;
19+
mod json_array_length;
1920
mod to_json;
2021

2122
pub use from_json::FromJson;
23+
pub use json_array_length::JsonArrayLength;
2224
pub use to_json::ToJson;

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,9 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
269269
private val conversionExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
270270
classOf[Cast] -> CometCast)
271271

272+
private val jsonExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
273+
classOf[LengthOfJsonArray] -> CometLengthOfJsonArray)
274+
272275
private[comet] val miscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
273276
// TODO PromotePrecision
274277
classOf[Alias] -> CometAlias,
@@ -295,7 +298,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
295298
mathExpressions ++ hashExpressions ++ stringExpressions ++
296299
conditionalExpressions ++ mapExpressions ++ predicateExpressions ++
297300
structExpressions ++ bitwiseExpressions ++ miscExpressions ++ arrayExpressions ++
298-
temporalExpressions ++ conversionExpressions ++ urlExpressions
301+
temporalExpressions ++ conversionExpressions ++ urlExpressions ++ jsonExpressions
299302

300303
/**
301304
* Mapping of Spark aggregate expression class to Comet expression handler.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.comet.serde
21+
22+
import org.apache.spark.sql.catalyst.expressions.LengthOfJsonArray
23+
24+
/**
 * Serde handler that maps Spark's `LengthOfJsonArray` expression to the scalar function named
 * `json_array_length`.
 *
 * The expression is always reported as `Incompatible`, with a single shared reason string.
 */
object CometLengthOfJsonArray
    extends CometScalarFunction[LengthOfJsonArray]("json_array_length") {

  // One explanation, used both for the reason listing and for the support level.
  private val IncompatibleReason: String =
    "Spark's lenient JSON parser allows single quotes, unescaped controls, " +
      "and trailing content, while Comet's serde_json requires strict JSON."

  override def getIncompatibleReasons(): Seq[String] = IncompatibleReason :: Nil

  override def getSupportLevel(expr: LengthOfJsonArray): SupportLevel = {
    val reason = Some(IncompatibleReason)
    Incompatible(reason)
  }
}

spark/src/main/spark-4.0/org/apache/comet/shims/CometExprShim.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet.shims
2121

2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
24-
import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator
24+
import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionUtils, StructsToJsonEvaluator}
2525
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
2626
import org.apache.spark.sql.catalyst.expressions.url.ParseUrlEvaluator
2727
import org.apache.spark.sql.internal.SQLConf
@@ -160,6 +160,20 @@ trait CometExprShim extends CommonStringExprs with CometExprShim4x {
160160
case _ => None
161161
}
162162

163+
case s: StaticInvoke =>
164+
(s.staticObject, s.functionName, s.arguments) match {
165+
case (cls, "lengthOfJsonArray", Seq(child)) if cls == classOf[JsonExpressionUtils] =>
166+
val lengthOfJsonArray = LengthOfJsonArray(child)
167+
val exprProto = exprToProtoInternal(lengthOfJsonArray, inputs, binding)
168+
if (exprProto.isEmpty) {
169+
lengthOfJsonArray
170+
.getTagValue(CometExplainInfo.FALLBACK_REASONS)
171+
.foreach(reasons => s.setTagValue(CometExplainInfo.FALLBACK_REASONS, reasons))
172+
}
173+
exprProto
174+
case _ => None
175+
}
176+
163177
case ms: MapSort =>
164178
val keyType = ms.dataType.asInstanceOf[MapType].keyType
165179
if (!supportedScalarSortElementType(keyType)) {

spark/src/main/spark-4.1/org/apache/comet/shims/CometExprShim.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet.shims
2121

2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
24-
import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator
24+
import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionUtils, StructsToJsonEvaluator}
2525
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
2626
import org.apache.spark.sql.catalyst.expressions.url.ParseUrlEvaluator
2727
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -191,6 +191,20 @@ trait CometExprShim extends CommonStringExprs with CometExprShim4x {
191191
case _ => None
192192
}
193193

194+
case s: StaticInvoke =>
195+
(s.staticObject, s.functionName, s.arguments) match {
196+
case (cls, "lengthOfJsonArray", Seq(child)) if cls == classOf[JsonExpressionUtils] =>
197+
val lengthOfJsonArray = LengthOfJsonArray(child)
198+
val exprProto = exprToProtoInternal(lengthOfJsonArray, inputs, binding)
199+
if (exprProto.isEmpty) {
200+
lengthOfJsonArray
201+
.getTagValue(CometExplainInfo.FALLBACK_REASONS)
202+
.foreach(reasons => s.setTagValue(CometExplainInfo.FALLBACK_REASONS, reasons))
203+
}
204+
exprProto
205+
case _ => None
206+
}
207+
194208
case ms: MapSort =>
195209
val keyType = ms.dataType.asInstanceOf[MapType].keyType
196210
if (!supportedScalarSortElementType(keyType)) {

spark/src/main/spark-4.2/org/apache/comet/shims/CometExprShim.scala

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet.shims
2121

2222
import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
24-
import org.apache.spark.sql.catalyst.expressions.json.StructsToJsonEvaluator
24+
import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionUtils, StructsToJsonEvaluator}
2525
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
2626
import org.apache.spark.sql.catalyst.expressions.url.ParseUrlEvaluator
2727
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -191,6 +191,20 @@ trait CometExprShim extends CommonStringExprs with CometExprShim4x {
191191
case _ => None
192192
}
193193

194+
case s: StaticInvoke =>
195+
(s.staticObject, s.functionName, s.arguments) match {
196+
case (cls, "lengthOfJsonArray", Seq(child)) if cls == classOf[JsonExpressionUtils] =>
197+
val lengthOfJsonArray = LengthOfJsonArray(child)
198+
val exprProto = exprToProtoInternal(lengthOfJsonArray, inputs, binding)
199+
if (exprProto.isEmpty) {
200+
lengthOfJsonArray
201+
.getTagValue(CometExplainInfo.FALLBACK_REASONS)
202+
.foreach(reasons => s.setTagValue(CometExplainInfo.FALLBACK_REASONS, reasons))
203+
}
204+
exprProto
205+
case _ => None
206+
}
207+
194208
case ms: MapSort =>
195209
val keyType = ms.dataType.asInstanceOf[MapType].keyType
196210
if (!supportedScalarSortElementType(keyType)) {

0 commit comments

Comments
 (0)