From 8384c8cd23de78ac94066ca500124bd31152ff9c Mon Sep 17 00:00:00 2001 From: ccfelius Date: Tue, 19 Nov 2024 12:12:39 +0100 Subject: [PATCH] Adding HMAC calculation for column keys --- src/core/crypto/CMakeLists.txt | 2 +- src/core/crypto/crypto_primitives.cpp | 35 ++++ .../function_data/encrypt_function_data.cpp | 1 - .../functions/scalar/encrypt_to_etype.cpp | 191 ++++++++++++------ src/core/functions/secrets/authentication.cpp | 16 +- .../core/crypto/crypto_primitives.hpp | 8 + src/include/simple_encryption_state.hpp | 4 + src/simple_encryption_state.cpp | 3 + test/sql/secrets/secrets_encryption.test | 8 +- 9 files changed, 185 insertions(+), 83 deletions(-) diff --git a/src/core/crypto/CMakeLists.txt b/src/core/crypto/CMakeLists.txt index db95d8f..6bae7f1 100644 --- a/src/core/crypto/CMakeLists.txt +++ b/src/core/crypto/CMakeLists.txt @@ -2,4 +2,4 @@ set(EXTENSION_SOURCES ${EXTENSION_SOURCES} ${CMAKE_CURRENT_SOURCE_DIR}/crypto_primitives.cpp PARENT_SCOPE -) \ No newline at end of file +) diff --git a/src/core/crypto/crypto_primitives.cpp b/src/core/crypto/crypto_primitives.cpp index 63a1e6f..885d46b 100644 --- a/src/core/crypto/crypto_primitives.cpp +++ b/src/core/crypto/crypto_primitives.cpp @@ -4,10 +4,15 @@ #include "duckdb/common/common.hpp" #include +// todo; use httplib for windows compatibility +//#define CPPHTTPLIB_OPENSSL_SUPPORT +//#include "duckdb/third_party/httplib/httplib.hpp" + // OpenSSL functions #include #include #include +#include namespace duckdb { @@ -190,4 +195,34 @@ extern "C" { DUCKDB_EXTENSION_API AESStateSSLFactory *CreateSSLFactory() { return new AESStateSSLFactory(); }; +} + +namespace simple_encryption { + +namespace core { + +std::string CalculateHMAC(const std::string &secret, const std::string &message, const uint32_t length) { + const EVP_MD *algorithm = EVP_sha256(); // Replace with EVP_sha1(), EVP_md5(), etc., if needed. + unsigned char key_buffer[32]; + + // Output buffer and length + unsigned char hmacResult[EVP_MAX_MD_SIZE]; + unsigned int hmacLength = 0; + + // Compute the HMAC + HMAC(algorithm, + secret.data(), secret.size(), // Key + reinterpret_cast(message.data()), message.size(), // Message + hmacResult, &hmacLength); + + // Copy the desired number of bytes + memcpy(key_buffer, hmacResult, length); + + // convert to string + std::string result_key(reinterpret_cast(key_buffer), length); + + return result_key; +} + +} } \ No newline at end of file diff --git a/src/core/functions/function_data/encrypt_function_data.cpp b/src/core/functions/function_data/encrypt_function_data.cpp index f4ed10f..ae52820 100644 --- a/src/core/functions/function_data/encrypt_function_data.cpp +++ b/src/core/functions/function_data/encrypt_function_data.cpp @@ -23,7 +23,6 @@ EncryptFunctionData::EncryptBind(ClientContext &context, vector> &arguments) { // here, implement bound statements? - // do something return make_uniq(context); } } // namespace core diff --git a/src/core/functions/scalar/encrypt_to_etype.cpp b/src/core/functions/scalar/encrypt_to_etype.cpp index 8b8bc00..af95e33 100644 --- a/src/core/functions/scalar/encrypt_to_etype.cpp +++ b/src/core/functions/scalar/encrypt_to_etype.cpp @@ -1,25 +1,28 @@ #define DUCKDB_EXTENSION_MAIN #include "duckdb.hpp" -#include "duckdb/common/exception.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/extension_util.hpp" +#include "duckdb/main/connection_manager.hpp" +#include "duckdb/main/secret/secret_manager.hpp" +#include "duckdb/function/scalar_function.hpp" #include "duckdb/common/types.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/types/blob.hpp" #include "duckdb/common/encryption_state.hpp" -#include "duckdb/function/scalar_function.hpp" -#include "duckdb/main/extension_util.hpp" +#include "duckdb/common/vector_operations/generic_executor.hpp" +#include "duckdb/planner/expression/bound_function_expression.hpp" #include "mbedtls_wrapper.hpp" + #include -#include "duckdb/common/types/blob.hpp" -#include "duckdb/main/connection_manager.hpp" -#include "simple_encryption/core/functions/scalar/encrypt.hpp" -#include "simple_encryption/core/functions/scalar.hpp" + #include "simple_encryption_state.hpp" -#include "duckdb/main/client_context.hpp" -#include "simple_encryption/core/functions/function_data/encrypt_function_data.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" #include "simple_encryption/core/types.hpp" +#include "simple_encryption/core/crypto/crypto_primitives.hpp" +#include "simple_encryption/core/functions/scalar.hpp" #include "simple_encryption/core/functions/secrets.hpp" -#include "duckdb/common/vector_operations/generic_executor.hpp" -#include "duckdb/main/secret/secret_manager.hpp" +#include "simple_encryption/core/functions/scalar/encrypt.hpp" +#include "simple_encryption/core/functions/function_data/encrypt_function_data.hpp" namespace simple_encryption { @@ -115,9 +118,35 @@ GetSimpleEncryptionState(ExpressionState &state) { "simple_encryption"); } +const KeyValueSecret* GetSecret(ExpressionState &state) { + + // todo; we can also put the secret in the state + auto &info = GetEncryptionBindInfo(state); + auto &secret_manager = SecretManager::Get(info.context); + auto transaction = CatalogTransaction::GetSystemCatalogTransaction(info.context); + auto secret_match = secret_manager.LookupSecret(transaction, "encryption", "encryption"); + + if (!secret_match.HasMatch()) { + throw InvalidInputException("No 'encryption' secret found. Please create a secret with 'CREATE SECRET' first."); + } + + auto &secret = secret_match.GetSecret(); + if (secret.GetType() != "encryption") { + throw InvalidInputException("Invalid secret type. Expected 'encryption', got '%s'", secret.GetType()); + } + + const auto *kv_secret = dynamic_cast(&secret); + if (!kv_secret) { + throw InvalidInputException("Invalid secret format for 'encryption' secret."); + } + + return kv_secret; +} + std::string GetKeyFromSecret(ExpressionState &state) { + // todo; we can also put the secret in the state auto &info = GetEncryptionBindInfo(state); auto &secret_manager = SecretManager::Get(info.context); auto transaction = CatalogTransaction::GetSystemCatalogTransaction(info.context); @@ -142,18 +171,46 @@ std::string GetKeyFromSecret(ExpressionState &state) { throw InvalidInputException("'token' not found in 'encryption' secret."); } -// -// // Parse optional label parameter -// std::string label = ""; // Default to fetching all emails if no label is provided -// if (input.named_parameters.find("mail_label") != input.named_parameters.end()) { -// label = input.named_parameters.at("mail_label").GetValue(); -// } + return token_value.ToString(); +} + +uint32_t GetLengthFromSecret(ExpressionState &state){ - std::string token = token_value.ToString(); + // todo; we can also put the secret in the state + auto &info = GetEncryptionBindInfo(state); + auto &secret_manager = SecretManager::Get(info.context); + auto transaction = CatalogTransaction::GetSystemCatalogTransaction(info.context); + auto secret_match = secret_manager.LookupSecret(transaction, "encryption", "encryption"); + + if (!secret_match.HasMatch()) { + throw InvalidInputException("No 'encryption' secret found. Please create a secret with 'CREATE SECRET' first."); + } + + auto &secret = secret_match.GetSecret(); + if (secret.GetType() != "encryption") { + throw InvalidInputException("Invalid secret type. Expected 'encryption', got '%s'", secret.GetType()); + } - return token; + const auto *kv_secret = dynamic_cast(&secret); + if (!kv_secret) { + throw InvalidInputException("Invalid secret format for 'encryption' secret."); + } + + Value length_value; + if (!kv_secret->TryGetValue("length", length_value)) { + throw InvalidInputException("'length' not found in 'encryption' secret."); + } + + return length_value.GetValue(); } +std::string GenerateColumnKey(ExpressionState &state, const std::string &message){ + // Get the encryption key from DuckDB Secrets Manager + auto secret = GetKeyFromSecret(state); + size_t length = GetLengthFromSecret(state); + + return CalculateHMAC(secret, message, length); +} bool HasSpace(shared_ptr simple_encryption_state, uint64_t size) { @@ -190,12 +247,19 @@ LogicalType CreateEVARtypeStruct() { template void EncryptToEtype(LogicalType result_struct, Vector &input_vector, - const string key_t, uint64_t size, ExpressionState &state, + const string message_t, uint64_t size, ExpressionState &state, Vector &result) { + // this now happens for every chunk, maybe we should already put it in the bind auto simple_encryption_state = GetSimpleEncryptionState(state); auto encryption_state = GetEncryptionState(state); + // calculate column key if no key set yet + if (!simple_encryption_state->key_flag){ + simple_encryption_state->key = GenerateColumnKey(state, message_t); + simple_encryption_state->key_flag = true; + } + // Reset the reference of the result vector Vector struct_vector(result_struct, size); result.ReferenceAndSetType(struct_vector); @@ -213,10 +277,6 @@ void EncryptToEtype(LogicalType result_struct, Vector &input_vector, using ENCRYPTED_TYPE = StructTypeTernary; using PLAINTEXT_TYPE = PrimitiveType; - encryption_state->InitializeEncryption( - reinterpret_cast(simple_encryption_state->iv), 16, - reinterpret_cast(&key_t)); - GenericExecutor::ExecuteUnary( input_vector, result, size, [&](PLAINTEXT_TYPE input) { @@ -226,7 +286,7 @@ void EncryptToEtype(LogicalType result_struct, Vector &input_vector, encryption_state->InitializeEncryption( reinterpret_cast(simple_encryption_state->iv), 16, - reinterpret_cast(&key_t)); + reinterpret_cast(&simple_encryption_state->key)); T encrypted_data = ProcessAndCastEncrypt(encryption_state, result, input.val, @@ -235,30 +295,33 @@ void EncryptToEtype(LogicalType result_struct, Vector &input_vector, return ENCRYPTED_TYPE{simple_encryption_state->iv[0], simple_encryption_state->iv[1], encrypted_data}; }); - - encryption_state->Finalize(simple_encryption_state->buffer_p, 0, nullptr, NULL); } + template -void DecryptFromEtype(Vector &input_vector, const string key_t, uint64_t size, +void DecryptFromEtype(Vector &input_vector, const string message_t, uint64_t size, ExpressionState &state, Vector &result) { auto simple_encryption_state = GetSimpleEncryptionState(state); auto encryption_state = GetEncryptionState(state); - uint64_t iv[2]; - iv[0] = iv[1] = 0; + // calculate column key if no key set yet + if (!simple_encryption_state->key_flag){ + simple_encryption_state->key = GenerateColumnKey(state, message_t); + simple_encryption_state->key_flag = true; + } using ENCRYPTED_TYPE = StructTypeTernary; using PLAINTEXT_TYPE = PrimitiveType; GenericExecutor::ExecuteUnary( input_vector, result, size, [&](ENCRYPTED_TYPE input) { - iv[0] = input.a_val; - iv[1] = input.b_val; + simple_encryption_state->iv[0] = input.a_val; + simple_encryption_state->iv[1] = input.b_val; encryption_state->InitializeDecryption( - reinterpret_cast(iv), 16, &key_t); + reinterpret_cast(simple_encryption_state->iv), 16, + reinterpret_cast(&simple_encryption_state->key)); T decrypted_data = ProcessAndCastDecrypt(encryption_state, result, input.c_val, @@ -275,51 +338,45 @@ static void EncryptDataToEtype(DataChunk &args, ExpressionState &state, auto vector_type = input_vector.GetType(); auto size = args.size(); - // Get the encryption key from DuckDB Secrets Manager - auto encryption_key = GetKeyFromSecret(state); - - // Check if a key is already present in the state - // if not, generate a new key - // Get the encryption key from client input - auto &key_vector = args.data[1]; - D_ASSERT(key_vector.GetVectorType() == VectorType::CONSTANT_VECTOR); - const string key_t = - ConstantVector::GetData(key_vector)[0].GetString(); + auto &message_vector = args.data[1]; + D_ASSERT(message_vector.GetVectorType() == VectorType::CONSTANT_VECTOR); + const string message_t = + ConstantVector::GetData(message_vector)[0].GetString(); if (vector_type.IsNumeric()) { switch (vector_type.id()) { case LogicalTypeId::TINYINT: case LogicalTypeId::UTINYINT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_t, + return EncryptToEtype(CreateEINTtypeStruct(), input_vector, message_t, size, state, result); case LogicalTypeId::SMALLINT: case LogicalTypeId::USMALLINT: return EncryptToEtype(CreateEINTtypeStruct(), input_vector, - key_t, size, state, result); + message_t, size, state, result); case LogicalTypeId::INTEGER: return EncryptToEtype(CreateEINTtypeStruct(), input_vector, - key_t, size, state, result); + message_t, size, state, result); case LogicalTypeId::UINTEGER: return EncryptToEtype(CreateEINTtypeStruct(), input_vector, - key_t, size, state, result); + message_t, size, state, result); case LogicalTypeId::BIGINT: return EncryptToEtype(CreateEINTtypeStruct(), input_vector, - key_t, size, state, result); + message_t, size, state, result); case LogicalTypeId::UBIGINT: return EncryptToEtype(CreateEINTtypeStruct(), input_vector, - key_t, size, state, result); + message_t, size, state, result); case LogicalTypeId::FLOAT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_t, + return EncryptToEtype(CreateEINTtypeStruct(), input_vector, message_t, size, state, result); case LogicalTypeId::DOUBLE: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_t, + return EncryptToEtype(CreateEINTtypeStruct(), input_vector, message_t, size, state, result); default: throw NotImplementedException("Unsupported numeric type for encryption"); } } else if (vector_type.id() == LogicalTypeId::VARCHAR) { - return EncryptToEtype(CreateEVARtypeStruct(), input_vector, key_t, + return EncryptToEtype(CreateEVARtypeStruct(), input_vector, message_t, size, state, result); } else if (vector_type.IsNested()) { throw NotImplementedException( @@ -336,12 +393,12 @@ static void DecryptDataFromEtype(DataChunk &args, ExpressionState &state, auto size = args.size(); auto &input_vector = args.data[0]; - auto &key_vector = args.data[1]; - D_ASSERT(key_vector.GetVectorType() == VectorType::CONSTANT_VECTOR); + auto &message_vector = args.data[1]; + D_ASSERT(message_vector.GetVectorType() == VectorType::CONSTANT_VECTOR); - // Fetch the encryption key as a constant string - const string key_t = - ConstantVector::GetData(key_vector)[0].GetString(); + // Fetch the message as a constant string + const string message_t = + ConstantVector::GetData(message_vector)[0].GetString(); auto &children = StructVector::GetEntries(input_vector); // get type of vector containing encrypted values @@ -351,32 +408,32 @@ static void DecryptDataFromEtype(DataChunk &args, ExpressionState &state, switch (vector_type.id()) { case LogicalTypeId::TINYINT: case LogicalTypeId::UTINYINT: - return DecryptFromEtype(input_vector, key_t, size, state, result); + return DecryptFromEtype(input_vector, message_t, size, state, result); case LogicalTypeId::SMALLINT: case LogicalTypeId::USMALLINT: - return DecryptFromEtype(input_vector, key_t, size, state, + return DecryptFromEtype(input_vector, message_t, size, state, result); case LogicalTypeId::INTEGER: - return DecryptFromEtype(input_vector, key_t, size, state, + return DecryptFromEtype(input_vector, message_t, size, state, result); case LogicalTypeId::UINTEGER: - return DecryptFromEtype(input_vector, key_t, size, state, + return DecryptFromEtype(input_vector, message_t, size, state, result); case LogicalTypeId::BIGINT: - return DecryptFromEtype(input_vector, key_t, size, state, + return DecryptFromEtype(input_vector, message_t, size, state, result); case LogicalTypeId::UBIGINT: - return DecryptFromEtype(input_vector, key_t, size, state, + return DecryptFromEtype(input_vector, message_t, size, state, result); case LogicalTypeId::FLOAT: - return DecryptFromEtype(input_vector, key_t, size, state, result); + return DecryptFromEtype(input_vector, message_t, size, state, result); case LogicalTypeId::DOUBLE: - return DecryptFromEtype(input_vector, key_t, size, state, result); + return DecryptFromEtype(input_vector, message_t, size, state, result); default: throw NotImplementedException("Unsupported numeric type for decryption"); } } else if (vector_type.id() == LogicalTypeId::VARCHAR) { - return EncryptToEtype(CreateEVARtypeStruct(), input_vector, key_t, + return EncryptToEtype(CreateEVARtypeStruct(), input_vector, message_t, size, state, result); } else if (vector_type.IsNested()) { throw NotImplementedException( diff --git a/src/core/functions/secrets/authentication.cpp b/src/core/functions/secrets/authentication.cpp index 814a90b..148f95b 100644 --- a/src/core/functions/secrets/authentication.cpp +++ b/src/core/functions/secrets/authentication.cpp @@ -67,7 +67,7 @@ static void AddSecretParameter(const std::string &key, const CreateSecretInput & static void RegisterCommonSecretParameters(CreateSecretFunction &function) { - function.named_parameters["key_value"] = LogicalType::VARCHAR; + function.named_parameters["master_key"] = LogicalType::VARCHAR; function.named_parameters["key_name"] = LogicalType::VARCHAR; function.named_parameters["length"] = LogicalType::INTEGER; } @@ -90,20 +90,16 @@ static unique_ptr CreateKeyEncryptionKey(ClientContext &context, Cre auto length = input.options["length"].GetValue(); if (!CheckKeySize(length)){ - throw InvalidInputException("Invalid size for encryption key: '%d', expected: 16, 24, or 32", length); + throw InvalidInputException("Invalid size for encryption key: '%d', only a length of 16 bytes is supported", length); } - - // get the results from the user input - auto password = input.options["key_value"].GetValue(); + // get the other results from the user input + auto master_key = input.options["master_key"].GetValue(); auto key_name = input.options["key_name"].GetValue(); - // todo: generate key from user input - // get token from user input - std::string token = "0123456789112345"; - // Store the token in the secret - result->secret_map["token"] = Value(token); + result->secret_map["token"] = Value(master_key); + result->secret_map["length"] = Value(to_string(length)); // Hide (redact) sensitive information RedactSensitiveKeys(*result); diff --git a/src/include/simple_encryption/core/crypto/crypto_primitives.hpp b/src/include/simple_encryption/core/crypto/crypto_primitives.hpp index ff6a3ca..b451a1a 100644 --- a/src/include/simple_encryption/core/crypto/crypto_primitives.hpp +++ b/src/include/simple_encryption/core/crypto/crypto_primitives.hpp @@ -45,6 +45,14 @@ class DUCKDB_EXTENSION_API AESStateSSL : public duckdb::EncryptionState { } // namespace duckdb +namespace simple_encryption { +namespace core { + +std::string CalculateHMAC(const std::string &secret, const std::string &message, uint32_t length); + +} +} + extern "C" { class DUCKDB_EXTENSION_API AESStateSSLFactory : public duckdb::EncryptionUtil { diff --git a/src/include/simple_encryption_state.hpp b/src/include/simple_encryption_state.hpp index 110131a..0ae08bb 100644 --- a/src/include/simple_encryption_state.hpp +++ b/src/include/simple_encryption_state.hpp @@ -21,6 +21,10 @@ class SimpleEncryptionState : public ClientContextState { bool is_initialized = false; uint64_t iv[2]; + // todo; key can also be 24 or 32 (resize or always allocate 32) + std::string key; + bool key_flag = false; + // encryption buffer uint8_t *buffer_p; }; diff --git a/src/simple_encryption_state.cpp b/src/simple_encryption_state.cpp index ff82216..4a2aa0b 100644 --- a/src/simple_encryption_state.cpp +++ b/src/simple_encryption_state.cpp @@ -33,6 +33,9 @@ SimpleEncryptionState::SimpleEncryptionState(shared_ptr context) uint8_t encryption_buffer[MAX_BUFFER_SIZE]; buffer_p = encryption_buffer; + // clear the iv + iv[0] = iv[1] = 0; + // Create a new table containing encryption metadata (nonce, tag) auto query = new_conn->Query( "CREATE TABLE IF NOT EXISTS __simple_encryption_internal (" diff --git a/test/sql/secrets/secrets_encryption.test b/test/sql/secrets/secrets_encryption.test index 10e7eb3..b806b53 100644 --- a/test/sql/secrets/secrets_encryption.test +++ b/test/sql/secrets/secrets_encryption.test @@ -16,7 +16,7 @@ statement ok CREATE SECRET test_key ( TYPE ENCRYPTION, KEY_NAME 'key_1', - KEY_VALUE '0123456789112345', + MASTER_KEY '0123456789112345', LENGTH 16 ); @@ -25,11 +25,11 @@ statement error CREATE SECRET test_wrong_length ( TYPE ENCRYPTION, KEY_NAME 'key_1', - KEY_VALUE '0123456789112345', + MASTER_KEY '0123456789112345', LENGTH 99 ); ---- -Invalid Input Error: Invalid size for encryption key: '99', expected: 16, 24, or 32 +Invalid Input Error: Invalid size for encryption key: '99', only a length of 16 bytes is supported statement ok -SELECT encrypt(11, '0123456789112345'); \ No newline at end of file +SELECT encrypt(11, 'random_message'); \ No newline at end of file