From 75c5461c150583d6cfb2dc79ccc2308f0272a396 Mon Sep 17 00:00:00 2001 From: ccfelius Date: Mon, 11 Nov 2024 13:57:42 +0100 Subject: [PATCH] added encryption and decryption into struct for int and varchar --- .../function_data/encrypt_function_data.cpp | 2 + .../functions/scalar/encrypt_to_etype.cpp | 206 ++++++++++++++++-- .../function_data/encrypt_function_data.hpp | 1 + 3 files changed, 185 insertions(+), 24 deletions(-) diff --git a/src/core/functions/function_data/encrypt_function_data.cpp b/src/core/functions/function_data/encrypt_function_data.cpp index b37a93b..f417d6d 100644 --- a/src/core/functions/function_data/encrypt_function_data.cpp +++ b/src/core/functions/function_data/encrypt_function_data.cpp @@ -19,6 +19,8 @@ bool EncryptFunctionData::Equals(const FunctionData &other_p) const { unique_ptr EncryptFunctionData::EncryptBind(ClientContext &context, ScalarFunction &bound_function, vector> &arguments) { + // here, implement bound statements? + // do something return make_uniq(context); } diff --git a/src/core/functions/scalar/encrypt_to_etype.cpp b/src/core/functions/scalar/encrypt_to_etype.cpp index 407ba7d..4b9fa77 100644 --- a/src/core/functions/scalar/encrypt_to_etype.cpp +++ b/src/core/functions/scalar/encrypt_to_etype.cpp @@ -35,6 +35,38 @@ int32_t EncryptValueInt32(EncryptionState *encryption_state, Vector &result, int return encrypted_data; } +string_t EncryptValueToString(shared_ptr encryption_state, Vector &result, string_t value, uint8_t *buffer_p) { + + // first encrypt the bytes of the string into a temp buffer_p + auto input_data = data_ptr_t(value.GetData()); + auto value_size = value.GetSize(); + encryption_state->Process(input_data, value_size, buffer_p, value_size); + + // Convert the encrypted data to Base64 + auto encrypted_data = string_t(reinterpret_cast(buffer_p), value_size); + size_t base64_size = Blob::ToBase64Size(encrypted_data); + + // convert to Base64 into a newly allocated string in the result vector + string_t 
base64_data = StringVector::EmptyString(result, base64_size); + Blob::ToBase64(encrypted_data, base64_data.GetDataWriteable()); + + return base64_data; +} + +string_t DecryptValueToString(shared_ptr encryption_state, Vector &result, string_t base64_data, uint8_t *buffer_p) { + + // first decode the Base64 input into the temp buffer_p + size_t encrypted_size = Blob::FromBase64Size(base64_data); + size_t decrypted_size = encrypted_size; + Blob::FromBase64(base64_data, reinterpret_cast(buffer_p), encrypted_size); + D_ASSERT(encrypted_size <= base64_data.GetSize()); + + string_t decrypted_data = StringVector::EmptyString(result, decrypted_size); + encryption_state->Process(buffer_p, encrypted_size, reinterpret_cast(decrypted_data.GetDataWriteable()), decrypted_size); + + return decrypted_data; +} + static void EncryptDataChunkStruct(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = (BoundFunctionExpression &)state.expr; @@ -62,11 +94,6 @@ static void EncryptDataChunkStruct(DataChunk &args, ExpressionState &state, Vect // reset the reference of the result vector result.ReferenceAndSetType(struct_vector); -// // Get the child vectors from the struct vector -// auto &children = StructVector::GetEntries(result); -// children[0] = make_uniq(LogicalType::INTEGER); -// children[1] = make_uniq(LogicalType::INTEGER); - // TODO: put this in the state of the extension uint8_t encryption_buffer[MAX_BUFFER_SIZE]; uint8_t *buffer_p = encryption_buffer; @@ -75,6 +102,7 @@ static void EncryptDataChunkStruct(DataChunk &args, ExpressionState &state, Vect auto encryption_state = simple_encryption_state->encryption_state; // this can be an int64_t, we have 12 bytes available... 
+ // this needs to be in the state btw, because it needs to keep increasing PER vector int32_t nonce_count = 0; // TODO: construct nonce based on immutable ROW_ID + hash(col_name) @@ -102,7 +130,7 @@ static void EncryptDataChunkStruct(DataChunk &args, ExpressionState &state, Vect }); } -static void DecryptDataChunkStruct(DataChunk &args, ExpressionState &state, Vector &result) { +static void EncryptDataChunkStructString(DataChunk &args, ExpressionState &state, Vector &result) { auto &func_expr = (BoundFunctionExpression &)state.expr; auto &info = (EncryptFunctionData &)*func_expr.bind_info; @@ -113,13 +141,65 @@ static void DecryptDataChunkStruct(DataChunk &args, ExpressionState &state, Vect "simple_encryption"); auto &input_vector = args.data[0]; - // D_assert to check whether it is a struct? -// D_ASSERT(input_vector.GetType() == LogicalType::STRUCT); - // get children of input data - auto &children = StructVector::GetEntries(input_vector); - auto &nonce_vector = children[0]; - auto &value_vector = children[1]; + auto &key_vector = args.data[1]; + + D_ASSERT(key_vector.GetVectorType() == VectorType::CONSTANT_VECTOR); + + // Fetch the encryption key as a constant string + const string key_t = + ConstantVector::GetData(key_vector)[0].GetString(); + + // create struct_type + LogicalType result_struct = LogicalType::STRUCT( + {{"nonce", LogicalType::INTEGER}, {"value", LogicalType::VARCHAR}}); + Vector struct_vector(result_struct, args.size()); + // reset the reference of the result vector + result.ReferenceAndSetType(struct_vector); + + // TODO: put this in the state of the extension + uint8_t encryption_buffer[MAX_BUFFER_SIZE]; + uint8_t *buffer_p = encryption_buffer; + + unsigned char iv[16]; + auto encryption_state = simple_encryption_state->encryption_state; + + // this can be an int64_t, we have 12 bytes available... 
+ // this needs to be in the state btw, because it needs to keep increasing PER vector + int32_t nonce_count = 0; + + // TODO: construct nonce based on immutable ROW_ID + hash(col_name) + memcpy(iv, "12345678901", 12); + iv[12] = iv[13] = iv[14] = iv[15] = 0x00; + + encryption_state->InitializeEncryption(iv, 16, &key_t); + using ENCRYPTED_TYPE = StructTypeBinary; + using PLAINTEXT_TYPE = PrimitiveType; + + GenericExecutor::ExecuteUnary(input_vector, result, args.size(), [&](PLAINTEXT_TYPE input) { + + // set the nonce + nonce_count++; + memcpy(iv, &nonce_count, sizeof(int32_t)); + + // Encrypt data + string_t encrypted_data = EncryptValueToString(encryption_state, result, input.val, buffer_p); + + return ENCRYPTED_TYPE {nonce_count, encrypted_data}; + }); +} + +static void DecryptDataChunkStruct(DataChunk &args, ExpressionState &state, Vector &result) { + + auto &func_expr = (BoundFunctionExpression &)state.expr; + auto &info = (EncryptFunctionData &)*func_expr.bind_info; + + // refactor this into GetSimpleEncryptionState(info.context); + auto simple_encryption_state = + info.context.registered_state->Get( + "simple_encryption"); + + auto &input_vector = args.data[0]; auto &key_vector = args.data[1]; D_ASSERT(key_vector.GetVectorType() == VectorType::CONSTANT_VECTOR); @@ -144,26 +224,94 @@ static void DecryptDataChunkStruct(DataChunk &args, ExpressionState &state, Vect // Decrypt data int32_t decrypted_data; - // also, template this function - BinaryExecutor::Execute( - reinterpret_cast(nonce_vector), - reinterpret_cast(value_vector), result, args.size(), - [&](int32_t nonce, int32_t input) { - // Set the nonce - memcpy(iv, &nonce, sizeof(int32_t)); + using ENCRYPTED_TYPE = StructTypeBinary; + using PLAINTEXT_TYPE = PrimitiveType; + + GenericExecutor::ExecuteUnary( + input_vector, result, args.size(), [&](ENCRYPTED_TYPE input) { + auto nonce = input.a_val; + auto value = input.b_val; + // Set the nonce + memcpy(iv, &nonce, sizeof(int32_t)); + + 
encryption_state->Process( + reinterpret_cast(&value), sizeof(int32_t), + reinterpret_cast(&decrypted_data), + sizeof(int32_t)); + + return decrypted_data; + }); +} + + static void DecryptDataChunkStructString(DataChunk &args, ExpressionState &state, Vector &result) { + + auto &func_expr = (BoundFunctionExpression &)state.expr; + auto &info = (EncryptFunctionData &)*func_expr.bind_info; + + // refactor this into GetSimpleEncryptionState(info.context); + auto simple_encryption_state = + info.context.registered_state->Get( + "simple_encryption"); + + auto &input_vector = args.data[0]; + auto &key_vector = args.data[1]; + D_ASSERT(key_vector.GetVectorType() == VectorType::CONSTANT_VECTOR); + + // Fetch the encryption key as a constant string + const string key_t = + ConstantVector::GetData(key_vector)[0].GetString(); + + // maybe convert vector to unified format? Like they do in other scalar functions + + // TODO: put this in the state of the extension + uint8_t encryption_buffer[MAX_BUFFER_SIZE]; + uint8_t *buffer_p = encryption_buffer; + + unsigned char iv[16]; + auto encryption_state = simple_encryption_state->encryption_state; + + // TODO: construct nonce based on immutable ROW_ID + hash(col_name) + memcpy(iv, "12345678901", 12); + iv[12] = iv[13] = iv[14] = iv[15] = 0x00; + + encryption_state->InitializeDecryption(iv, 16, &key_t); // Decrypt data - encryption_state->Process(reinterpret_cast(&input), sizeof(int32_t), reinterpret_cast(&decrypted_data), sizeof(int32_t)); + int32_t decrypted_data; - return decrypted_data; - }); + using ENCRYPTED_TYPE = StructTypeBinary; + using PLAINTEXT_TYPE = PrimitiveType; + + GenericExecutor::ExecuteUnary( + input_vector, result, args.size(), [&](ENCRYPTED_TYPE input) { + auto nonce = input.a_val; + auto value = input.b_val; + + // Set the nonce + memcpy(iv, &nonce, sizeof(int32_t)); + + // Decrypt data + string_t decrypted_data = + DecryptValueToString(encryption_state, result, value, buffer_p); + return decrypted_data; + }); } 
ScalarFunctionSet GetEncryptionStructFunction() { ScalarFunctionSet set("encrypt_etypes"); - set.AddFunction(ScalarFunction({LogicalTypeId::INTEGER, LogicalType::VARCHAR}, EncryptionTypes::E_INT(), EncryptDataChunkStruct, +// set.AddFunction(ScalarFunction({LogicalTypeId::INTEGER, LogicalType::VARCHAR}, EncryptionTypes::E_INT(), EncryptDataChunkStruct, +// EncryptFunctionData::EncryptBind)); + + // Function to Encrypt INTEGERS + set.AddFunction(ScalarFunction({LogicalTypeId::INTEGER, LogicalType::VARCHAR}, LogicalType::STRUCT( + {{"nonce", LogicalType::INTEGER}, {"value", LogicalType::INTEGER}}), EncryptDataChunkStruct, + EncryptFunctionData::EncryptBind)); + + // Function to encrypt VARCHAR + set.AddFunction(ScalarFunction({LogicalTypeId::VARCHAR, LogicalType::VARCHAR}, LogicalType::STRUCT( + {{"nonce", LogicalType::INTEGER}, {"value", LogicalType::VARCHAR}}), EncryptDataChunkStructString, EncryptFunctionData::EncryptBind)); return set; @@ -172,9 +320,19 @@ ScalarFunctionSet GetEncryptionStructFunction() { ScalarFunctionSet GetDecryptionStructFunction() { ScalarFunctionSet set("decrypt_etypes"); - set.AddFunction(ScalarFunction({EncryptionTypes::E_INT(), LogicalType::VARCHAR}, LogicalTypeId::INTEGER, DecryptDataChunkStruct, + // Why is E_INT not working? +// set.AddFunction(ScalarFunction({EncryptionTypes::E_INT(), LogicalType::VARCHAR}, LogicalTypeId::INTEGER, DecryptDataChunkStruct, +// EncryptFunctionData::EncryptBind)); + + // try with input struct? 
+ set.AddFunction(ScalarFunction({LogicalType::STRUCT( +{{"nonce", LogicalType::INTEGER}, {"value", LogicalType::INTEGER}}), LogicalType::VARCHAR}, LogicalTypeId::INTEGER, DecryptDataChunkStruct, EncryptFunctionData::EncryptBind)); + set.AddFunction(ScalarFunction({LogicalType::STRUCT( + {{"nonce", LogicalType::INTEGER}, {"value", LogicalType::VARCHAR}}), LogicalType::VARCHAR}, LogicalTypeId::VARCHAR, DecryptDataChunkStructString, + EncryptFunctionData::EncryptBind)); + return set; } diff --git a/src/include/simple_encryption/core/functions/function_data/encrypt_function_data.hpp b/src/include/simple_encryption/core/functions/function_data/encrypt_function_data.hpp index b50bfe0..d61f8a1 100644 --- a/src/include/simple_encryption/core/functions/function_data/encrypt_function_data.hpp +++ b/src/include/simple_encryption/core/functions/function_data/encrypt_function_data.hpp @@ -10,6 +10,7 @@ struct EncryptFunctionData: FunctionData { // Save the ClientContext ClientContext &context; +// BoundStatement relation; EncryptFunctionData(ClientContext &context) : context(context) {}