Skip to content

Commit baf139c

Browse files
authored
feat: implement truncate max for literals (#585)
1 parent 066bee0 commit baf139c

File tree

3 files changed

+357
-3
lines changed

3 files changed

+357
-3
lines changed

src/iceberg/test/truncate_util_test.cc

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <gtest/gtest.h>
2323

2424
#include "iceberg/expression/literal.h"
25+
#include "iceberg/test/matchers.h"
2526

2627
namespace iceberg {
2728

@@ -50,4 +51,141 @@ TEST(TruncateUtilTest, TruncateLiteral) {
5051
Literal::Binary(std::vector<uint8_t>(expected.begin(), expected.end())));
5152
}
5253

54+
// Exercises TruncateLiteralMax on binary literals: successful upper bounds,
// the pass-through case when the width covers the whole input, and the error
// returned when every byte of the kept prefix is 0xFF.
TEST(TruncateUtilTest, TruncateBinaryMax) {
  const std::vector<uint8_t> plain{1, 1, 2};
  const std::vector<uint8_t> ff_in_third_byte{1, 1, 0xFF, 2};
  const std::vector<uint8_t> ff_prefix{0xFF, 0xFF, 0xFF, 2};
  const std::vector<uint8_t> zero_tail{1, 1, 0};
  // Every successful truncation below is expected to yield {1, 2}.
  const std::vector<uint8_t> bumped_prefix{1, 2};

  // {1, 1, 2} cut to 2 bytes -> {1, 2}
  ICEBERG_UNWRAP_OR_FAIL(auto plain_to_2,
                         TruncateUtils::TruncateLiteralMax(Literal::Binary(plain), 2));
  EXPECT_EQ(plain_to_2, Literal::Binary(bumped_prefix));

  // {1, 1, 0xFF, 2} cut to 2 bytes -> {1, 2}
  ICEBERG_UNWRAP_OR_FAIL(
      auto ff_to_2,
      TruncateUtils::TruncateLiteralMax(Literal::Binary(ff_in_third_byte), 2));
  EXPECT_EQ(ff_to_2, Literal::Binary(bumped_prefix));

  // {1, 1, 0xFF, 2} cut to 3 bytes -> {1, 2}: the trailing 0xFF cannot be
  // incremented, so the byte before it is bumped instead.
  ICEBERG_UNWRAP_OR_FAIL(
      auto ff_to_3,
      TruncateUtils::TruncateLiteralMax(Literal::Binary(ff_in_third_byte), 3));
  EXPECT_EQ(ff_to_3, Literal::Binary(bumped_prefix));

  // A width of at least the input size leaves the literal untouched.
  ICEBERG_UNWRAP_OR_FAIL(
      auto ff_prefix_to_5,
      TruncateUtils::TruncateLiteralMax(Literal::Binary(ff_prefix), 5));
  EXPECT_EQ(ff_prefix_to_5, Literal::Binary(ff_prefix));

  // A kept prefix made only of 0xFF bytes has no upper bound.
  EXPECT_THAT(TruncateUtils::TruncateLiteralMax(Literal::Binary(ff_prefix), 2),
              IsError(ErrorKind::kInvalidArgument));

  // {1, 1, 0} cut to 2 bytes -> {1, 2}
  ICEBERG_UNWRAP_OR_FAIL(
      auto zero_tail_to_2,
      TruncateUtils::TruncateLiteralMax(Literal::Binary(zero_tail), 2));
  EXPECT_EQ(zero_tail_to_2, Literal::Binary(bumped_prefix));
}
90+
91+
// Exercises TruncateLiteralMax on string literals. Widths are counted in code
// points; the expectations below cover 2-, 3- and 4-byte UTF-8 sequences, the
// jump over the surrogate range, and the error raised when nothing in the
// kept prefix can be incremented.
TEST(TruncateUtilTest, TruncateStringMax) {
  // Katakana "イロハニホヘト": seven 3-byte code points.
  const std::string katakana =
      "\xE3\x82\xA4\xE3\x83\xAD\xE3\x83\x8F\xE3\x83\x8B"
      "\xE3\x83\x9B\xE3\x83\x98\xE3\x83\x88";
  const std::string katakana_upper_2 = "\xE3\x82\xA4\xE3\x83\xAE";              // "イヮ"
  const std::string katakana_upper_3 = "\xE3\x82\xA4\xE3\x83\xAD\xE3\x83\x90";  // "イロバ"

  ICEBERG_UNWRAP_OR_FAIL(auto katakana_2,
                         TruncateUtils::TruncateLiteralMax(Literal::String(katakana), 2));
  EXPECT_EQ(katakana_2, Literal::String(katakana_upper_2));

  ICEBERG_UNWRAP_OR_FAIL(auto katakana_3,
                         TruncateUtils::TruncateLiteralMax(Literal::String(katakana), 3));
  EXPECT_EQ(katakana_3, Literal::String(katakana_upper_3));

  // Widths that already cover the whole input return it unchanged.
  ICEBERG_UNWRAP_OR_FAIL(auto katakana_7,
                         TruncateUtils::TruncateLiteralMax(Literal::String(katakana), 7));
  EXPECT_EQ(katakana_7, Literal::String(katakana));

  ICEBERG_UNWRAP_OR_FAIL(auto katakana_8,
                         TruncateUtils::TruncateLiteralMax(Literal::String(katakana), 8));
  EXPECT_EQ(katakana_8, Literal::String(katakana));

  // Mixed scripts "щщаεはчωいにπάほхεろへσκζ" (2- and 3-byte code points).
  const std::string mixed =
      "\xD1\x89\xD1\x89\xD0\xB0\xCE\xB5\xE3\x81\xAF\xD1\x87\xCF\x89"  // щщаεはчω
      "\xE3\x81\x84\xE3\x81\xAB\xCF\x80\xCE\xAC\xE3\x81\xBB\xD1\x85"  // いにπάほх
      "\xCE\xB5\xE3\x82\x8D\xE3\x81\xB8\xCF\x83\xCE\xBA\xCE\xB6";     // εろへσκζ
  const std::string mixed_upper_7 =
      "\xD1\x89\xD1\x89\xD0\xB0\xCE\xB5\xE3\x81\xAF\xD1\x87\xCF\x8A";  // "щщаεはчϊ"

  ICEBERG_UNWRAP_OR_FAIL(auto mixed_7,
                         TruncateUtils::TruncateLiteralMax(Literal::String(mixed), 7));
  EXPECT_EQ(mixed_7, Literal::String(mixed_upper_7));

  // "aनि" followed by two U+FFFF code points.
  const std::string devanagari = "a\xE0\xA4\xA8\xE0\xA4\xBF\xEF\xBF\xBF\xEF\xBF\xBF";
  const std::string devanagari_upper_3 = "a\xE0\xA4\xA8\xE0\xA5\x80";  // "aनी"

  ICEBERG_UNWRAP_OR_FAIL(
      auto devanagari_3,
      TruncateUtils::TruncateLiteralMax(Literal::String(devanagari), 3));
  EXPECT_EQ(devanagari_3, Literal::String(devanagari_upper_3));

  // U+FFFF is the largest 3-byte code point; its successor needs 4 bytes.
  const std::string max_three_byte = "\xEF\xBF\xBF\xEF\xBF\xBF";
  const std::string first_four_byte = "\xF0\x90\x80\x80";  // U+10000

  ICEBERG_UNWRAP_OR_FAIL(
      auto max_three_byte_1,
      TruncateUtils::TruncateLiteralMax(Literal::String(max_three_byte), 1));
  EXPECT_EQ(max_three_byte_1, Literal::String(first_four_byte));

  // U+10FFFF has no successor, so a prefix made only of it cannot be bounded.
  const std::string max_code_points = "\xF4\x8F\xBF\xBF\xF4\x8F\xBF\xBF";  // U+10FFFF x2
  EXPECT_THAT(TruncateUtils::TruncateLiteralMax(Literal::String(max_code_points), 1),
              IsError(ErrorKind::kInvalidArgument));

  // 4-byte code point U+103FF ("\uD800\uDFFF" as a UTF-16 pair), twice.
  const std::string supplementary = "\xF0\x90\x8F\xBF\xF0\x90\x8F\xBF";
  const std::string supplementary_upper_1 = "\xF0\x90\x90\x80";  // U+10400

  ICEBERG_UNWRAP_OR_FAIL(
      auto supplementary_1,
      TruncateUtils::TruncateLiteralMax(Literal::String(supplementary), 1));
  EXPECT_EQ(supplementary_1, Literal::String(supplementary_upper_1));

  // Emoji: U+1F602 three times.
  const std::string emoji = "\xF0\x9F\x98\x82\xF0\x9F\x98\x82\xF0\x9F\x98\x82";  // 😂😂😂
  const std::string emoji_upper_2 = "\xF0\x9F\x98\x82\xF0\x9F\x98\x83";          // 😂😃
  const std::string emoji_upper_1 = "\xF0\x9F\x98\x83";                          // 😃

  ICEBERG_UNWRAP_OR_FAIL(auto emoji_2,
                         TruncateUtils::TruncateLiteralMax(Literal::String(emoji), 2));
  EXPECT_EQ(emoji_2, Literal::String(emoji_upper_2));

  ICEBERG_UNWRAP_OR_FAIL(auto emoji_1,
                         TruncateUtils::TruncateLiteralMax(Literal::String(emoji), 1));
  EXPECT_EQ(emoji_1, Literal::String(emoji_upper_1));

  // Carry case: the kept prefix ends in U+10FFFF, so the preceding 'a' is
  // incremented instead. The literal is split because 'c' is a hex digit and
  // would otherwise extend the preceding \x escape.
  const std::string carry =
      "a\xF4\x8F\xBF\xBF"
      "c";  // a U+10FFFF c
  const std::string carry_upper_2 = "b";

  ICEBERG_UNWRAP_OR_FAIL(auto carry_2,
                         TruncateUtils::TruncateLiteralMax(Literal::String(carry), 2));
  EXPECT_EQ(carry_2, Literal::String(carry_upper_2));

  // Surrogate skip: the successor of U+D7FF is U+E000, not U+D800. The
  // literal is split because 'b' is a hex digit.
  const std::string before_surrogates =
      "a\xED\x9F\xBF"
      "b";  // a U+D7FF b
  const std::string after_surrogates = "a\xEE\x80\x80";  // a U+E000

  ICEBERG_UNWRAP_OR_FAIL(
      auto before_surrogates_2,
      TruncateUtils::TruncateLiteralMax(Literal::String(before_surrogates), 2));
  EXPECT_EQ(before_surrogates_2, Literal::String(after_surrogates));
}
190+
53191
} // namespace iceberg

src/iceberg/util/truncate_util.cc

Lines changed: 192 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,105 @@
2929
namespace iceberg {
3030

3131
namespace {
32-
template <TypeId type_id>
33-
Literal TruncateLiteralImpl(const Literal& literal, int32_t width) {
34-
std::unreachable();
32+
// Largest valid Unicode code point.
constexpr uint32_t kUtf8MaxCodePoint = 0x10FFFF;
// Inclusive bounds of the UTF-16 surrogate range, which is not encodable in UTF-8.
constexpr uint32_t kUtf8MinSurrogate = 0xD800;
constexpr uint32_t kUtf8MaxSurrogate = 0xDFFF;

// Decodes the UTF-8 sequence that starts at the first byte of `source`.
//
// Returns the decoded code point, or std::nullopt when `source` is empty, the
// lead byte is not a valid start byte, the sequence is cut short, a
// continuation byte is malformed, the encoding is overlong, or the value is a
// surrogate or exceeds kUtf8MaxCodePoint. Bytes after the decoded sequence are
// ignored.
std::optional<uint32_t> DecodeUtf8CodePoint(std::string_view source) {
  if (source.empty()) {
    return std::nullopt;
  }

  const auto lead = static_cast<uint8_t>(source.front());
  if (lead < 0x80) {
    // 0xxxxxxx: plain ASCII.
    return lead;
  }

  // Classify the lead byte: total sequence length, payload bits carried by the
  // lead byte, and the smallest value that legitimately needs this length.
  size_t length = 0;
  uint32_t value = 0;
  uint32_t smallest = 0;
  if ((lead & 0xE0) == 0xC0) {  // 110xxxxx
    length = 2;
    value = lead & 0x1F;
    smallest = 0x80;
  } else if ((lead & 0xF0) == 0xE0) {  // 1110xxxx
    length = 3;
    value = lead & 0x0F;
    smallest = 0x800;
  } else if ((lead & 0xF8) == 0xF0) {  // 11110xxx
    length = 4;
    value = lead & 0x07;
    smallest = 0x10000;
  } else {
    // Stray continuation byte or an invalid start byte.
    return std::nullopt;
  }

  if (source.size() < length) {
    return std::nullopt;
  }

  // Fold in the continuation bytes (10xxxxxx), six payload bits each.
  for (size_t i = 1; i < length; ++i) {
    const auto trail = static_cast<uint8_t>(source[i]);
    if ((trail & 0xC0) != 0x80) {
      return std::nullopt;
    }
    value = (value << 6) | (trail & 0x3F);
  }

  // Reject overlong encodings.
  if (value < smallest) {
    return std::nullopt;
  }
  // Only 4-byte sequences can exceed the Unicode range.
  if (value > kUtf8MaxCodePoint) {
    return std::nullopt;
  }
  // Only 3-byte sequences can land in the surrogate range.
  if (value >= kUtf8MinSurrogate && value <= kUtf8MaxSurrogate) {
    return std::nullopt;
  }
  return value;
}
36109

110+
// Appends the UTF-8 encoding of `code_point` to `target`.
//
// The encoded length is chosen from the value alone (1 byte up to 0x7F,
// 2 up to 0x7FF, 3 up to 0xFFFF, otherwise 4). No validation is performed, so
// the caller must pass a code point that is legal to encode.
void AppendUtf8CodePoint(uint32_t code_point, std::string& target) {
  // Builds the continuation byte (10xxxxxx) holding the six bits at `shift`.
  const auto continuation = [code_point](int shift) {
    return static_cast<char>(0x80 | ((code_point >> shift) & 0x3F));
  };

  if (code_point <= 0x7F) {
    target += static_cast<char>(code_point);
    return;
  }
  if (code_point <= 0x7FF) {
    target += static_cast<char>(0xC0 | (code_point >> 6));
    target += continuation(0);
    return;
  }
  if (code_point <= 0xFFFF) {
    target += static_cast<char>(0xE0 | (code_point >> 12));
    target += continuation(6);
    target += continuation(0);
    return;
  }
  target += static_cast<char>(0xF0 | (code_point >> 18));
  target += continuation(12);
  target += continuation(6);
  target += continuation(0);
}
127+
128+
template <TypeId type_id>
129+
Literal TruncateLiteralImpl(const Literal& literal, int32_t width) = delete;
130+
37131
template <>
38132
Literal TruncateLiteralImpl<TypeId::kInt>(const Literal& literal, int32_t width) {
39133
int32_t v = std::get<int32_t>(literal.value());
@@ -72,8 +166,80 @@ Literal TruncateLiteralImpl<TypeId::kBinary>(const Literal& literal, int32_t wid
72166
return Literal::Binary(std::vector<uint8_t>(data.begin(), data.begin() + width));
73167
}
74168

169+
template <TypeId type_id>
170+
Result<Literal> TruncateLiteralMaxImpl(const Literal& literal, int32_t width) = delete;
171+
172+
template <>
173+
Result<Literal> TruncateLiteralMaxImpl<TypeId::kString>(const Literal& literal,
174+
int32_t width) {
175+
const auto& str = std::get<std::string>(literal.value());
176+
ICEBERG_ASSIGN_OR_RAISE(std::string truncated,
177+
TruncateUtils::TruncateUTF8Max(str, width));
178+
return Literal::String(std::move(truncated));
179+
}
180+
181+
template <>
182+
Result<Literal> TruncateLiteralMaxImpl<TypeId::kBinary>(const Literal& literal,
183+
int32_t width) {
184+
const auto& data = std::get<std::vector<uint8_t>>(literal.value());
185+
if (static_cast<int32_t>(data.size()) <= width) {
186+
return literal;
187+
}
188+
189+
std::vector<uint8_t> truncated(data.begin(), data.begin() + width);
190+
for (auto it = truncated.rbegin(); it != truncated.rend(); ++it) {
191+
if (*it < 0xFF) {
192+
++(*it);
193+
truncated.resize(truncated.size() - std::distance(truncated.rbegin(), it));
194+
return Literal::Binary(std::move(truncated));
195+
}
196+
}
197+
return InvalidArgument("Cannot truncate upper bound for binary: all bytes are 0xFF");
198+
}
199+
75200
} // namespace
76201

202+
// Computes an upper bound for `source` based on TruncateUTF8(source, L).
//
// If truncation leaves the string unchanged it is returned as is. Otherwise
// the truncated prefix is scanned from its last code point backwards; the
// first code point that has a successor is replaced by that successor
// (jumping over the surrogate range) and everything after it is dropped.
//
// Returns InvalidArgument when the prefix contains a sequence that
// DecodeUtf8CodePoint rejects, or when no code point in the prefix can be
// incremented.
Result<std::string> TruncateUtils::TruncateUTF8Max(const std::string& source, size_t L) {
  std::string result = TruncateUTF8(source, L);
  if (result == source) {
    return result;
  }

  // [begin, end) delimits the code point under inspection; `end` moves left
  // one code point per iteration.
  for (size_t end = result.size(); end > 0;) {
    // Step back over continuation bytes (10xxxxxx) to the lead byte.
    size_t begin = end - 1;
    while (begin > 0 && (static_cast<uint8_t>(result[begin]) & 0xC0) == 0x80) {
      --begin;
    }

    const auto decoded =
        DecodeUtf8CodePoint(std::string_view(result).substr(begin, end - begin));
    if (!decoded.has_value()) {
      return InvalidArgument("Invalid UTF-8 in string literal");
    }

    if (*decoded < kUtf8MaxCodePoint) {
      uint32_t successor = *decoded + 1;
      // Surrogates are not encodable; continue right after the range.
      const bool in_surrogates =
          successor >= kUtf8MinSurrogate && successor <= kUtf8MaxSurrogate;
      if (in_surrogates) {
        successor = kUtf8MaxSurrogate + 1;
      }
      if (successor <= kUtf8MaxCodePoint) {
        result.resize(begin);
        AppendUtf8CodePoint(successor, result);
        return result;
      }
    }

    // This code point has no successor; try the one before it.
    end = begin;
  }

  return InvalidArgument(
      "Cannot truncate upper bound for string: all code points are 0x10FFFF");
}
242+
77243
// Subtracts from `decimal` its remainder modulo `width`, where the remainder
// is first shifted by `width` and reduced again before being subtracted.
Decimal TruncateUtils::TruncateDecimal(const Decimal& decimal, int32_t width) {
  const auto raw_remainder = decimal % width;
  const auto shifted_remainder = (raw_remainder + width) % width;
  return decimal - shifted_remainder;
}
@@ -104,4 +270,27 @@ Result<Literal> TruncateUtils::TruncateLiteral(const Literal& literal, int32_t w
104270
}
105271
}
106272

273+
#define DISPATCH_TRUNCATE_LITERAL_MAX(TYPE_ID) \
274+
case TYPE_ID: \
275+
return TruncateLiteralMaxImpl<TYPE_ID>(literal, width);
276+
277+
Result<Literal> TruncateUtils::TruncateLiteralMax(const Literal& literal, int32_t width) {
278+
if (literal.IsNull()) [[unlikely]] {
279+
// Return null as is
280+
return literal;
281+
}
282+
283+
if (literal.IsAboveMax() || literal.IsBelowMin()) [[unlikely]] {
284+
return NotSupported("Cannot truncate {}", literal.ToString());
285+
}
286+
287+
switch (literal.type()->type_id()) {
288+
DISPATCH_TRUNCATE_LITERAL_MAX(TypeId::kString);
289+
DISPATCH_TRUNCATE_LITERAL_MAX(TypeId::kBinary);
290+
default:
291+
return NotSupported("Truncate max is not supported for type: {}",
292+
literal.type()->ToString());
293+
}
294+
}
295+
107296
} // namespace iceberg

0 commit comments

Comments
 (0)