Skip to content

Commit 113fd2d

Browse files
authored
Add levenshtein and hamming_distance functions (#60412)
Related Issue: #48203 Related PR: #57144 (reference) Problem Summary: support levenshtein (Hive) and hamming_distance (Trino/Presto).
1 parent 2570dd8 commit 113fd2d

11 files changed

Lines changed: 1225 additions & 2 deletions

File tree

Lines changed: 320 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,320 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include <algorithm>
19+
#include <vector>
20+
21+
#include "common/status.h"
22+
#include "core/column/column_nullable.h"
23+
#include "core/column/column_string.h"
24+
#include "core/data_type/data_type_number.h"
25+
#include "core/string_ref.h"
26+
#include "exprs/function/simple_function_factory.h"
27+
#include "util/simd/vstring_function.h"
28+
29+
namespace doris {
30+
31+
class FunctionHammingDistance : public IFunction {
32+
public:
33+
using ResultDataType = DataTypeInt64;
34+
using ResultPaddedPODArray = PaddedPODArray<Int64>;
35+
using ResultColumnType = ColumnVector<ResultDataType::PType>;
36+
37+
static constexpr auto name = "hamming_distance";
38+
39+
static FunctionPtr create() { return std::make_shared<FunctionHammingDistance>(); }
40+
41+
String get_name() const override { return name; }
42+
size_t get_number_of_arguments() const override { return 2; }
43+
44+
DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
45+
const bool has_nullable = std::ranges::any_of(
46+
arguments, [](const DataTypePtr& type) { return type->is_nullable(); });
47+
if (has_nullable) {
48+
return make_nullable(std::make_shared<ResultDataType>());
49+
}
50+
return std::make_shared<ResultDataType>();
51+
}
52+
53+
bool use_default_implementation_for_nulls() const override { return false; }
54+
55+
Status execute_impl(FunctionContext* /*context*/, Block& block, const ColumnNumbers& arguments,
56+
uint32_t result, size_t input_rows_count) const override {
57+
const auto& [left_col, left_const] =
58+
unpack_if_const(block.get_by_position(arguments[0]).column);
59+
const auto& [right_col, right_const] =
60+
unpack_if_const(block.get_by_position(arguments[1]).column);
61+
62+
const auto* left_nullable = check_and_get_column<ColumnNullable>(left_col.get());
63+
const auto* right_nullable = check_and_get_column<ColumnNullable>(right_col.get());
64+
65+
const IColumn* left_nested =
66+
left_nullable ? &left_nullable->get_nested_column() : left_col.get();
67+
const IColumn* right_nested =
68+
right_nullable ? &right_nullable->get_nested_column() : right_col.get();
69+
70+
const auto* left_str_col = assert_cast<const ColumnString*>(left_nested);
71+
const auto* right_str_col = assert_cast<const ColumnString*>(right_nested);
72+
73+
auto res_col = ResultColumnType::create(input_rows_count);
74+
auto& res_data = res_col->get_data();
75+
76+
const NullMap* left_null_map =
77+
left_nullable ? &left_nullable->get_null_map_data() : nullptr;
78+
const NullMap* right_null_map =
79+
right_nullable ? &right_nullable->get_null_map_data() : nullptr;
80+
const bool has_nullable = left_null_map != nullptr || right_null_map != nullptr;
81+
82+
if (!has_nullable) {
83+
if (left_const) {
84+
RETURN_IF_ERROR(scalar_vector(left_str_col->get_data_at(0).trim_tail_padding_zero(),
85+
*right_str_col, res_data));
86+
} else if (right_const) {
87+
RETURN_IF_ERROR(vector_scalar(
88+
*left_str_col, right_str_col->get_data_at(0).trim_tail_padding_zero(),
89+
res_data));
90+
} else {
91+
RETURN_IF_ERROR(vector_vector(*left_str_col, *right_str_col, res_data));
92+
}
93+
block.replace_by_position(result, std::move(res_col));
94+
return Status::OK();
95+
}
96+
97+
auto null_col = ColumnUInt8::create(input_rows_count, 0);
98+
auto& null_map = null_col->get_data();
99+
if (left_const) {
100+
if (left_null_map && (*left_null_map)[0]) {
101+
std::fill(null_map.begin(), null_map.end(), 1);
102+
block.replace_by_position(
103+
result, ColumnNullable::create(std::move(res_col), std::move(null_col)));
104+
return Status::OK();
105+
}
106+
107+
const auto left = left_str_col->get_data_at(0).trim_tail_padding_zero();
108+
RETURN_IF_ERROR(scalar_vector_nullable(left, *right_str_col, right_null_map, res_data,
109+
null_map));
110+
} else if (right_const) {
111+
if (right_null_map && (*right_null_map)[0]) {
112+
std::fill(null_map.begin(), null_map.end(), 1);
113+
block.replace_by_position(
114+
result, ColumnNullable::create(std::move(res_col), std::move(null_col)));
115+
return Status::OK();
116+
}
117+
118+
RETURN_IF_ERROR(vector_scalar_nullable(
119+
*left_str_col, right_str_col->get_data_at(0).trim_tail_padding_zero(),
120+
left_null_map, res_data, null_map));
121+
} else {
122+
for (size_t i = 0; i < input_rows_count; ++i) {
123+
const bool left_is_null = left_null_map && (*left_null_map)[i];
124+
const bool right_is_null = right_null_map && (*right_null_map)[i];
125+
if (left_is_null || right_is_null) {
126+
null_map[i] = 1;
127+
res_data[i] = 0;
128+
continue;
129+
}
130+
131+
RETURN_IF_ERROR(hamming_distance(
132+
left_str_col->get_data_at(i).trim_tail_padding_zero(),
133+
right_str_col->get_data_at(i).trim_tail_padding_zero(), res_data[i], i));
134+
}
135+
}
136+
137+
block.replace_by_position(result,
138+
ColumnNullable::create(std::move(res_col), std::move(null_col)));
139+
return Status::OK();
140+
}
141+
142+
private:
143+
static Status vector_vector(const ColumnString& lcol, const ColumnString& rcol,
144+
ResultPaddedPODArray& res) {
145+
DCHECK_EQ(lcol.size(), rcol.size());
146+
147+
const size_t size = lcol.size();
148+
res.resize(size);
149+
std::vector<size_t> left_offsets;
150+
std::vector<size_t> right_offsets;
151+
for (size_t i = 0; i < size; ++i) {
152+
const auto left = lcol.get_data_at(i).trim_tail_padding_zero();
153+
const auto right = rcol.get_data_at(i).trim_tail_padding_zero();
154+
RETURN_IF_ERROR(hamming_distance_with_offsets(
155+
left, left_offsets, false, simd::VStringFunctions::is_ascii(left), right,
156+
right_offsets, false, simd::VStringFunctions::is_ascii(right), res[i], i));
157+
}
158+
return Status::OK();
159+
}
160+
161+
static Status vector_scalar(const ColumnString& lcol, const StringRef& rdata,
162+
ResultPaddedPODArray& res) {
163+
const size_t size = lcol.size();
164+
res.resize(size);
165+
const bool right_ascii = simd::VStringFunctions::is_ascii(rdata);
166+
std::vector<size_t> right_offsets;
167+
simd::VStringFunctions::get_utf8_char_offsets(rdata, right_offsets);
168+
std::vector<size_t> left_offsets;
169+
for (size_t i = 0; i < size; ++i) {
170+
const auto left = lcol.get_data_at(i).trim_tail_padding_zero();
171+
RETURN_IF_ERROR(hamming_distance_with_offsets(
172+
left, left_offsets, false, simd::VStringFunctions::is_ascii(left), rdata,
173+
right_offsets, true, right_ascii, res[i], i));
174+
}
175+
return Status::OK();
176+
}
177+
178+
static Status scalar_vector(const StringRef& ldata, const ColumnString& rcol,
179+
ResultPaddedPODArray& res) {
180+
const size_t size = rcol.size();
181+
res.resize(size);
182+
const bool left_ascii = simd::VStringFunctions::is_ascii(ldata);
183+
std::vector<size_t> left_offsets;
184+
simd::VStringFunctions::get_utf8_char_offsets(ldata, left_offsets);
185+
std::vector<size_t> right_offsets;
186+
for (size_t i = 0; i < size; ++i) {
187+
const auto right = rcol.get_data_at(i).trim_tail_padding_zero();
188+
RETURN_IF_ERROR(hamming_distance_with_offsets(
189+
ldata, left_offsets, true, left_ascii, right, right_offsets, false,
190+
simd::VStringFunctions::is_ascii(right), res[i], i));
191+
}
192+
return Status::OK();
193+
}
194+
195+
static Status vector_scalar_nullable(const ColumnString& lcol, const StringRef& rdata,
196+
const NullMap* left_null_map, ResultPaddedPODArray& res,
197+
NullMap& null_map) {
198+
const size_t size = lcol.size();
199+
res.resize(size);
200+
const bool right_ascii = simd::VStringFunctions::is_ascii(rdata);
201+
std::vector<size_t> right_offsets;
202+
simd::VStringFunctions::get_utf8_char_offsets(rdata, right_offsets);
203+
std::vector<size_t> left_offsets;
204+
for (size_t i = 0; i < size; ++i) {
205+
if (left_null_map && (*left_null_map)[i]) {
206+
null_map[i] = 1;
207+
res[i] = 0;
208+
continue;
209+
}
210+
211+
const auto left = lcol.get_data_at(i).trim_tail_padding_zero();
212+
RETURN_IF_ERROR(hamming_distance_with_offsets(
213+
left, left_offsets, false, simd::VStringFunctions::is_ascii(left), rdata,
214+
right_offsets, true, right_ascii, res[i], i));
215+
}
216+
return Status::OK();
217+
}
218+
219+
static Status scalar_vector_nullable(const StringRef& ldata, const ColumnString& rcol,
220+
const NullMap* right_null_map, ResultPaddedPODArray& res,
221+
NullMap& null_map) {
222+
const size_t size = rcol.size();
223+
res.resize(size);
224+
const bool left_ascii = simd::VStringFunctions::is_ascii(ldata);
225+
std::vector<size_t> left_offsets;
226+
simd::VStringFunctions::get_utf8_char_offsets(ldata, left_offsets);
227+
std::vector<size_t> right_offsets;
228+
for (size_t i = 0; i < size; ++i) {
229+
if (right_null_map && (*right_null_map)[i]) {
230+
null_map[i] = 1;
231+
res[i] = 0;
232+
continue;
233+
}
234+
235+
const auto right = rcol.get_data_at(i).trim_tail_padding_zero();
236+
RETURN_IF_ERROR(hamming_distance_with_offsets(
237+
ldata, left_offsets, true, left_ascii, right, right_offsets, false,
238+
simd::VStringFunctions::is_ascii(right), res[i], i));
239+
}
240+
return Status::OK();
241+
}
242+
243+
static Status hamming_distance_ascii(const StringRef& left, const StringRef& right,
244+
Int64& result, size_t row) {
245+
if (left.size != right.size) {
246+
return Status::InvalidArgument(
247+
"hamming_distance requires strings of the same length at row {}", row);
248+
}
249+
250+
Int64 distance = 0;
251+
for (size_t i = 0; i < left.size; ++i) {
252+
distance += static_cast<Int64>(left.data[i] != right.data[i]);
253+
}
254+
result = distance;
255+
return Status::OK();
256+
}
257+
258+
static Status hamming_distance_utf8(const StringRef& left,
259+
const std::vector<size_t>& left_offsets,
260+
const StringRef& right,
261+
const std::vector<size_t>& right_offsets, Int64& result,
262+
size_t row) {
263+
if (left_offsets.size() != right_offsets.size()) {
264+
return Status::InvalidArgument(
265+
"hamming_distance requires strings of the same length at row {}", row);
266+
}
267+
268+
Int64 distance = 0;
269+
const size_t len = left_offsets.size();
270+
for (size_t i = 0; i + 1 < len; ++i) {
271+
const size_t left_off = left_offsets[i];
272+
const size_t left_next = left_offsets[i + 1];
273+
const size_t right_off = right_offsets[i];
274+
const size_t right_next = right_offsets[i + 1];
275+
distance += static_cast<Int64>(!simd::VStringFunctions::utf8_char_equal(
276+
left, left_off, left_next, right, right_off, right_next));
277+
}
278+
if (len > 0) {
279+
const size_t left_off = left_offsets[len - 1];
280+
const size_t right_off = right_offsets[len - 1];
281+
distance += static_cast<Int64>(!simd::VStringFunctions::utf8_char_equal(
282+
left, left_off, left.size, right, right_off, right.size));
283+
}
284+
285+
result = distance;
286+
return Status::OK();
287+
}
288+
289+
static Status hamming_distance_with_offsets(
290+
const StringRef& left, std::vector<size_t>& left_offsets, bool left_offsets_ready,
291+
bool left_ascii, const StringRef& right, std::vector<size_t>& right_offsets,
292+
bool right_offsets_ready, bool right_ascii, Int64& result, size_t row) {
293+
if (left_ascii && right_ascii) {
294+
return hamming_distance_ascii(left, right, result, row);
295+
}
296+
297+
if (!left_offsets_ready) {
298+
simd::VStringFunctions::get_utf8_char_offsets(left, left_offsets);
299+
}
300+
if (!right_offsets_ready) {
301+
simd::VStringFunctions::get_utf8_char_offsets(right, right_offsets);
302+
}
303+
return hamming_distance_utf8(left, left_offsets, right, right_offsets, result, row);
304+
}
305+
306+
static Status hamming_distance(const StringRef& left, const StringRef& right, Int64& result,
307+
size_t row) {
308+
std::vector<size_t> left_offsets;
309+
std::vector<size_t> right_offsets;
310+
return hamming_distance_with_offsets(
311+
left, left_offsets, false, simd::VStringFunctions::is_ascii(left), right,
312+
right_offsets, false, simd::VStringFunctions::is_ascii(right), result, row);
313+
}
314+
};
315+
316+
void register_function_hamming_distance(SimpleFunctionFactory& factory) {
317+
factory.register_function<FunctionHammingDistance>();
318+
}
319+
320+
} // namespace doris

0 commit comments

Comments
 (0)