Skip to content

Commit 1f1ce67

Browse files
committed
cast inplace for mean
1 parent 95dc7ed commit 1f1ce67

2 files changed

Lines changed: 19 additions & 20 deletions

File tree

src/core/column/mean.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,20 +25,20 @@
2525
namespace dt {
2626

2727

28-
template <typename T_IN, typename T_OUT, bool IS_GROUPED>
29-
class Mean_ColumnImpl : public ReduceUnary_ColumnImpl<T_IN, T_OUT, IS_GROUPED> {
28+
template <typename T_IN, bool IS_GROUPED>
29+
class Mean_ColumnImpl : public ReduceUnary_ColumnImpl<T_IN, T_IN, IS_GROUPED> {
3030
public:
31-
using ReduceUnary_ColumnImpl<T_IN, T_OUT, IS_GROUPED>::ReduceUnary_ColumnImpl;
31+
using ReduceUnary_ColumnImpl<T_IN, T_IN, IS_GROUPED>::ReduceUnary_ColumnImpl;
3232

33-
bool get_element(size_t i, T_OUT* out) const override {
33+
bool get_element(size_t i, T_IN* out) const override {
3434
T_IN value;
3535
size_t i0, i1;
3636
this->gby_.get_group(i, &i0, &i1);
3737

3838
if (IS_GROUPED){
3939
bool is_valid = this->col_.get_element(i, &value);
4040
if (!is_valid) return false;
41-
*out = static_cast<T_OUT>(value);
41+
*out = value;
4242
return true;
4343
} else {
4444
double sum = 0;
@@ -51,7 +51,7 @@ class Mean_ColumnImpl : public ReduceUnary_ColumnImpl<T_IN, T_OUT, IS_GROUPED> {
5151
}
5252
}
5353
if (!count) return false;
54-
*out = static_cast<T_OUT>(sum / static_cast<double>(count));
54+
*out = static_cast<T_IN>(sum / static_cast<double>(count));
5555
return true;
5656
}
5757

src/core/expr/fexpr_mean.cc

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -80,42 +80,41 @@ class FExpr_Mean : public FExpr_Func {
8080
));
8181
case SType::BOOL:
8282
case SType::INT8:
83-
return make<int8_t, double>(std::move(col), gby, is_grouped);
8483
case SType::INT16:
85-
return make<int16_t, double>(std::move(col), gby, is_grouped);
86-
case SType::INT32:
87-
return make<int32_t, double>(std::move(col), gby, is_grouped);
84+
case SType::INT32:
8885
case SType::INT64:
89-
return make<int64_t, double>(std::move(col), gby, is_grouped);
9086
case SType::FLOAT64:
91-
return make<double, double>(std::move(col), gby, is_grouped);
87+
return make<double>(std::move(col), SType::FLOAT64, gby, is_grouped);
9288
case SType::FLOAT32:
93-
return make<float, float>(std::move(col), gby, is_grouped);
89+
return make<float>(std::move(col), SType::FLOAT32, gby, is_grouped);
90+
9491
case SType::DATE32: {
95-
Column coli = make<double, double>(std::move(col), gby, is_grouped);
92+
Column coli = make<double>(std::move(col), SType::FLOAT64, gby, is_grouped);
9693
coli.cast_inplace(SType::DATE32);
9794
return coli;
9895
}
9996
case SType::TIME64: {
100-
Column coli = make<double, double>(std::move(col), gby, is_grouped);
97+
Column coli = make<double>(std::move(col), SType::FLOAT64, gby, is_grouped);
10198
coli.cast_inplace(SType::TIME64);
10299
return coli;
103-
}
100+
}
101+
104102
default:
105103
throw TypeError()
106104
<< "Invalid column of type `" << stype << "` in " << repr();
107105
}
108106
}
109107

110108

111-
template <typename T_IN, typename T_OUT>
112-
Column make(Column &&col, const Groupby& gby, bool is_grouped) const {
109+
template <typename T_IN>
110+
Column make(Column &&col, SType stype, const Groupby& gby, bool is_grouped) const {
111+
col.cast_inplace(stype);
113112
if (is_grouped) {
114-
return Column(new Latent_ColumnImpl(new Mean_ColumnImpl<T_IN, T_OUT, true>(
113+
return Column(new Latent_ColumnImpl(new Mean_ColumnImpl<T_IN, true>(
115114
std::move(col), gby
116115
)));
117116
} else {
118-
return Column(new Latent_ColumnImpl(new Mean_ColumnImpl<T_IN, T_OUT, false>(
117+
return Column(new Latent_ColumnImpl(new Mean_ColumnImpl<T_IN, false>(
119118
std::move(col), gby
120119
)));
121120
}

0 commit comments

Comments
 (0)