Skip to content

Commit e984a38

Browse files
authored
[Enhancement](doris-future) Support REGR_ARGX, REGR_ARGY, REGR_COUNT and REGR_R2 aggregate functions (#61352)
Issue Number: #38974, #38976 Related PR: #55940 Problem Summary: This PR completes the remaining statistical regression aggregate functions (`REGR_*`) by adding support for `REGR_AVGX`, `REGR_AVGY`, `REGR_COUNT`, and `REGR_R2`, based on the unified Regr aggregate function approach introduced in #55940.
1 parent df8e0a9 commit e984a38

34 files changed

Lines changed: 2041 additions & 608 deletions
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include "exprs/aggregate/aggregate_function_regr.h"
19+
20+
#include "core/data_type/data_type.h"
21+
#include "core/data_type/define_primitive_type.h"
22+
#include "exprs/aggregate/aggregate_function.h"
23+
#include "exprs/aggregate/aggregate_function_simple_factory.h"
24+
#include "exprs/aggregate/factory_helpers.h"
25+
#include "exprs/aggregate/helpers.h"
26+
27+
namespace doris {
28+
29+
template <RegrFunctionKind kind>
30+
AggregateFunctionPtr create_aggregate_function_regr(const std::string& name,
31+
const DataTypes& argument_types,
32+
const DataTypePtr& result_type,
33+
const bool result_is_nullable,
34+
const AggregateFunctionAttr& attr) {
35+
assert_arity_range(name, argument_types, 2, 2);
36+
DCHECK(argument_types[0]->get_primitive_type() == TYPE_DOUBLE);
37+
DCHECK(argument_types[1]->get_primitive_type() == TYPE_DOUBLE);
38+
if constexpr (kind == RegrFunctionKind::regr_count) {
39+
DCHECK(result_type->get_primitive_type() == TYPE_BIGINT);
40+
DCHECK(!result_is_nullable);
41+
} else {
42+
DCHECK(result_type->get_primitive_type() == TYPE_DOUBLE);
43+
DCHECK(result_is_nullable);
44+
}
45+
46+
bool y_nullable_input = argument_types[0]->is_nullable();
47+
bool x_nullable_input = argument_types[1]->is_nullable();
48+
if (y_nullable_input) {
49+
if (x_nullable_input) {
50+
return creator_without_type::create_ignore_nullable<
51+
AggregateFunctionRegr<TYPE_DOUBLE, kind, true, true>>(argument_types,
52+
result_is_nullable, attr);
53+
} else {
54+
return creator_without_type::create_ignore_nullable<
55+
AggregateFunctionRegr<TYPE_DOUBLE, kind, true, false>>(
56+
argument_types, result_is_nullable, attr);
57+
}
58+
} else {
59+
if (x_nullable_input) {
60+
return creator_without_type::create_ignore_nullable<
61+
AggregateFunctionRegr<TYPE_DOUBLE, kind, false, true>>(
62+
argument_types, result_is_nullable, attr);
63+
} else {
64+
return creator_without_type::create_ignore_nullable<
65+
AggregateFunctionRegr<TYPE_DOUBLE, kind, false, false>>(
66+
argument_types, result_is_nullable, attr);
67+
}
68+
}
69+
}
70+
71+
void register_aggregate_function_regr(AggregateFunctionSimpleFactory& factory) {
72+
factory.register_function_both(RegrTraits<RegrFunctionKind::regr_avgx>::name,
73+
create_aggregate_function_regr<RegrFunctionKind::regr_avgx>);
74+
factory.register_function_both(RegrTraits<RegrFunctionKind::regr_avgy>::name,
75+
create_aggregate_function_regr<RegrFunctionKind::regr_avgy>);
76+
factory.register_function_both(RegrTraits<RegrFunctionKind::regr_count>::name,
77+
create_aggregate_function_regr<RegrFunctionKind::regr_count>);
78+
factory.register_function_both(RegrTraits<RegrFunctionKind::regr_slope>::name,
79+
create_aggregate_function_regr<RegrFunctionKind::regr_slope>);
80+
factory.register_function_both(
81+
RegrTraits<RegrFunctionKind::regr_intercept>::name,
82+
create_aggregate_function_regr<RegrFunctionKind::regr_intercept>);
83+
factory.register_function_both(RegrTraits<RegrFunctionKind::regr_sxx>::name,
84+
create_aggregate_function_regr<RegrFunctionKind::regr_sxx>);
85+
factory.register_function_both(RegrTraits<RegrFunctionKind::regr_syy>::name,
86+
create_aggregate_function_regr<RegrFunctionKind::regr_syy>);
87+
factory.register_function_both(RegrTraits<RegrFunctionKind::regr_sxy>::name,
88+
create_aggregate_function_regr<RegrFunctionKind::regr_sxy>);
89+
factory.register_function_both(RegrTraits<RegrFunctionKind::regr_r2>::name,
90+
create_aggregate_function_regr<RegrFunctionKind::regr_r2>);
91+
}
92+
93+
} // namespace doris

0 commit comments

Comments
 (0)