Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Find key by key_name in DuckDB secrets #13

Merged
merged 2 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 100 additions & 36 deletions src/core/functions/scalar/encrypt_to_etype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,40 @@ uint32_t GetLengthFromSecret(ExpressionState &state){
return length_value.GetValue<uint32_t>();
}

std::string GenerateColumnKey(ExpressionState &state, const std::string &message){
string GetSecretFromKeyName(ExpressionState &state, string key_name){

// todo; we can also put the secret in the state
auto &info = GetEncryptionBindInfo(state);
auto &secret_manager = SecretManager::Get(info.context);
auto transaction = CatalogTransaction::GetSystemCatalogTransaction(info.context);
auto secret_entry = secret_manager.GetSecretByName(transaction, key_name);

if (!secret_entry) {
throw InvalidInputException("No secret found with name '%s'.", key_name);
}

// Safely access the secret
if (!secret_entry->secret) {
throw InvalidInputException("Secret found, but '%s' contains no actual secret.", key_name);
}

// Retrieve the secret
auto &secret = *secret_entry->secret;
// Retrieve the key, value secret
const auto *kv_secret = dynamic_cast<const KeyValueSecret *>(&secret);

Value token_value;
if (!kv_secret->TryGetValue("token", token_value)) {
throw InvalidInputException("'token' not found in 'encryption' secret.");
}

return token_value.ToString();
}

std::string GenerateColumnKey(ExpressionState &state, const std::string &key_name, const std::string &message){
// Get the encryption key from DuckDB Secrets Manager
auto secret = GetKeyFromSecret(state);
// auto secret = GetKeyFromSecret(state);
auto secret = GetSecretFromKeyName(state, key_name);
size_t length = GetLengthFromSecret(state);

return CalculateHMAC(secret, message, length);
Expand All @@ -223,11 +254,23 @@ bool HasSpace(shared_ptr<SimpleEncryptionState> simple_encryption_state,


void SetIV(shared_ptr<SimpleEncryptionState> simple_encryption_state) {
simple_encryption_state->iv[0] = simple_encryption_state->iv[1] = 0;
simple_encryption_state->iv[1] = 0;
simple_encryption_state->encryption_state->GenerateRandomData(
reinterpret_cast<data_ptr_t>(simple_encryption_state->iv), 12);
}

bool CheckGeneratedKeySize(const uint32_t size){

switch(size){
case 16:
case 24:
case 32:
return true;
default:
return false;
}
}

shared_ptr<EncryptionState> GetEncryptionState(ExpressionState &state) {
return GetSimpleEncryptionState(state)->encryption_state;
}
Expand All @@ -247,7 +290,7 @@ LogicalType CreateEVARtypeStruct() {

template <typename T>
void EncryptToEtype(LogicalType result_struct, Vector &input_vector,
const string message_t, uint64_t size, ExpressionState &state,
const string key_name, const string message_t, uint64_t size, ExpressionState &state,
Vector &result) {

// this now happens for every chunk, maybe we should already put it in the bind
Expand All @@ -256,10 +299,12 @@ void EncryptToEtype(LogicalType result_struct, Vector &input_vector,

// calculate column key if no key set yet
if (!simple_encryption_state->key_flag){
simple_encryption_state->key = GenerateColumnKey(state, message_t);
simple_encryption_state->key = GenerateColumnKey(state, key_name, message_t);
simple_encryption_state->key_flag = true;
}

D_ASSERT(CheckGeneratedKeySize(simple_encryption_state->key.size()));

// Reset the reference of the result vector
Vector struct_vector(result_struct, size);
result.ReferenceAndSetType(struct_vector);
Expand All @@ -274,16 +319,18 @@ void EncryptToEtype(LogicalType result_struct, Vector &input_vector,
auto &nonce_hi = children[0];
nonce_hi->SetVectorType(VectorType::CONSTANT_VECTOR);

auto nonce_lo = simple_encryption_state->iv[1];

using ENCRYPTED_TYPE = StructTypeTernary<uint64_t, uint64_t, T>;
using PLAINTEXT_TYPE = PrimitiveType<T>;

encryption_state->InitializeEncryption(
reinterpret_cast<const_data_ptr_t>(simple_encryption_state->iv), 16,
reinterpret_cast<const string *>(&simple_encryption_state->key));

GenericExecutor::ExecuteUnary<PLAINTEXT_TYPE, ENCRYPTED_TYPE>(
input_vector, result, size, [&](PLAINTEXT_TYPE input) {

// increment the low part of the nonce
simple_encryption_state->iv[1]++;
simple_encryption_state->counter++;

encryption_state->InitializeEncryption(
reinterpret_cast<const_data_ptr_t>(simple_encryption_state->iv), 16,
reinterpret_cast<const string *>(&simple_encryption_state->key));
Expand All @@ -292,25 +339,31 @@ void EncryptToEtype(LogicalType result_struct, Vector &input_vector,
ProcessAndCastEncrypt(encryption_state, result, input.val,
simple_encryption_state->buffer_p);

nonce_lo = simple_encryption_state->iv[1];
simple_encryption_state->counter++;
simple_encryption_state->iv[1]++;

return ENCRYPTED_TYPE{simple_encryption_state->iv[0],
simple_encryption_state->iv[1], encrypted_data};
nonce_lo, encrypted_data};
});
}


template <typename T>
void DecryptFromEtype(Vector &input_vector, const string message_t, uint64_t size,
void DecryptFromEtype(Vector &input_vector, const string key_name, const string message_t, uint64_t size,
ExpressionState &state, Vector &result) {

auto simple_encryption_state = GetSimpleEncryptionState(state);
auto encryption_state = GetEncryptionState(state);

// calculate column key if no key set yet
if (!simple_encryption_state->key_flag){
simple_encryption_state->key = GenerateColumnKey(state, message_t);
simple_encryption_state->key = GenerateColumnKey(state, key_name, message_t);
simple_encryption_state->key_flag = true;
}

D_ASSERT(CheckGeneratedKeySize(simple_encryption_state->key.size()));

using ENCRYPTED_TYPE = StructTypeTernary<uint64_t, uint64_t, T>;
using PLAINTEXT_TYPE = PrimitiveType<T>;

Expand All @@ -320,7 +373,7 @@ void DecryptFromEtype(Vector &input_vector, const string message_t, uint64_t siz
simple_encryption_state->iv[1] = input.b_val;

encryption_state->InitializeDecryption(
reinterpret_cast<const_data_ptr_t>(simple_encryption_state->iv), 16,
reinterpret_cast<const_data_ptr_t>(simple_encryption_state->iv), 12,
reinterpret_cast<const string *>(&simple_encryption_state->key));

T decrypted_data =
Expand All @@ -338,45 +391,50 @@ static void EncryptDataToEtype(DataChunk &args, ExpressionState &state,
auto vector_type = input_vector.GetType();
auto size = args.size();

// Get the encryption key from client input
auto &message_vector = args.data[1];
// Get the key_name and message from client input
auto &key_name_vector = args.data[1];
auto &message_vector = args.data[2];

D_ASSERT(message_vector.GetVectorType() == VectorType::CONSTANT_VECTOR);
D_ASSERT(key_name_vector.GetVectorType() == VectorType::CONSTANT_VECTOR);
const string message_t =
ConstantVector::GetData<string_t>(message_vector)[0].GetString();
const string key_name_t =
ConstantVector::GetData<string_t>(key_name_vector)[0].GetString();

if (vector_type.IsNumeric()) {
switch (vector_type.id()) {
case LogicalTypeId::TINYINT:
case LogicalTypeId::UTINYINT:
return EncryptToEtype<int8_t>(CreateEINTtypeStruct(), input_vector, message_t,
return EncryptToEtype<int8_t>(CreateEINTtypeStruct(), input_vector, key_name_t, message_t,
size, state, result);
case LogicalTypeId::SMALLINT:
case LogicalTypeId::USMALLINT:
return EncryptToEtype<int16_t>(CreateEINTtypeStruct(), input_vector,
message_t, size, state, result);
key_name_t, message_t, size, state, result);
case LogicalTypeId::INTEGER:
return EncryptToEtype<int32_t>(CreateEINTtypeStruct(), input_vector,
message_t, size, state, result);
key_name_t, message_t, size, state, result);
case LogicalTypeId::UINTEGER:
return EncryptToEtype<uint32_t>(CreateEINTtypeStruct(), input_vector,
message_t, size, state, result);
key_name_t, message_t, size, state, result);
case LogicalTypeId::BIGINT:
return EncryptToEtype<int64_t>(CreateEINTtypeStruct(), input_vector,
message_t, size, state, result);
key_name_t, message_t, size, state, result);
case LogicalTypeId::UBIGINT:
return EncryptToEtype<uint64_t>(CreateEINTtypeStruct(), input_vector,
message_t, size, state, result);
key_name_t, message_t, size, state, result);
case LogicalTypeId::FLOAT:
return EncryptToEtype<float>(CreateEINTtypeStruct(), input_vector, message_t,
return EncryptToEtype<float>(CreateEINTtypeStruct(), input_vector, key_name_t, message_t,
size, state, result);
case LogicalTypeId::DOUBLE:
return EncryptToEtype<double>(CreateEINTtypeStruct(), input_vector, message_t,
return EncryptToEtype<double>(CreateEINTtypeStruct(), input_vector, key_name_t, message_t,
size, state, result);
default:
throw NotImplementedException("Unsupported numeric type for encryption");
}
} else if (vector_type.id() == LogicalTypeId::VARCHAR) {
return EncryptToEtype<string_t>(CreateEVARtypeStruct(), input_vector, message_t,
return EncryptToEtype<string_t>(CreateEVARtypeStruct(), input_vector, key_name_t, message_t,
size, state, result);
} else if (vector_type.IsNested()) {
throw NotImplementedException(
Expand All @@ -393,10 +451,16 @@ static void DecryptDataFromEtype(DataChunk &args, ExpressionState &state,

auto size = args.size();
auto &input_vector = args.data[0];
auto &message_vector = args.data[1];

// Get the key_name and message from client input
auto &key_name_vector = args.data[1];
auto &message_vector = args.data[2];
D_ASSERT(message_vector.GetVectorType() == VectorType::CONSTANT_VECTOR);
D_ASSERT(key_name_vector.GetVectorType() == VectorType::CONSTANT_VECTOR);

// Fetch the message as a constant string
const string key_name_t =
ConstantVector::GetData<string_t>(key_name_vector)[0].GetString();
const string message_t =
ConstantVector::GetData<string_t>(message_vector)[0].GetString();

Expand All @@ -408,32 +472,32 @@ static void DecryptDataFromEtype(DataChunk &args, ExpressionState &state,
switch (vector_type.id()) {
case LogicalTypeId::TINYINT:
case LogicalTypeId::UTINYINT:
return DecryptFromEtype<int8_t>(input_vector, message_t, size, state, result);
return DecryptFromEtype<int8_t>(input_vector, key_name_t, message_t, size, state, result);
case LogicalTypeId::SMALLINT:
case LogicalTypeId::USMALLINT:
return DecryptFromEtype<int16_t>(input_vector, message_t, size, state,
return DecryptFromEtype<int16_t>(input_vector, key_name_t, message_t, size, state,
result);
case LogicalTypeId::INTEGER:
return DecryptFromEtype<int32_t>(input_vector, message_t, size, state,
return DecryptFromEtype<int32_t>(input_vector, key_name_t, message_t, size, state,
result);
case LogicalTypeId::UINTEGER:
return DecryptFromEtype<uint32_t>(input_vector, message_t, size, state,
return DecryptFromEtype<uint32_t>(input_vector, key_name_t, message_t, size, state,
result);
case LogicalTypeId::BIGINT:
return DecryptFromEtype<int64_t>(input_vector, message_t, size, state,
return DecryptFromEtype<int64_t>(input_vector, key_name_t, message_t, size, state,
result);
case LogicalTypeId::UBIGINT:
return DecryptFromEtype<uint64_t>(input_vector, message_t, size, state,
return DecryptFromEtype<uint64_t>(input_vector, key_name_t, message_t, size, state,
result);
case LogicalTypeId::FLOAT:
return DecryptFromEtype<float>(input_vector, message_t, size, state, result);
return DecryptFromEtype<float>(input_vector, key_name_t, message_t, size, state, result);
case LogicalTypeId::DOUBLE:
return DecryptFromEtype<double>(input_vector, message_t, size, state, result);
return DecryptFromEtype<double>(input_vector, key_name_t, message_t, size, state, result);
default:
throw NotImplementedException("Unsupported numeric type for decryption");
}
} else if (vector_type.id() == LogicalTypeId::VARCHAR) {
return EncryptToEtype<string_t>(CreateEVARtypeStruct(), input_vector, message_t,
return EncryptToEtype<string_t>(CreateEVARtypeStruct(), input_vector, key_name_t, message_t,
size, state, result);
} else if (vector_type.IsNested()) {
throw NotImplementedException(
Expand All @@ -449,7 +513,7 @@ ScalarFunctionSet GetEncryptionStructFunction() {

for (auto &type : LogicalType::AllTypes()) {
set.AddFunction(
ScalarFunction({type, LogicalType::VARCHAR},
ScalarFunction({type, LogicalType::VARCHAR, LogicalType::VARCHAR},
LogicalType::STRUCT({{"nonce_hi", LogicalType::UBIGINT},
{"nonce_lo", LogicalType::UBIGINT},
{"value", type}}),
Expand All @@ -469,7 +533,7 @@ ScalarFunctionSet GetDecryptionStructFunction() {
{LogicalType::STRUCT({{"nonce_hi", nonce_type_a},
{"nonce_lo", nonce_type_b},
{"value", type}}),
LogicalType::VARCHAR},
LogicalType::VARCHAR, LogicalType::VARCHAR},
type, DecryptDataFromEtype, EncryptFunctionData::EncryptBind));
}
}
Expand Down
2 changes: 0 additions & 2 deletions src/core/functions/secrets/authentication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ static void AddSecretParameter(const std::string &key, const CreateSecretInput &

static void RegisterCommonSecretParameters(CreateSecretFunction &function) {
function.named_parameters["master_key"] = LogicalType::VARCHAR;
function.named_parameters["key_name"] = LogicalType::VARCHAR;
function.named_parameters["length"] = LogicalType::INTEGER;
}

Expand All @@ -95,7 +94,6 @@ static unique_ptr<BaseSecret> CreateKeyEncryptionKey(ClientContext &context, Cre

// get the other results from the user input
auto master_key = input.options["master_key"].GetValue<std::string>();
auto key_name = input.options["key_name"].GetValue<std::string>();

// Store the token in the secret
result->secret_map["token"] = Value(master_key);
Expand Down
68 changes: 57 additions & 11 deletions test/sql/secrets/secrets_encryption.test
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,71 @@ require simple_encryption
statement ok
set allow_persistent_secrets=false;

# Create an internal secret (for internal encryption of columns)
statement ok
CREATE SECRET test_key (
TYPE ENCRYPTION,
KEY_NAME 'key_1',
MASTER_KEY '0123456789112345',
LENGTH 16
);
statement error
SELECT encrypt(11, 'key_1', 'random_message');
----
Invalid Input Error: No secret found with name 'key_1'

# Create an internal secret (for internal encryption of columns)
# Create an internal secret with wrong size
statement error
CREATE SECRET test_wrong_length (
TYPE ENCRYPTION,
KEY_NAME 'key_1',
MASTER_KEY '0123456789112345',
LENGTH 99
);
----
Invalid Input Error: Invalid size for encryption key: '99', only a length of 16 bytes is supported

# Create an internal secret (for internal encryption of columns)
statement ok
CREATE SECRET key_1 (
TYPE ENCRYPTION,
MASTER_KEY '0123456789112345',
LENGTH 16
);

query I
SELECT decrypt({'nonce_hi': 11752579000357969348, 'nonce_lo': 2472254480, 'value': 1288890}, 'key_1', 'random_message');
----
2082890652

# nonces are smaller here?
query I
SELECT decrypt({'nonce_hi': 9915119614377941136, 'nonce_lo': 5152853787508998146, 'value': -2098331716}, 'key_1', 'random_message');
----
11

statement ok
SELECT encrypt(11, 'key_1', 'random_message');

statement ok
CREATE TABLE test_1 AS SELECT 1 AS value FROM range(10);

statement ok
SELECT encrypt(value, '0123456789112345') AS encrypted_value FROM test_1;

statement ok
SELECT encrypt(11, 'random_message');
ALTER TABLE test_1 ADD COLUMN encrypted_values STRUCT(nonce_hi UBIGINT, nonce_lo UBIGINT, value INTEGER);

statement ok
ALTER TABLE test_1 ADD COLUMN decrypted_values INTEGER;

statement ok
UPDATE test_1 SET encrypted_values = encrypt(value, 'key_1', 'random_message');

statement ok
UPDATE test_1 SET decrypted_values = decrypt(encrypted_values, 'key_1', 'random_message');

query I
SELECT decrypted_values FROM test_1;
----
1
1
1
1
1
1
1
1
1
1
Loading