Skip to content

Commit d9aeaa1

Browse files
authored
GH-49614: [C++] Report an error instead of silent truncation in base64_decode on invalid input (#49660)
### Rationale for this change

`arrow::util::base64_decode` previously accepted invalid input and decoded as much of it as it could, which could result in silently truncated or incorrect output without signaling an error. This can lead to unintended data corruption.

### What changes are included in this PR?

- Change `base64_decode` to return `arrow::Result<std::string>` instead of `std::string`
- Add validation for:
  - invalid input length (not a multiple of 4)
  - invalid base64 characters
  - incorrect padding
- Return an error (`Status::Invalid`) for invalid input instead of producing partial output
- Update all call sites to handle `Result<std::string>`
- Add unit tests covering valid and invalid inputs

### Are these changes tested?

Yes. Unit tests have been added to verify:

- valid decoding behavior
- invalid input length
- invalid characters
- incorrect padding handling

### Are there any user-facing changes?

- The API now returns `arrow::Result<std::string>` instead of `std::string`
- Invalid base64 input now results in an error (`Status::Invalid`) instead of returning partial or incorrect output

* GitHub Issue: #49614

Authored-by: Aaditya Srinivasan <aadityasri03@gmail.com>
Signed-off-by: Sutou Kouhei <kou@clear-code.com>
1 parent b2f2692 commit d9aeaa1

File tree

10 files changed

+173
-33
lines changed

10 files changed

+173
-33
lines changed

cpp/src/arrow/flight/flight_test.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,7 +620,8 @@ void ParseBasicHeader(const CallHeaders& incoming_headers, std::string& username
620620
std::string& password) {
621621
std::string encoded_credentials =
622622
FindKeyValPrefixInCallHeaders(incoming_headers, kAuthHeader, kBasicPrefix);
623-
std::stringstream decoded_stream(arrow::util::base64_decode(encoded_credentials));
623+
ASSERT_OK_AND_ASSIGN(auto decoded, arrow::util::base64_decode(encoded_credentials));
624+
std::stringstream decoded_stream(decoded);
624625
std::getline(decoded_stream, username, ':');
625626
std::getline(decoded_stream, password, ':');
626627
}

cpp/src/arrow/util/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ add_arrow_test(utility-test
4949
SOURCES
5050
align_util_test.cc
5151
atfork_test.cc
52+
base64_test.cc
5253
byte_size_test.cc
5354
byte_stream_split_test.cc
5455
cache_test.cc

cpp/src/arrow/util/base64.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <string>
2121
#include <string_view>
2222

23+
#include "arrow/result.h"
2324
#include "arrow/util/visibility.h"
2425

2526
namespace arrow {
@@ -29,7 +30,7 @@ ARROW_EXPORT
2930
std::string base64_encode(std::string_view s);
3031

3132
ARROW_EXPORT
32-
std::string base64_decode(std::string_view s);
33+
arrow::Result<std::string> base64_decode(std::string_view s);
3334

3435
} // namespace util
3536
} // namespace arrow

cpp/src/arrow/util/base64_test.cc

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
#include "arrow/util/base64.h"
19+
#include "arrow/testing/gtest_util.h"
20+
21+
namespace arrow {
22+
namespace util {
23+
24+
TEST(Base64DecodeTest, ValidInputs) {
25+
ASSERT_OK_AND_ASSIGN(auto empty, base64_decode(""));
26+
EXPECT_EQ(empty, "");
27+
28+
ASSERT_OK_AND_ASSIGN(auto two_paddings, base64_decode("Zg=="));
29+
EXPECT_EQ(two_paddings, "f");
30+
31+
ASSERT_OK_AND_ASSIGN(auto one_padding, base64_decode("Zm8="));
32+
EXPECT_EQ(one_padding, "fo");
33+
34+
ASSERT_OK_AND_ASSIGN(auto no_padding, base64_decode("Zm9v"));
35+
EXPECT_EQ(no_padding, "foo");
36+
37+
ASSERT_OK_AND_ASSIGN(auto multiblock, base64_decode("SGVsbG8gd29ybGQ="));
38+
EXPECT_EQ(multiblock, "Hello world");
39+
}
40+
41+
TEST(Base64DecodeTest, BinaryOutput) {
42+
// 'A' maps to index 0 — same zero value used for padding slots
43+
// Verifies that 'A' is decoded as a real zero-valued sextet and is not mistaken for a padding slot.
44+
ASSERT_OK_AND_ASSIGN(auto all_A, base64_decode("AAAA"));
45+
EXPECT_EQ(all_A, std::string("\x00\x00\x00", 3));
46+
47+
// Arbitrary non-ASCII output bytes
48+
ASSERT_OK_AND_ASSIGN(auto binary, base64_decode("AP8A"));
49+
EXPECT_EQ(binary, std::string("\x00\xff\x00", 3));
50+
}
51+
52+
TEST(Base64DecodeTest, InvalidLength) {
53+
ASSERT_RAISES_WITH_MESSAGE(
54+
Invalid, "Invalid: Invalid base64 input: length is not a multiple of 4",
55+
base64_decode("abc"));
56+
}
57+
58+
TEST(Base64DecodeTest, InvalidCharacters) {
59+
ASSERT_RAISES_WITH_MESSAGE(
60+
Invalid, "Invalid: Invalid base64 input: character is not valid base64 character",
61+
base64_decode("ab$="));
62+
63+
// Non-ASCII byte
64+
std::string non_ascii = std::string("abc") + static_cast<char>(0xFF);
65+
ASSERT_RAISES_WITH_MESSAGE(
66+
Invalid, "Invalid: Invalid base64 input: character is not valid base64 character",
67+
base64_decode(non_ascii));
68+
69+
// Corruption mid-string across multiple blocks
70+
ASSERT_RAISES_WITH_MESSAGE(
71+
Invalid, "Invalid: Invalid base64 input: character is not valid base64 character",
72+
base64_decode("aGVs$G8gd29ybGQ="));
73+
}
74+
75+
TEST(Base64DecodeTest, InvalidPadding) {
76+
// Padding in wrong position within block
77+
ASSERT_RAISES_WITH_MESSAGE(Invalid,
78+
"Invalid: Invalid base64 input: padding in wrong position",
79+
base64_decode("ab=c"));
80+
81+
// 3 padding characters — exceeds maximum of 2
82+
ASSERT_RAISES_WITH_MESSAGE(Invalid,
83+
"Invalid: Invalid base64 input: too many padding characters",
84+
base64_decode("a==="));
85+
86+
// 4 padding characters
87+
ASSERT_RAISES_WITH_MESSAGE(Invalid,
88+
"Invalid: Invalid base64 input: too many padding characters",
89+
base64_decode("===="));
90+
91+
// Padding in non-final block across multiple blocks
92+
ASSERT_RAISES_WITH_MESSAGE(Invalid,
93+
"Invalid: Invalid base64 input: padding in wrong position",
94+
base64_decode("Zm8=Zm8="));
95+
}
96+
97+
} // namespace util
98+
} // namespace arrow

cpp/src/arrow/vendored/base64.cpp

Lines changed: 48 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,6 @@ static const std::string base64_chars =
4040
"abcdefghijklmnopqrstuvwxyz"
4141
"0123456789+/";
4242

43-
44-
static inline bool is_base64(unsigned char c) {
45-
return (isalnum(c) || (c == '+') || (c == '/'));
46-
}
47-
4843
static std::string base64_encode(unsigned char const* bytes_to_encode, unsigned int in_len) {
4944
std::string ret;
5045
int i = 0;
@@ -93,38 +88,65 @@ std::string base64_encode(std::string_view string_to_encode) {
9388
return base64_encode(bytes_to_encode, in_len);
9489
}
9590

96-
std::string base64_decode(std::string_view encoded_string) {
91+
Result<std::string> base64_decode(std::string_view encoded_string) {
9792
size_t in_len = encoded_string.size();
9893
int i = 0;
99-
int j = 0;
100-
int in_ = 0;
94+
std::string_view::size_type in_ = 0;
95+
int padding_count = 0;
96+
int block_padding = 0;
97+
bool padding_started = false;
10198
unsigned char char_array_4[4], char_array_3[3];
10299
std::string ret;
103100

104-
while (in_len-- && ( encoded_string[in_] != '=') && is_base64(encoded_string[in_])) {
105-
char_array_4[i++] = encoded_string[in_]; in_++;
106-
if (i ==4) {
107-
for (i = 0; i <4; i++)
108-
char_array_4[i] = base64_chars.find(char_array_4[i]) & 0xff;
101+
if (encoded_string.size() % 4 != 0) {
102+
return Status::Invalid("Invalid base64 input: length is not a multiple of 4");
103+
}
109104

110-
char_array_3[0] = ( char_array_4[0] << 2 ) + ((char_array_4[1] & 0x30) >> 4);
111-
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
112-
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
105+
while (in_len--) {
106+
unsigned char c = encoded_string[in_];
113107

114-
for (i = 0; (i < 3); i++)
115-
ret += char_array_3[i];
116-
i = 0;
108+
if (c == '=') {
109+
padding_started = true;
110+
padding_count++;
111+
112+
if (padding_count > 2) {
113+
return Status::Invalid("Invalid base64 input: too many padding characters");
114+
}
115+
116+
char_array_4[i++] = 0;
117+
} else {
118+
if (padding_started) {
119+
return Status::Invalid("Invalid base64 input: padding in wrong position");
120+
}
121+
122+
if (base64_chars.find(c) == std::string::npos) {
123+
return Status::Invalid("Invalid base64 input: character is not valid base64 character");
124+
}
125+
126+
char_array_4[i++] = c;
117127
}
118-
}
119128

120-
if (i) {
121-
for (j = 0; j < i; j++)
122-
char_array_4[j] = base64_chars.find(char_array_4[j]) & 0xff;
129+
in_++;
130+
131+
if (i == 4) {
132+
for (i = 0; i < 4; i++) {
133+
if (char_array_4[i] != 0) {
134+
char_array_4[i] = base64_chars.find(char_array_4[i]) & 0xff;
135+
}
136+
}
137+
138+
char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
139+
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
140+
char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3];
141+
142+
block_padding = padding_count;
123143

124-
char_array_3[0] = (char_array_4[0] << 2) + ((char_array_4[1] & 0x30) >> 4);
125-
char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2);
144+
for (i = 0; i < 3 - block_padding; i++) {
145+
ret += char_array_3[i];
146+
}
126147

127-
for (j = 0; (j < i - 1); j++) ret += char_array_3[j];
148+
i = 0;
149+
}
128150
}
129151

130152
return ret;

cpp/src/gandiva/gdv_function_stubs.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,15 @@ const char* gdv_fn_base64_decode_utf8(int64_t context, const char* in, int32_t i
269269
return "";
270270
}
271271
// use arrow method to decode base64 string
272-
std::string decoded_str = arrow::util::base64_decode(std::string_view(in, in_len));
272+
auto result = arrow::util::base64_decode(std::string_view(in, in_len));
273+
if (!result.ok()) {
274+
gdv_fn_context_set_error_msg(context, result.status().message().c_str());
275+
*out_len = 0;
276+
return "";
277+
}
278+
279+
std::string decoded_str = *result;
280+
273281
*out_len = static_cast<int32_t>(decoded_str.length());
274282
// allocate memory for response
275283
char* ret = reinterpret_cast<char*>(

cpp/src/parquet/arrow/fuzz_internal.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,11 @@ class FuzzDecryptionKeyRetriever : public DecryptionKeyRetriever {
8383
}
8484
// Is it a key generated by MakeEncryptionKey?
8585
if (key_id.starts_with(kInlineKeyPrefix)) {
86-
return SecureString(
86+
PARQUET_ASSIGN_OR_THROW(
87+
auto decoded_key,
8788
::arrow::util::base64_decode(key_id.substr(kInlineKeyPrefix.length())));
89+
90+
return SecureString(std::move(decoded_key));
8891
}
8992
throw ParquetException("Unknown fuzz encryption key_id");
9093
}

cpp/src/parquet/arrow/schema.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,8 @@ Status GetOriginSchema(const std::shared_ptr<const KeyValueMetadata>& metadata,
953953
// The original Arrow schema was serialized using the store_schema option.
954954
// We deserialize it here and use it to inform read options such as
955955
// dictionary-encoded fields.
956-
auto decoded = ::arrow::util::base64_decode(metadata->value(schema_index));
956+
ARROW_ASSIGN_OR_RAISE(auto decoded,
957+
::arrow::util::base64_decode(metadata->value(schema_index)));
957958
auto schema_buf = std::make_shared<Buffer>(decoded);
958959

959960
::arrow::ipc::DictionaryMemo dict_memo;

cpp/src/parquet/encryption/file_key_unwrapper.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "arrow/util/utf8.h"
2121

22+
#include "arrow/util/base64.h"
2223
#include "parquet/encryption/file_key_unwrapper.h"
2324
#include "parquet/encryption/key_metadata.h"
2425

@@ -122,7 +123,10 @@ KeyWithMasterId FileKeyUnwrapper::GetDataEncryptionKey(const KeyMaterial& key_ma
122123
});
123124

124125
// Decrypt the data key
125-
std::string aad = ::arrow::util::base64_decode(encoded_kek_id);
126+
PARQUET_ASSIGN_OR_THROW(auto decoded_kek,
127+
::arrow::util::base64_decode(encoded_kek_id));
128+
129+
std::string aad = std::move(decoded_kek);
126130
data_key = internal::DecryptKeyLocally(encoded_wrapped_dek, kek_bytes, aad);
127131
}
128132

cpp/src/parquet/encryption/key_toolkit_internal.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ std::string EncryptKeyLocally(const SecureString& key_bytes,
5252

5353
SecureString DecryptKeyLocally(const std::string& encoded_encrypted_key,
5454
const SecureString& master_key, const std::string& aad) {
55-
std::string encrypted_key = ::arrow::util::base64_decode(encoded_encrypted_key);
55+
PARQUET_ASSIGN_OR_THROW(auto encrypted_key,
56+
::arrow::util::base64_decode(encoded_encrypted_key));
5657

5758
AesDecryptor key_decryptor(ParquetCipher::AES_GCM_V1,
5859
static_cast<int>(master_key.size()), false,

0 commit comments

Comments
 (0)