diff --git a/src/core/functions/scalar/encrypt.cpp b/src/core/functions/scalar/encrypt.cpp index 1618077..0a1440c 100644 --- a/src/core/functions/scalar/encrypt.cpp +++ b/src/core/functions/scalar/encrypt.cpp @@ -21,7 +21,6 @@ #include "duckdb/main/client_context.hpp" #include "simple_encryption/core/functions/function_data/encrypt_function_data.hpp" #include "duckdb/planner/expression/bound_function_expression.hpp" -#include "simple_encryption/core/types.hpp" namespace simple_encryption { @@ -69,15 +68,13 @@ shared_ptr InitializeCryptoState(ExpressionState &state) { if (!encryption_state) { return duckdb_mbedtls::MbedTlsWrapper::AESGCMStateMBEDTLSFactory() - .CreateEncryptionState(); + .CreateEncryptionState(); } return encryption_state->CreateEncryptionState(); } template - -// Fix this now with IsNUMERIC LogicalType::IsNumeric() typename std::enable_if::value || std::is_floating_point::value, T>::type EncryptValue(EncryptionState *encryption_state, Vector &result, T plaintext_data, uint8_t *buffer_p) { // actually, you can just for process already give the pointer to the result, thus skip buffer @@ -132,42 +129,6 @@ DecryptValue(EncryptionState *encryption_state, Vector &result, T base64_data, u return decrypted_data; } -template -void ExecuteEncryptStructExecutor(Vector &vector, Vector &result, idx_t size, ExpressionState &state, const string &key_t) { - - // TODO: put this in the state of the extension - uint8_t encryption_buffer[MAX_BUFFER_SIZE]; - uint8_t *buffer_p = encryption_buffer; - - unsigned char iv[16]; - auto encryption_state = InitializeCryptoState(state); - - // TODO: construct nonce based on immutable ROW_ID + hash(col_name) - memcpy(iv, "12345678901", 12); - iv[12] = iv[13] = iv[14] = iv[15] = 0x00; - - UnaryExecutor::Execute(vector, result, size, [&](T input) -> T { - unsigned char byte_array[sizeof(T)]; - auto data_size = GetSizeOfInput(input); - encryption_state->InitializeEncryption(iv, 16, &key_t); - - // Convert input to byte array for processing - memcpy(byte_array, &input, sizeof(T)); - - // Encrypt data - encryption_state->Process(byte_array, data_size, buffer_p, data_size); - T encrypted_data = ConvertCipherText(buffer_p, data_size, byte_array); - -#if 0 - D_ASSERT(CheckEncryption(printable_encrypted_data, buffer_p, size, reinterpret_cast(name.GetData()), state) == 1); -#endif - - // this does not do anything for CTR and therefore can be skipped - encryption_state->Finalize(buffer_p, 0, nullptr, 0); - uint32_t nonce = 1; - return encrypted_data; - }); -} template void ExecuteEncryptExecutor(Vector &vector, Vector &result, idx_t size, ExpressionState &state, const string &key_t) { @@ -209,25 +170,6 @@ void ExecuteEncrypt(Vector &vector, Vector &result, idx_t size, ExpressionState throw NotImplementedException("Unsupported type for Encryption"); } } - -// Helper function that dispatches the runtime type to the appropriate templated function -void ExecuteEncryptStruct(Vector &vector, Vector &result, idx_t size, ExpressionState &state, const string &key_t) { - // Check the vector type and call the correct templated version - switch (vector.GetType().id()) { - case LogicalTypeId::INTEGER: - ExecuteEncryptStructExecutor(vector, result, size, state, key_t); - break; - case LogicalTypeId::BIGINT: - ExecuteEncryptStructExecutor(vector, result, size, state, key_t); - break; - case LogicalTypeId::VARCHAR: - ExecuteEncryptStructExecutor(vector, result, size, state, key_t); - break; - default: - throw NotImplementedException("Unsupported type for Encryption"); - } -} - //--------------------------------------------------------------------------------------------- template @@ -287,21 +229,6 @@ static void EncryptData(DataChunk &args, ExpressionState &state, Vector &result) ExecuteEncrypt(value_vector, result, args.size(), state, key_t); } -static void EncryptDataStruct(DataChunk &args, ExpressionState &state, Vector &result) { - - auto &value_vector = args.data[0]; - - // Get the encryption key - auto &key_vector = args.data[1]; - D_ASSERT(key_vector.GetVectorType() == VectorType::CONSTANT_VECTOR); - - // Fetch the encryption key as a constant string - const string key_t = ConstantVector::GetData(key_vector)[0].GetString(); - - // can we not pass by reference? - ExecuteEncryptStruct(value_vector, result, args.size(), state, key_t); -} - static void DecryptData(DataChunk &args, ExpressionState &state, Vector &result) { auto &value_vector = args.data[0]; @@ -321,10 +248,7 @@ static void DecryptData(DataChunk &args, ExpressionState &state, Vector &result) ScalarFunctionSet GetEncryptionFunction() { ScalarFunctionSet set("encrypt"); -// set.AddFunction(ScalarFunction({LogicalTypeId::INTEGER, LogicalType::VARCHAR}, LogicalTypeId::INTEGER, EncryptData, -// EncryptFunctionData::EncryptBind)); - - set.AddFunction(ScalarFunction({LogicalTypeId::INTEGER, LogicalType::VARCHAR}, LogicalTypeId::INTEGER, EncryptDataStruct, + set.AddFunction(ScalarFunction({LogicalTypeId::INTEGER, LogicalType::VARCHAR}, LogicalTypeId::INTEGER, EncryptData, EncryptFunctionData::EncryptBind)); set.AddFunction(ScalarFunction({LogicalTypeId::BIGINT, LogicalType::VARCHAR}, LogicalTypeId::BIGINT, EncryptData, diff --git a/src/core/functions/scalar/struct_encrypt.cpp b/src/core/functions/scalar/struct_encrypt.cpp index d19eb85..4cbdb58 100644 --- a/src/core/functions/scalar/struct_encrypt.cpp +++ b/src/core/functions/scalar/struct_encrypt.cpp @@ -426,7 +426,7 @@ static void EncryptDataStruct(DataChunk &args, ExpressionState &state, Vector &r auto &value_vector = *children[1]; // do we need to put pointers in result.auxiliary - result.auiliary = &result_vector; + result.auxiliary = &result_vector; // just now execute encrypt per value ExecuteEncryptStruct(plaintext_vector, result, args.size(), state, key_t);