diff --git a/src/core/functions/scalar/encrypt_to_etype.cpp b/src/core/functions/scalar/encrypt_to_etype.cpp index af95e33..2f5eba5 100644 --- a/src/core/functions/scalar/encrypt_to_etype.cpp +++ b/src/core/functions/scalar/encrypt_to_etype.cpp @@ -204,9 +204,40 @@ uint32_t GetLengthFromSecret(ExpressionState &state){ return length_value.GetValue(); } -std::string GenerateColumnKey(ExpressionState &state, const std::string &message){ +string GetSecretFromKeyName(ExpressionState &state, string key_name){ + + // todo; we can also put the secret in the state + auto &info = GetEncryptionBindInfo(state); + auto &secret_manager = SecretManager::Get(info.context); + auto transaction = CatalogTransaction::GetSystemCatalogTransaction(info.context); + auto secret_entry = secret_manager.GetSecretByName(transaction, key_name); + + if (!secret_entry) { + throw InvalidInputException("No secret found with name '%s'.", key_name); + } + + // Safely access the secret + if (!secret_entry->secret) { + throw InvalidInputException("Secret found, but '%s' contains no actual secret.", key_name); + } + + // Retrieve the secret + auto &secret = *secret_entry->secret; + // Retrieve the key, value secret + const auto *kv_secret = dynamic_cast(&secret); + + Value token_value; + if (!kv_secret->TryGetValue("token", token_value)) { + throw InvalidInputException("'token' not found in 'encryption' secret."); + } + + return token_value.ToString(); +} + +std::string GenerateColumnKey(ExpressionState &state, const std::string &key_name, const std::string &message){ // Get the encryption key from DuckDB Secrets Manager - auto secret = GetKeyFromSecret(state); +// auto secret = GetKeyFromSecret(state); + auto secret = GetSecretFromKeyName(state, key_name); size_t length = GetLengthFromSecret(state); return CalculateHMAC(secret, message, length); @@ -223,11 +254,23 @@ bool HasSpace(shared_ptr simple_encryption_state, void SetIV(shared_ptr simple_encryption_state) { - simple_encryption_state->iv[0] = simple_encryption_state->iv[1] = 0; + simple_encryption_state->iv[1] = 0; simple_encryption_state->encryption_state->GenerateRandomData( reinterpret_cast(simple_encryption_state->iv), 12); } +bool CheckGeneratedKeySize(const uint32_t size){ + + switch(size){ + case 16: + case 24: + case 32: + return true; + default: + return false; + } +} + shared_ptr GetEncryptionState(ExpressionState &state) { return GetSimpleEncryptionState(state)->encryption_state; } @@ -247,7 +290,7 @@ LogicalType CreateEVARtypeStruct() { template void EncryptToEtype(LogicalType result_struct, Vector &input_vector, - const string message_t, uint64_t size, ExpressionState &state, + const string key_name, const string message_t, uint64_t size, ExpressionState &state, Vector &result) { // this now happens for every chunk, maybe we should already put it in the bind @@ -256,10 +299,12 @@ void EncryptToEtype(LogicalType result_struct, Vector &input_vector, // calculate column key if no key set yet if (!simple_encryption_state->key_flag){ - simple_encryption_state->key = GenerateColumnKey(state, message_t); + simple_encryption_state->key = GenerateColumnKey(state, key_name, message_t); simple_encryption_state->key_flag = true; } + D_ASSERT(CheckGeneratedKeySize(simple_encryption_state->key.size())); + // Reset the reference of the result vector Vector struct_vector(result_struct, size); result.ReferenceAndSetType(struct_vector); @@ -274,16 +319,18 @@ void EncryptToEtype(LogicalType result_struct, Vector &input_vector, auto &nonce_hi = children[0]; nonce_hi->SetVectorType(VectorType::CONSTANT_VECTOR); + auto nonce_lo = simple_encryption_state->iv[1]; + using ENCRYPTED_TYPE = StructTypeTernary; using PLAINTEXT_TYPE = PrimitiveType; + encryption_state->InitializeEncryption( + reinterpret_cast(simple_encryption_state->iv), 16, + reinterpret_cast(&simple_encryption_state->key)); + GenericExecutor::ExecuteUnary( input_vector, result, size, [&](PLAINTEXT_TYPE input) { - // increment the low part of the nonce - simple_encryption_state->iv[1]++; - simple_encryption_state->counter++; - encryption_state->InitializeEncryption( reinterpret_cast(simple_encryption_state->iv), 16, reinterpret_cast(&simple_encryption_state->key)); @@ -292,14 +339,18 @@ void EncryptToEtype(LogicalType result_struct, Vector &input_vector, ProcessAndCastEncrypt(encryption_state, result, input.val, simple_encryption_state->buffer_p); + nonce_lo = simple_encryption_state->iv[1]; + simple_encryption_state->counter++; + simple_encryption_state->iv[1]++; + return ENCRYPTED_TYPE{simple_encryption_state->iv[0], - simple_encryption_state->iv[1], encrypted_data}; + nonce_lo, encrypted_data}; }); } template -void DecryptFromEtype(Vector &input_vector, const string message_t, uint64_t size, +void DecryptFromEtype(Vector &input_vector, const string key_name, const string message_t, uint64_t size, ExpressionState &state, Vector &result) { auto simple_encryption_state = GetSimpleEncryptionState(state); @@ -307,10 +358,12 @@ void DecryptFromEtype(Vector &input_vector, const string message_t, uint64_t siz // calculate column key if no key set yet if (!simple_encryption_state->key_flag){ - simple_encryption_state->key = GenerateColumnKey(state, message_t); + simple_encryption_state->key = GenerateColumnKey(state, key_name, message_t); simple_encryption_state->key_flag = true; } + D_ASSERT(CheckGeneratedKeySize(simple_encryption_state->key.size())); + using ENCRYPTED_TYPE = StructTypeTernary; using PLAINTEXT_TYPE = PrimitiveType; @@ -320,7 +373,7 @@ void DecryptFromEtype(Vector &input_vector, const string message_t, uint64_t siz simple_encryption_state->iv[1] = input.b_val; encryption_state->InitializeDecryption( - reinterpret_cast(simple_encryption_state->iv), 16, + reinterpret_cast(simple_encryption_state->iv), 12, reinterpret_cast(&simple_encryption_state->key)); T decrypted_data = @@ -338,45 +391,50 @@ static void EncryptDataToEtype(DataChunk &args, ExpressionState &state, auto vector_type = input_vector.GetType(); auto size = args.size(); - // Get the encryption key from client input - auto &message_vector = args.data[1]; + // Get the key_name and message from client input + auto &key_name_vector = args.data[1]; + auto &message_vector = args.data[2]; + D_ASSERT(message_vector.GetVectorType() == VectorType::CONSTANT_VECTOR); + D_ASSERT(key_name_vector.GetVectorType() == VectorType::CONSTANT_VECTOR); const string message_t = ConstantVector::GetData(message_vector)[0].GetString(); + const string key_name_t = + ConstantVector::GetData(key_name_vector)[0].GetString(); if (vector_type.IsNumeric()) { switch (vector_type.id()) { case LogicalTypeId::TINYINT: case LogicalTypeId::UTINYINT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, message_t, + return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_name_t, message_t, size, state, result); case LogicalTypeId::SMALLINT: case LogicalTypeId::USMALLINT: return EncryptToEtype(CreateEINTtypeStruct(), input_vector, - message_t, size, state, result); + key_name_t, message_t, size, state, result); case LogicalTypeId::INTEGER: return EncryptToEtype(CreateEINTtypeStruct(), input_vector, - message_t, size, state, result); + key_name_t, message_t, size, state, result); case LogicalTypeId::UINTEGER: return EncryptToEtype(CreateEINTtypeStruct(), input_vector, - message_t, size, state, result); + key_name_t, message_t, size, state, result); case LogicalTypeId::BIGINT: return EncryptToEtype(CreateEINTtypeStruct(), input_vector, - message_t, size, state, result); + key_name_t, message_t, size, state, result); case LogicalTypeId::UBIGINT: return EncryptToEtype(CreateEINTtypeStruct(), input_vector, - message_t, size, state, result); + key_name_t, message_t, size, state, result); case LogicalTypeId::FLOAT: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, message_t, + return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_name_t, message_t, size, state, result); case LogicalTypeId::DOUBLE: - return EncryptToEtype(CreateEINTtypeStruct(), input_vector, message_t, + return EncryptToEtype(CreateEINTtypeStruct(), input_vector, key_name_t, message_t, size, state, result); default: throw NotImplementedException("Unsupported numeric type for encryption"); } } else if (vector_type.id() == LogicalTypeId::VARCHAR) { - return EncryptToEtype(CreateEVARtypeStruct(), input_vector, message_t, + return EncryptToEtype(CreateEVARtypeStruct(), input_vector, key_name_t, message_t, size, state, result); } else if (vector_type.IsNested()) { throw NotImplementedException( @@ -393,10 +451,16 @@ static void DecryptDataFromEtype(DataChunk &args, ExpressionState &state, auto size = args.size(); auto &input_vector = args.data[0]; - auto &message_vector = args.data[1]; + + // Get the key_name and message from client input + auto &key_name_vector = args.data[1]; + auto &message_vector = args.data[2]; D_ASSERT(message_vector.GetVectorType() == VectorType::CONSTANT_VECTOR); + D_ASSERT(key_name_vector.GetVectorType() == VectorType::CONSTANT_VECTOR); // Fetch the message as a constant string + const string key_name_t = + ConstantVector::GetData(key_name_vector)[0].GetString(); const string message_t = ConstantVector::GetData(message_vector)[0].GetString(); @@ -408,32 +472,32 @@ static void DecryptDataFromEtype(DataChunk &args, ExpressionState &state, switch (vector_type.id()) { case LogicalTypeId::TINYINT: case LogicalTypeId::UTINYINT: - return DecryptFromEtype(input_vector, message_t, size, state, result); + return DecryptFromEtype(input_vector, key_name_t, message_t, size, state, result); case LogicalTypeId::SMALLINT: case LogicalTypeId::USMALLINT: - return DecryptFromEtype(input_vector, message_t, size, state, + return DecryptFromEtype(input_vector, key_name_t, message_t, size, state, result); case LogicalTypeId::INTEGER: - return DecryptFromEtype(input_vector, message_t, size, state, + return DecryptFromEtype(input_vector, key_name_t, message_t, size, state, result); case LogicalTypeId::UINTEGER: - return DecryptFromEtype(input_vector, message_t, size, state, + return DecryptFromEtype(input_vector, key_name_t, message_t, size, state, result); case LogicalTypeId::BIGINT: - return DecryptFromEtype(input_vector, message_t, size, state, + return DecryptFromEtype(input_vector, key_name_t, message_t, size, state, result); case LogicalTypeId::UBIGINT: - return DecryptFromEtype(input_vector, message_t, size, state, + return DecryptFromEtype(input_vector, key_name_t, message_t, size, state, result); case LogicalTypeId::FLOAT: - return DecryptFromEtype(input_vector, message_t, size, state, result); + return DecryptFromEtype(input_vector, key_name_t, message_t, size, state, result); case LogicalTypeId::DOUBLE: - return DecryptFromEtype(input_vector, message_t, size, state, result); + return DecryptFromEtype(input_vector, key_name_t, message_t, size, state, result); default: throw NotImplementedException("Unsupported numeric type for decryption"); } } else if (vector_type.id() == LogicalTypeId::VARCHAR) { - return EncryptToEtype(CreateEVARtypeStruct(), input_vector, message_t, + return EncryptToEtype(CreateEVARtypeStruct(), input_vector, key_name_t, message_t, size, state, result); } else if (vector_type.IsNested()) { throw NotImplementedException( @@ -449,7 +513,7 @@ ScalarFunctionSet GetEncryptionStructFunction() { for (auto &type : LogicalType::AllTypes()) { set.AddFunction( - ScalarFunction({type, LogicalType::VARCHAR}, + ScalarFunction({type, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT}, {"nonce_lo", LogicalType::UBIGINT}, {"value", type}}), @@ -469,7 +533,7 @@ ScalarFunctionSet GetDecryptionStructFunction() { {LogicalType::STRUCT({{"nonce_hi", nonce_type_a}, {"nonce_lo", nonce_type_b}, {"value", type}}), - LogicalType::VARCHAR}, + LogicalType::VARCHAR, LogicalType::VARCHAR}, type, DecryptDataFromEtype, EncryptFunctionData::EncryptBind)); } } diff --git a/src/core/functions/secrets/authentication.cpp b/src/core/functions/secrets/authentication.cpp index 148f95b..9dac397 100644 --- a/src/core/functions/secrets/authentication.cpp +++ b/src/core/functions/secrets/authentication.cpp @@ -68,7 +68,6 @@ static void AddSecretParameter(const std::string &key, const CreateSecretInput & static void RegisterCommonSecretParameters(CreateSecretFunction &function) { function.named_parameters["master_key"] = LogicalType::VARCHAR; - function.named_parameters["key_name"] = LogicalType::VARCHAR; function.named_parameters["length"] = LogicalType::INTEGER; } @@ -95,7 +94,6 @@ static unique_ptr CreateKeyEncryptionKey(ClientContext &context, Cre // get the other results from the user input auto master_key = input.options["master_key"].GetValue(); - auto key_name = input.options["key_name"].GetValue(); // Store the token in the secret result->secret_map["token"] = Value(master_key); diff --git a/test/sql/secrets/secrets_encryption.test b/test/sql/secrets/secrets_encryption.test index b806b53..15fcb11 100644 --- a/test/sql/secrets/secrets_encryption.test +++ b/test/sql/secrets/secrets_encryption.test @@ -11,25 +11,71 @@ require simple_encryption statement ok set allow_persistent_secrets=false; -# Create an internal secret (for internal encryption of columns) -statement ok -CREATE SECRET test_key ( - TYPE ENCRYPTION, - KEY_NAME 'key_1', - MASTER_KEY '0123456789112345', - LENGTH 16 -); +statement error +SELECT encrypt(11, 'key_1', 'random_message'); +---- +Invalid Input Error: No secret found with name 'key_1' -# Create an internal secret (for internal encryption of columns) +# Create an internal secret with wrong size statement error CREATE SECRET test_wrong_length ( TYPE ENCRYPTION, - KEY_NAME 'key_1', MASTER_KEY '0123456789112345', LENGTH 99 ); ---- Invalid Input Error: Invalid size for encryption key: '99', only a length of 16 bytes is supported +# Create an internal secret (for internal encryption of columns) +statement ok +CREATE SECRET key_1 ( + TYPE ENCRYPTION, + MASTER_KEY '0123456789112345', + LENGTH 16 +); + +query I +SELECT decrypt({'nonce_hi': 11752579000357969348, 'nonce_lo': 2472254480, 'value': 1288890}, 'key_1', 'random_message'); +---- +2082890652 + +# nonces are smaller here? +query I +SELECT decrypt({'nonce_hi': 9915119614377941136, 'nonce_lo': 5152853787508998146, 'value': -2098331716}, 'key_1', 'random_message'); +---- +11 + +statement ok +SELECT encrypt(11, 'key_1', 'random_message'); + +statement ok +CREATE TABLE test_1 AS SELECT 1 AS value FROM range(10); + +statement ok +SELECT encrypt(value, '0123456789112345') AS encrypted_value FROM test_1; + statement ok -SELECT encrypt(11, 'random_message'); \ No newline at end of file +ALTER TABLE test_1 ADD COLUMN encrypted_values STRUCT(nonce_hi UBIGINT, nonce_lo UBIGINT, value INTEGER); + +statement ok +ALTER TABLE test_1 ADD COLUMN decrypted_values INTEGER; + +statement ok +UPDATE test_1 SET encrypted_values = encrypt(value, 'key_1', 'random_message'); + +statement ok +UPDATE test_1 SET decrypted_values = decrypt(encrypted_values, 'key_1', 'random_message'); + +query I +SELECT decrypted_values FROM test_1; +---- +1 +1 +1 +1 +1 +1 +1 +1 +1 +1