Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 54 additions & 20 deletions datafusion/functions/src/regex/regexpinstr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,6 @@ where
)
.map(|(value, regex, start, nth, flags, subexp)| match regex {
None => Ok(None),
Some("") => Ok(Some(0)),
Some(regex) => get_index(
value,
regex,
Expand Down Expand Up @@ -395,11 +394,8 @@ where
{
let value = match value {
None => return Ok(None),
Some("") => return Ok(Some(0)),
Some(value) => value,
};
let pattern: &Regex = compile_and_cache_regex(pattern, flags, regex_cache)?;
// println!("get_index: value = {}, pattern = {}, start = {}, n = {}, subexpr = {}, flags = {:?}", value, pattern, start, n, subexpr, flags);
if start < 1 {
return Err(ArrowError::ComputeError(
"regexp_instr() requires start to be 1-based".to_string(),
Expand All @@ -412,8 +408,22 @@ where
));
}

// --- Simplified byte_start_offset calculation ---
let total_chars = value.chars().count() as i64;
if pattern.is_empty() {
compile_and_cache_regex(pattern, flags, regex_cache)?;
if subexpr > 0 {
return Ok(Some(0));
}

let match_position = start.saturating_add(n).saturating_sub(1);
return Ok(Some(if match_position <= total_chars + 1 {
match_position
} else {
0
}));
}

let pattern: &Regex = compile_and_cache_regex(pattern, flags, regex_cache)?;
let byte_start_offset: usize = if start > total_chars {
// If start is beyond the total characters, it means we start searching
// after the string effectively. No matches possible.
Expand All @@ -426,7 +436,6 @@ where
.map(|(idx, _)| idx)
.unwrap_or(0) // Should not happen if start is valid and <= total_chars
};
// --- End simplified calculation ---

let search_slice = &value[byte_start_offset..];

Expand All @@ -452,6 +461,7 @@ mod tests {
test_case_sensitive_regexp_instr_scalar_start();
test_case_sensitive_regexp_instr_scalar_nth();
test_case_sensitive_regexp_instr_scalar_subexp();
test_regexp_instr_empty_pattern_global_flag();

test_case_sensitive_regexp_instr_array::<GenericStringArray<i32>>();
test_case_sensitive_regexp_instr_array::<GenericStringArray<i64>>();
Expand Down Expand Up @@ -492,7 +502,7 @@ mod tests {
fn test_case_sensitive_regexp_instr_nulls() {
let v = "";
let r = "";
let expected = 0;
let expected = 1;
let regex_sv = ScalarValue::Utf8(Some(r.to_string()));
let re = regexp_instr_with_scalar_values(&[v.to_string().into(), regex_sv]);
// let res_exp = re.unwrap();
Expand All @@ -511,10 +521,11 @@ mod tests {
"no match here",
"abc",
"ДатаФусион数据融合📊🔥",
"abc",
];
let regex = ["o", "d", "123", "z", "gg", "📊"];
let regex = ["o", "d", "123", "z", "gg", "📊", ""];

let expected: Vec<i64> = vec![5, 4, 4, 0, 0, 15];
let expected: Vec<i64> = vec![5, 4, 4, 0, 0, 15, 1];

izip!(values.iter(), regex.iter())
.enumerate()
Expand Down Expand Up @@ -762,6 +773,22 @@ mod tests {
});
}

/// Verifies that `regexp_instr` rejects the global (`g`) flag even when the
/// pattern is the empty string.
///
/// The call must fail, and the error text must contain
/// "does not support the global flag".
///
/// This function carries no `#[test]` attribute of its own; it is invoked
/// by name from another test function in this module.
fn test_regexp_instr_empty_pattern_global_flag() {
    // Positional scalar arguments; presumably (value, pattern, start, nth, flags)
    // -- confirm against `regexp_instr_with_scalar_values`.
    let value = ScalarValue::Utf8(Some("abc".to_string()));
    let empty_pattern = ScalarValue::Utf8(Some(String::new()));
    let start = ScalarValue::Int64(Some(1));
    let nth = ScalarValue::Int64(Some(1));
    let global_flag = ScalarValue::Utf8(Some("g".to_string()));

    let outcome =
        regexp_instr_with_scalar_values(&[value, empty_pattern, start, nth, global_flag]);
    let err = outcome.expect_err("global flag should be rejected for empty patterns");

    // On failure, print the full error so the mismatch is visible.
    let message = err.to_string();
    assert!(
        message.contains("does not support the global flag"),
        "{err}"
    );
}

fn test_case_sensitive_regexp_instr_array<A>()
where
A: From<Vec<&'static str>> + Array + 'static,
Expand All @@ -772,10 +799,11 @@ mod tests {
"xyz123xyz",
"no match here",
"",
"abc",
]);
let regex = A::from(vec!["o", "d", "123", "z", "gg"]);
let regex = A::from(vec!["o", "d", "123", "z", "gg", ""]);

let expected = Int64Array::from(vec![5, 4, 4, 0, 0]);
let expected = Int64Array::from(vec![5, 4, 4, 0, 0, 1]);
let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex)]).unwrap();
assert_eq!(re.as_ref(), &expected);
}
Expand All @@ -784,10 +812,10 @@ mod tests {
where
A: From<Vec<&'static str>> + Array + 'static,
{
let values = A::from(vec!["abcabcabc", "abcabcabc", ""]);
let regex = A::from(vec!["abc", "abc", "gg"]);
let start = Int64Array::from(vec![4, 5, 5]);
let expected = Int64Array::from(vec![4, 7, 0]);
let values = A::from(vec!["abcabcabc", "abcabcabc", "", "abc"]);
let regex = A::from(vec!["abc", "abc", "gg", ""]);
let start = Int64Array::from(vec![4, 5, 5, 2]);
let expected = Int64Array::from(vec![4, 7, 0, 2]);

let re = regexp_instr_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)])
.unwrap();
Expand All @@ -798,11 +826,17 @@ mod tests {
where
A: From<Vec<&'static str>> + Array + 'static,
{
let values = A::from(vec!["abcabcabc", "abcabcabc", "abcabcabc", "abcabcabc"]);
let regex = A::from(vec!["abc", "abc", "abc", "abc"]);
let start = Int64Array::from(vec![1, 1, 1, 1]);
let nth = Int64Array::from(vec![1, 2, 3, 4]);
let expected = Int64Array::from(vec![1, 4, 7, 0]);
let values = A::from(vec![
"abcabcabc",
"abcabcabc",
"abcabcabc",
"abcabcabc",
"abc",
]);
let regex = A::from(vec!["abc", "abc", "abc", "abc", ""]);
let start = Int64Array::from(vec![1, 1, 1, 1, 2]);
let nth = Int64Array::from(vec![1, 2, 3, 4, 3]);
let expected = Int64Array::from(vec![1, 4, 7, 0, 4]);

let re = regexp_instr_func(&[
Arc::new(values),
Expand Down
20 changes: 20 additions & 0 deletions datafusion/sqllogictest/test_files/regexp/regexp_instr.slt
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,26 @@ SELECT regexp_instr('123123123123123', '(12)3');
----
1

query I
SELECT regexp_instr('abc', '');
----
1

query I
SELECT regexp_instr('', '');
----
1

query I
SELECT regexp_instr('abc', '', 2, 3);
----
4

query I
SELECT regexp_instr('abc', '', 5);
----
0

query I
SELECT regexp_instr('123123123123', '123', 1);
----
Expand Down