Skip to content

Commit 26b0e85

Browse files
authored
feat: add native support for substring_index expression (#4286)
1 parent 76adfa2 commit 26b0e85

4 files changed

Lines changed: 138 additions & 2 deletions

File tree

docs/source/contributor-guide/spark_expressions_support.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@
563563
- [x] startswith
564564
- [x] substr
565565
- [x] substring
566-
- [ ] substring_index
566+
- [x] substring_index
567567
- [ ] to_binary
568568
- [ ] to_char
569569
- [ ] to_number

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim {
198198
classOf[Left] -> CometLeft,
199199
classOf[Right] -> CometRight,
200200
classOf[Substring] -> CometSubstring,
201+
classOf[SubstringIndex] -> CometSubstringIndex,
201202
classOf[Upper] -> CometUpper)
202203

203204
private val bitwiseExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(

spark/src/main/scala/org/apache/comet/serde/strings.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ package org.apache.comet.serde
2121

2222
import java.util.Locale
2323

24-
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper}
24+
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, SubstringIndex, Upper}
2525
import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType}
2626
import org.apache.spark.unsafe.types.UTF8String
2727

@@ -129,6 +129,22 @@ object CometSubstring extends CometExpressionSerde[Substring] {
129129
}
130130
}
131131

132+
object CometSubstringIndex extends CometExpressionSerde[SubstringIndex] {
133+
134+
override def convert(
135+
expr: SubstringIndex,
136+
inputs: Seq[Attribute],
137+
binding: Boolean): Option[ExprOuterClass.Expr] = {
138+
val strExpr = exprToProtoInternal(expr.strExpr, inputs, binding)
139+
val delimExpr = exprToProtoInternal(expr.delimExpr, inputs, binding)
140+
val countCast = Cast(expr.countExpr, LongType)
141+
val countExpr = exprToProtoInternal(countCast, inputs, binding)
142+
val optExpr =
143+
scalarFunctionExprToProto("substring_index", strExpr, delimExpr, countExpr)
144+
optExprWithInfo(optExpr, expr, expr.strExpr, expr.delimExpr, expr.countExpr)
145+
}
146+
}
147+
132148
object CometLeft extends CometExpressionSerde[Left] {
133149

134150
override def getUnsupportedReasons(): Seq[String] = Seq(
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
-- Licensed to the Apache Software Foundation (ASF) under one
2+
-- or more contributor license agreements. See the NOTICE file
3+
-- distributed with this work for additional information
4+
-- regarding copyright ownership. The ASF licenses this file
5+
-- to you under the Apache License, Version 2.0 (the
6+
-- "License"); you may not use this file except in compliance
7+
-- with the License. You may obtain a copy of the License at
8+
--
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
--
11+
-- Unless required by applicable law or agreed to in writing,
12+
-- software distributed under the License is distributed on an
13+
-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
-- KIND, either express or implied. See the License for the
15+
-- specific language governing permissions and limitations
16+
-- under the License.
17+
18+
-- ConfigMatrix: parquet.enable.dictionary=false,true
19+
20+
statement
21+
CREATE TABLE test_substring_index(s string, delim string, cnt int) USING parquet
22+
23+
statement
24+
INSERT INTO test_substring_index VALUES
25+
('www.apache.org', '.', 1),
26+
('www.apache.org', '.', 2),
27+
('www.apache.org', '.', 3),
28+
('www.apache.org', '.', -1),
29+
('www.apache.org', '.', -2),
30+
('www.apache.org', '.', -3),
31+
('www.apache.org', '.', 0),
32+
('hello', '.', 1),
33+
('', '.', 1),
34+
('www.apache.org', '', 1),
35+
(NULL, '.', 1),
36+
('www.apache.org', NULL, 1),
37+
('www.apache.org', '.', NULL)
38+
39+
-- all columns
40+
query
41+
SELECT substring_index(s, delim, cnt) FROM test_substring_index
42+
43+
-- literal arguments
44+
query
45+
SELECT substring_index('www.apache.org', '.', 1),
46+
substring_index('www.apache.org', '.', 2),
47+
substring_index('www.apache.org', '.', -1),
48+
substring_index('www.apache.org', '.', -2),
49+
substring_index('www.apache.org', '.', 0)
50+
51+
-- NULL literal arguments
52+
query
53+
SELECT substring_index(NULL, '.', 1),
54+
substring_index('www.apache.org', NULL, 1),
55+
substring_index('www.apache.org', '.', NULL)
56+
57+
-- column string, literal delimiter and count
58+
query
59+
SELECT substring_index(s, '.', 1) FROM test_substring_index
60+
61+
-- literal string, column delimiter and count
62+
query
63+
SELECT substring_index('www.apache.org', delim, cnt) FROM test_substring_index
64+
65+
-- count exceeds number of delimiters (returns full string)
66+
query
67+
SELECT substring_index('www.apache.org', '.', 10),
68+
substring_index('www.apache.org', '.', -10)
69+
70+
-- multi-character delimiter
71+
query
72+
SELECT substring_index('one::two::three', '::', 1),
73+
substring_index('one::two::three', '::', 2),
74+
substring_index('one::two::three', '::', -1),
75+
substring_index('one::two::three', '::', -2)
76+
77+
-- delimiter not found
78+
query
79+
SELECT substring_index('hello world', 'xyz', 1),
80+
substring_index('hello world', 'xyz', -1)
81+
82+
-- empty string input
83+
query
84+
SELECT substring_index('', '.', 1),
85+
substring_index('', '.', -1)
86+
87+
-- empty delimiter
88+
query
89+
SELECT substring_index('www.apache.org', '', 1),
90+
substring_index('www.apache.org', '', -1)
91+
92+
-- multibyte UTF-8 characters
93+
query
94+
SELECT substring_index('a.b.c', '.', 2),
95+
substring_index('中文.测试.数据', '.', 1),
96+
substring_index('中文.测试.数据', '.', -1),
97+
substring_index('中文.测试.数据', '.', 2)
98+
99+
-- delimiter at start of string
100+
query
101+
SELECT substring_index('.www.apache.org', '.', 1),
102+
substring_index('.www.apache.org', '.', 2),
103+
substring_index('.www.apache.org', '.', -1)
104+
105+
-- delimiter at end of string
106+
query
107+
SELECT substring_index('www.apache.org.', '.', -1),
108+
substring_index('www.apache.org.', '.', 3),
109+
substring_index('www.apache.org.', '.', -2)
110+
111+
-- delimiter equals the full string
112+
query
113+
SELECT substring_index('abc', 'abc', 1),
114+
substring_index('abc', 'abc', -1)
115+
116+
-- large count values
117+
query
118+
SELECT substring_index('www.apache.org', '.', 2147483647),
119+
substring_index('www.apache.org', '.', -2147483647)

0 commit comments

Comments
 (0)