Skip to content

Commit a19dbc5

Browse files
authored
fix: regexp_count should count empty-pattern matches (#22311)
## Which issue does this PR close?

- Closes #22267

## Rationale for this change

`regexp_count` did not handle empty regular-expression patterns correctly. An empty pattern matches at every character boundary, so those matches should be counted instead of returning `0` (for example, `regexp_count('abc', '')` should return `4`). This also affects calls that use the `start` argument and certain flag combinations.

## What changes are included in this PR?

- Fix `regexp_count` so empty-pattern matches are counted correctly.
- Adjust `start` handling so character offsets are computed correctly.
- Update unit tests and sqllogictest coverage for empty patterns, `start`, and flags.
- Update expected results to match the corrected behavior.

## Are these changes tested?

- Yes. Rust unit tests were updated.
- Yes. Sqllogictest coverage was added/updated.
- I also ran:
  - `cargo test -p datafusion-sqllogictest --test sqllogictests table_functions`
  - `cargo test -p datafusion-sqllogictest --test sqllogictests scalar`

## Are there any user-facing changes?

- Yes. `regexp_count` now returns correct counts for empty patterns, so results may differ from the previous behavior.
1 parent bb7d486 commit a19dbc5

2 files changed

Lines changed: 110 additions & 19 deletions

File tree

datafusion/functions/src/regex/regexpcount.rs

Lines changed: 85 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,10 @@ where
268268
S: StringArrayType<'a>,
269269
{
270270
let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 {
271-
(Some(regex_array.value(0)), true)
271+
(
272+
(!regex_array.is_null(0)).then(|| regex_array.value(0)),
273+
true,
274+
)
272275
} else {
273276
(None, false)
274277
};
@@ -300,7 +303,7 @@ where
300303
match (is_regex_scalar, is_start_scalar, is_flags_scalar) {
301304
(true, true, true) => {
302305
let regex = match regex_scalar {
303-
None | Some("") => {
306+
None => {
304307
return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
305308
}
306309
Some(regex) => regex,
@@ -317,7 +320,7 @@ where
317320
}
318321
(true, true, false) => {
319322
let regex = match regex_scalar {
320-
None | Some("") => {
323+
None => {
321324
return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
322325
}
323326
Some(regex) => regex,
@@ -346,7 +349,7 @@ where
346349
}
347350
(true, false, true) => {
348351
let regex = match regex_scalar {
349-
None | Some("") => {
352+
None => {
350353
return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
351354
}
352355
Some(regex) => regex,
@@ -366,7 +369,7 @@ where
366369
}
367370
(true, false, false) => {
368371
let regex = match regex_scalar {
369-
None | Some("") => {
372+
None => {
370373
return Ok(Arc::new(Int64Array::from(vec![0; values.len()])));
371374
}
372375
Some(regex) => regex,
@@ -411,7 +414,7 @@ where
411414
.zip(regex_array.iter())
412415
.map(|(value, regex)| {
413416
let regex = match regex {
414-
None | Some("") => return Ok(0),
417+
None => return Ok(0),
415418
Some(regex) => regex,
416419
};
417420

@@ -447,7 +450,7 @@ where
447450
izip!(values.iter(), regex_array.iter(), flags_array.iter())
448451
.map(|(value, regex, flags)| {
449452
let regex = match regex {
450-
None | Some("") => return Ok(0),
453+
None => return Ok(0),
451454
Some(regex) => regex,
452455
};
453456

@@ -481,7 +484,7 @@ where
481484
izip!(values.iter(), regex_array.iter(), start_array.iter())
482485
.map(|(value, regex, start)| {
483486
let regex = match regex {
484-
None | Some("") => return Ok(0),
487+
None => return Ok(0),
485488
Some(regex) => regex,
486489
};
487490

@@ -531,7 +534,7 @@ where
531534
)
532535
.map(|(value, regex, start, flags)| {
533536
let regex = match regex {
534-
None | Some("") => return Ok(0),
537+
None => return Ok(0),
535538
Some(regex) => regex,
536539
};
537540

@@ -551,7 +554,7 @@ fn count_matches(
551554
start: Option<i64>,
552555
) -> Result<i64, ArrowError> {
553556
let value = match value {
554-
None | Some("") => return Ok(0),
557+
None => return Ok(0),
555558
Some(value) => value,
556559
};
557560

@@ -562,12 +565,23 @@ fn count_matches(
562565
));
563566
}
564567

568+
let char_len = value.chars().count();
569+
let start_index = (start as usize).saturating_sub(1);
570+
571+
if start_index > char_len {
572+
return Ok(0);
573+
}
574+
565575
// Find the byte offset for the start position (1-based character index)
566-
let byte_offset = value
567-
.char_indices()
568-
.nth((start as usize).saturating_sub(1))
569-
.map(|(idx, _)| idx)
570-
.unwrap_or(value.len());
576+
let byte_offset = if start_index == char_len {
577+
value.len()
578+
} else {
579+
value
580+
.char_indices()
581+
.nth(start_index)
582+
.map(|(idx, _)| idx)
583+
.unwrap_or(value.len())
584+
};
571585

572586
// Use string slicing instead of collecting chars into a new String
573587
let find_slice = &value[byte_offset..];
@@ -589,6 +603,7 @@ mod tests {
589603
#[test]
590604
fn test_regexp_count() {
591605
test_case_sensitive_regexp_count_scalar();
606+
test_case_sensitive_regexp_count_empty_pattern_scalar();
592607
test_case_sensitive_regexp_count_scalar_start();
593608
test_case_insensitive_regexp_count_scalar_flags();
594609
test_case_sensitive_regexp_count_start_scalar_complex();
@@ -675,6 +690,57 @@ mod tests {
675690
});
676691
}
677692

693+
fn test_case_sensitive_regexp_count_empty_pattern_scalar() {
694+
let values = ["", "abc", "abc"];
695+
let start_positions = [1, 1, 2];
696+
let expected: Vec<i64> = vec![1, 4, 3];
697+
698+
values
699+
.iter()
700+
.zip(start_positions.iter())
701+
.enumerate()
702+
.for_each(|(pos, (&value, &start))| {
703+
let expected = expected.get(pos).cloned();
704+
let start_sv = ScalarValue::Int64(Some(start));
705+
706+
let re = regexp_count_with_scalar_values(&[
707+
ScalarValue::Utf8(Some(value.to_string())),
708+
ScalarValue::Utf8(Some("".to_string())),
709+
start_sv.clone(),
710+
]);
711+
match re {
712+
Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
713+
assert_eq!(v, expected, "regexp_count scalar test failed");
714+
}
715+
_ => panic!("Unexpected result"),
716+
}
717+
718+
let re = regexp_count_with_scalar_values(&[
719+
ScalarValue::LargeUtf8(Some(value.to_string())),
720+
ScalarValue::LargeUtf8(Some("".to_string())),
721+
start_sv.clone(),
722+
]);
723+
match re {
724+
Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
725+
assert_eq!(v, expected, "regexp_count scalar test failed");
726+
}
727+
_ => panic!("Unexpected result"),
728+
}
729+
730+
let re = regexp_count_with_scalar_values(&[
731+
ScalarValue::Utf8View(Some(value.to_string())),
732+
ScalarValue::Utf8View(Some("".to_string())),
733+
start_sv,
734+
]);
735+
match re {
736+
Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
737+
assert_eq!(v, expected, "regexp_count scalar test failed");
738+
}
739+
_ => panic!("Unexpected result"),
740+
}
741+
});
742+
}
743+
678744
fn test_case_sensitive_regexp_count_scalar_start() {
679745
let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
680746
let regex = "abc";
@@ -792,7 +858,7 @@ mod tests {
792858
let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcAbc"]);
793859
let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
794860

795-
let expected = Int64Array::from(vec![0, 1, 2, 2, 2]);
861+
let expected = Int64Array::from(vec![1, 1, 2, 2, 2]);
796862

797863
let re = regexp_count_func(&[Arc::new(values), Arc::new(regex)]).unwrap();
798864
assert_eq!(re.as_ref(), &expected);
@@ -806,7 +872,7 @@ mod tests {
806872
let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
807873
let start = Int64Array::from(vec![1, 2, 3, 4, 5]);
808874

809-
let expected = Int64Array::from(vec![0, 0, 1, 1, 0]);
875+
let expected = Int64Array::from(vec![1, 0, 1, 1, 0]);
810876

811877
let re = regexp_count_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)])
812878
.unwrap();
@@ -822,7 +888,7 @@ mod tests {
822888
let start = Int64Array::from(vec![1]);
823889
let flags = A::from(vec!["", "i", "", "", "i"]);
824890

825-
let expected = Int64Array::from(vec![0, 1, 2, 2, 3]);
891+
let expected = Int64Array::from(vec![1, 1, 2, 2, 3]);
826892

827893
let re = regexp_count_func(&[
828894
Arc::new(values),
@@ -910,7 +976,7 @@ mod tests {
910976
let start = Int64Array::from(vec![1, 2, 3, 4, 5]);
911977
let flags = A::from(vec!["", "i", "", "", "i"]);
912978

913-
let expected = Int64Array::from(vec![0, 1, 1, 1, 1]);
979+
let expected = Int64Array::from(vec![1, 1, 1, 1, 1]);
914980

915981
let re = regexp_count_func(&[
916982
Arc::new(values),

datafusion/sqllogictest/test_files/regexp/regexp_count.slt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,31 @@ SELECT regexp_count('ABCABCABCABC', 'Abc', 1, 'i');
5151
----
5252
4
5353

54+
query I
55+
SELECT regexp_count('abc', '');
56+
----
57+
4
58+
59+
query I
60+
SELECT regexp_count('abc', '', 2);
61+
----
62+
3
63+
64+
query I
65+
SELECT regexp_count('abc', '', 1, 'i');
66+
----
67+
4
68+
69+
query I
70+
SELECT regexp_count('abc', '', 4);
71+
----
72+
1
73+
74+
query I
75+
SELECT regexp_count('abc', '', 5);
76+
----
77+
0
78+
5479
statement error
5580
External error: query failed: DataFusion error: Arrow error: Compute error: regexp_count() requires start to be 1 based
5681
SELECT regexp_count('123123123123', '123', 0);

0 commit comments

Comments
 (0)