Skip to content

Commit 2e8a3e8

Browse files
xxlaykxxlriggs
authored and committed
DX-67209 updated aes_encrypt/decrypt (apache#29)
* DX-67209 updated aes_encrypt/decrypt
1 parent 981845f commit 2e8a3e8

6 files changed

+161
-109
lines changed

cpp/src/gandiva/encrypt_utils.cc

+22-6
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,24 @@
1616
// under the License.
1717

1818
#include "gandiva/encrypt_utils.h"
19+
#include <string.h>
1920

2021
#include <stdexcept>
2122

2223
namespace gandiva {
2324
GANDIVA_EXPORT
24-
int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key,
25-
unsigned char* cipher) {
25+
int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key,
26+
int32_t key_len, unsigned char* cipher) {
2627
int32_t cipher_len = 0;
2728
int32_t len = 0;
2829
EVP_CIPHER_CTX* en_ctx = EVP_CIPHER_CTX_new();
30+
const EVP_CIPHER* cipher_algo = get_cipher_algo(key_len);
2931

3032
if (!en_ctx) {
3133
throw std::runtime_error("could not create a new evp cipher ctx for encryption");
3234
}
3335

34-
if (!EVP_EncryptInit_ex(en_ctx, EVP_aes_128_ecb(), nullptr,
36+
if (!EVP_EncryptInit_ex(en_ctx, cipher_algo, nullptr,
3537
reinterpret_cast<const unsigned char*>(key), nullptr)) {
3638
throw std::runtime_error("could not initialize evp cipher ctx for encryption");
3739
}
@@ -55,17 +57,18 @@ int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* ke
5557
}
5658

5759
GANDIVA_EXPORT
58-
int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key,
59-
unsigned char* plaintext) {
60+
int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key,
61+
int32_t key_len, unsigned char* plaintext) {
6062
int32_t plaintext_len = 0;
6163
int32_t len = 0;
6264
EVP_CIPHER_CTX* de_ctx = EVP_CIPHER_CTX_new();
65+
const EVP_CIPHER* cipher_algo = get_cipher_algo(key_len);
6366

6467
if (!de_ctx) {
6568
throw std::runtime_error("could not create a new evp cipher ctx for decryption");
6669
}
6770

68-
if (!EVP_DecryptInit_ex(de_ctx, EVP_aes_128_ecb(), nullptr,
71+
if (!EVP_DecryptInit_ex(de_ctx, cipher_algo, nullptr,
6972
reinterpret_cast<const unsigned char*>(key), nullptr)) {
7073
throw std::runtime_error("could not initialize evp cipher ctx for decryption");
7174
}
@@ -87,4 +90,17 @@ int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char*
8790
EVP_CIPHER_CTX_free(de_ctx);
8891
return plaintext_len;
8992
}
93+
94+
// Picks the OpenSSL AES-ECB cipher that matches the key size in bytes:
// 16 -> AES-128, 24 -> AES-192, 32 -> AES-256. Any other size is rejected
// with std::runtime_error.
const EVP_CIPHER* get_cipher_algo(int32_t key_length) {
  if (key_length == 16) {
    return EVP_aes_128_ecb();
  }
  if (key_length == 24) {
    return EVP_aes_192_ecb();
  }
  if (key_length == 32) {
    return EVP_aes_256_ecb();
  }
  throw std::runtime_error("unsupported key length");
}
90106
} // namespace gandiva

cpp/src/gandiva/encrypt_utils.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,15 @@ namespace gandiva {
2828
**/
2929
GANDIVA_EXPORT
3030
int32_t aes_encrypt(const char* plaintext, int32_t plaintext_len, const char* key,
31-
unsigned char* cipher);
31+
int32_t key_len, unsigned char* cipher);
3232

3333
/**
3434
* Decrypt data using aes algorithm
3535
**/
3636
GANDIVA_EXPORT
3737
int32_t aes_decrypt(const char* ciphertext, int32_t ciphertext_len, const char* key,
38-
unsigned char* plaintext);
38+
int32_t key_len, unsigned char* plaintext);
39+
40+
/**
 * Get the AES ECB cipher matching a key length given in bytes (16, 24 or 32);
 * throws std::runtime_error for any other length
 **/
const EVP_CIPHER* get_cipher_algo(int32_t key_length);
3941

4042
} // namespace gandiva

cpp/src/gandiva/encrypt_utils_test.cc

+35-79
Original file line numberDiff line numberDiff line change
@@ -20,36 +20,38 @@
2020
#include <gtest/gtest.h>
2121

2222
TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) {
23-
// 8 bytes key
24-
auto* key = "1234abcd";
23+
// 16 bytes key
24+
auto* key = "12345678abcdefgh";
2525
auto* to_encrypt = "some test string";
2626

27+
auto key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
2728
auto to_encrypt_len =
2829
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
2930
unsigned char cipher_1[64];
3031

31-
int32_t cipher_1_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_1);
32+
int32_t cipher_1_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_1);
3233

3334
unsigned char decrypted_1[64];
3435
int32_t decrypted_1_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_1),
35-
cipher_1_len, key, decrypted_1);
36+
cipher_1_len, key, key_len, decrypted_1);
3637

3738
EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
3839
std::string(reinterpret_cast<const char*>(decrypted_1), decrypted_1_len));
3940

40-
// 16 bytes key
41-
key = "12345678abcdefgh";
41+
// 24 bytes key
42+
key = "12345678abcdefgh12345678";
4243
to_encrypt = "some\ntest\nstring";
4344

45+
key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
4446
to_encrypt_len =
4547
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
4648
unsigned char cipher_2[64];
4749

48-
int32_t cipher_2_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_2);
50+
int32_t cipher_2_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_2);
4951

5052
unsigned char decrypted_2[64];
5153
int32_t decrypted_2_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_2),
52-
cipher_2_len, key, decrypted_2);
54+
cipher_2_len, key, key_len, decrypted_2);
5355

5456
EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
5557
std::string(reinterpret_cast<const char*>(decrypted_2), decrypted_2_len));
@@ -58,97 +60,51 @@ TEST(TestShaEncryptUtils, TestAesEncryptDecrypt) {
5860
key = "12345678abcdefgh12345678abcdefgh";
5961
to_encrypt = "New\ntest\nstring";
6062

63+
key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
6164
to_encrypt_len =
6265
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
6366
unsigned char cipher_3[64];
6467

65-
int32_t cipher_3_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_3);
68+
int32_t cipher_3_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_3);
6669

6770
unsigned char decrypted_3[64];
6871
int32_t decrypted_3_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_3),
69-
cipher_3_len, key, decrypted_3);
72+
cipher_3_len, key, key_len, decrypted_3);
7073

7174
EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
7275
std::string(reinterpret_cast<const char*>(decrypted_3), decrypted_3_len));
7376

74-
// 64 bytes key
77+
// check exception
78+
char cipher[64] = "JBB7oJAQuqhDCx01fvBRi8PcljW1+nbnOSMk+R0Sz7E==";
79+
int32_t cipher_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(cipher)));
80+
unsigned char plain_text[64];
81+
7582
key = "12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh";
7683
to_encrypt = "New\ntest\nstring";
7784

85+
key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
7886
to_encrypt_len =
7987
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
8088
unsigned char cipher_4[64];
89+
ASSERT_THROW({
90+
gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_4);
91+
}, std::runtime_error);
8192

82-
int32_t cipher_4_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_4);
83-
84-
unsigned char decrypted_4[64];
85-
int32_t decrypted_4_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_4),
86-
cipher_4_len, key, decrypted_4);
87-
88-
EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
89-
std::string(reinterpret_cast<const char*>(decrypted_4), decrypted_4_len));
90-
91-
// 128 bytes key
92-
key =
93-
"12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12"
94-
"345678abcdefgh12345678abcdefgh12345678abcdefgh";
95-
to_encrypt = "A much more longer string then the previous one, but without newline";
96-
97-
to_encrypt_len =
98-
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
99-
unsigned char cipher_5[128];
100-
101-
int32_t cipher_5_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_5);
102-
103-
unsigned char decrypted_5[128];
104-
int32_t decrypted_5_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_5),
105-
cipher_5_len, key, decrypted_5);
106-
107-
EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
108-
std::string(reinterpret_cast<const char*>(decrypted_5), decrypted_5_len));
109-
110-
// 192 bytes key
111-
key =
112-
"12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12"
113-
"345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh1234"
114-
"5678abcdefgh12345678abcdefgh";
115-
to_encrypt =
116-
"A much more longer string then the previous one, but with \nnewline, pretty cool, "
117-
"right?";
118-
119-
to_encrypt_len =
120-
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
121-
unsigned char cipher_6[256];
122-
123-
int32_t cipher_6_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_6);
124-
125-
unsigned char decrypted_6[256];
126-
int32_t decrypted_6_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_6),
127-
cipher_6_len, key, decrypted_6);
93+
ASSERT_THROW({
94+
gandiva::aes_decrypt(cipher, cipher_len, key, key_len, plain_text);
95+
}, std::runtime_error);
12896

129-
EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
130-
std::string(reinterpret_cast<const char*>(decrypted_6), decrypted_6_len));
131-
132-
// 256 bytes key
133-
key =
134-
"12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12"
135-
"345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh1234"
136-
"5678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh12345678abcdefgh123456"
137-
"78abcdefgh";
138-
to_encrypt =
139-
"A much more longer string then the previous one, but with \nnewline, pretty cool, "
140-
"right?";
97+
key = "12345678";
98+
to_encrypt = "New\ntest\nstring";
14199

100+
key_len = static_cast<int32_t>(strlen(reinterpret_cast<const char*>(key)));
142101
to_encrypt_len =
143102
static_cast<int32_t>(strlen(reinterpret_cast<const char*>(to_encrypt)));
144-
unsigned char cipher_7[256];
145-
146-
int32_t cipher_7_len = gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, cipher_7);
147-
148-
unsigned char decrypted_7[256];
149-
int32_t decrypted_7_len = gandiva::aes_decrypt(reinterpret_cast<const char*>(cipher_7),
150-
cipher_7_len, key, decrypted_7);
151-
152-
EXPECT_EQ(std::string(reinterpret_cast<const char*>(to_encrypt), to_encrypt_len),
153-
std::string(reinterpret_cast<const char*>(decrypted_7), decrypted_7_len));
103+
unsigned char cipher_5[64];
104+
ASSERT_THROW({
105+
gandiva::aes_encrypt(to_encrypt, to_encrypt_len, key, key_len, cipher_5);
106+
}, std::runtime_error);
107+
ASSERT_THROW({
108+
gandiva::aes_decrypt(cipher, cipher_len, key, key_len, plain_text);
109+
}, std::runtime_error);
154110
}

cpp/src/gandiva/gdv_function_stubs.cc

+21-5
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,6 @@ CAST_NUMERIC_FROM_VARBINARY(double, arrow::DoubleType, FLOAT8)
306306
#undef GDV_FN_CAST_VARCHAR_INTEGER
307307
#undef GDV_FN_CAST_VARCHAR_REAL
308308

309-
static constexpr int64_t kAesBlockSize = 16; // bytes
310-
311309
GANDIVA_EXPORT
312310
const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_len,
313311
const char* key_data, int32_t key_data_len,
@@ -318,6 +316,15 @@ const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_l
318316
return "";
319317
}
320318

319+
int64_t kAesBlockSize = 0;
320+
if (key_data_len == 16 || key_data_len == 24 || key_data_len == 32) {
321+
kAesBlockSize = static_cast<int64_t>(key_data_len);
322+
} else {
323+
gdv_fn_context_set_error_msg(context, "invalid key length");
324+
*out_len = 0;
325+
return nullptr;
326+
}
327+
321328
*out_len =
322329
static_cast<int32_t>(arrow::bit_util::RoundUpToPowerOf2(data_len, kAesBlockSize));
323330
char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
@@ -329,7 +336,7 @@ const char* gdv_fn_aes_encrypt(int64_t context, const char* data, int32_t data_l
329336
}
330337

331338
try {
332-
*out_len = gandiva::aes_encrypt(data, data_len, key_data,
339+
*out_len = gandiva::aes_encrypt(data, data_len, key_data, key_data_len,
333340
reinterpret_cast<unsigned char*>(ret));
334341
} catch (const std::runtime_error& e) {
335342
gdv_fn_context_set_error_msg(context, e.what());
@@ -349,6 +356,15 @@ const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_l
349356
return "";
350357
}
351358

359+
int64_t kAesBlockSize = 0;
360+
if (key_data_len == 16 || key_data_len == 24 || key_data_len == 32) {
361+
kAesBlockSize = static_cast<int64_t>(key_data_len);
362+
} else {
363+
gdv_fn_context_set_error_msg(context, "invalid key length");
364+
*out_len = 0;
365+
return nullptr;
366+
}
367+
352368
*out_len =
353369
static_cast<int32_t>(arrow::bit_util::RoundUpToPowerOf2(data_len, kAesBlockSize));
354370
char* ret = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(context, *out_len));
@@ -360,13 +376,13 @@ const char* gdv_fn_aes_decrypt(int64_t context, const char* data, int32_t data_l
360376
}
361377

362378
try {
363-
*out_len = gandiva::aes_decrypt(data, data_len, key_data,
379+
*out_len = gandiva::aes_decrypt(data, data_len, key_data, key_data_len,
364380
reinterpret_cast<unsigned char*>(ret));
365381
} catch (const std::runtime_error& e) {
366382
gdv_fn_context_set_error_msg(context, e.what());
367383
return nullptr;
368384
}
369-
385+
ret[*out_len] = '\0';
370386
return ret;
371387
}
372388

cpp/src/gandiva/gdv_function_stubs_test.cc

+70
Original file line numberDiff line numberDiff line change
@@ -1345,4 +1345,74 @@ TEST(TestGdvFnStubs, TestMask) {
13451345
EXPECT_EQ(std::string(result, out_len), expected);
13461346
}
13471347

1348+
// Round-trips a short plaintext through gdv_fn_aes_encrypt and
// gdv_fn_aes_decrypt with a 16-byte key and expects the original text back.
TEST(TestGdvFnStubs, TestAesEncryptDecrypt16) {
  const std::string plaintext = "test string";
  const std::string key = "12345678abcdefgh";
  const auto plaintext_len = static_cast<int32_t>(plaintext.size());
  const auto key_len = static_cast<int32_t>(key.size());

  gandiva::ExecutionContext ctx;
  const auto ctx_ptr = reinterpret_cast<int64_t>(&ctx);

  int32_t encrypted_len = 0;
  const char* encrypted = gdv_fn_aes_encrypt(ctx_ptr, plaintext.c_str(), plaintext_len,
                                             key.c_str(), key_len, &encrypted_len);

  int32_t roundtrip_len = 0;
  const char* roundtrip = gdv_fn_aes_decrypt(ctx_ptr, encrypted, encrypted_len,
                                             key.c_str(), key_len, &roundtrip_len);

  EXPECT_EQ(plaintext, std::string(roundtrip, roundtrip_len));
}
1363+
1364+
// Round-trips a short plaintext through gdv_fn_aes_encrypt and
// gdv_fn_aes_decrypt with a 24-byte key and expects the original text back.
TEST(TestGdvFnStubs, TestAesEncryptDecrypt24) {
  const std::string plaintext = "test string";
  const std::string key = "12345678abcdefgh12345678";
  const auto plaintext_len = static_cast<int32_t>(plaintext.size());
  const auto key_len = static_cast<int32_t>(key.size());

  gandiva::ExecutionContext ctx;
  const auto ctx_ptr = reinterpret_cast<int64_t>(&ctx);

  int32_t encrypted_len = 0;
  const char* encrypted = gdv_fn_aes_encrypt(ctx_ptr, plaintext.c_str(), plaintext_len,
                                             key.c_str(), key_len, &encrypted_len);

  int32_t roundtrip_len = 0;
  const char* roundtrip = gdv_fn_aes_decrypt(ctx_ptr, encrypted, encrypted_len,
                                             key.c_str(), key_len, &roundtrip_len);

  EXPECT_EQ(plaintext, std::string(roundtrip, roundtrip_len));
}
1380+
1381+
// Round-trips a short plaintext through gdv_fn_aes_encrypt and
// gdv_fn_aes_decrypt with a 32-byte key and expects the original text back.
TEST(TestGdvFnStubs, TestAesEncryptDecrypt32) {
  const std::string plaintext = "test string";
  const std::string key = "12345678abcdefgh12345678abcdefgh";
  const auto plaintext_len = static_cast<int32_t>(plaintext.size());
  const auto key_len = static_cast<int32_t>(key.size());

  gandiva::ExecutionContext ctx;
  const auto ctx_ptr = reinterpret_cast<int64_t>(&ctx);

  int32_t encrypted_len = 0;
  const char* encrypted = gdv_fn_aes_encrypt(ctx_ptr, plaintext.c_str(), plaintext_len,
                                             key.c_str(), key_len, &encrypted_len);

  int32_t roundtrip_len = 0;
  const char* roundtrip = gdv_fn_aes_decrypt(ctx_ptr, encrypted, encrypted_len,
                                             key.c_str(), key_len, &roundtrip_len);

  EXPECT_EQ(plaintext, std::string(roundtrip, roundtrip_len));
}
1397+
1398+
// Verifies that both AES stubs reject a 33-byte key (not one of the accepted
// 16/24/32-byte sizes) and report "invalid key length" on the execution
// context.
TEST(TestGdvFnStubs, TestAesEncryptDecryptValidation) {
  gandiva::ExecutionContext ctx;
  int64_t ctx_ptr = reinterpret_cast<int64_t>(&ctx);

  std::string key33 = "12345678abcdefgh12345678abcdefghb";
  auto key33_len = static_cast<int32_t>(key33.length());

  std::string data = "test string";
  auto data_len = static_cast<int32_t>(data.length());

  std::string cipher = "12345678abcdefgh12345678abcdefghb";
  auto cipher_len = static_cast<int32_t>(cipher.length());

  // Each call gets its own output slot. Passing &cipher_len to the encrypt
  // stub would let it overwrite the ciphertext length (it zeroes *out_len on
  // an invalid key) before that length is handed to the decrypt stub.
  int32_t encrypted_len = 0;
  int32_t decrypted_len = 0;

  gdv_fn_aes_encrypt(ctx_ptr, data.c_str(), data_len, key33.c_str(), key33_len,
                     &encrypted_len);
  EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("invalid key length"));
  ctx.Reset();

  gdv_fn_aes_decrypt(ctx_ptr, cipher.c_str(), cipher_len, key33.c_str(), key33_len,
                     &decrypted_len);
  EXPECT_THAT(ctx.get_error(), ::testing::HasSubstr("invalid key length"));
  ctx.Reset();
}
13481418
} // namespace gandiva

0 commit comments

Comments
 (0)