Skip to content

Commit

Permalink
Merge branch 'main' into types
Browse files Browse the repository at this point in the history
  • Loading branch information
ccfelius committed Nov 8, 2024
2 parents 1acef62 + 56c96d5 commit a1870ab
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 161 deletions.
181 changes: 43 additions & 138 deletions src/core/functions/scalar/encrypt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,119 +75,62 @@ shared_ptr<EncryptionState> InitializeCryptoState(ExpressionState &state) {
return encryption_state->CreateEncryptionState();
}

// Creates a crypto state via InitializeCryptoState(state) and initialises it for
// decryption with the hard-coded TEST_KEY and a fixed 16-byte IV.
//
// state: expression state forwarded unchanged to InitializeCryptoState.
// Returns the state after InitializeDecryption(iv, 16, &key) has been called on it.
shared_ptr<EncryptionState> InitializeDecryption(ExpressionState &state) {

    // For now, hardcode everything
    const string key = TEST_KEY;

    // The literal occupies only 12 bytes (11 characters plus the terminating NUL).
    // Copying 16 bytes from it read past the end of the literal, so zero-initialise
    // the IV and copy exactly the literal instead. The resulting IV is the same:
    // bytes 0..10 hold the characters, bytes 11..15 are zero.
    static constexpr char IV_PREFIX[] = "12345678901";
    unsigned char iv[16] = {0};
    static_assert(sizeof(IV_PREFIX) <= sizeof(iv), "IV prefix must fit into the 16-byte IV");
    memcpy(iv, IV_PREFIX, sizeof(IV_PREFIX));

    // TODO: construct nonce based on immutable ROW_ID + hash(col_name)
    // iv[12..15] stay 0x00 through the zero-initialisation above.

    auto decryption_state = InitializeCryptoState(state);
    decryption_state->InitializeDecryption(iv, 16, &key);

    return decryption_state;
}

inline const uint8_t *DecryptValue(uint8_t *buffer, size_t size, ExpressionState &state) {

// Initialize Encryption
auto encryption_state = InitializeDecryption(state);
uint8_t decryption_buffer[MAX_BUFFER_SIZE];
uint8_t *temp_buf = decryption_buffer;

encryption_state->Process(buffer, size, temp_buf, size);

return temp_buf;
}

// Debug helper that validates an encryption result in two steps:
//   1. converts `printable_encrypted_data` with Blob::ToBlob and checks that its
//      first `size` bytes equal `buffer`;
//   2. decrypts `buffer` with DecryptValue and checks that the first `size` bytes
//      of the result equal `value`.
//
// printable_encrypted_data: printable form of the encrypted bytes.
// buffer: the raw encrypted bytes.
// size:   number of bytes compared in both checks.
// value:  the original plaintext bytes.
// state:  expression state forwarded to DecryptValue.
//
// Returns true when both checks pass; throws InvalidInputException otherwise.
bool CheckEncryption(string_t printable_encrypted_data, uint8_t *buffer,
                     size_t size, const uint8_t *value, ExpressionState &state) {

    // cast encrypted data to blob back and forth
    // to check whether data will be lost with casting
    auto unblobbed_data = Blob::ToBlob(printable_encrypted_data);

    // A blob shorter than `size` cannot match, and comparing `size` bytes against
    // it would read past its end, so reject it before calling memcmp.
    if (unblobbed_data.size() < size ||
        memcmp(unblobbed_data.data(), buffer, size) != 0) {
        throw InvalidInputException(
            "Original Encrypted Data differs from Unblobbed Encrypted Data");
    }

    auto decrypted_data = DecryptValue(buffer, size, state);
    if (memcmp(decrypted_data, value, size) != 0) {
        throw InvalidInputException(
            "Original Data differs from Decrypted Data");
    }
    return true;
}

// Generated code
//---------------------------------------------------------------------------------------------

template <typename T>

// Fix this now with IsNUMERIC LogicalType::IsNumeric()
typename std::enable_if<std::is_integral<T>::value || std::is_floating_point<T>::value, T>::type
ConvertCipherText(uint8_t *buffer_p, size_t data_size, const uint8_t *input_data) {
EncryptValue(EncryptionState *encryption_state, Vector &result, T plaintext_data, uint8_t *buffer_p) {
// TODO: Process could write directly into the result, which would make the intermediate buffer unnecessary
T encrypted_data;
memcpy(&encrypted_data, buffer_p, sizeof(T));
encryption_state->Process(reinterpret_cast<unsigned char*>(&plaintext_data), sizeof(T), reinterpret_cast<unsigned char*>(&encrypted_data), sizeof(T));
return encrypted_data;
}

// Handle string_t type and convert to Base64
//template <typename T>
//typename std::enable_if<std::is_same<T, string_t>::value, T>::type
//ConvertCipherText(uint8_t *buffer_p, size_t data_size, const uint8_t *input_data) {
// // Create a blob from the encrypted buffer data
// string_t blob(reinterpret_cast<const char *>(buffer_p), data_size);
//
// // Define a base64 output buffer large enough to store the encoded result
// size_t base64_size = Blob::ToBase64Size(blob);
// unique_ptr<char[]> base64_output(new char[base64_size]);
//
// // Convert blob to base64 and store it in the output buffer
// Blob::ToBase64(blob, base64_output.get());
//
// // Return the base64-encoded result as a new string_t
// return string_t(blob.GetString());
//}

// TODO: for decryption, convert a string to blob and then decrypt and then return string_t?
template <typename T>
typename std::enable_if<std::is_same<T, string_t>::value, T>::type
ConvertCipherText(uint8_t *buffer_p, size_t data_size, const uint8_t *input_data) {
return string_t(reinterpret_cast<const char *>(buffer_p), data_size);
typename std::enable_if<std::is_integral<T>::value || std::is_floating_point<T>::value, T>::type
DecryptValue(EncryptionState *encryption_state, Vector &result, T encrypted_data, uint8_t *buffer_p) {
// TODO: Process could write directly into the result, which would make the intermediate buffer unnecessary
T decrypted_data;
encryption_state->Process(reinterpret_cast<unsigned char*>(&encrypted_data), sizeof(T), reinterpret_cast<unsigned char*>(&decrypted_data), sizeof(T));
return decrypted_data;
}

// Catch-all for unsupported types
// Handle string_t type and convert to Base64
template <typename T>
typename std::enable_if<!std::is_integral<T>::value && !std::is_floating_point<T>::value && !std::is_same<T, string_t>::value, T>::type
ConvertCipherText(uint8_t *buffer_p, size_t data_size, const uint8_t *input_data) {
throw std::invalid_argument("Unsupported type for Encryption");
}
typename std::enable_if<std::is_same<T, string_t>::value, T>::type
EncryptValue(EncryptionState *encryption_state, Vector &result, T value, uint8_t *buffer_p) {

template <typename T>
typename std::enable_if<!std::is_same<T, string_t>::value, size_t>::type
GetSizeOfInput(const T &input) {
// For numeric types, use sizeof(T) directly
return sizeof(T);
// first encrypt the bytes of the string into a temp buffer_p
auto input_data = data_ptr_t(value.GetData());
auto value_size = value.GetSize();
encryption_state->Process(input_data, value_size, buffer_p, value_size);

// Convert the encrypted data to Base64
auto encrypted_data = string_t(reinterpret_cast<const char*>(buffer_p), value_size);
size_t base64_size = Blob::ToBase64Size(encrypted_data);

// convert to Base64 into a newly allocated string in the result vector
string_t base64_data = StringVector::EmptyString(result, base64_size);
Blob::ToBase64(encrypted_data, base64_data.GetDataWriteable());

return base64_data;
}

// Specialized template for string_t type
template <typename T>
typename std::enable_if<std::is_same<T, string_t>::value, size_t>::type
GetSizeOfInput(const T &input) {
// For string_t, get actual string data size
return input.GetSize();
typename std::enable_if<std::is_same<T, string_t>::value, T>::type
DecryptValue(EncryptionState *encryption_state, Vector &result, T base64_data, uint8_t *buffer_p) {

// first decode the Base64 input into the temp buffer_p
size_t encrypted_size = Blob::FromBase64Size(base64_data);
size_t decrypted_size = encrypted_size;
Blob::FromBase64(base64_data, reinterpret_cast<data_ptr_t>(buffer_p), encrypted_size);
D_ASSERT(encrypted_size <= base64_data.GetSize());

string_t decrypted_data = StringVector::EmptyString(result, decrypted_size);
encryption_state->Process(buffer_p, encrypted_size, reinterpret_cast<unsigned char*>(decrypted_data.GetDataWriteable()), decrypted_size);

return decrypted_data;
}
//---------------------------------------------------------------------------------------------

template <typename T>
void ExecuteEncryptStructExecutor(Vector &vector, Vector &result, idx_t size, ExpressionState &state, const string &key_t) {
Expand Down Expand Up @@ -241,24 +184,8 @@ void ExecuteEncryptExecutor(Vector &vector, Vector &result, idx_t size, Expressi
iv[12] = iv[13] = iv[14] = iv[15] = 0x00;

UnaryExecutor::Execute<T, T>(vector, result, size, [&](T input) -> T {
unsigned char byte_array[sizeof(T)];
auto data_size = GetSizeOfInput(input);
encryption_state->InitializeEncryption(iv, 16, &key_t);

// Convert input to byte array for processing
memcpy(byte_array, &input, sizeof(T));

// Encrypt data
encryption_state->Process(byte_array, data_size, buffer_p, data_size);
T encrypted_data = ConvertCipherText<T>(buffer_p, data_size, byte_array);

#if 0
D_ASSERT(CheckEncryption(printable_encrypted_data, buffer_p, size, reinterpret_cast<const_data_ptr_t>(name.GetData()), state) == 1);
#endif

// this does not do anything for CTR and therefore can be skipped
encryption_state->Finalize(buffer_p, 0, nullptr, 0);
return encrypted_data;
return EncryptValue<T>(encryption_state.get(), result, input, buffer_p);;
});
}

Expand Down Expand Up @@ -318,25 +245,8 @@ void ExecuteDecryptExecutor(Vector &vector, Vector &result, idx_t size, Expressi
iv[12] = iv[13] = iv[14] = iv[15] = 0x00;

UnaryExecutor::Execute<T, T>(vector, result, size, [&](T input) -> T {
unsigned char byte_array[sizeof(T)];
auto data_size = GetSizeOfInput(input);
encryption_state->InitializeDecryption(iv, 16, &key_t);

// Convert input to byte array for processing
memcpy(byte_array, &input, sizeof(T));

// Encrypt data
encryption_state->Process(byte_array, data_size, buffer_p, data_size);
T decrypted_data = ConvertCipherText<T>(buffer_p, data_size, byte_array);

#if 0
D_ASSERT(CheckEncryption(printable_encrypted_data, buffer_p, size, reinterpret_cast<const_data_ptr_t>(name.GetData()), state) == 1);
#endif


// this does not do anything for CTR and therefore can be skipped
encryption_state->Finalize(buffer_p, 0, nullptr, 0);
return decrypted_data;
return DecryptValue<T>(encryption_state.get(), result, input, buffer_p);;
});
}

Expand Down Expand Up @@ -407,6 +317,7 @@ static void DecryptData(DataChunk &args, ExpressionState &state, Vector &result)
ExecuteDecrypt(value_vector, result, args.size(), state, key_t);
}


ScalarFunctionSet GetEncryptionFunction() {
ScalarFunctionSet set("encrypt");

Expand All @@ -419,12 +330,9 @@ ScalarFunctionSet GetEncryptionFunction() {
set.AddFunction(ScalarFunction({LogicalTypeId::BIGINT, LogicalType::VARCHAR}, LogicalTypeId::BIGINT, EncryptData,
EncryptFunctionData::EncryptBind));

set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BLOB, EncryptData,
set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, EncryptData,
EncryptFunctionData::EncryptBind));

// set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, EncryptData,
// EncryptFunctionData::EncryptBind));

return set;
}

Expand All @@ -441,9 +349,6 @@ ScalarFunctionSet GetDecryptionFunction() {
set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, DecryptData,
EncryptFunctionData::EncryptBind));

// set.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BLOB, DecryptData,
// EncryptFunctionData::EncryptBind));

return set;
}

Expand Down
63 changes: 43 additions & 20 deletions test/sql/encrypt_decrypt.test
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,46 @@
# group: [simple_encryption]

# Before we load the extension, this will fail
statement error
SELECT simple_encryption('Test');
----
Catalog Error: Scalar Function with name simple_encryption does not exist!

# Require statement will ensure this test is run with this extension loaded
require simple_encryption

# Confirm the extension works
query I
SELECT encrypt('testtest');
----
\x8A+\xBD\x00\xE4]\xA6L


# Confirm the extension works
query I
SELECT decrypt('\x8A+\xBD\x00\xE4]\xA6L');
----
test
# INTEGER columns: encrypt a constant column, then decrypt the result
statement ok
CREATE TABLE rd_data_2 AS
    SELECT
        1 AS testint
    FROM
        range(10);

# one ADD COLUMN per ALTER TABLE statement, as in the rd_data setup further down
statement ok
ALTER TABLE rd_data_2
    ADD COLUMN encrypted_value INTEGER;

statement ok
ALTER TABLE rd_data_2
    ADD COLUMN decrypted_value INTEGER;

statement ok
UPDATE rd_data_2
    SET encrypted_value = encrypt(testint, '0123456789112345');

# without the UPDATE line this statement is parsed as a configuration SET
statement ok
UPDATE rd_data_2
    SET decrypted_value = decrypt(encrypted_value, '0123456789112345');

# VARCHAR columns: encrypt random 5-character strings, then decrypt the result
statement ok
CREATE TABLE rd_data AS
    SELECT
        SUBSTRING(MD5(RANDOM()::TEXT), 1, 5) AS rd_values
    FROM
        range(10);

statement ok
ALTER TABLE rd_data
    ADD COLUMN encrypted_value VARCHAR;

statement ok
ALTER TABLE rd_data
    ADD COLUMN decrypted_value VARCHAR;

statement ok
UPDATE rd_data
    SET encrypted_value = encrypt(rd_values, '0123456789112345');

statement ok
UPDATE rd_data
    SET decrypted_value = decrypt(encrypted_value, '0123456789112345');
12 changes: 9 additions & 3 deletions test/sql/simple_encryption.test
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,22 @@ SELECT decrypt(4095259532786215143, '0123456789112345');
query I
SELECT encrypt('testtest', '0123456789112345');
----
\xF6N\xCEt\xE4]\xA6L
iiu9AORdpkw=

#VARCHAR
query I
SELECT decrypt('\xF6N\xCEt\xE4]\xA6L', '0123456789112345');
SELECT decrypt('iiu9AORdpkw=', '0123456789112345');
----
testtest

#VARCHAR
query I
SELECT encrypt('test', '0123456789112345');
----
\xFAN\xCEt
iiu9AA==

#VARCHAR
query I
SELECT decrypt('iiu9AA==', '0123456789112345');
----
test

0 comments on commit a1870ab

Please sign in to comment.