Skip to content

Commit 1e75bfc

Browse files
HappenLeeCopilot
andauthored
[feature](be/fe) Add exponential_moving_average aggregate function (#63499)
Problem Summary: Doris lacked an exponential moving average aggregate function. This implements `exponential_moving_average(half_decay, value, time)` ported from ClickHouse's ExponentialMovingAverage aggregate. Algorithm: - State stores (value, time, half_decay) as doubles. - On add(): scale existing value to the new timestamp via `2^(-dt/half_decay)`, then accumulate. Store half_decay in state so it is available during merge(). - On merge(): advance both states to the later timestamp, then sum. This is commutative and associative, so row order doesn't matter. - Result: `value * (1 - 2^(-1/half_decay))` (normalised by sum of weights). Changes: - BE: `aggregate_function_ema.h/cpp` — state struct + function class, registered in `aggregate_function_simple_factory.cpp`. - FE: `ExponentialMovingAverage.java` — 3-arg DOUBLE signature, visitor method added to `AggregateFunctionVisitor`, registered in `BuiltinAggregateFunctions`. - Regression test: `query_p0/aggregate/exponential_moving_average/`. ### Release note New aggregate function `exponential_moving_average(half_decay, value, time)` computes the exponential moving average over a stream of (value, time) pairs with the given half-decay parameter. ### Check List (For Author) - Test: Regression test added (query_p0/aggregate/exponential_moving_average) - Behavior changed: No (new function) - Does this need documentation: [docs](apache/doris-website#3726) --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 895102b commit 1e75bfc

8 files changed

Lines changed: 524 additions & 0 deletions

File tree

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include "exprs/aggregate/aggregate_function_ema.h"
19+
20+
#include "exprs/aggregate/aggregate_function_simple_factory.h"
21+
#include "exprs/aggregate/helpers.h"
22+
23+
namespace doris {
24+
25+
void register_aggregate_function_ema(AggregateFunctionSimpleFactory& factory) {
26+
factory.register_function_both(
27+
"exponential_moving_average",
28+
creator_without_type::creator<AggregateFunctionExponentialMovingAverage>);
29+
}
30+
31+
} // namespace doris
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
// This file is adapted from
19+
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionExponentialMovingAverage.cpp
20+
21+
#pragma once
22+
23+
#include <cmath>
24+
#include <memory>
25+
26+
#include "core/assert_cast.h"
27+
#include "core/column/column_vector.h"
28+
#include "core/data_type/data_type_number.h"
29+
#include "core/types.h"
30+
#include "exprs/aggregate/aggregate_function.h"
31+
32+
namespace doris {
33+
class Arena;
34+
class BufferReadable;
35+
class BufferWritable;
36+
class IColumn;
37+
38+
/**
39+
* Exponentially smoothed moving average over time.
40+
*
41+
* Each value corresponds to a timeunit index. The half_decay parameter is the
42+
* time lag at which exponential weights decay by one-half.
43+
*
44+
* State is a (value, time) pair representing the exponentially accumulated sum
45+
* at a reference time. To get the average, divide by sumWeights(half_decay).
46+
*
47+
* Formula:
48+
* scale(dt, x) = 2^(-dt/x)
49+
* sumWeights(x) = 1 / (1 - 2^(-1/x))
50+
* add(v, t): merge current state with point (v, t)
51+
* merge(a, b): move both to the later time, then sum values
52+
* get(): value / sumWeights(half_decay)
53+
*
54+
* Usage: exponential_moving_average(half_decay, value, timeunit)
55+
* - half_decay: constant double, the half-life period in timeunit units
56+
* - value: numeric column to average
57+
* - timeunit: numeric time index (not raw timestamp; use intDiv if needed)
58+
* Returns DOUBLE.
59+
*/
60+
struct ExponentialMovingAverageData {
61+
double value = 0.0;
62+
double time = 0.0;
63+
double half_decay = 0.0;
64+
65+
static double scale(double time_passed, double hd) { return std::exp2(-time_passed / hd); }
66+
67+
static double sum_weights(double hd) { return 1.0 / (1.0 - std::exp2(-1.0 / hd)); }
68+
69+
void add(double new_value, double current_time, double hd) {
70+
half_decay = hd;
71+
ExponentialMovingAverageData other;
72+
other.value = new_value;
73+
other.time = current_time;
74+
merge_point(other, hd);
75+
}
76+
77+
void merge_point(const ExponentialMovingAverageData& other, double hd) {
78+
if (time > other.time) {
79+
value = value + other.value * scale(time - other.time, hd);
80+
} else if (time < other.time) {
81+
value = other.value + value * scale(other.time - time, hd);
82+
time = other.time;
83+
} else {
84+
value = value + other.value;
85+
}
86+
}
87+
88+
void merge(const ExponentialMovingAverageData& rhs) {
89+
double hd = half_decay != 0.0 ? half_decay : rhs.half_decay;
90+
if (hd == 0.0) {
91+
return;
92+
}
93+
half_decay = hd;
94+
merge_point(rhs, hd);
95+
}
96+
97+
double get() const {
98+
if (half_decay == 0.0) {
99+
return 0.0;
100+
}
101+
return value / sum_weights(half_decay);
102+
}
103+
104+
void write(BufferWritable& buf) const {
105+
buf.write_binary(value);
106+
buf.write_binary(time);
107+
buf.write_binary(half_decay);
108+
}
109+
110+
void read(BufferReadable& buf) {
111+
buf.read_binary(value);
112+
buf.read_binary(time);
113+
buf.read_binary(half_decay);
114+
}
115+
116+
void reset() {
117+
value = 0.0;
118+
time = 0.0;
119+
half_decay = 0.0;
120+
}
121+
};
122+
123+
class AggregateFunctionExponentialMovingAverage final
124+
: public IAggregateFunctionDataHelper<ExponentialMovingAverageData,
125+
AggregateFunctionExponentialMovingAverage>,
126+
MultiExpression,
127+
NullableAggregateFunction {
128+
public:
129+
AggregateFunctionExponentialMovingAverage(const DataTypes& argument_types_)
130+
: IAggregateFunctionDataHelper<ExponentialMovingAverageData,
131+
AggregateFunctionExponentialMovingAverage>(
132+
argument_types_) {}
133+
134+
String get_name() const override { return "exponential_moving_average"; }
135+
136+
DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }
137+
138+
void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }
139+
140+
void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
141+
Arena&) const override {
142+
const double half_decay =
143+
assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0])
144+
.get_data()[row_num];
145+
const double new_value =
146+
assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1])
147+
.get_data()[row_num];
148+
const double current_time =
149+
assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[2])
150+
.get_data()[row_num];
151+
this->data(place).add(new_value, current_time, half_decay);
152+
}
153+
154+
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
155+
Arena&) const override {
156+
this->data(place).merge(this->data(rhs));
157+
}
158+
159+
void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
160+
this->data(place).write(buf);
161+
}
162+
163+
void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
164+
Arena&) const override {
165+
this->data(place).read(buf);
166+
}
167+
168+
void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
169+
assert_cast<ColumnFloat64&>(to).get_data().push_back(this->data(place).get());
170+
}
171+
};
172+
173+
} // namespace doris

be/src/exprs/aggregate/aggregate_function_simple_factory.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ void register_aggregate_function_percentile_reservoir(AggregateFunctionSimpleFac
7979
void register_aggregate_function_ai_agg(AggregateFunctionSimpleFactory& factory);
8080
void register_aggregate_function_bool_union(AggregateFunctionSimpleFactory& factory);
8181
void register_aggregate_function_sem(AggregateFunctionSimpleFactory& factory);
82+
void register_aggregate_function_ema(AggregateFunctionSimpleFactory& factory);
8283

8384
AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
8485
static std::once_flag oc;
@@ -136,6 +137,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
136137
register_aggregate_function_ai_agg(instance);
137138
register_aggregate_function_bool_union(instance);
138139
register_aggregate_function_sem(instance);
140+
register_aggregate_function_ema(instance);
139141
// Register foreach and foreachv2 functions
140142
register_aggregate_function_combinator_foreach(instance);
141143
register_aggregate_function_combinator_foreachv2(instance);

fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinAggregateFunctions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
3939
import org.apache.doris.nereids.trees.expressions.functions.agg.Covar;
4040
import org.apache.doris.nereids.trees.expressions.functions.agg.CovarSamp;
41+
import org.apache.doris.nereids.trees.expressions.functions.agg.ExponentialMovingAverage;
4142
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupArrayIntersect;
4243
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupArrayUnion;
4344
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
@@ -139,6 +140,7 @@ private BuiltinAggregateFunctions() {
139140
agg(CollectSet.class, "collect_set", "group_uniq_array"),
140141
agg(Corr.class, "corr"),
141142
agg(CorrWelford.class, "corr_welford"),
143+
agg(ExponentialMovingAverage.class, "exponential_moving_average"),
142144
agg(Count.class, "count"),
143145
agg(CountByEnum.class, "count_by_enum"),
144146
agg(Covar.class, "covar", "covar_pop"),
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
package org.apache.doris.nereids.trees.expressions.functions.agg;
19+
20+
import org.apache.doris.catalog.FunctionSignature;
21+
import org.apache.doris.nereids.exceptions.AnalysisException;
22+
import org.apache.doris.nereids.trees.expressions.Expression;
23+
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
24+
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
25+
import org.apache.doris.nereids.types.DoubleType;
26+
27+
import com.google.common.base.Preconditions;
28+
import com.google.common.collect.ImmutableList;
29+
30+
import java.util.List;
31+
32+
/**
33+
* Exponential Moving Average aggregate function.
34+
*
35+
* <p>Computes the exponentially smoothed moving average over time-indexed values.
36+
* The half_decay parameter controls the half-life period: the time after which the
37+
* exponential weight of a past value decays by a factor of 1/2.
38+
*
39+
* <p>Signature: {@code exponential_moving_average(half_decay DOUBLE, value DOUBLE,
40+
* timeunit DOUBLE) -> DOUBLE}
41+
*
42+
* <p>The timeunit argument is a numeric time index, not a raw timestamp. For
43+
* timestamp columns use {@code intDiv(toUnixTimestamp(ts), interval_seconds)}.
44+
*/
45+
public class ExponentialMovingAverage extends NullableAggregateFunction
46+
implements ExplicitlyCastableSignature {
47+
48+
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
49+
FunctionSignature.ret(DoubleType.INSTANCE)
50+
.args(DoubleType.INSTANCE, DoubleType.INSTANCE, DoubleType.INSTANCE)
51+
);
52+
53+
/**
54+
* Constructor with 3 arguments: (half_decay, value, timeunit).
55+
*/
56+
public ExponentialMovingAverage(Expression halfDecay, Expression value, Expression timeunit) {
57+
this(false, halfDecay, value, timeunit);
58+
}
59+
60+
/**
61+
* Constructor with distinct flag and 3 arguments.
62+
*/
63+
public ExponentialMovingAverage(boolean distinct, Expression halfDecay,
64+
Expression value, Expression timeunit) {
65+
this(distinct, false, halfDecay, value, timeunit);
66+
}
67+
68+
/**
69+
* Full constructor.
70+
*/
71+
public ExponentialMovingAverage(boolean distinct, boolean alwaysNullable,
72+
Expression halfDecay, Expression value, Expression timeunit) {
73+
super("exponential_moving_average", distinct, alwaysNullable, halfDecay, value, timeunit);
74+
}
75+
76+
/** Constructor for withChildren and reuse signature. */
77+
private ExponentialMovingAverage(NullableAggregateFunctionParams functionParams) {
78+
super(functionParams);
79+
}
80+
81+
@Override
82+
public void checkLegalityBeforeTypeCoercion() {
83+
if (!getArgument(0).isConstant()) {
84+
throw new AnalysisException("The half_decay argument of "
85+
+ getName() + " must be a constant");
86+
}
87+
if (!getArgumentType(0).isNumericType()) {
88+
throw new AnalysisException("The half_decay argument of "
89+
+ getName() + " must be numeric");
90+
}
91+
if (!getArgumentType(1).isNumericType()) {
92+
throw new AnalysisException("The value argument of "
93+
+ getName() + " must be numeric");
94+
}
95+
if (!getArgumentType(2).isNumericType()) {
96+
throw new AnalysisException("The timeunit argument of "
97+
+ getName() + " must be numeric");
98+
}
99+
}
100+
101+
@Override
102+
public ExponentialMovingAverage withDistinctAndChildren(boolean distinct,
103+
List<Expression> children) {
104+
Preconditions.checkArgument(children.size() == 3);
105+
return new ExponentialMovingAverage(getFunctionParams(distinct, children));
106+
}
107+
108+
@Override
109+
public ExponentialMovingAverage withAlwaysNullable(boolean alwaysNullable) {
110+
return new ExponentialMovingAverage(getAlwaysNullableFunctionParams(alwaysNullable));
111+
}
112+
113+
@Override
114+
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
115+
return visitor.visitExponentialMovingAverage(this, context);
116+
}
117+
118+
@Override
119+
public List<FunctionSignature> getSignatures() {
120+
return SIGNATURES;
121+
}
122+
}

fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
4040
import org.apache.doris.nereids.trees.expressions.functions.agg.Covar;
4141
import org.apache.doris.nereids.trees.expressions.functions.agg.CovarSamp;
42+
import org.apache.doris.nereids.trees.expressions.functions.agg.ExponentialMovingAverage;
4243
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupArrayIntersect;
4344
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupArrayUnion;
4445
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
@@ -133,6 +134,10 @@ default R visitBitmapAgg(BitmapAgg bitmapAgg, C context) {
133134
return visitAggregateFunction(bitmapAgg, context);
134135
}
135136

137+
default R visitExponentialMovingAverage(ExponentialMovingAverage ema, C context) {
138+
return visitNullableAggregateFunction(ema, context);
139+
}
140+
136141
default R visitBitmapIntersect(BitmapIntersect bitmapIntersect, C context) {
137142
return visitAggregateFunction(bitmapIntersect, context);
138143
}

0 commit comments

Comments
 (0)