Skip to content

Commit 33d8b77

Browse files
committed
functions: implement aggregate function entropy
1 parent 1cde086 commit 33d8b77

File tree

11 files changed

+771
-10
lines changed

11 files changed

+771
-10
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
// This file is copied from
18+
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionEntropy.cpp
19+
// and modified by Doris
20+
21+
#include "vec/aggregate_functions/aggregate_function_entropy.h"
22+
23+
#include "runtime/define_primitive_type.h"
24+
#include "vec/aggregate_functions/aggregate_function_simple_factory.h"
25+
#include "vec/aggregate_functions/helpers.h"
26+
27+
namespace doris::vectorized {
28+
#include "common/compile_check_begin.h"
29+
30+
/// Builds the `entropy` aggregate function for the given argument types.
///
/// With exactly one argument the typed creator is tried first; when it yields
/// nothing and the argument is a string, varbinary or JSONB type, the
/// string-hashing state is used instead. Every other case (including any
/// argument count other than one) is served by the generic state.
/// `name` and `result_type` are not consulted here.
AggregateFunctionPtr create_aggregate_function_entropy(const std::string& name,
                                                       const DataTypes& argument_types,
                                                       const DataTypePtr& result_type,
                                                       const bool result_is_nullable,
                                                       const AggregateFunctionAttr& attr) {
    using GenericEntropy = AggregateFunctionEntropy<AggregateFunctionEntropyGenericData>;
    using StringEntropy = AggregateFunctionEntropy<AggregateFunctionEntropySingleStringData>;
    using TypedCreator = creator_with_type_list<
            TYPE_BOOLEAN, TYPE_TINYINT, TYPE_SMALLINT, TYPE_INT, TYPE_BIGINT, TYPE_LARGEINT,
            TYPE_DECIMAL32, TYPE_DECIMAL64, TYPE_DECIMAL128I, TYPE_DECIMAL256, TYPE_DECIMALV2,
            TYPE_FLOAT, TYPE_DOUBLE, TYPE_DATE, TYPE_DATETIME, TYPE_DATEV2, TYPE_DATETIMEV2,
            TYPE_TIME, TYPE_TIMEV2, TYPE_TIMESTAMPTZ>;

    // Anything but a single argument goes straight to the generic state.
    if (argument_types.size() != 1) {
        return creator_without_type::create<GenericEntropy>(argument_types, result_is_nullable,
                                                            attr);
    }

    // Typed path: a non-null result means the argument type is in the list above.
    if (auto typed = TypedCreator::create<AggregateFunctionEntropy,
                                          AggregateFunctionEntropySingleNumericData>(
                argument_types, result_is_nullable, attr)) {
        return typed;
    }

    const auto primitive_type = argument_types[0]->get_primitive_type();
    const bool hashes_raw_bytes = is_string_type(primitive_type) ||
                                  is_varbinary(primitive_type) || primitive_type == TYPE_JSONB;
    if (hashes_raw_bytes) {
        return creator_without_type::create<StringEntropy>(argument_types, result_is_nullable,
                                                           attr);
    }

    return creator_without_type::create<GenericEntropy>(argument_types, result_is_nullable, attr);
}
60+
61+
/// Hooks the entropy creator into `factory` under the name "entropy" by way of
/// `register_function_both`.
void register_aggregate_function_entropy(AggregateFunctionSimpleFactory& factory) {
    const auto& creator = create_aggregate_function_entropy;
    factory.register_function_both("entropy", creator);
}
64+
65+
} // namespace doris::vectorized
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
// This file is copied from
18+
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionEntropy.cpp
19+
// and modified by Doris
20+
21+
#pragma once
22+
23+
#include "vec/aggregate_functions/aggregate_function.h"
24+
#include "vec/columns/column.h"
25+
#include "vec/columns/column_decimal.h"
26+
#include "vec/common/assert_cast.h"
27+
#include "vec/common/hash_table/hash.h"
28+
#include "vec/common/hash_table/phmap_fwd_decl.h"
29+
#include "vec/common/string_ref.h"
30+
#include "vec/common/uint128.h"
31+
#include "vec/core/types.h"
32+
#include "vec/data_types/data_type.h"
33+
#include "vec/data_types/data_type_number.h"
34+
35+
namespace doris::vectorized {
36+
#include "common/compile_check_begin.h"
37+
38+
class Arena;
39+
class BufferReadable;
40+
class BufferWritable;
41+
template <PrimitiveType>
42+
class ColumnVector;
43+
44+
/** Calculates Shannon Entropy, using HashMap and computing empirical distribution function.
45+
* Entropy is measured in bits (base-2 logarithm is used).
46+
*/
47+
template <typename Value, typename Hash = HashCRC32<Value>>
48+
struct AggregateFunctionEntropyData {
49+
using Container = flat_hash_map<Value, uint64_t, Hash>;
50+
using Self = AggregateFunctionEntropyData<Value, Hash>;
51+
Container frequency_map;
52+
53+
void clear() { frequency_map.clear(); }
54+
55+
void add(const Value& elem) { ++frequency_map[elem]; }
56+
57+
void merge(const Self& rhs) {
58+
frequency_map.reserve(frequency_map.size() + rhs.frequency_map.size());
59+
for (const auto& [elem, count] : rhs.frequency_map) {
60+
frequency_map[elem] += count;
61+
}
62+
}
63+
64+
void write(BufferWritable& buf) const {
65+
buf.write_var_uint(frequency_map.size());
66+
for (const auto& [elem, count] : frequency_map) {
67+
buf.write_binary(elem);
68+
buf.write_binary(count);
69+
}
70+
}
71+
72+
void read(BufferReadable& buf) {
73+
uint64_t new_size = 0;
74+
buf.read_var_uint(new_size);
75+
frequency_map.reserve(frequency_map.size() + new_size);
76+
77+
Value elem;
78+
uint64_t count;
79+
for (size_t i = 0; i < new_size; ++i) {
80+
buf.read_binary(elem);
81+
buf.read_binary(count);
82+
frequency_map[elem] += count;
83+
}
84+
}
85+
86+
Float64 get_result() const {
87+
Float64 entropy = 0;
88+
uint64_t total_count = 0;
89+
for (const auto& [_, count] : frequency_map) {
90+
total_count += count;
91+
}
92+
for (const auto& [_, count] : frequency_map) {
93+
Float64 p = static_cast<Float64>(count) / static_cast<Float64>(total_count);
94+
entropy -= p * std::log2(p);
95+
}
96+
return entropy;
97+
}
98+
99+
static String get_name() { return "entropy"; }
100+
};
101+
102+
template <PrimitiveType T>
103+
struct AggregateFunctionEntropySingleNumericData
104+
: public AggregateFunctionEntropyData<typename PrimitiveTypeTraits<T>::CppType> {
105+
using Base = AggregateFunctionEntropyData<typename PrimitiveTypeTraits<T>::CppType>;
106+
107+
void add(const IColumn** columns, size_t /* columns_num */, size_t row_num, Arena&) {
108+
const auto& vec = assert_cast<const typename PrimitiveTypeTraits<T>::ColumnType&,
109+
TypeCheckOnRelease::DISABLE>(*columns[0])
110+
.get_data();
111+
Base::add(vec[row_num]);
112+
}
113+
};
114+
115+
struct AggregateFunctionEntropySingleStringData
116+
: public AggregateFunctionEntropyData<UInt128, UInt128TrivialHash> {
117+
using Base = AggregateFunctionEntropyData<UInt128, UInt128TrivialHash>;
118+
119+
void add(const IColumn** columns, size_t /* columns_num */, size_t row_num, Arena&) {
120+
auto key = columns[0]->get_data_at(row_num);
121+
auto hash_value = XXH_INLINE_XXH128(key.data, key.size, 0);
122+
Base::add(UInt128 {hash_value.high64, hash_value.low64});
123+
}
124+
};
125+
126+
struct AggregateFunctionEntropyGenericData
127+
: public AggregateFunctionEntropyData<UInt128, UInt128TrivialHash> {
128+
using Base = AggregateFunctionEntropyData<UInt128, UInt128TrivialHash>;
129+
130+
void add(const IColumn** columns, size_t columns_num, size_t row_num, Arena& arena) {
131+
const char* begin = nullptr;
132+
StringRef key(begin, 0);
133+
for (size_t i = 0; i < columns_num; ++i) {
134+
auto cur_ref = columns[i]->serialize_value_into_arena(row_num, arena, begin);
135+
key.data = cur_ref.data - key.size;
136+
key.size += cur_ref.size;
137+
}
138+
auto hash_value = XXH_INLINE_XXH128(key.data, key.size, 0);
139+
Base::add(UInt128 {hash_value.high64, hash_value.low64});
140+
}
141+
};
142+
143+
template <typename Data>
144+
class AggregateFunctionEntropy
145+
: public IAggregateFunctionDataHelper<Data, AggregateFunctionEntropy<Data>>,
146+
VarargsExpression,
147+
NullableAggregateFunction {
148+
private:
149+
size_t arguments_num;
150+
151+
public:
152+
AggregateFunctionEntropy(const DataTypes& arguments)
153+
: IAggregateFunctionDataHelper<Data, AggregateFunctionEntropy<Data>>(arguments),
154+
arguments_num(arguments.size()) {}
155+
156+
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
157+
Arena& arena) const override {
158+
this->data(place).add(columns, arguments_num, row_num, arena);
159+
}
160+
161+
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
162+
Arena& arena) const override {
163+
this->data(place).merge(this->data(rhs));
164+
}
165+
166+
void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
167+
this->data(place).write(buf);
168+
}
169+
170+
void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
171+
Arena&) const override {
172+
this->data(place).read(buf);
173+
}
174+
175+
void insert_result_into(ConstAggregateDataPtr place, IColumn& to) const override {
176+
auto& column = assert_cast<ColumnFloat64&>(to);
177+
column.get_data().push_back(this->data(place).get_result());
178+
}
179+
180+
void reset(AggregateDataPtr place) const override { this->data(place).clear(); }
181+
182+
String get_name() const override { return Data::get_name(); }
183+
184+
DataTypePtr get_return_type() const override {
185+
return std::make_shared<DataTypeNumber<TYPE_DOUBLE>>();
186+
}
187+
};
188+
189+
} // namespace doris::vectorized
190+
191+
#include "common/compile_check_end.h"

be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ void register_aggregate_function_percentile_reservoir(AggregateFunctionSimpleFac
7979
void register_aggregate_function_ai_agg(AggregateFunctionSimpleFactory& factory);
8080
void register_aggregate_function_bool_union(AggregateFunctionSimpleFactory& factory);
8181
void register_aggregate_function_sem(AggregateFunctionSimpleFactory& factory);
82+
void register_aggregate_function_entropy(AggregateFunctionSimpleFactory& factory);
8283

8384
AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
8485
static std::once_flag oc;
@@ -135,6 +136,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
135136
register_aggregate_function_ai_agg(instance);
136137
register_aggregate_function_bool_union(instance);
137138
register_aggregate_function_sem(instance);
139+
register_aggregate_function_entropy(instance);
138140
// Register foreach and foreachv2 functions
139141
register_aggregate_function_combinator_foreach(instance);
140142
register_aggregate_function_combinator_foreachv2(instance);

be/test/testutil/column_helper.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,15 @@ struct ColumnHelper {
127127
return block;
128128
}
129129

130-
template <typename DataType>
131-
static Block create_block(const std::vector<typename DataType::FieldType>& data1,
132-
const std::vector<typename DataType::FieldType>& data2) {
133-
auto column1 = create_column<DataType>(data1);
134-
auto column2 = create_column<DataType>(data2);
135-
auto data_type = std::make_shared<DataType>();
136-
Block block({ColumnWithTypeAndName(column1, data_type, "column1"),
137-
ColumnWithTypeAndName(column2, data_type, "column2")});
130+
template <typename DataType1, typename DataType2 = DataType1>
131+
static Block create_block(const std::vector<typename DataType1::FieldType>& data1,
132+
const std::vector<typename DataType2::FieldType>& data2) {
133+
auto column1 = create_column<DataType1>(data1);
134+
auto column2 = create_column<DataType2>(data2);
135+
auto data_type1 = std::make_shared<DataType1>();
136+
auto data_type2 = std::make_shared<DataType2>();
137+
Block block({ColumnWithTypeAndName(column1, data_type1, "column1"),
138+
ColumnWithTypeAndName(column2, data_type2, "column2")});
138139
return block;
139140
}
140141

0 commit comments

Comments
 (0)