diff --git a/CMakeLists.txt b/CMakeLists.txt index 5552115..e697be3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,9 +4,10 @@ cmake_minimum_required(VERSION 3.5) set(TARGET_NAME simple_encryption) set(CMAKE_CXX_STANDARD 11) -# DuckDB's extension distribution supports vcpkg. As such, dependencies can be added in ./vcpkg.json and then -# used in cmake with find_package. Feel free to remove or replace with other dependencies. -# Note that it should also be removed from vcpkg.json to prevent needlessly installing it.. +# DuckDB's extension distribution supports vcpkg. As such, dependencies can be +# added in ./vcpkg.json and then used in cmake with find_package. Feel free to +# remove or replace with other dependencies. Note that it should also be removed +# from vcpkg.json to prevent needlessly installing it.. find_package(OpenSSL REQUIRED) set(EXTENSION_NAME ${TARGET_NAME}_extension) @@ -18,19 +19,19 @@ add_subdirectory(src) include_directories(../duckdb/third_party/httplib/include) # by now do this manually, later fix this -set(EXTENSION_SOURCES src/simple_encryption_extension.cpp - src/simple_encryption_extension.cpp - src/simple_encryption_state.cpp - src/core/module.cpp - src/core/types.cpp - src/core/functions/scalar/encrypt.cpp - src/core/functions/scalar/encrypt_to_etype.cpp - src/core/functions/function_data/encrypt_function_data.cpp - src/core/functions/cast/varchar_cast.cpp - src/core/functions/table/encrypt_table.cpp - src/core/utils/simple_encryption_utils.cpp - src/core/crypto/crypto_primitives.cpp -) +set(EXTENSION_SOURCES + src/simple_encryption_extension.cpp + src/simple_encryption_extension.cpp + src/simple_encryption_state.cpp + src/core/module.cpp + src/core/types.cpp + src/core/functions/scalar/encrypt.cpp + src/core/functions/scalar/encrypt_to_etype.cpp + src/core/functions/function_data/encrypt_function_data.cpp + src/core/functions/cast/varchar_cast.cpp + src/core/functions/table/encrypt_table.cpp + src/core/utils/simple_encryption_utils.cpp + src/core/crypto/crypto_primitives.cpp) build_static_extension(${TARGET_NAME} ${EXTENSION_SOURCES}) build_loadable_extension(${TARGET_NAME} " " ${EXTENSION_SOURCES}) diff --git a/src/core/crypto/crypto_primitives.cpp b/src/core/crypto/crypto_primitives.cpp index e26c6ea..50c1f0a 100644 --- a/src/core/crypto/crypto_primitives.cpp +++ b/src/core/crypto/crypto_primitives.cpp @@ -15,8 +15,10 @@ void sha256(const char *in, size_t in_len, hash_bytes &out) { duckdb_mbedtls::MbedTlsWrapper::ComputeSha256Hash(in, in_len, (char *)out); } -void hmac256(const std::string &message, const char *secret, size_t secret_len, hash_bytes &out) { - duckdb_mbedtls::MbedTlsWrapper::Hmac256(secret, secret_len, message.data(), message.size(), (char *)out); +void hmac256(const std::string &message, const char *secret, size_t secret_len, + hash_bytes &out) { + duckdb_mbedtls::MbedTlsWrapper::Hmac256(secret, secret_len, message.data(), + message.size(), (char *)out); } void hmac256(std::string message, hash_bytes secret, hash_bytes &out) { @@ -33,44 +35,45 @@ void hex256(hash_bytes &in, hash_str &out) { } } -const EVP_CIPHER *GetCipher(const string &key, AESStateSSL::Algorithm algorithm) { - - switch(algorithm) { - case AESStateSSL::GCM: - switch (key.size()) { - case 16: - return EVP_aes_128_gcm(); - case 24: - return EVP_aes_192_gcm(); - case 32: - return EVP_aes_256_gcm(); - default: - throw InternalException("Invalid AES key length"); - } +const EVP_CIPHER *GetCipher(const string &key, + AESStateSSL::Algorithm algorithm) { + + switch (algorithm) { + case AESStateSSL::GCM: + switch (key.size()) { + case 16: + return EVP_aes_128_gcm(); + case 24: + return EVP_aes_192_gcm(); + case 32: + return EVP_aes_256_gcm(); + default: + throw InternalException("Invalid AES key length"); + } - case AESStateSSL::CTR: - switch (key.size()) { - case 16: - return EVP_aes_128_ctr(); - case 24: - return EVP_aes_192_ctr(); - case 32: - return EVP_aes_256_ctr(); - default: - throw InternalException("Invalid AES key length"); - } - case AESStateSSL::OCB: - // For now, we only support GCM ciphers - switch (key.size()) { - case 16: - return EVP_aes_128_ocb(); - case 24: - return EVP_aes_192_ocb(); - case 32: - return EVP_aes_256_ocb(); - default: - throw InternalException("Invalid AES key length"); - } + case AESStateSSL::CTR: + switch (key.size()) { + case 16: + return EVP_aes_128_ctr(); + case 24: + return EVP_aes_192_ctr(); + case 32: + return EVP_aes_256_ctr(); + default: + throw InternalException("Invalid AES key length"); + } + case AESStateSSL::OCB: + // For now, we only support GCM ciphers + switch (key.size()) { + case 16: + return EVP_aes_128_ocb(); + case 24: + return EVP_aes_192_ocb(); + case 32: + return EVP_aes_256_ocb(); + default: + throw InternalException("Invalid AES key length"); + } } } @@ -85,9 +88,7 @@ AESStateSSL::~AESStateSSL() { EVP_CIPHER_CTX_free(context); } -bool AESStateSSL::IsOpenSSL() { - return ssl; -} +bool AESStateSSL::IsOpenSSL() { return ssl; } void AESStateSSL::SetEncryptionAlgorithm(string_t s_algorithm) { @@ -107,36 +108,44 @@ void AESStateSSL::GenerateRandomData(data_ptr_t data, idx_t len) { RAND_bytes(data, len); } -void AESStateSSL::InitializeEncryption(const_data_ptr_t iv, idx_t iv_len, const string *key) { - // somewhere here or earlier we should set the encryption algorithm (maybe manually) +void AESStateSSL::InitializeEncryption(const_data_ptr_t iv, idx_t iv_len, + const string *key) { + // somewhere here or earlier we should set the encryption algorithm (maybe + // manually) mode = ENCRYPT; - if (1 != EVP_EncryptInit_ex(context, GetCipher(*key, algorithm), NULL, const_data_ptr_cast(key->data()), iv)) { + if (1 != EVP_EncryptInit_ex(context, GetCipher(*key, algorithm), NULL, + const_data_ptr_cast(key->data()), iv)) { throw InternalException("EncryptInit failed"); } } -void AESStateSSL::InitializeDecryption(const_data_ptr_t iv, idx_t iv_len, const string *key) { +void AESStateSSL::InitializeDecryption(const_data_ptr_t iv, idx_t iv_len, + const string *key) { mode = DECRYPT; - if (1 != EVP_DecryptInit_ex(context, GetCipher(*key, algorithm), NULL, const_data_ptr_cast(key->data()), iv)) { + if (1 != EVP_DecryptInit_ex(context, GetCipher(*key, algorithm), NULL, + const_data_ptr_cast(key->data()), iv)) { throw InternalException("DecryptInit failed"); } } -size_t AESStateSSL::Process(const_data_ptr_t in, idx_t in_len, data_ptr_t out, idx_t out_len) { +size_t AESStateSSL::Process(const_data_ptr_t in, idx_t in_len, data_ptr_t out, + idx_t out_len) { switch (mode) { case ENCRYPT: - if (1 != EVP_EncryptUpdate(context, data_ptr_cast(out), reinterpret_cast(&out_len), + if (1 != EVP_EncryptUpdate(context, data_ptr_cast(out), + reinterpret_cast(&out_len), const_data_ptr_cast(in), (int)in_len)) { throw InternalException("Encryption failed at OpenSSL EVP_EncryptUpdate"); } break; case DECRYPT: - if (1 != EVP_DecryptUpdate(context, data_ptr_cast(out), reinterpret_cast(&out_len), + if (1 != EVP_DecryptUpdate(context, data_ptr_cast(out), + reinterpret_cast(&out_len), const_data_ptr_cast(in), (int)in_len)) { throw InternalException("Decryption failed at OpenSSL EVP_DecryptUpdate"); @@ -151,13 +160,15 @@ size_t AESStateSSL::Process(const_data_ptr_t in, idx_t in_len, data_ptr_t out, i return out_len; } -size_t AESStateSSL::Finalize(data_ptr_t out, idx_t out_len, data_ptr_t tag, idx_t tag_len) { +size_t AESStateSSL::Finalize(data_ptr_t out, idx_t out_len, data_ptr_t tag, + idx_t tag_len) { auto text_len = out_len; switch (mode) { case ENCRYPT: - if (1 != EVP_EncryptFinal_ex(context, data_ptr_cast(out) + out_len, reinterpret_cast(&out_len))) { + if (1 != EVP_EncryptFinal_ex(context, data_ptr_cast(out) + out_len, + reinterpret_cast(&out_len))) { throw InternalException("EncryptFinal failed"); } @@ -166,31 +177,32 @@ size_t AESStateSSL::Finalize(data_ptr_t out, idx_t out_len, data_ptr_t tag, idx_ } // The computed tag is written at the end of a chunk for OCB and GCM - if (1 != EVP_CIPHER_CTX_ctrl(context, EVP_CTRL_GCM_GET_TAG, tag_len, - tag)) { + if (1 != EVP_CIPHER_CTX_ctrl(context, EVP_CTRL_GCM_GET_TAG, tag_len, tag)) { throw InternalException("Calculating the tag failed"); } return text_len; case DECRYPT: - if (algorithm != CTR){ + if (algorithm != CTR) { // Set expected tag value - if (!EVP_CIPHER_CTX_ctrl(context, EVP_CTRL_GCM_SET_TAG, tag_len, - tag)) { + if (!EVP_CIPHER_CTX_ctrl(context, EVP_CTRL_GCM_SET_TAG, tag_len, tag)) { throw InternalException("Finalizing tag failed"); } } - // EVP_DecryptFinal() will return an error code if final block is not correctly formatted. - int ret = EVP_DecryptFinal_ex(context, data_ptr_cast(out) + out_len, reinterpret_cast(&out_len)); + // EVP_DecryptFinal() will return an error code if final block is not + // correctly formatted. + int ret = EVP_DecryptFinal_ex(context, data_ptr_cast(out) + out_len, + reinterpret_cast(&out_len)); text_len += out_len; if (ret > 0) { // success return text_len; } - throw InvalidInputException("Computed AES tag differs from read AES tag, are you using the right key?"); + throw InvalidInputException("Computed AES tag differs from read AES tag, " + "are you using the right key?"); } } diff --git a/src/core/functions/cast/varchar_cast.cpp b/src/core/functions/cast/varchar_cast.cpp index 43aca83..c1ba5f8 100644 --- a/src/core/functions/cast/varchar_cast.cpp +++ b/src/core/functions/cast/varchar_cast.cpp @@ -12,4 +12,4 @@ namespace core { // do something } -} \ No newline at end of file +} // namespace simple_encrypt \ No newline at end of file diff --git a/src/core/functions/function_data/encrypt_function_data.cpp b/src/core/functions/function_data/encrypt_function_data.cpp index f417d6d..f4ed10f 100644 --- a/src/core/functions/function_data/encrypt_function_data.cpp +++ b/src/core/functions/function_data/encrypt_function_data.cpp @@ -17,14 +17,14 @@ bool EncryptFunctionData::Equals(const FunctionData &other_p) const { return true; } -unique_ptr EncryptFunctionData::EncryptBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { +unique_ptr +EncryptFunctionData::EncryptBind(ClientContext &context, + ScalarFunction &bound_function, + vector> &arguments) { // here, implement bound statements? // do something return make_uniq(context); } -} -} - - +} // namespace core +} // namespace simple_encryption diff --git a/src/core/functions/scalar/encrypt.cpp b/src/core/functions/scalar/encrypt.cpp index b495408..30dfa46 100644 --- a/src/core/functions/scalar/encrypt.cpp +++ b/src/core/functions/scalar/encrypt.cpp @@ -45,27 +45,38 @@ shared_ptr InitializeCryptoState(ExpressionState &state) { } template -typename std::enable_if::value || std::is_floating_point::value, T>::type -EncryptValue(EncryptionState *encryption_state, Vector &result, T plaintext_data, uint8_t *buffer_p) { - // actually, you can just for process already give the pointer to the result, thus skip buffer +typename std::enable_if< + std::is_integral::value || std::is_floating_point::value, T>::type +EncryptValue(EncryptionState *encryption_state, Vector &result, + T plaintext_data, uint8_t *buffer_p) { + // actually, you can just for process already give the pointer to the result, + // thus skip buffer T encrypted_data; - encryption_state->Process(reinterpret_cast(&plaintext_data), sizeof(T), reinterpret_cast(&encrypted_data), sizeof(T)); + encryption_state->Process( + reinterpret_cast(&plaintext_data), sizeof(T), + reinterpret_cast(&encrypted_data), sizeof(T)); return encrypted_data; } template -typename std::enable_if::value || std::is_floating_point::value, T>::type -DecryptValue(EncryptionState *encryption_state, Vector &result, T encrypted_data, uint8_t *buffer_p) { - // actually, you can just for process already give the pointer to the result, thus skip buffer +typename std::enable_if< + std::is_integral::value || std::is_floating_point::value, T>::type +DecryptValue(EncryptionState *encryption_state, Vector &result, + T encrypted_data, uint8_t *buffer_p) { + // actually, you can just for process already give the pointer to the result, + // thus skip buffer T decrypted_data; - encryption_state->Process(reinterpret_cast(&encrypted_data), sizeof(T), reinterpret_cast(&decrypted_data), sizeof(T)); + encryption_state->Process( + reinterpret_cast(&encrypted_data), sizeof(T), + reinterpret_cast(&decrypted_data), sizeof(T)); return decrypted_data; } // Handle string_t type and convert to Base64 template typename std::enable_if::value, T>::type -EncryptValue(EncryptionState *encryption_state, Vector &result, T value, uint8_t *buffer_p) { +EncryptValue(EncryptionState *encryption_state, Vector &result, T value, + uint8_t *buffer_p) { // first encrypt the bytes of the string into a temp buffer_p auto input_data = data_ptr_t(value.GetData()); @@ -73,7 +84,8 @@ EncryptValue(EncryptionState *encryption_state, Vector &result, T value, uint8_t encryption_state->Process(input_data, value_size, buffer_p, value_size); // Convert the encrypted data to Base64 - auto encrypted_data = string_t(reinterpret_cast(buffer_p), value_size); + auto encrypted_data = + string_t(reinterpret_cast(buffer_p), value_size); size_t base64_size = Blob::ToBase64Size(encrypted_data); // convert to Base64 into a newly allocated string in the result vector @@ -85,23 +97,28 @@ EncryptValue(EncryptionState *encryption_state, Vector &result, T value, uint8_t template typename std::enable_if::value, T>::type -DecryptValue(EncryptionState *encryption_state, Vector &result, T base64_data, uint8_t *buffer_p) { +DecryptValue(EncryptionState *encryption_state, Vector &result, T base64_data, + uint8_t *buffer_p) { // first encrypt the bytes of the string into a temp buffer_p size_t encrypted_size = Blob::FromBase64Size(base64_data); size_t decrypted_size = encrypted_size; - Blob::FromBase64(base64_data, reinterpret_cast(buffer_p), encrypted_size); + Blob::FromBase64(base64_data, reinterpret_cast(buffer_p), + encrypted_size); D_ASSERT(encrypted_size <= base64_data.GetSize()); string_t decrypted_data = StringVector::EmptyString(result, decrypted_size); - encryption_state->Process(buffer_p, encrypted_size, reinterpret_cast(decrypted_data.GetDataWriteable()), decrypted_size); + encryption_state->Process( + buffer_p, encrypted_size, + reinterpret_cast(decrypted_data.GetDataWriteable()), + decrypted_size); return decrypted_data; } - template -void ExecuteEncryptExecutor(Vector &vector, Vector &result, idx_t size, ExpressionState &state, const string &key_t) { +void ExecuteEncryptExecutor(Vector &vector, Vector &result, idx_t size, + ExpressionState &state, const string &key_t) { // TODO: put this in the state of the extension uint8_t encryption_buffer[MAX_BUFFER_SIZE]; @@ -116,15 +133,18 @@ void ExecuteEncryptExecutor(Vector &vector, Vector &result, idx_t size, Expressi UnaryExecutor::Execute(vector, result, size, [&](T input) -> T { encryption_state->InitializeEncryption(iv, 16, &key_t); - return EncryptValue(encryption_state.get(), result, input, buffer_p);; + return EncryptValue(encryption_state.get(), result, input, buffer_p); + ; }); } // Generated code //--------------------------------------------------------------------------------------------- -// Helper function that dispatches the runtime type to the appropriate templated function -void ExecuteEncrypt(Vector &vector, Vector &result, idx_t size, ExpressionState &state, const string &key_t) { +// Helper function that dispatches the runtime type to the appropriate templated +// function +void ExecuteEncrypt(Vector &vector, Vector &result, idx_t size, + ExpressionState &state, const string &key_t) { // Check the vector type and call the correct templated version switch (vector.GetType().id()) { case LogicalTypeId::INTEGER: @@ -143,7 +163,8 @@ void ExecuteEncrypt(Vector &vector, Vector &result, idx_t size, ExpressionState //--------------------------------------------------------------------------------------------- template -void ExecuteDecryptExecutor(Vector &vector, Vector &result, idx_t size, ExpressionState &state, const string &key_t) { +void ExecuteDecryptExecutor(Vector &vector, Vector &result, idx_t size, + ExpressionState &state, const string &key_t) { // TODO: put this in the state of the extension uint8_t encryption_buffer[MAX_BUFFER_SIZE]; @@ -158,15 +179,18 @@ void ExecuteDecryptExecutor(Vector &vector, Vector &result, idx_t size, Expressi UnaryExecutor::Execute(vector, result, size, [&](T input) -> T { encryption_state->InitializeDecryption(iv, 16, &key_t); - return DecryptValue(encryption_state.get(), result, input, buffer_p);; + return DecryptValue(encryption_state.get(), result, input, buffer_p); + ; }); } // Generated code //--------------------------------------------------------------------------------------------- -// Helper function that dispatches the runtime type to the appropriate templated function -void ExecuteDecrypt(Vector &vector, Vector &result, idx_t size, ExpressionState &state, const string &key_t) { +// Helper function that dispatches the runtime type to the appropriate templated +// function +void ExecuteDecrypt(Vector &vector, Vector &result, idx_t size, + ExpressionState &state, const string &key_t) { // Check the vector type and call the correct templated version switch (vector.GetType().id()) { case LogicalTypeId::INTEGER: @@ -184,7 +208,8 @@ void ExecuteDecrypt(Vector &vector, Vector &result, idx_t size, ExpressionState } //--------------------------------------------------------------------------------------------- -static void EncryptData(DataChunk &args, ExpressionState &state, Vector &result) { +static void EncryptData(DataChunk &args, ExpressionState &state, + Vector &result) { auto &value_vector = args.data[0]; @@ -193,13 +218,15 @@ static void EncryptData(DataChunk &args, ExpressionState &state, Vector &result) D_ASSERT(key_vector.GetVectorType() == VectorType::CONSTANT_VECTOR); // Fetch the encryption key as a constant string - const string key_t = ConstantVector::GetData(key_vector)[0].GetString(); + const string key_t = + ConstantVector::GetData(key_vector)[0].GetString(); // can we not pass by reference? ExecuteEncrypt(value_vector, result, args.size(), state, key_t); } -static void DecryptData(DataChunk &args, ExpressionState &state, Vector &result) { +static void DecryptData(DataChunk &args, ExpressionState &state, + Vector &result) { auto &value_vector = args.data[0]; @@ -208,39 +235,45 @@ static void DecryptData(DataChunk &args, ExpressionState &state, Vector &result) D_ASSERT(key_vector.GetVectorType() == VectorType::CONSTANT_VECTOR); // Fetch the encryption key as a constant string - const string key_t = ConstantVector::GetData(key_vector)[0].GetString(); + const string key_t = + ConstantVector::GetData(key_vector)[0].GetString(); // can we not pass by reference? ExecuteDecrypt(value_vector, result, args.size(), state, key_t); } - ScalarFunctionSet GetEncryptionFunction() { - ScalarFunctionSet set("encrypt"); + ScalarFunctionSet set("encrypt_simple"); - set.AddFunction(ScalarFunction({LogicalTypeId::INTEGER, LogicalType::VARCHAR}, LogicalTypeId::INTEGER, EncryptData, + set.AddFunction(ScalarFunction({LogicalTypeId::INTEGER, LogicalType::VARCHAR}, + LogicalTypeId::INTEGER, EncryptData, EncryptFunctionData::EncryptBind)); - set.AddFunction(ScalarFunction({LogicalTypeId::BIGINT, LogicalType::VARCHAR}, LogicalTypeId::BIGINT, EncryptData, + set.AddFunction(ScalarFunction({LogicalTypeId::BIGINT, LogicalType::VARCHAR}, + LogicalTypeId::BIGINT, EncryptData, EncryptFunctionData::EncryptBind)); - set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, EncryptData, + set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, + LogicalType::VARCHAR, EncryptData, EncryptFunctionData::EncryptBind)); return set; } ScalarFunctionSet GetDecryptionFunction() { - ScalarFunctionSet set("decrypt"); + ScalarFunctionSet set("decrypt_simple"); // input is column of any type, key is of type VARCHAR, output is of same type - set.AddFunction(ScalarFunction({LogicalTypeId::INTEGER, LogicalType::VARCHAR}, LogicalTypeId::INTEGER, DecryptData, + set.AddFunction(ScalarFunction({LogicalTypeId::INTEGER, LogicalType::VARCHAR}, + LogicalTypeId::INTEGER, DecryptData, EncryptFunctionData::EncryptBind)); - set.AddFunction(ScalarFunction({LogicalTypeId::BIGINT, LogicalType::VARCHAR}, LogicalTypeId::BIGINT, DecryptData, + set.AddFunction(ScalarFunction({LogicalTypeId::BIGINT, LogicalType::VARCHAR}, + LogicalTypeId::BIGINT, DecryptData, EncryptFunctionData::EncryptBind)); - set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, DecryptData, + set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, + LogicalType::VARCHAR, DecryptData, EncryptFunctionData::EncryptBind)); return set; @@ -255,5 +288,5 @@ void CoreScalarFunctions::RegisterEncryptDataScalarFunction( ExtensionUtil::RegisterFunction(db, GetEncryptionFunction()); ExtensionUtil::RegisterFunction(db, GetDecryptionFunction()); } -} -} \ No newline at end of file +} // namespace core +} // namespace simple_encryption \ No newline at end of file diff --git a/src/core/functions/scalar/encrypt_to_etype.cpp b/src/core/functions/scalar/encrypt_to_etype.cpp index 8fcb6f1..ded8766 100644 --- a/src/core/functions/scalar/encrypt_to_etype.cpp +++ b/src/core/functions/scalar/encrypt_to_etype.cpp @@ -29,16 +29,21 @@ namespace simple_encryption { namespace core { template -typename std::enable_if::value || std::is_floating_point::value, T>::type -ProcessAndCastEncrypt(shared_ptr encryption_state, Vector &result, T plaintext_data, uint8_t *buffer_p) { +typename std::enable_if< + std::is_integral::value || std::is_floating_point::value, T>::type +ProcessAndCastEncrypt(shared_ptr encryption_state, + Vector &result, T plaintext_data, uint8_t *buffer_p) { T encrypted_data; - encryption_state->Process(reinterpret_cast(&plaintext_data), sizeof(int32_t), reinterpret_cast(&encrypted_data), sizeof(int32_t)); + encryption_state->Process( + reinterpret_cast(&plaintext_data), sizeof(int32_t), + reinterpret_cast(&encrypted_data), sizeof(int32_t)); return encrypted_data; } template typename std::enable_if::value, T>::type -ProcessAndCastEncrypt(shared_ptr encryption_state, Vector &result, T plaintext_data, uint8_t *buffer_p) { +ProcessAndCastEncrypt(shared_ptr encryption_state, + Vector &result, T plaintext_data, uint8_t *buffer_p) { auto &children = StructVector::GetEntries(result); // take the third vector of the struct @@ -50,7 +55,8 @@ ProcessAndCastEncrypt(shared_ptr encryption_state, Vector &resu encryption_state->Process(input_data, value_size, buffer_p, value_size); // Convert the encrypted data to Base64 - auto encrypted_data = string_t(reinterpret_cast(buffer_p), value_size); + auto encrypted_data = + string_t(reinterpret_cast(buffer_p), value_size); size_t base64_size = Blob::ToBase64Size(encrypted_data); // convert to Base64 into a newly allocated string in the result vector @@ -62,7 +68,8 @@ ProcessAndCastEncrypt(shared_ptr encryption_state, Vector &resu template typename std::enable_if::value, T>::type -ProcessAndCastDecrypt(shared_ptr encryption_state, Vector &result, T base64_data, uint8_t *buffer_p) { +ProcessAndCastDecrypt(shared_ptr encryption_state, + Vector &result, T base64_data, uint8_t *buffer_p) { auto &children = StructVector::GetEntries(result); auto &result_vector = children[2]; @@ -70,24 +77,34 @@ ProcessAndCastDecrypt(shared_ptr encryption_state, Vector &resu // first encrypt the bytes of the string into a temp buffer_p size_t encrypted_size = Blob::FromBase64Size(base64_data); size_t decrypted_size = encrypted_size; - Blob::FromBase64(base64_data, reinterpret_cast(buffer_p), encrypted_size); + Blob::FromBase64(base64_data, reinterpret_cast(buffer_p), + encrypted_size); D_ASSERT(encrypted_size <= base64_data.GetSize()); - string_t decrypted_data = StringVector::EmptyString(*result_vector, decrypted_size); - encryption_state->Process(buffer_p, encrypted_size, reinterpret_cast(decrypted_data.GetDataWriteable()), decrypted_size); + string_t decrypted_data = + StringVector::EmptyString(*result_vector, decrypted_size); + encryption_state->Process( + buffer_p, encrypted_size, + reinterpret_cast(decrypted_data.GetDataWriteable()), + decrypted_size); return decrypted_data; } template -typename std::enable_if::value || std::is_floating_point::value, T>::type -ProcessAndCastDecrypt(shared_ptr encryption_state, Vector &result, T encrypted_data, uint8_t *buffer_p) { +typename std::enable_if< + std::is_integral::value || std::is_floating_point::value, T>::type +ProcessAndCastDecrypt(shared_ptr encryption_state, + Vector &result, T encrypted_data, uint8_t *buffer_p) { T decrypted_data; - encryption_state->Process(reinterpret_cast(&encrypted_data), sizeof(T), reinterpret_cast(&decrypted_data), sizeof(T)); + encryption_state->Process( + reinterpret_cast(&encrypted_data), sizeof(T), + reinterpret_cast(&decrypted_data), sizeof(T)); return decrypted_data; } -shared_ptr GetSimpleEncryptionState(ExpressionState &state){ +shared_ptr +GetSimpleEncryptionState(ExpressionState &state) { auto &func_expr = (BoundFunctionExpression &)state.expr; auto &info = (EncryptFunctionData &)*func_expr.bind_info; @@ -97,24 +114,25 @@ shared_ptr GetSimpleEncryptionState(ExpressionState &stat "simple_encryption"); return simple_encryption_state; - } -bool HasSpace(shared_ptr simple_encryption_state, uint64_t size) { +bool HasSpace(shared_ptr simple_encryption_state, + uint64_t size) { uint32_t max_value = ~0u; - if ((max_value - simple_encryption_state->counter) > size){ + if ((max_value - simple_encryption_state->counter) > size) { return true; } return false; } void SetIV(shared_ptr simple_encryption_state) { - simple_encryption_state->encryption_state->GenerateRandomData(reinterpret_cast(simple_encryption_state->iv), 12); + simple_encryption_state->iv[0] = simple_encryption_state->iv[1] = 0; + simple_encryption_state->encryption_state->GenerateRandomData( + reinterpret_cast(simple_encryption_state->iv), 12); } -shared_ptr GetEncryptionState(ExpressionState &state){ +shared_ptr GetEncryptionState(ExpressionState &state) { return GetSimpleEncryptionState(state)->encryption_state; - } // todo; template @@ -124,7 +142,6 @@ LogicalType CreateEINTtypeStruct() { {"value", LogicalType::INTEGER}}); } - LogicalType CreateEVARtypeStruct() { return LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT}, {"nonce_lo", LogicalType::UBIGINT}, @@ -132,69 +149,80 @@ LogicalType CreateEVARtypeStruct() { } template -void EncryptToEtype(LogicalType result_struct, Vector &input_vector, const string key_t, uint64_t size, ExpressionState &state, Vector &result){ +void EncryptToEtype(LogicalType result_struct, Vector &input_vector, + const string key_t, uint64_t size, ExpressionState &state, + Vector &result) { auto simple_encryption_state = GetSimpleEncryptionState(state); + auto encryption_state = GetEncryptionState(state); + // Reset the reference of the result vector Vector struct_vector(result_struct, size); result.ReferenceAndSetType(struct_vector); - auto encryption_state = GetEncryptionState(state); - - if (simple_encryption_state->counter == 0 || !HasSpace(simple_encryption_state, size)) { + if ((simple_encryption_state->counter == 0) || (HasSpace(simple_encryption_state, size) == false)) { // generate new random IV and reset counter SetIV(simple_encryption_state); simple_encryption_state->counter = 0; } + auto &children = StructVector::GetEntries(result); + auto &nonce_hi = children[0]; + nonce_hi->SetVectorType(VectorType::CONSTANT_VECTOR); + using ENCRYPTED_TYPE = StructTypeTernary; using PLAINTEXT_TYPE = PrimitiveType; - uint8_t encryption_buffer[MAX_BUFFER_SIZE]; - uint8_t *buffer_p = encryption_buffer; - - GenericExecutor::ExecuteUnary(input_vector, result, size, [&](PLAINTEXT_TYPE input) { - - // increment the low part of the nonce - simple_encryption_state->iv[1]++; - simple_encryption_state->counter++; + GenericExecutor::ExecuteUnary( + input_vector, result, size, [&](PLAINTEXT_TYPE input) { + // increment the low part of the nonce + simple_encryption_state->iv[1]++; + simple_encryption_state->counter++; - encryption_state->InitializeEncryption( - reinterpret_cast(simple_encryption_state->iv), 16, reinterpret_cast(&key_t)); + encryption_state->InitializeEncryption( + reinterpret_cast(simple_encryption_state->iv), 16, + reinterpret_cast(&key_t)); - T encrypted_data = ProcessAndCastEncrypt(encryption_state, result, input.val, simple_encryption_state->buffer_p); + T encrypted_data = + ProcessAndCastEncrypt(encryption_state, result, input.val, + simple_encryption_state->buffer_p); - return ENCRYPTED_TYPE {simple_encryption_state->iv[0], simple_encryption_state->iv[1], encrypted_data}; - }); + return ENCRYPTED_TYPE{simple_encryption_state->iv[0], + simple_encryption_state->iv[1], encrypted_data}; + }); } template -void DecryptFromEtype(Vector &input_vector, const string key_t, uint64_t size, ExpressionState &state, Vector &result){ +void DecryptFromEtype(Vector &input_vector, const string key_t, uint64_t size, + ExpressionState &state, Vector &result) { auto simple_encryption_state = GetSimpleEncryptionState(state); + auto encryption_state = GetEncryptionState(state); uint64_t iv[2]; iv[0] = iv[1] = 0; - auto encryption_state = GetEncryptionState(state); - using ENCRYPTED_TYPE = StructTypeTernary; using PLAINTEXT_TYPE = PrimitiveType; - GenericExecutor::ExecuteUnary(input_vector, result, size, [&](ENCRYPTED_TYPE input) { + GenericExecutor::ExecuteUnary( + input_vector, result, size, [&](ENCRYPTED_TYPE input) { + iv[0] = input.a_val; + iv[1] = input.b_val; - iv[0] = input.a_val; - iv[1] = input.b_val; + encryption_state->InitializeDecryption( + reinterpret_cast(iv), 16, &key_t); - encryption_state->InitializeDecryption( - reinterpret_cast(iv), 16, &key_t); - - T decrypted_data = ProcessAndCastDecrypt(encryption_state, result, input.c_val, simple_encryption_state->buffer_p); - return decrypted_data; - }); + T decrypted_data = + ProcessAndCastDecrypt(encryption_state, result, input.c_val, + simple_encryption_state->buffer_p); + return decrypted_data; + }); } -static void EncryptDataToEtype(DataChunk &args, ExpressionState &state, Vector &result) { + +static void EncryptDataToEtype(DataChunk &args, ExpressionState &state, + Vector &result) { auto &input_vector = args.data[0]; auto vector_type = input_vector.GetType(); @@ -207,30 +235,39 @@ static void EncryptDataToEtype(DataChunk &args, ExpressionState &state, Vector & ConstantVector::GetData(key_vector)[0].GetString(); if (vector_type.IsNumeric()) { - switch (vector_type.id()){ - case LogicalTypeId::TINYINT: - case LogicalTypeId::UTINYINT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_t, size, state, result); - case LogicalTypeId::SMALLINT: - case LogicalTypeId::USMALLINT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_t, size, state, result); - case LogicalTypeId::INTEGER: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_t, size, state, result); - case LogicalTypeId::UINTEGER: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_t, size, state, result); - case LogicalTypeId::BIGINT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_t, size, state, result); - case LogicalTypeId::UBIGINT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_t, size, state, result); - case LogicalTypeId::FLOAT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_t, size, state, result); - case LogicalTypeId::DOUBLE: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_t, size, state, result); - default: - throw NotImplementedException("Unsupported numeric type for encryption"); - } + switch (vector_type.id()) { + case LogicalTypeId::TINYINT: + case LogicalTypeId::UTINYINT: + return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_t, + size, state, result); + case LogicalTypeId::SMALLINT: + case LogicalTypeId::USMALLINT: + return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + key_t, size, state, result); + case LogicalTypeId::INTEGER: + return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + key_t, size, state, result); + case LogicalTypeId::UINTEGER: + return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + key_t, size, state, result); + case LogicalTypeId::BIGINT: + return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + key_t, size, state, result); + case LogicalTypeId::UBIGINT: + return EncryptToEtype(CreateEINTtypeStruct(), input_vector, + key_t, size, state, result); + case LogicalTypeId::FLOAT: + return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_t, + size, state, result); + case LogicalTypeId::DOUBLE: + return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_t, + size, state, result); + default: + throw NotImplementedException("Unsupported numeric type for encryption"); + } } else if (vector_type.id() == LogicalTypeId::VARCHAR) { - return EncryptToEtype(CreateEVARtypeStruct(), input_vector, key_t, size, state, result); + return EncryptToEtype(CreateEVARtypeStruct(), input_vector, key_t, + size, state, result); } else if (vector_type.IsNested()) { throw NotImplementedException( "Nested types are not supported for encryption"); @@ -240,7 +277,9 @@ static void EncryptDataToEtype(DataChunk &args, ExpressionState &state, Vector & } } -static void DecryptDataFromEtype(DataChunk &args, ExpressionState &state, Vector &result) { + +static void DecryptDataFromEtype(DataChunk &args, ExpressionState &state, + Vector &result) { auto size = args.size(); auto &input_vector = args.data[0]; @@ -262,15 +301,20 @@ static void DecryptDataFromEtype(DataChunk &args, ExpressionState &state, Vector return DecryptFromEtype(input_vector, key_t, size, state, result); case LogicalTypeId::SMALLINT: case LogicalTypeId::USMALLINT: - return DecryptFromEtype(input_vector, key_t, size, state, result); + return DecryptFromEtype(input_vector, key_t, size, state, + result); case LogicalTypeId::INTEGER: - return DecryptFromEtype(input_vector, key_t, size, state, result); + return DecryptFromEtype(input_vector, key_t, size, state, + result); case LogicalTypeId::UINTEGER: - return DecryptFromEtype(input_vector, key_t, size, state, result); + return DecryptFromEtype(input_vector, key_t, size, state, + result); case LogicalTypeId::BIGINT: - return DecryptFromEtype(input_vector, key_t, size, state, result); + return DecryptFromEtype(input_vector, key_t, size, state, + result); case LogicalTypeId::UBIGINT: - return DecryptFromEtype(input_vector, key_t, size, state, result); + return DecryptFromEtype(input_vector, key_t, size, state, + result); case LogicalTypeId::FLOAT: return DecryptFromEtype(input_vector, key_t, size, state, result); case LogicalTypeId::DOUBLE: @@ -291,9 +335,9 @@ static void DecryptDataFromEtype(DataChunk &args, ExpressionState &state, Vector } ScalarFunctionSet GetEncryptionStructFunction() { - ScalarFunctionSet set("encrypt_etypes"); + ScalarFunctionSet set("encrypt"); - for (auto &type: LogicalType::AllTypes()) { + for (auto &type : LogicalType::AllTypes()) { set.AddFunction( ScalarFunction({type, LogicalType::VARCHAR}, LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT}, @@ -306,9 +350,9 @@ ScalarFunctionSet GetEncryptionStructFunction() { } ScalarFunctionSet GetDecryptionStructFunction() { - ScalarFunctionSet set("decrypt_etypes"); + ScalarFunctionSet set("decrypt"); - for (auto &type: LogicalType::AllTypes()) { + for (auto &type : LogicalType::AllTypes()) { for (auto &nonce_type_a : LogicalType::Numeric()) { for (auto &nonce_type_b : LogicalType::Numeric()) { set.AddFunction(ScalarFunction( @@ -319,17 +363,16 @@ ScalarFunctionSet GetDecryptionStructFunction() { type, DecryptDataFromEtype, EncryptFunctionData::EncryptBind)); } } - } - // TODO: Fix EINT encryption -// set.AddFunction(ScalarFunction({EncryptionTypes::E_INT(), LogicalType::VARCHAR}, LogicalTypeId::INTEGER, DecryptDataChunkStruct, -// EncryptFunctionData::EncryptBind)); + // TODO: Fix EINT encryption +// set.AddFunction(ScalarFunction({EncryptionTypes::E_INTEGER(), +// LogicalType::VARCHAR}, LogicalTypeId::INTEGER, DecryptDataFromEtype, +// EncryptFunctionData::EncryptBind)); + } return set; } - - //------------------------------------------------------------------------------ // Register functions //------------------------------------------------------------------------------ @@ -339,5 +382,5 @@ void CoreScalarFunctions::RegisterEncryptDataStructScalarFunction( ExtensionUtil::RegisterFunction(db, GetEncryptionStructFunction()); ExtensionUtil::RegisterFunction(db, GetDecryptionStructFunction()); } -} -} +} // namespace core +} // namespace simple_encryption diff --git a/src/core/functions/table/encrypt_table.cpp b/src/core/functions/table/encrypt_table.cpp index f296b34..6286798 100644 --- a/src/core/functions/table/encrypt_table.cpp +++ b/src/core/functions/table/encrypt_table.cpp @@ -12,8 +12,7 @@ namespace simple_encryption { namespace core { void CreateEncryptColumnFunction::CreateEncryptColumnFunc( - ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { -} + ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {} //------------------------------------------------------------------------------ // Register functions @@ -24,5 +23,5 @@ void CoreTableFunctions::RegisterEncryptColumnTableFunction( ExtensionUtil::RegisterFunction(db, CreateEncryptColumnFunction()); } -} -} \ No newline at end of file +} // namespace core +} // namespace simple_encryption \ No newline at end of file diff --git a/src/core/types.cpp b/src/core/types.cpp index 393c862..947365d 100644 --- a/src/core/types.cpp +++ b/src/core/types.cpp @@ -8,22 +8,48 @@ namespace simple_encryption { namespace core { -LogicalType EncryptionTypes::E_INT() { - auto type = LogicalType::STRUCT({{"nonce", LogicalType::INTEGER}, {"value", LogicalType::INTEGER}}); - type.SetAlias("E_INT"); +LogicalType EncryptionTypes::E_INTEGER() { + auto type = LogicalType::STRUCT( + {{"nonce_hi", LogicalType::BIGINT}, {"nonce_lo", LogicalType::BIGINT}, {"value", LogicalType::INTEGER}}); + type.SetAlias("E_INTEGER"); + return type; +} + +LogicalType EncryptionTypes::EA_INTEGER() { + auto type = LogicalType::STRUCT( + {{"value", LogicalType::INTEGER}, {"nonce_hi", LogicalType::BIGINT}, {"nonce_lo", LogicalType::BIGINT}, + {"tag", LogicalType::VARCHAR}}); + type.SetAlias("EA_INTEGER"); + return type; +} + +LogicalType EncryptionTypes::E_UINTEGER() { + auto type = LogicalType::STRUCT( + {{"nonce_hi", LogicalType::BIGINT}, {"nonce_lo", LogicalType::BIGINT}, {"value", LogicalType::UINTEGER}}); + type.SetAlias("E_UINTEGER"); + return type; +} + +LogicalType EncryptionTypes::EA_UINTEGER() { + auto type = LogicalType::STRUCT( + {{"value", LogicalType::UINTEGER}, {"nonce_hi", LogicalType::BIGINT}, {"nonce_lo", LogicalType::BIGINT}, + {"tag", LogicalType::VARCHAR}}); + type.SetAlias("EA_UINTEGER"); return type; } LogicalType EncryptionTypes::E_VARCHAR() { - auto blob_type = LogicalType::STRUCT({{"nonce", LogicalType::INTEGER}, {"value", LogicalType::VARCHAR}}); - blob_type.SetAlias("E_VARCHAR"); - return blob_type; + auto type = LogicalType::STRUCT( + {{"nonce_hi", LogicalType::BIGINT}, {"nonce_lo", LogicalType::BIGINT}, {"value", LogicalType::VARCHAR}}); +type.SetAlias("E_VARCHAR"); + return type; } void EncryptionTypes::Register(DatabaseInstance &db) { - // Encrypted INT - ExtensionUtil::RegisterType(db, "E_INT", EncryptionTypes::E_INT()); + // Supported Numeric Values + ExtensionUtil::RegisterType(db, "E_INTEGER", EncryptionTypes::E_INTEGER()); + ExtensionUtil::RegisterType(db, "E_UINTEGER", EncryptionTypes::E_UINTEGER()); // Encrypted VARCHAR ExtensionUtil::RegisterType(db, "E_VARCHAR", EncryptionTypes::E_VARCHAR()); @@ -31,4 +57,4 @@ void EncryptionTypes::Register(DatabaseInstance &db) { } // namespace core -} // namespace spatial +} // namespace simple_encryption diff --git a/src/core/utils/simple_encryption_utils.cpp b/src/core/utils/simple_encryption_utils.cpp index b20f45e..03e96fb 100644 --- a/src/core/utils/simple_encryption_utils.cpp +++ b/src/core/utils/simple_encryption_utils.cpp @@ -7,8 +7,10 @@ namespace simple_encryption { namespace core { // Get SimpleEncryptionState from ClientContext -shared_ptr GetSimpleEncryptionState(ClientContext &context) { - auto lookup = context.registered_state->Get("simple_encryption"); +shared_ptr +GetSimpleEncryptionState(ClientContext &context) { + auto lookup = + context.registered_state->Get("simple_encryption"); if (!lookup) { throw Exception(ExceptionType::INVALID, "Registered simple encryption state not found"); @@ -16,5 +18,5 @@ shared_ptr GetSimpleEncryptionState(ClientContext &contex return lookup; } -} -} +} // namespace core +} // namespace simple_encryption diff --git a/src/include/simple_encryption/core/crypto/crypto_primitives.hpp b/src/include/simple_encryption/core/crypto/crypto_primitives.hpp index c8c2f08..b9c2cdb 100644 --- a/src/include/simple_encryption/core/crypto/crypto_primitives.hpp +++ b/src/include/simple_encryption/core/crypto/crypto_primitives.hpp @@ -6,7 +6,6 @@ #include #include - typedef struct evp_cipher_ctx_st EVP_CIPHER_CTX; namespace duckdb { @@ -16,7 +15,8 @@ typedef unsigned char hash_str[64]; void sha256(const char *in, size_t in_len, hash_bytes &out); -void hmac256(const std::string &message, const char *secret, size_t secret_len, hash_bytes &out); +void hmac256(const std::string &message, const char *secret, size_t secret_len, + hash_bytes &out); void hmac256(std::string message, hash_bytes secret, hash_bytes &out); @@ -33,10 +33,14 @@ class DUCKDB_EXTENSION_API AESStateSSL : public duckdb::EncryptionState { public: bool IsOpenSSL() override; - void InitializeEncryption(const_data_ptr_t iv, idx_t iv_len, const std::string *key) override; - void InitializeDecryption(const_data_ptr_t iv, idx_t iv_len, const std::string *key) override; - size_t Process(const_data_ptr_t in, idx_t in_len, data_ptr_t out, idx_t out_len) override; - size_t Finalize(data_ptr_t out, idx_t out_len, data_ptr_t tag, idx_t tag_len) override; + void InitializeEncryption(const_data_ptr_t iv, idx_t iv_len, + const std::string *key) override; + void InitializeDecryption(const_data_ptr_t iv, idx_t iv_len, + const std::string *key) override; + size_t Process(const_data_ptr_t in, idx_t in_len, data_ptr_t out, + idx_t out_len) override; + size_t Finalize(data_ptr_t out, idx_t out_len, data_ptr_t tag, + idx_t tag_len) override; void GenerateRandomData(data_ptr_t data, idx_t len) override; // crypto-specific functions @@ -57,14 +61,13 @@ extern "C" { class DUCKDB_EXTENSION_API AESStateSSLFactory : public duckdb::EncryptionUtil { public: - explicit AESStateSSLFactory() { - } + explicit AESStateSSLFactory() {} - duckdb::shared_ptr CreateEncryptionState() const override { + duckdb::shared_ptr + CreateEncryptionState() const override { return duckdb::make_shared_ptr(); } - ~AESStateSSLFactory() override { - } + ~AESStateSSLFactory() override {} }; } \ No newline at end of file diff --git a/src/include/simple_encryption/core/functions/cast.hpp b/src/include/simple_encryption/core/functions/cast.hpp index 9ae24b8..593085b 100644 --- a/src/include/simple_encryption/core/functions/cast.hpp +++ b/src/include/simple_encryption/core/functions/cast.hpp @@ -14,9 +14,7 @@ struct CoreVectorOperations { struct CoreCastFunctions { public: - static void Register(DatabaseInstance &db) { - RegisterVarcharCasts(db); - } + static void Register(DatabaseInstance &db) { RegisterVarcharCasts(db); } private: static void RegisterVarcharCasts(DatabaseInstance &db); @@ -24,4 +22,4 @@ struct CoreCastFunctions { } // namespace core -} // namespace spatial \ No newline at end of file +} // namespace simple_encrypt \ No newline at end of file diff --git a/src/include/simple_encryption/core/functions/function_data/encrypt_function_data.hpp b/src/include/simple_encryption/core/functions/function_data/encrypt_function_data.hpp index d61f8a1..303d885 100644 --- a/src/include/simple_encryption/core/functions/function_data/encrypt_function_data.hpp +++ b/src/include/simple_encryption/core/functions/function_data/encrypt_function_data.hpp @@ -6,17 +6,17 @@ namespace simple_encryption { namespace core { -struct EncryptFunctionData: FunctionData { +struct EncryptFunctionData : FunctionData { // Save the ClientContext ClientContext &context; -// BoundStatement relation; + // BoundStatement relation; - EncryptFunctionData(ClientContext &context) - : context(context) {} + EncryptFunctionData(ClientContext &context) : context(context) {} - static unique_ptr EncryptBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments); + static unique_ptr + EncryptBind(ClientContext &context, ScalarFunction &bound_function, + vector> &arguments); unique_ptr Copy() const override; bool Equals(const FunctionData &other_p) const override; diff --git a/src/include/simple_encryption/core/functions/scalar.hpp b/src/include/simple_encryption/core/functions/scalar.hpp index 04307bb..c15efe3 100644 --- a/src/include/simple_encryption/core/functions/scalar.hpp +++ b/src/include/simple_encryption/core/functions/scalar.hpp @@ -14,10 +14,10 @@ struct CoreScalarFunctions { private: static void RegisterEncryptDataScalarFunction(duckdb::DatabaseInstance &db); - static void RegisterEncryptDataStructScalarFunction(duckdb::DatabaseInstance &db); + static void + RegisterEncryptDataStructScalarFunction(duckdb::DatabaseInstance &db); }; - } // namespace core } // namespace simple_encryption diff --git a/src/include/simple_encryption/core/functions/scalar/encrypt.hpp b/src/include/simple_encryption/core/functions/scalar/encrypt.hpp index 3fa0b6d..92a6601 100644 --- a/src/include/simple_encryption/core/functions/scalar/encrypt.hpp +++ b/src/include/simple_encryption/core/functions/scalar/encrypt.hpp @@ -29,5 +29,5 @@ class SimpleEncryptKeys : public ObjectCacheEntry { unordered_map keys; }; -} -} \ No newline at end of file +} // namespace core +} // namespace simple_encryption \ No newline at end of file diff --git a/src/include/simple_encryption/core/functions/table.hpp b/src/include/simple_encryption/core/functions/table.hpp index 475c156..2607828 100644 --- a/src/include/simple_encryption/core/functions/table.hpp +++ b/src/include/simple_encryption/core/functions/table.hpp @@ -15,7 +15,6 @@ struct CoreTableFunctions { static void RegisterEncryptColumnTableFunction(duckdb::DatabaseInstance &db); }; - } // namespace core } // namespace simple_encryption \ No newline at end of file diff --git a/src/include/simple_encryption/core/functions/table/encrypt_table.hpp b/src/include/simple_encryption/core/functions/table/encrypt_table.hpp index e91ec42..0fee058 100644 --- a/src/include/simple_encryption/core/functions/table/encrypt_table.hpp +++ b/src/include/simple_encryption/core/functions/table/encrypt_table.hpp @@ -14,7 +14,6 @@ class CreateEncryptColumnFunction : public TableFunction { static void CreateEncryptColumnFunc(ClientContext &context, TableFunctionInput &data_p, DataChunk &output); - }; -} -} \ No newline at end of file +} // namespace core +} // namespace simple_encryption \ No newline at end of file diff --git a/src/include/simple_encryption/core/types.hpp b/src/include/simple_encryption/core/types.hpp index 6be3248..0e2bb8a 100644 --- a/src/include/simple_encryption/core/types.hpp +++ b/src/include/simple_encryption/core/types.hpp @@ -6,9 +6,16 @@ namespace simple_encryption { namespace core { struct EncryptionTypes { - static LogicalType E_INT(); + static LogicalType E_INTEGER(); + static LogicalType E_UINTEGER(); + static LogicalType E_BIGINT(); + static LogicalType E_UBIGINT(); static LogicalType E_VARCHAR(); + // For authenticated encryption + static LogicalType EA_INTEGER(); + static LogicalType EA_UINTEGER(); + static void Register(DatabaseInstance &db); }; diff --git a/src/include/simple_encryption/core/utils/simple_encryption_utils.hpp b/src/include/simple_encryption/core/utils/simple_encryption_utils.hpp index ba74a39..0c8d403 100644 --- a/src/include/simple_encryption/core/utils/simple_encryption_utils.hpp +++ b/src/include/simple_encryption/core/utils/simple_encryption_utils.hpp @@ -8,7 +8,8 @@ namespace simple_encryption { namespace core { // Function to get DuckPGQState from ClientContext -shared_ptr GetSimpleEncryptionState(ClientContext &context); +shared_ptr +GetSimpleEncryptionState(ClientContext &context); } // namespace core } // namespace simple_encryption diff --git a/src/include/simple_encryption_extension.hpp b/src/include/simple_encryption_extension.hpp index e3d58dd..852f308 100644 --- a/src/include/simple_encryption_extension.hpp +++ b/src/include/simple_encryption_extension.hpp @@ -6,9 +6,9 @@ namespace duckdb { class SimpleEncryptionExtension : public Extension { public: - void Load(DuckDB &db) override; - std::string Name() override; - std::string Version() const override; + void Load(DuckDB &db) override; + std::string Name() override; + std::string Version() const override; }; } // namespace duckdb diff --git a/src/include/simple_encryption_extension_callback.hpp b/src/include/simple_encryption_extension_callback.hpp index 41924f5..6108d72 100644 --- a/src/include/simple_encryption_extension_callback.hpp +++ b/src/include/simple_encryption_extension_callback.hpp @@ -8,8 +8,9 @@ namespace duckdb { class SimpleEncryptionExtensionCallback : public ExtensionCallback { void OnConnectionOpened(ClientContext &context) override { - context.registered_state->Insert("simple_encryption", - make_shared_ptr(context.shared_from_this())); + context.registered_state->Insert( + "simple_encryption", + make_shared_ptr(context.shared_from_this())); } }; -} \ No newline at end of file +} // namespace duckdb \ No newline at end of file diff --git a/src/include/simple_encryption_state.hpp b/src/include/simple_encryption_state.hpp index 3ae9656..c17ac07 100644 --- a/src/include/simple_encryption_state.hpp +++ b/src/include/simple_encryption_state.hpp @@ -22,8 +22,6 @@ class SimpleEncryptionState : public ClientContextState { // encryption buffer uint8_t *buffer_p; - }; } // namespace duckdb - diff --git a/src/simple_encryption_extension.cpp b/src/simple_encryption_extension.cpp index 2139cef..a2a5088 100644 --- a/src/simple_encryption_extension.cpp +++ b/src/simple_encryption_extension.cpp @@ -18,38 +18,36 @@ namespace duckdb { static void LoadInternal(DatabaseInstance &instance) { - // register functions in the core module - simple_encryption::core::CoreModule::Register(instance); + // register functions in the core module + simple_encryption::core::CoreModule::Register(instance); - // Register the SimpleEncryptionState for all connections - auto &config = DBConfig::GetConfig(instance); + // Register the SimpleEncryptionState for all connections + auto &config = DBConfig::GetConfig(instance); - // set pointer to OpenSSL encryption state - config.encryption_util = make_shared_ptr(); + // set pointer to OpenSSL encryption state + config.encryption_util = make_shared_ptr(); - // Add extension callback - config.extension_callbacks.push_back(make_uniq()); + // Add extension callback + config.extension_callbacks.push_back( + make_uniq()); - // Register the SimpleEncryptionState for all connections - for (auto &connection : ConnectionManager::Get(instance).GetConnectionList()) { - connection->registered_state->Insert( - "simple_encryption", - make_shared_ptr(connection)); - } + // Register the SimpleEncryptionState for all connections + for (auto &connection : + ConnectionManager::Get(instance).GetConnectionList()) { + connection->registered_state->Insert( + "simple_encryption", + make_shared_ptr(connection)); + } } -void SimpleEncryptionExtension::Load(DuckDB &db) { - LoadInternal(*db.instance); -} -std::string SimpleEncryptionExtension::Name() { - return "simple_encryption"; -} +void SimpleEncryptionExtension::Load(DuckDB &db) { LoadInternal(*db.instance); } +std::string SimpleEncryptionExtension::Name() { return "simple_encryption"; } std::string SimpleEncryptionExtension::Version() const { #ifdef EXT_VERSION_SIMPLE_ENCRYPTION - return EXT_VERSION_SIMPLE_ENCRYPTION; + return EXT_VERSION_SIMPLE_ENCRYPTION; #else - return "V0.0.1"; + return "V0.0.1"; #endif } @@ -57,14 +55,14 @@ std::string SimpleEncryptionExtension::Version() const { extern "C" { - DUCKDB_EXTENSION_API void simple_encryption_init(duckdb::DatabaseInstance &db) { - duckdb::DuckDB db_wrapper(db); - db_wrapper.LoadExtension(); - } +DUCKDB_EXTENSION_API void simple_encryption_init(duckdb::DatabaseInstance &db) { + duckdb::DuckDB db_wrapper(db); + db_wrapper.LoadExtension(); +} - DUCKDB_EXTENSION_API const char *simple_encryption_version() { - return duckdb::DuckDB::LibraryVersion(); - } +DUCKDB_EXTENSION_API const char *simple_encryption_version() { + return duckdb::DuckDB::LibraryVersion(); +} } #ifndef DUCKDB_EXTENSION_MAIN diff --git a/src/simple_encryption_state.cpp b/src/simple_encryption_state.cpp index b9bd18b..64fc1f7 100644 --- a/src/simple_encryption_state.cpp +++ b/src/simple_encryption_state.cpp @@ -14,7 +14,8 @@ shared_ptr GetEncryptionUtil(ClientContext &context_p) { if (config.encryption_util) { return config.encryption_util; } else { - return make_shared_ptr(); + return make_shared_ptr< + duckdb_mbedtls::MbedTlsWrapper::AESGCMStateMBEDTLSFactory>(); } } @@ -32,9 +33,11 @@ SimpleEncryptionState::SimpleEncryptionState(shared_ptr context) buffer_p = encryption_buffer; // Create a new table containing encryption metadata (nonce, tag) - auto query = new_conn->Query("CREATE TABLE IF NOT EXISTS __simple_encryption_internal (" - "nonce VARCHAR, " - "tag VARCHAR)", false); + auto query = new_conn->Query( + "CREATE TABLE IF NOT EXISTS __simple_encryption_internal (" + "nonce VARCHAR, " + "tag VARCHAR)", + false); if (query->HasError()) { throw TransactionException(query->GetError()); @@ -42,6 +45,6 @@ SimpleEncryptionState::SimpleEncryptionState(shared_ptr context) } void SimpleEncryptionState::QueryEnd() { -// clean up - } + // clean up } +} // namespace duckdb diff --git a/test/sql/bulk_encryption.test b/test/sql/bulk_encryption.test new file mode 100644 index 0000000..3b6cb0a --- /dev/null +++ b/test/sql/bulk_encryption.test @@ -0,0 +1,39 @@ +# name: test/sql/bulk_encryption.test +# description: test simple__encryption extension +# group: [simple_encryption] + +# Require statement will ensure this test is run with this extension loaded +require simple_encryption + +# So, when we do more then 2048 values (i.e. vector size) it crashes +statement ok +CREATE TABLE test_1 AS SELECT 1 AS value FROM range(2048); + +statement ok +SELECT encrypt(value, '0123456789112345') AS encrypted_value FROM test_1; + +statement ok +ALTER TABLE test_1 ADD COLUMN encrypted_values STRUCT(nonce_hi UBIGINT, nonce_lo UBIGINT, value INTEGER); + +statement ok +ALTER TABLE test_1 ADD COLUMN decrypted_values INTEGER; + +statement ok +UPDATE test_1 SET encrypted_values = encrypt(value, '0123456789112345'); + +statement ok +UPDATE test_1 SET decrypted_values = decrypt(encrypted_values, '0123456789112345'); + +query I +SELECT decrypted_values FROM test_1 LIMIT 10; +---- +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 \ No newline at end of file diff --git a/test/sql/simple_struct_encryption.test b/test/sql/simple_struct_encryption.test index 905c5d3..36ea26b 100644 --- a/test/sql/simple_struct_encryption.test +++ b/test/sql/simple_struct_encryption.test @@ -23,7 +23,7 @@ statement ok CREATE TABLE test_1 AS SELECT 1 AS value FROM range(10); statement ok -SELECT encrypt_etypes(value, '0123456789112345') AS encrypted_value FROM test_1; +SELECT encrypt(value, '0123456789112345') AS encrypted_value FROM test_1; statement ok ALTER TABLE test_1 ADD COLUMN encrypted_values STRUCT(nonce_hi UBIGINT, nonce_lo UBIGINT, value INTEGER); @@ -32,10 +32,10 @@ statement ok ALTER TABLE test_1 ADD COLUMN decrypted_values INTEGER; statement ok -UPDATE test_1 SET encrypted_values = encrypt_etypes(value, '0123456789112345'); +UPDATE test_1 SET encrypted_values = encrypt(value, '0123456789112345'); statement ok -UPDATE test_1 SET decrypted_values = decrypt_etypes(encrypted_values, '0123456789112345'); +UPDATE test_1 SET decrypted_values = decrypt(encrypted_values, '0123456789112345'); query I SELECT decrypted_values FROM test_1; @@ -51,6 +51,9 @@ SELECT decrypted_values FROM test_1; 1 1 +statement ok +SELECT encrypt_etypes('testtest', '0123456789112345'); + statement ok CREATE TABLE test_varchar AS SELECT CAST('hello' AS VARCHAR) AS value FROM range(10);