Skip to content

Commit 68372fc

Browse files
committed
Fix decimal floor/ceil (#10365)
1 parent c6dc5fc commit 68372fc

3 files changed

Lines changed: 298 additions & 73 deletions

File tree

dbms/src/Functions/FunctionsRound.h

Lines changed: 80 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -199,77 +199,6 @@ enum class RoundingMode
199199
#endif
200200
};
201201

202-
/** Rounding functions for decimal values
203-
*/
204-
205-
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
206-
struct DecimalRoundingComputation
207-
{
208-
static_assert(IsDecimal<T>);
209-
static const size_t data_count = 1;
210-
static size_t prepare(size_t scale) { return scale; }
211-
// compute need decimal_scale to interpret decimals
212-
static inline void compute(
213-
const T * __restrict in,
214-
size_t scale,
215-
OutputType * __restrict out,
216-
ScaleType decimal_scale)
217-
{
218-
static_assert(std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
219-
Float64 val = in->template toFloat<Float64>(decimal_scale);
220-
221-
if constexpr (scale_mode == ScaleMode::Positive)
222-
{
223-
val = val * scale;
224-
}
225-
else if constexpr (scale_mode == ScaleMode::Negative)
226-
{
227-
val = val / scale;
228-
}
229-
230-
if constexpr (rounding_mode == RoundingMode::Round)
231-
{
232-
val = round(val);
233-
}
234-
else if constexpr (rounding_mode == RoundingMode::Floor)
235-
{
236-
val = floor(val);
237-
}
238-
else if constexpr (rounding_mode == RoundingMode::Ceil)
239-
{
240-
val = ceil(val);
241-
}
242-
else if constexpr (rounding_mode == RoundingMode::Trunc)
243-
{
244-
val = trunc(val);
245-
}
246-
247-
248-
if constexpr (scale_mode == ScaleMode::Positive)
249-
{
250-
val = val / scale;
251-
}
252-
else if constexpr (scale_mode == ScaleMode::Negative)
253-
{
254-
val = val * scale;
255-
}
256-
257-
if constexpr (std::is_same_v<T, OutputType>)
258-
{
259-
*out = ToDecimal<Float64, T>(val, decimal_scale);
260-
}
261-
else if constexpr (std::is_same_v<OutputType, Int64>)
262-
{
263-
*out = static_cast<Int64>(val);
264-
}
265-
else
266-
{
267-
; // never arrived here
268-
}
269-
}
270-
};
271-
272-
273202
/** Rounding functions for integer values.
274203
*/
275204
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
@@ -327,12 +256,90 @@ struct IntegerRoundingComputation
327256
}
328257
}
329258

330-
static ALWAYS_INLINE void compute(const T * __restrict in, size_t scale, T * __restrict out)
259+
static ALWAYS_INLINE void compute(const T * __restrict in, T scale, T * __restrict out)
331260
{
332261
*out = compute(*in, scale);
333262
}
334263
};
335264

265+
/** Rounding functions for decimal values
266+
*/
267+
268+
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
269+
struct DecimalRoundingComputation
270+
{
271+
static_assert(IsDecimal<T>);
272+
static const size_t data_count = 1;
273+
static size_t prepare(size_t scale) { return scale; }
274+
// compute need decimal_scale to interpret decimals
275+
static inline void compute(
276+
const T * __restrict in,
277+
size_t scale,
278+
OutputType * __restrict out,
279+
ScaleType decimal_scale)
280+
{
281+
static_assert(std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
282+
using NativeType = T::NativeType;
283+
// Currently, we only use DecimalRoundingComputation for floor/ceil.
284+
// As for round/truncate, we always use tidbRoundWithFrac/tidbTruncateWithFrac.
285+
// So, we only handle ScaleMode::Zero here.
286+
if constexpr (scale_mode == ScaleMode::Zero)
287+
{
288+
using Op = IntegerRoundingComputation<NativeType, rounding_mode, ScaleMode::Negative>;
289+
auto scale_factor = intExp10OfSize<NativeType>(decimal_scale);
290+
291+
if constexpr (std::is_same_v<T, OutputType>)
292+
{
293+
Op::compute(&in->value, scale_factor, &out->value);
294+
}
295+
else if constexpr (std::is_same_v<OutputType, Int64>)
296+
{
297+
try
298+
{
299+
if constexpr (rounding_mode == RoundingMode::Floor)
300+
{
301+
auto x = in->value;
302+
if (x < 0)
303+
x -= scale_factor - 1;
304+
*out = static_cast<Int64>(x / scale_factor);
305+
}
306+
else if constexpr (rounding_mode == RoundingMode::Ceil)
307+
{
308+
auto x = in->value;
309+
if (x >= 0)
310+
x += scale_factor - 1;
311+
*out = static_cast<Int64>(x / scale_factor);
312+
}
313+
else
314+
{
315+
throw Exception(
316+
"Logical error: unexpected 'rounding_mode' of DecimalRoundingComputation",
317+
ErrorCodes::LOGICAL_ERROR);
318+
}
319+
}
320+
catch (const std::overflow_error & e)
321+
{
322+
throw Exception(
323+
"Logical error: unexpected Type of DecimalRoundingComputation for INT result",
324+
ErrorCodes::LOGICAL_ERROR);
325+
}
326+
}
327+
else
328+
{
329+
throw Exception(
330+
"Logical error: unexpected OutputType of DecimalRoundingComputation",
331+
ErrorCodes::LOGICAL_ERROR);
332+
}
333+
}
334+
else
335+
{
336+
throw Exception(
337+
"Logical error: unexpected 'scale_mode' of DecimalRoundingComputation and unexpected scale: "
338+
+ toString(scale),
339+
ErrorCodes::LOGICAL_ERROR);
340+
}
341+
}
342+
};
336343

337344
#if __SSE4_1__
338345

@@ -540,7 +547,7 @@ struct IntegerRoundingImpl
540547

541548
while (p_in < end_in)
542549
{
543-
Op::compute(p_in, scale, p_out);
550+
Op::compute(p_in, static_cast<T>(scale), p_out);
544551
++p_in;
545552
++p_out;
546553
}

libs/libcommon/include/common/intExp.h

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,202 @@ inline uint64_t intExp10(int x)
5656

5757
return table[x];
5858
}
59+
60+
constexpr int64_t exp10_i64_table[]
61+
= {1LL,
62+
10LL,
63+
100LL,
64+
1000LL,
65+
10000LL,
66+
100000LL,
67+
1000000LL,
68+
10000000LL,
69+
100000000LL,
70+
1000000000LL,
71+
10000000000LL,
72+
100000000000LL,
73+
1000000000000LL,
74+
10000000000000LL,
75+
100000000000000LL,
76+
1000000000000000LL,
77+
10000000000000000LL,
78+
100000000000000000LL,
79+
1000000000000000000LL};
80+
81+
constexpr Int128 exp10_i128_table[]
82+
= {static_cast<Int128>(1LL),
83+
static_cast<Int128>(10LL),
84+
static_cast<Int128>(100LL),
85+
static_cast<Int128>(1000LL),
86+
static_cast<Int128>(10000LL),
87+
static_cast<Int128>(100000LL),
88+
static_cast<Int128>(1000000LL),
89+
static_cast<Int128>(10000000LL),
90+
static_cast<Int128>(100000000LL),
91+
static_cast<Int128>(1000000000LL),
92+
static_cast<Int128>(10000000000LL),
93+
static_cast<Int128>(100000000000LL),
94+
static_cast<Int128>(1000000000000LL),
95+
static_cast<Int128>(10000000000000LL),
96+
static_cast<Int128>(100000000000000LL),
97+
static_cast<Int128>(1000000000000000LL),
98+
static_cast<Int128>(10000000000000000LL),
99+
static_cast<Int128>(100000000000000000LL),
100+
static_cast<Int128>(1000000000000000000LL),
101+
static_cast<Int128>(1000000000000000000LL) * 10LL,
102+
static_cast<Int128>(1000000000000000000LL) * 100LL,
103+
static_cast<Int128>(1000000000000000000LL) * 1000LL,
104+
static_cast<Int128>(1000000000000000000LL) * 10000LL,
105+
static_cast<Int128>(1000000000000000000LL) * 100000LL,
106+
static_cast<Int128>(1000000000000000000LL) * 1000000LL,
107+
static_cast<Int128>(1000000000000000000LL) * 10000000LL,
108+
static_cast<Int128>(1000000000000000000LL) * 100000000LL,
109+
static_cast<Int128>(1000000000000000000LL) * 1000000000LL,
110+
static_cast<Int128>(1000000000000000000LL) * 10000000000LL,
111+
static_cast<Int128>(1000000000000000000LL) * 100000000000LL,
112+
static_cast<Int128>(1000000000000000000LL) * 1000000000000LL,
113+
static_cast<Int128>(1000000000000000000LL) * 10000000000000LL,
114+
static_cast<Int128>(1000000000000000000LL) * 100000000000000LL,
115+
static_cast<Int128>(1000000000000000000LL) * 1000000000000000LL,
116+
static_cast<Int128>(1000000000000000000LL) * 10000000000000000LL,
117+
static_cast<Int128>(1000000000000000000LL) * 100000000000000000LL,
118+
static_cast<Int128>(1000000000000000000LL) * 100000000000000000LL * 10LL,
119+
static_cast<Int128>(1000000000000000000LL) * 100000000000000000LL * 100LL,
120+
static_cast<Int128>(1000000000000000000LL) * 100000000000000000LL * 1000LL};
121+
122+
constexpr Int256 i10e18{1000000000000000000ll};
123+
constexpr Int256 exp10_i256_table[] = {
124+
static_cast<Int256>(1ll),
125+
static_cast<Int256>(10ll),
126+
static_cast<Int256>(100ll),
127+
static_cast<Int256>(1000ll),
128+
static_cast<Int256>(10000ll),
129+
static_cast<Int256>(100000ll),
130+
static_cast<Int256>(1000000ll),
131+
static_cast<Int256>(10000000ll),
132+
static_cast<Int256>(100000000ll),
133+
static_cast<Int256>(1000000000ll),
134+
static_cast<Int256>(10000000000ll),
135+
static_cast<Int256>(100000000000ll),
136+
static_cast<Int256>(1000000000000ll),
137+
static_cast<Int256>(10000000000000ll),
138+
static_cast<Int256>(100000000000000ll),
139+
static_cast<Int256>(1000000000000000ll),
140+
static_cast<Int256>(10000000000000000ll),
141+
static_cast<Int256>(100000000000000000ll),
142+
i10e18,
143+
i10e18 * 10ll,
144+
i10e18 * 100ll,
145+
i10e18 * 1000ll,
146+
i10e18 * 10000ll,
147+
i10e18 * 100000ll,
148+
i10e18 * 1000000ll,
149+
i10e18 * 10000000ll,
150+
i10e18 * 100000000ll,
151+
i10e18 * 1000000000ll,
152+
i10e18 * 10000000000ll,
153+
i10e18 * 100000000000ll,
154+
i10e18 * 1000000000000ll,
155+
i10e18 * 10000000000000ll,
156+
i10e18 * 100000000000000ll,
157+
i10e18 * 1000000000000000ll,
158+
i10e18 * 10000000000000000ll,
159+
i10e18 * 100000000000000000ll,
160+
i10e18 * 100000000000000000ll * 10ll,
161+
i10e18 * 100000000000000000ll * 100ll,
162+
i10e18 * 100000000000000000ll * 1000ll,
163+
i10e18 * 100000000000000000ll * 10000ll,
164+
i10e18 * 100000000000000000ll * 100000ll,
165+
i10e18 * 100000000000000000ll * 1000000ll,
166+
i10e18 * 100000000000000000ll * 10000000ll,
167+
i10e18 * 100000000000000000ll * 100000000ll,
168+
i10e18 * 100000000000000000ll * 1000000000ll,
169+
i10e18 * 100000000000000000ll * 10000000000ll,
170+
i10e18 * 100000000000000000ll * 100000000000ll,
171+
i10e18 * 100000000000000000ll * 1000000000000ll,
172+
i10e18 * 100000000000000000ll * 10000000000000ll,
173+
i10e18 * 100000000000000000ll * 100000000000000ll,
174+
i10e18 * 100000000000000000ll * 1000000000000000ll,
175+
i10e18 * 100000000000000000ll * 10000000000000000ll,
176+
i10e18 * 100000000000000000ll * 100000000000000000ll,
177+
i10e18 * 100000000000000000ll * 100000000000000000ll * 10ll,
178+
i10e18 * 100000000000000000ll * 100000000000000000ll * 100ll,
179+
i10e18 * 100000000000000000ll * 100000000000000000ll * 1000ll,
180+
i10e18 * 100000000000000000ll * 100000000000000000ll * 10000ll,
181+
i10e18 * 100000000000000000ll * 100000000000000000ll * 100000ll,
182+
i10e18 * 100000000000000000ll * 100000000000000000ll * 1000000ll,
183+
i10e18 * 100000000000000000ll * 100000000000000000ll * 10000000ll,
184+
i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000ll,
185+
i10e18 * 100000000000000000ll * 100000000000000000ll * 1000000000ll,
186+
i10e18 * 100000000000000000ll * 100000000000000000ll * 10000000000ll,
187+
i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000ll,
188+
i10e18 * 100000000000000000ll * 100000000000000000ll * 1000000000000ll,
189+
i10e18 * 100000000000000000ll * 100000000000000000ll * 10000000000000ll,
190+
i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000ll,
191+
i10e18 * 100000000000000000ll * 100000000000000000ll * 1000000000000000ll,
192+
i10e18 * 100000000000000000ll * 100000000000000000ll * 10000000000000000ll,
193+
i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000000ll,
194+
i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000000ll * 10ll,
195+
i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000000ll * 100ll,
196+
i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000000ll * 1000ll,
197+
i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000000ll * 10000ll,
198+
i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000000ll * 100000ll,
199+
i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000000ll * 1000000ll,
200+
i10e18 * 100000000000000000ll * 100000000000000000ll * 100000000000000000ll * 10000000ll,
201+
};
202+
203+
constexpr int exp10_i32(int x)
204+
{
205+
if (x < 0)
206+
return 0;
207+
if (x > 9)
208+
return std::numeric_limits<int>::max();
209+
210+
constexpr int exp10_i32_table[10] = {1, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000};
211+
return exp10_i32_table[x];
212+
}
213+
214+
constexpr int64_t exp10_i64(int x)
215+
{
216+
if (x < 0)
217+
return 0;
218+
if (x > 18)
219+
return std::numeric_limits<int64_t>::max();
220+
221+
return exp10_i64_table[x];
222+
}
223+
224+
constexpr Int128 exp10_i128(int x)
225+
{
226+
if (x < 0)
227+
return 0;
228+
if (x > 38)
229+
return std::numeric_limits<Int128>::max();
230+
231+
return exp10_i128_table[x];
232+
}
233+
234+
constexpr Int256 exp10_i256(int x)
235+
{
236+
if (x < 0)
237+
return 0;
238+
if (x > 76)
239+
return std::numeric_limits<Int256>::max();
240+
241+
return exp10_i256_table[x];
242+
}
243+
244+
245+
/// intExp10 returning the type T.
246+
template <typename T>
247+
T intExp10OfSize(int x)
248+
{
249+
if constexpr (sizeof(T) <= 4)
250+
return static_cast<T>(exp10_i32(x));
251+
else if constexpr (sizeof(T) <= 8)
252+
return exp10_i64(x);
253+
else if constexpr (sizeof(T) <= 16)
254+
return exp10_i128(x);
255+
else
256+
return exp10_i256(x);
257+
}

0 commit comments

Comments
 (0)