Skip to content

Commit

Permalink
Merge pull request #13 from ccfelius/secrets
Browse files Browse the repository at this point in the history
Find key by key_name in DuckDB secrets
  • Loading branch information
ccfelius authored Nov 20, 2024
2 parents e869f0b + 12c3511 commit b82f54a
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 49 deletions.
136 changes: 100 additions & 36 deletions src/core/functions/scalar/encrypt_to_etype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,40 @@ uint32_t GetLengthFromSecret(ExpressionState &state){
return length_value.GetValue<uint32_t>();
}

std::string GenerateColumnKey(ExpressionState &state, const std::string &message){
string GetSecretFromKeyName(ExpressionState &state, string key_name){

// todo; we can also put the secret in the state
auto &info = GetEncryptionBindInfo(state);
auto &secret_manager = SecretManager::Get(info.context);
auto transaction = CatalogTransaction::GetSystemCatalogTransaction(info.context);
auto secret_entry = secret_manager.GetSecretByName(transaction, key_name);

if (!secret_entry) {
throw InvalidInputException("No secret found with name '%s'.", key_name);
}

// Safely access the secret
if (!secret_entry->secret) {
throw InvalidInputException("Secret found, but '%s' contains no actual secret.", key_name);
}

// Retrieve the secret
auto &secret = *secret_entry->secret;
// Retrieve the key, value secret
const auto *kv_secret = dynamic_cast<const KeyValueSecret *>(&secret);

Value token_value;
if (!kv_secret->TryGetValue("token", token_value)) {
throw InvalidInputException("'token' not found in 'encryption' secret.");
}

return token_value.ToString();
}

std::string GenerateColumnKey(ExpressionState &state, const std::string &key_name, const std::string &message){
// Get the encryption key from DuckDB Secrets Manager
auto secret = GetKeyFromSecret(state);
// auto secret = GetKeyFromSecret(state);
auto secret = GetSecretFromKeyName(state, key_name);
size_t length = GetLengthFromSecret(state);

return CalculateHMAC(secret, message, length);
Expand All @@ -223,11 +254,23 @@ bool HasSpace(shared_ptr<SimpleEncryptionState> simple_encryption_state,


void SetIV(shared_ptr<SimpleEncryptionState> simple_encryption_state) {
simple_encryption_state->iv[0] = simple_encryption_state->iv[1] = 0;
simple_encryption_state->iv[1] = 0;
simple_encryption_state->encryption_state->GenerateRandomData(
reinterpret_cast<data_ptr_t>(simple_encryption_state->iv), 12);
}

bool CheckGeneratedKeySize(const uint32_t size){

switch(size){
case 16:
case 24:
case 32:
return true;
default:
return false;
}
}

shared_ptr<EncryptionState> GetEncryptionState(ExpressionState &state) {
return GetSimpleEncryptionState(state)->encryption_state;
}
Expand All @@ -247,7 +290,7 @@ LogicalType CreateEVARtypeStruct() {

template <typename T>
void EncryptToEtype(LogicalType result_struct, Vector &input_vector,
const string message_t, uint64_t size, ExpressionState &state,
const string key_name, const string message_t, uint64_t size, ExpressionState &state,
Vector &result) {

// this now happens for every chunk, maybe we should already put it in the bind
Expand All @@ -256,10 +299,12 @@ void EncryptToEtype(LogicalType result_struct, Vector &input_vector,

// calculate column key if no key set yet
if (!simple_encryption_state->key_flag){
simple_encryption_state->key = GenerateColumnKey(state, message_t);
simple_encryption_state->key = GenerateColumnKey(state, key_name, message_t);
simple_encryption_state->key_flag = true;
}

D_ASSERT(CheckGeneratedKeySize(simple_encryption_state->key.size()));

// Reset the reference of the result vector
Vector struct_vector(result_struct, size);
result.ReferenceAndSetType(struct_vector);
Expand All @@ -274,16 +319,18 @@ void EncryptToEtype(LogicalType result_struct, Vector &input_vector,
auto &nonce_hi = children[0];
nonce_hi->SetVectorType(VectorType::CONSTANT_VECTOR);

auto nonce_lo = simple_encryption_state->iv[1];

using ENCRYPTED_TYPE = StructTypeTernary<uint64_t, uint64_t, T>;
using PLAINTEXT_TYPE = PrimitiveType<T>;

encryption_state->InitializeEncryption(
reinterpret_cast<const_data_ptr_t>(simple_encryption_state->iv), 16,
reinterpret_cast<const string *>(&simple_encryption_state->key));

GenericExecutor::ExecuteUnary<PLAINTEXT_TYPE, ENCRYPTED_TYPE>(
input_vector, result, size, [&](PLAINTEXT_TYPE input) {

// increment the low part of the nonce
simple_encryption_state->iv[1]++;
simple_encryption_state->counter++;

encryption_state->InitializeEncryption(
reinterpret_cast<const_data_ptr_t>(simple_encryption_state->iv), 16,
reinterpret_cast<const string *>(&simple_encryption_state->key));
Expand All @@ -292,25 +339,31 @@ void EncryptToEtype(LogicalType result_struct, Vector &input_vector,
ProcessAndCastEncrypt(encryption_state, result, input.val,
simple_encryption_state->buffer_p);

nonce_lo = simple_encryption_state->iv[1];
simple_encryption_state->counter++;
simple_encryption_state->iv[1]++;

return ENCRYPTED_TYPE{simple_encryption_state->iv[0],
simple_encryption_state->iv[1], encrypted_data};
nonce_lo, encrypted_data};
});
}


template <typename T>
void DecryptFromEtype(Vector &input_vector, const string message_t, uint64_t size,
void DecryptFromEtype(Vector &input_vector, const string key_name, const string message_t, uint64_t size,
ExpressionState &state, Vector &result) {

auto simple_encryption_state = GetSimpleEncryptionState(state);
auto encryption_state = GetEncryptionState(state);

// calculate column key if no key set yet
if (!simple_encryption_state->key_flag){
simple_encryption_state->key = GenerateColumnKey(state, message_t);
simple_encryption_state->key = GenerateColumnKey(state, key_name, message_t);
simple_encryption_state->key_flag = true;
}

D_ASSERT(CheckGeneratedKeySize(simple_encryption_state->key.size()));

using ENCRYPTED_TYPE = StructTypeTernary<uint64_t, uint64_t, T>;
using PLAINTEXT_TYPE = PrimitiveType<T>;

Expand All @@ -320,7 +373,7 @@ void DecryptFromEtype(Vector &input_vector, const string message_t, uint64_t siz
simple_encryption_state->iv[1] = input.b_val;

encryption_state->InitializeDecryption(
reinterpret_cast<const_data_ptr_t>(simple_encryption_state->iv), 16,
reinterpret_cast<const_data_ptr_t>(simple_encryption_state->iv), 12,
reinterpret_cast<const string *>(&simple_encryption_state->key));

T decrypted_data =
Expand All @@ -338,45 +391,50 @@ static void EncryptDataToEtype(DataChunk &args, ExpressionState &state,
auto vector_type = input_vector.GetType();
auto size = args.size();

// Get the encryption key from client input
auto &message_vector = args.data[1];
// Get the key_name and message from client input
auto &key_name_vector = args.data[1];
auto &message_vector = args.data[2];

D_ASSERT(message_vector.GetVectorType() == VectorType::CONSTANT_VECTOR);
D_ASSERT(key_name_vector.GetVectorType() == VectorType::CONSTANT_VECTOR);
const string message_t =
ConstantVector::GetData<string_t>(message_vector)[0].GetString();
const string key_name_t =
ConstantVector::GetData<string_t>(key_name_vector)[0].GetString();

if (vector_type.IsNumeric()) {
switch (vector_type.id()) {
case LogicalTypeId::TINYINT:
case LogicalTypeId::UTINYINT:
return EncryptToEtype<int8_t>(CreateEINTtypeStruct(), input_vector, message_t,
return EncryptToEtype<int8_t>(CreateEINTtypeStruct(), input_vector, key_name_t, message_t,
size, state, result);
case LogicalTypeId::SMALLINT:
case LogicalTypeId::USMALLINT:
return EncryptToEtype<int16_t>(CreateEINTtypeStruct(), input_vector,
message_t, size, state, result);
key_name_t, message_t, size, state, result);
case LogicalTypeId::INTEGER:
return EncryptToEtype<int32_t>(CreateEINTtypeStruct(), input_vector,
message_t, size, state, result);
key_name_t, message_t, size, state, result);
case LogicalTypeId::UINTEGER:
return EncryptToEtype<uint32_t>(CreateEINTtypeStruct(), input_vector,
message_t, size, state, result);
key_name_t, message_t, size, state, result);
case LogicalTypeId::BIGINT:
return EncryptToEtype<int64_t>(CreateEINTtypeStruct(), input_vector,
message_t, size, state, result);
key_name_t, message_t, size, state, result);
case LogicalTypeId::UBIGINT:
return EncryptToEtype<uint64_t>(CreateEINTtypeStruct(), input_vector,
message_t, size, state, result);
key_name_t, message_t, size, state, result);
case LogicalTypeId::FLOAT:
return EncryptToEtype<float>(CreateEINTtypeStruct(), input_vector, message_t,
return EncryptToEtype<float>(CreateEINTtypeStruct(), input_vector, key_name_t, message_t,
size, state, result);
case LogicalTypeId::DOUBLE:
return EncryptToEtype<double>(CreateEINTtypeStruct(), input_vector, message_t,
return EncryptToEtype<double>(CreateEINTtypeStruct(), input_vector, key_name_t, message_t,
size, state, result);
default:
throw NotImplementedException("Unsupported numeric type for encryption");
}
} else if (vector_type.id() == LogicalTypeId::VARCHAR) {
return EncryptToEtype<string_t>(CreateEVARtypeStruct(), input_vector, message_t,
return EncryptToEtype<string_t>(CreateEVARtypeStruct(), input_vector, key_name_t, message_t,
size, state, result);
} else if (vector_type.IsNested()) {
throw NotImplementedException(
Expand All @@ -393,10 +451,16 @@ static void DecryptDataFromEtype(DataChunk &args, ExpressionState &state,

auto size = args.size();
auto &input_vector = args.data[0];
auto &message_vector = args.data[1];

// Get the key_name and message from client input
auto &key_name_vector = args.data[1];
auto &message_vector = args.data[2];
D_ASSERT(message_vector.GetVectorType() == VectorType::CONSTANT_VECTOR);
D_ASSERT(key_name_vector.GetVectorType() == VectorType::CONSTANT_VECTOR);

// Fetch the message as a constant string
const string key_name_t =
ConstantVector::GetData<string_t>(key_name_vector)[0].GetString();
const string message_t =
ConstantVector::GetData<string_t>(message_vector)[0].GetString();

Expand All @@ -408,32 +472,32 @@ static void DecryptDataFromEtype(DataChunk &args, ExpressionState &state,
switch (vector_type.id()) {
case LogicalTypeId::TINYINT:
case LogicalTypeId::UTINYINT:
return DecryptFromEtype<int8_t>(input_vector, message_t, size, state, result);
return DecryptFromEtype<int8_t>(input_vector, key_name_t, message_t, size, state, result);
case LogicalTypeId::SMALLINT:
case LogicalTypeId::USMALLINT:
return DecryptFromEtype<int16_t>(input_vector, message_t, size, state,
return DecryptFromEtype<int16_t>(input_vector, key_name_t, message_t, size, state,
result);
case LogicalTypeId::INTEGER:
return DecryptFromEtype<int32_t>(input_vector, message_t, size, state,
return DecryptFromEtype<int32_t>(input_vector, key_name_t, message_t, size, state,
result);
case LogicalTypeId::UINTEGER:
return DecryptFromEtype<uint32_t>(input_vector, message_t, size, state,
return DecryptFromEtype<uint32_t>(input_vector, key_name_t, message_t, size, state,
result);
case LogicalTypeId::BIGINT:
return DecryptFromEtype<int64_t>(input_vector, message_t, size, state,
return DecryptFromEtype<int64_t>(input_vector, key_name_t, message_t, size, state,
result);
case LogicalTypeId::UBIGINT:
return DecryptFromEtype<uint64_t>(input_vector, message_t, size, state,
return DecryptFromEtype<uint64_t>(input_vector, key_name_t, message_t, size, state,
result);
case LogicalTypeId::FLOAT:
return DecryptFromEtype<float>(input_vector, message_t, size, state, result);
return DecryptFromEtype<float>(input_vector, key_name_t, message_t, size, state, result);
case LogicalTypeId::DOUBLE:
return DecryptFromEtype<double>(input_vector, message_t, size, state, result);
return DecryptFromEtype<double>(input_vector, key_name_t, message_t, size, state, result);
default:
throw NotImplementedException("Unsupported numeric type for decryption");
}
} else if (vector_type.id() == LogicalTypeId::VARCHAR) {
return EncryptToEtype<string_t>(CreateEVARtypeStruct(), input_vector, message_t,
return EncryptToEtype<string_t>(CreateEVARtypeStruct(), input_vector, key_name_t, message_t,
size, state, result);
} else if (vector_type.IsNested()) {
throw NotImplementedException(
Expand All @@ -449,7 +513,7 @@ ScalarFunctionSet GetEncryptionStructFunction() {

for (auto &type : LogicalType::AllTypes()) {
set.AddFunction(
ScalarFunction({type, LogicalType::VARCHAR},
ScalarFunction({type, LogicalType::VARCHAR, LogicalType::VARCHAR},
LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT},
{"nonce_lo", LogicalType::UBIGINT},
{"value", type}}),
Expand All @@ -469,7 +533,7 @@ ScalarFunctionSet GetDecryptionStructFunction() {
{LogicalType::STRUCT({{"nonce_hi", nonce_type_a},
{"nonce_lo", nonce_type_b},
{"value", type}}),
LogicalType::VARCHAR},
LogicalType::VARCHAR, LogicalType::VARCHAR},
type, DecryptDataFromEtype, EncryptFunctionData::EncryptBind));
}
}
Expand Down
2 changes: 0 additions & 2 deletions src/core/functions/secrets/authentication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ static void AddSecretParameter(const std::string &key, const CreateSecretInput &

static void RegisterCommonSecretParameters(CreateSecretFunction &function) {
function.named_parameters["master_key"] = LogicalType::VARCHAR;
function.named_parameters["key_name"] = LogicalType::VARCHAR;
function.named_parameters["length"] = LogicalType::INTEGER;
}

Expand All @@ -95,7 +94,6 @@ static unique_ptr<BaseSecret> CreateKeyEncryptionKey(ClientContext &context, Cre

// get the other results from the user input
auto master_key = input.options["master_key"].GetValue<std::string>();
auto key_name = input.options["key_name"].GetValue<std::string>();

// Store the token in the secret
result->secret_map["token"] = Value(master_key);
Expand Down
68 changes: 57 additions & 11 deletions test/sql/secrets/secrets_encryption.test
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,71 @@ require simple_encryption
statement ok
set allow_persistent_secrets=false;

# Create an internal secret (for internal encryption of columns)
statement ok
CREATE SECRET test_key (
TYPE ENCRYPTION,
KEY_NAME 'key_1',
MASTER_KEY '0123456789112345',
LENGTH 16
);
statement error
SELECT encrypt(11, 'key_1', 'random_message');
----
Invalid Input Error: No secret found with name 'key_1'

# Create an internal secret (for internal encryption of columns)
# Create an internal secret with wrong size
statement error
CREATE SECRET test_wrong_length (
TYPE ENCRYPTION,
KEY_NAME 'key_1',
MASTER_KEY '0123456789112345',
LENGTH 99
);
----
Invalid Input Error: Invalid size for encryption key: '99', only a length of 16 bytes is supported

# Create an internal secret (for internal encryption of columns)
statement ok
CREATE SECRET key_1 (
TYPE ENCRYPTION,
MASTER_KEY '0123456789112345',
LENGTH 16
);

query I
SELECT decrypt({'nonce_hi': 11752579000357969348, 'nonce_lo': 2472254480, 'value': 1288890}, 'key_1', 'random_message');
----
2082890652

# nonces are smaller here?
query I
SELECT decrypt({'nonce_hi': 9915119614377941136, 'nonce_lo': 5152853787508998146, 'value': -2098331716}, 'key_1', 'random_message');
----
11

statement ok
SELECT encrypt(11, 'key_1', 'random_message');

statement ok
CREATE TABLE test_1 AS SELECT 1 AS value FROM range(10);

statement ok
SELECT encrypt(value, '0123456789112345') AS encrypted_value FROM test_1;

statement ok
SELECT encrypt(11, 'random_message');
ALTER TABLE test_1 ADD COLUMN encrypted_values STRUCT(nonce_hi UBIGINT, nonce_lo UBIGINT, value INTEGER);

statement ok
ALTER TABLE test_1 ADD COLUMN decrypted_values INTEGER;

statement ok
UPDATE test_1 SET encrypted_values = encrypt(value, 'key_1', 'random_message');

statement ok
UPDATE test_1 SET decrypted_values = decrypt(encrypted_values, 'key_1', 'random_message');

query I
SELECT decrypted_values FROM test_1;
----
1
1
1
1
1
1
1
1
1
1

0 comments on commit b82f54a

Please sign in to comment.