-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathcrypto_enc.cpp
More file actions
341 lines (284 loc) · 11.6 KB
/
crypto_enc.cpp
File metadata and controls
341 lines (284 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
#include "crypto_extension.hpp"
#include "crypto_enc.hpp"
#include "duckdb.hpp"
#include "duckdb/common/exception.hpp"
#include "duckdb/common/string_util.hpp"
#include "duckdb/function/scalar_function.hpp"
#include <duckdb/parser/parsed_data/create_scalar_function_info.hpp>
#include <openssl/evp.h>
#include <openssl/err.h>
#include <openssl/rand.h>
#include <algorithm>
#include <cctype>
#include <unordered_map>
#include <functional>
namespace duckdb
{
constexpr int MODE_ENCRYPT = 1;
constexpr int MODE_DECRYPT = 0;
constexpr size_t GCM_TAG_SIZE = 16;
void throwOpensslError(const std::string &prefix)
{
unsigned long err = ERR_get_error();
char err_buf[256];
ERR_error_string_n(err, err_buf, sizeof(err_buf));
std::string error_msg = prefix + ": " + err_buf;
throw InternalException(error_msg);
}
std::string generate_random_string(size_t len)
{
std::string out(len, '\0'); // allocate len bytes
if (RAND_bytes(reinterpret_cast<unsigned char *>(&out[0]), len) != 1)
{
throwOpensslError("RAND_bytes failed");
}
return out;
}
class EvpCipherContext
{
public:
EvpCipherContext(const EVP_CIPHER *cipher_) : cipher(cipher_), ctx(EVP_CIPHER_CTX_new(),
[](EVP_CIPHER_CTX *p)
{
if (p)
{
EVP_CIPHER_CTX_free(p);
}
})
{
}
bool NeedsTag()
{
return (EVP_CIPHER_get_flags(cipher) & EVP_CIPH_FLAG_AEAD_CIPHER) != 0;
}
int IvLength() const
{
return EVP_CIPHER_iv_length(cipher);
}
void SetIv(const unsigned char *iv_)
{
iv = iv_;
}
void SetKey(const unsigned char *key_)
{
key = key_;
}
void SetTag(unsigned char *tag_)
{
in_tag = tag_;
}
std::string GetTag()
{
std::string tag;
tag.resize(GCM_TAG_SIZE);
if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_AEAD_GET_TAG, GCM_TAG_SIZE, reinterpret_cast<unsigned char *>(&tag[0])))
{
throwOpensslError("Failed to get tag");
}
return tag;
}
void Init(int mode)
{
if (!EVP_CipherInit_ex(ctx.get(), cipher, nullptr, nullptr, nullptr, mode))
{
throwOpensslError("Failed to initialize cipher");
}
if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_AEAD_SET_IVLEN, IvLength(), nullptr))
{
throwOpensslError("Failed to set iv length");
}
if (!EVP_CipherInit_ex(ctx.get(), cipher, nullptr, key, iv, mode))
{
throwOpensslError("Failed to initialize cipher");
}
if (NeedsTag() && mode == MODE_DECRYPT)
{
if (in_tag == nullptr)
{
throw InvalidInputException("Decryption mode for AEAD cipher requires a tag to be set");
}
// Decryption mode, set expected tag
if (!EVP_CIPHER_CTX_ctrl(ctx.get(), EVP_CTRL_AEAD_SET_TAG, 16, in_tag))
{
throwOpensslError("Failed to set expected tag");
}
}
}
std::string Update(const std::string &data)
{
return Update(reinterpret_cast<const unsigned char *>(data.data()), data.size());
}
std::string Update(const unsigned char *data, size_t size)
{
int outlen = 0;
std::string out;
out.resize(size + EVP_CIPHER_block_size(cipher));
if (EVP_CipherUpdate(ctx.get(), reinterpret_cast<unsigned char *>(&out[0]), &outlen,
data, size) != 1)
{
throwOpensslError("Failed to update cipher");
}
out.resize(outlen);
return out;
}
std::string Finalize()
{
std::string result;
result.resize(EVP_CIPHER_block_size(cipher));
int result_len = 0;
if (EVP_CipherFinal_ex(ctx.get(), reinterpret_cast<unsigned char *>(&result[0]), &result_len) != 1)
{
throwOpensslError("Failed to finalize cipher");
}
result.resize(result_len);
return result;
}
EvpCipherContext(const EvpCipherContext &) = delete;
EvpCipherContext &operator=(const EvpCipherContext &) = delete;
// Enable move operations (needed because copy is deleted)
EvpCipherContext(EvpCipherContext &&) = default;
EvpCipherContext &operator=(EvpCipherContext &&) = default;
private:
// We don't own these
const EVP_CIPHER *cipher = nullptr;
const unsigned char *iv = nullptr;
const unsigned char *key = nullptr;
unsigned char *in_tag = nullptr;
std::unique_ptr<EVP_CIPHER_CTX, void (*)(EVP_CIPHER_CTX *)> ctx;
};
// This is dramatically simplefied wrapper for an encrypted text blob
// It expects to be used with the Plaintext class
// It will do AES AEAD
class CipherText
{
public:
/// @brief
/// @param cipher
/// @param ciphertext - this is the raw ciphertext bytes, it will be iv + encrypted bytes + tag
CipherText(const EVP_CIPHER *cipher, std::string ciphertext_) : ciphertext(ciphertext_), ctx(cipher)
{
}
static CipherText Encrypt(const EVP_CIPHER *cipher, const unsigned char *key, const std::string &plaintext)
{
EvpCipherContext ctx(cipher);
std::string iv = generate_random_string(ctx.IvLength());
ctx.SetKey(key);
ctx.SetIv(reinterpret_cast<const unsigned char *>(iv.data()));
ctx.Init(MODE_ENCRYPT); // 1 for encrypt
std::string encrypted = ctx.Update(plaintext);
encrypted += ctx.Finalize();
std::string ciphertext = iv + encrypted;
if (ctx.NeedsTag())
{
std::string tag = ctx.GetTag();
ciphertext += tag;
}
return CipherText(cipher, ciphertext);
}
const char *GetTag() const
{
// Tag is last 16 bytes
return ciphertext.data() + ciphertext.size() - 16;
}
const char *GetIv() const
{
// IV is first iv_length bytes
return ciphertext.data();
}
const char *GetEncryptedData() const
{
// Encrypted data is between iv and tag
return ciphertext.data() + ctx.IvLength();
}
// std::string Update(const unsigned char *data, size_t size)
std::string Decrypt(const unsigned char *key)
{
ctx.SetKey(key);
ctx.SetIv(reinterpret_cast<const unsigned char *>(GetIv()));
ctx.SetTag(reinterpret_cast<unsigned char *>(const_cast<char *>(GetTag())));
ctx.Init(MODE_DECRYPT);
std::string decrypted = ctx.Update(reinterpret_cast<const unsigned char *>(GetEncryptedData()),
ciphertext.size() - ctx.IvLength() - 16);
decrypted += ctx.Finalize();
return decrypted;
}
std::string &GetValue()
{
return ciphertext;
}
CipherText(const CipherText &) = delete;
CipherText &operator=(const CipherText &) = delete;
// Enable move operations (needed because copy is deleted)
CipherText(CipherText &&) = default;
CipherText &operator=(CipherText &&) = default;
std::string ciphertext;
EvpCipherContext ctx;
std::string tag = "";
};
// // TODO: handle key and iv size checks
// // TODO: handle padding if necessary
// // TODO: handle output buffer size
// // TODO: handle different modes of operation
// evp_ctx.Init(cipher, key, iv);
// evp_ctx.Update(data, data_len);
// evp_ctx.Finalize();
// }
inline void CryptoScalarFun(DataChunk &args, ExpressionState &state, Vector &result, int mode)
{
// This is called with three arguments:
// 1. The cipher name
// 2. The key
// 3. The value
//
auto &hash_function_name_vector = args.data[0];
auto &key_vector = args.data[1];
auto &value_vector = args.data[2];
TernaryExecutor::Execute<string_t, string_t, string_t, string_t>(
hash_function_name_vector, key_vector, value_vector, result, args.size(),
[&](string_t cipher_name, string_t key, string_t value)
{
string algorithm(cipher_name.GetData(), cipher_name.GetSize());
string key_str(key.GetData(), key.GetSize());
string value_str(value.GetData(), value.GetSize());
// TODO: handle iv properly
// TODO: only allow the algorithm to be set once instead of per row
std::string algo_lower = StringUtil::Lower(algorithm);
const EVP_CIPHER *cipher = EVP_get_cipherbyname(algo_lower.c_str());
if (!cipher)
{
throw InvalidInputException("Invalid ciphername '" + algorithm + "'");
}
if (mode == 0)
{ // decrypt
CipherText ct(cipher, value_str);
std::string decrypted = ct.Decrypt(reinterpret_cast<const unsigned char *>(key_str.data()));
return StringVector::AddStringOrBlob(result, string_t(decrypted.data(), decrypted.size()));
}
else if (mode == 1)
{ // encrypt
CipherText ct = CipherText::Encrypt(cipher, reinterpret_cast<const unsigned char *>(key_str.data()), value_str);
return StringVector::AddStringOrBlob(result, string_t(ct.GetValue().data(), ct.GetValue().size()));
}
else
{
throw InternalException("Invalid mode for CryptoScalarFun");
}
});
}
inline void CryptoScalarEncryptFun(DataChunk &args, ExpressionState &state, Vector &result)
{
CryptoScalarFun(args, state, result, MODE_ENCRYPT);
}
inline void CryptoScalarDecryptFun(DataChunk &args, ExpressionState &state, Vector &result)
{
CryptoScalarFun(args, state, result, MODE_DECRYPT);
}
void LoadCipherInternal(ExtensionLoader &loader)
{
// crypto_hash accepts VARCHAR for algorithm name and ANY type for the data to hash
auto crypto_encrypt_scalar_function = ScalarFunction("crypto_encrypt", {LogicalType::VARCHAR, LogicalType::BLOB, LogicalType::ANY}, LogicalType::BLOB, CryptoScalarEncryptFun);
loader.RegisterFunction(crypto_encrypt_scalar_function);
auto crypto_decrypt_scalar_function = ScalarFunction("crypto_decrypt", {LogicalType::VARCHAR, LogicalType::BLOB, LogicalType::ANY}, LogicalType::BLOB, CryptoScalarDecryptFun);
loader.RegisterFunction(crypto_decrypt_scalar_function);
}
};