Skip to content

Commit

Permalink
Merge pull request #4 from ccfelius/types
Browse files Browse the repository at this point in the history
Supporting varchar struct format
  • Loading branch information
ccfelius authored Nov 11, 2024
2 parents 5ef8bf3 + 75c5461 commit 9a5b021
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 24 deletions.
2 changes: 2 additions & 0 deletions src/core/functions/function_data/encrypt_function_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ bool EncryptFunctionData::Equals(const FunctionData &other_p) const {

unique_ptr<FunctionData> EncryptFunctionData::EncryptBind(ClientContext &context, ScalarFunction &bound_function,
vector<unique_ptr<Expression>> &arguments) {
// here, implement bound statements?

// do something
return make_uniq<EncryptFunctionData>(context);
}
Expand Down
206 changes: 182 additions & 24 deletions src/core/functions/scalar/encrypt_to_etype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,38 @@ int32_t EncryptValueInt32(EncryptionState *encryption_state, Vector &result, int
return encrypted_data;
}

string_t EncryptValueToString(shared_ptr<EncryptionState> encryption_state, Vector &result, string_t value, uint8_t *buffer_p) {

// first encrypt the bytes of the string into a temp buffer_p
auto input_data = data_ptr_t(value.GetData());
auto value_size = value.GetSize();
encryption_state->Process(input_data, value_size, buffer_p, value_size);

// Convert the encrypted data to Base64
auto encrypted_data = string_t(reinterpret_cast<const char*>(buffer_p), value_size);
size_t base64_size = Blob::ToBase64Size(encrypted_data);

// convert to Base64 into a newly allocated string in the result vector
string_t base64_data = StringVector::EmptyString(result, base64_size);
Blob::ToBase64(encrypted_data, base64_data.GetDataWriteable());

return base64_data;
}

string_t DecryptValueToString(shared_ptr<EncryptionState> encryption_state, Vector &result, string_t base64_data, uint8_t *buffer_p) {

// first encrypt the bytes of the string into a temp buffer_p
size_t encrypted_size = Blob::FromBase64Size(base64_data);
size_t decrypted_size = encrypted_size;
Blob::FromBase64(base64_data, reinterpret_cast<data_ptr_t>(buffer_p), encrypted_size);
D_ASSERT(encrypted_size <= base64_data.GetSize());

string_t decrypted_data = StringVector::EmptyString(result, decrypted_size);
encryption_state->Process(buffer_p, encrypted_size, reinterpret_cast<unsigned char*>(decrypted_data.GetDataWriteable()), decrypted_size);

return decrypted_data;
}

static void EncryptDataChunkStruct(DataChunk &args, ExpressionState &state, Vector &result) {

auto &func_expr = (BoundFunctionExpression &)state.expr;
Expand Down Expand Up @@ -62,11 +94,6 @@ static void EncryptDataChunkStruct(DataChunk &args, ExpressionState &state, Vect
// reset the reference of the result vector
result.ReferenceAndSetType(struct_vector);

// // Get the child vectors from the struct vector
// auto &children = StructVector::GetEntries(result);
// children[0] = make_uniq<Vector>(LogicalType::INTEGER);
// children[1] = make_uniq<Vector>(LogicalType::INTEGER);

// TODO: put this in the state of the extension
uint8_t encryption_buffer[MAX_BUFFER_SIZE];
uint8_t *buffer_p = encryption_buffer;
Expand All @@ -75,6 +102,7 @@ static void EncryptDataChunkStruct(DataChunk &args, ExpressionState &state, Vect
auto encryption_state = simple_encryption_state->encryption_state;

// this can be an int64_t, we have 12 bytes available...
// this needs to be in the state btw, because it needs to keep increasing PER vector
int32_t nonce_count = 0;

// TODO: construct nonce based on immutable ROW_ID + hash(col_name)
Expand Down Expand Up @@ -102,7 +130,7 @@ static void EncryptDataChunkStruct(DataChunk &args, ExpressionState &state, Vect
});
}

static void DecryptDataChunkStruct(DataChunk &args, ExpressionState &state, Vector &result) {
static void EncryptDataChunkStructString(DataChunk &args, ExpressionState &state, Vector &result) {

auto &func_expr = (BoundFunctionExpression &)state.expr;
auto &info = (EncryptFunctionData &)*func_expr.bind_info;
Expand All @@ -113,13 +141,65 @@ static void DecryptDataChunkStruct(DataChunk &args, ExpressionState &state, Vect
"simple_encryption");

auto &input_vector = args.data[0];
// D_assert to check whether it is a struct?
// D_ASSERT(input_vector.GetType() == LogicalType::STRUCT);
// get children of input data
auto &children = StructVector::GetEntries(input_vector);
auto &nonce_vector = children[0];
auto &value_vector = children[1];
auto &key_vector = args.data[1];

D_ASSERT(key_vector.GetVectorType() == VectorType::CONSTANT_VECTOR);

// Fetch the encryption key as a constant string
const string key_t =
ConstantVector::GetData<string_t>(key_vector)[0].GetString();

// create struct_type
LogicalType result_struct = LogicalType::STRUCT(
{{"nonce", LogicalType::INTEGER}, {"value", LogicalType::VARCHAR}});

Vector struct_vector(result_struct, args.size());
// reset the reference of the result vector
result.ReferenceAndSetType(struct_vector);

// TODO: put this in the state of the extension
uint8_t encryption_buffer[MAX_BUFFER_SIZE];
uint8_t *buffer_p = encryption_buffer;

unsigned char iv[16];
auto encryption_state = simple_encryption_state->encryption_state;

// this can be an int64_t, we have 12 bytes available...
// this needs to be in the state btw, because it needs to keep increasing PER vector
int32_t nonce_count = 0;

// TODO: construct nonce based on immutable ROW_ID + hash(col_name)
memcpy(iv, "12345678901", 12);
iv[12] = iv[13] = iv[14] = iv[15] = 0x00;

encryption_state->InitializeEncryption(iv, 16, &key_t);
using ENCRYPTED_TYPE = StructTypeBinary<int32_t, string_t>;
using PLAINTEXT_TYPE = PrimitiveType<string_t>;

GenericExecutor::ExecuteUnary<PLAINTEXT_TYPE, ENCRYPTED_TYPE>(input_vector, result, args.size(), [&](PLAINTEXT_TYPE input) {

// set the nonce
nonce_count++;
memcpy(iv, &nonce_count, sizeof(int32_t));

// Encrypt data
string_t encrypted_data = EncryptValueToString(encryption_state, result, input.val, buffer_p);

return ENCRYPTED_TYPE {nonce_count, encrypted_data};
});
}

static void DecryptDataChunkStruct(DataChunk &args, ExpressionState &state, Vector &result) {

auto &func_expr = (BoundFunctionExpression &)state.expr;
auto &info = (EncryptFunctionData &)*func_expr.bind_info;

// refactor this into GetSimpleEncryptionState(info.context);
auto simple_encryption_state =
info.context.registered_state->Get<SimpleEncryptionState>(
"simple_encryption");

auto &input_vector = args.data[0];
auto &key_vector = args.data[1];
D_ASSERT(key_vector.GetVectorType() == VectorType::CONSTANT_VECTOR);

Expand All @@ -144,26 +224,94 @@ static void DecryptDataChunkStruct(DataChunk &args, ExpressionState &state, Vect
// Decrypt data
int32_t decrypted_data;

// also, template this function
BinaryExecutor::Execute<int32_t, int32_t, int32_t>(
reinterpret_cast<Vector &>(nonce_vector),
reinterpret_cast<Vector &>(value_vector), result, args.size(),
[&](int32_t nonce, int32_t input) {
// Set the nonce
memcpy(iv, &nonce, sizeof(int32_t));
using ENCRYPTED_TYPE = StructTypeBinary<int32_t, int32_t>;
using PLAINTEXT_TYPE = PrimitiveType<int32_t>;

GenericExecutor::ExecuteUnary<ENCRYPTED_TYPE, PLAINTEXT_TYPE>(
input_vector, result, args.size(), [&](ENCRYPTED_TYPE input) {
auto nonce = input.a_val;
auto value = input.b_val;

// Set the nonce
memcpy(iv, &nonce, sizeof(int32_t));

encryption_state->Process(
reinterpret_cast<unsigned char *>(&value), sizeof(int32_t),
reinterpret_cast<unsigned char *>(&decrypted_data),
sizeof(int32_t));

return decrypted_data;
});
}

static void DecryptDataChunkStructString(DataChunk &args, ExpressionState &state, Vector &result) {

auto &func_expr = (BoundFunctionExpression &)state.expr;
auto &info = (EncryptFunctionData &)*func_expr.bind_info;

// refactor this into GetSimpleEncryptionState(info.context);
auto simple_encryption_state =
info.context.registered_state->Get<SimpleEncryptionState>(
"simple_encryption");

auto &input_vector = args.data[0];
auto &key_vector = args.data[1];
D_ASSERT(key_vector.GetVectorType() == VectorType::CONSTANT_VECTOR);

// Fetch the encryption key as a constant string
const string key_t =
ConstantVector::GetData<string_t>(key_vector)[0].GetString();

// maybe convert vector to unified format? Like they do in other scalar fucntions

// TODO: put this in the state of the extension
uint8_t encryption_buffer[MAX_BUFFER_SIZE];
uint8_t *buffer_p = encryption_buffer;

unsigned char iv[16];
auto encryption_state = simple_encryption_state->encryption_state;

// TODO: construct nonce based on immutable ROW_ID + hash(col_name)
memcpy(iv, "12345678901", 12);
iv[12] = iv[13] = iv[14] = iv[15] = 0x00;

encryption_state->InitializeDecryption(iv, 16, &key_t);
// Decrypt data
encryption_state->Process(reinterpret_cast<unsigned char*>(&input), sizeof(int32_t), reinterpret_cast<unsigned char*>(&decrypted_data), sizeof(int32_t));
int32_t decrypted_data;

return decrypted_data;
});
using ENCRYPTED_TYPE = StructTypeBinary<int32_t, string_t>;
using PLAINTEXT_TYPE = PrimitiveType<string_t>;

GenericExecutor::ExecuteUnary<ENCRYPTED_TYPE, PLAINTEXT_TYPE>(
input_vector, result, args.size(), [&](ENCRYPTED_TYPE input) {
auto nonce = input.a_val;
auto value = input.b_val;

// Set the nonce
memcpy(iv, &nonce, sizeof(int32_t));

// Decrypt data
string_t decrypted_data =
DecryptValueToString(encryption_state, result, value, buffer_p);
return decrypted_data;
});
}


ScalarFunctionSet GetEncryptionStructFunction() {
ScalarFunctionSet set("encrypt_etypes");

set.AddFunction(ScalarFunction({LogicalTypeId::INTEGER, LogicalType::VARCHAR}, EncryptionTypes::E_INT(), EncryptDataChunkStruct,
// set.AddFunction(ScalarFunction({LogicalTypeId::INTEGER, LogicalType::VARCHAR}, EncryptionTypes::E_INT(), EncryptDataChunkStruct,
// EncryptFunctionData::EncryptBind));

// Function to Encrypt INTEGERS
set.AddFunction(ScalarFunction({LogicalTypeId::INTEGER, LogicalType::VARCHAR}, LogicalType::STRUCT(
{{"nonce", LogicalType::INTEGER}, {"value", LogicalType::INTEGER}}), EncryptDataChunkStruct,
EncryptFunctionData::EncryptBind));

// Function to encrypt VARCHAR
set.AddFunction(ScalarFunction({LogicalTypeId::VARCHAR, LogicalType::VARCHAR}, LogicalType::STRUCT(
{{"nonce", LogicalType::INTEGER}, {"value", LogicalType::VARCHAR}}), EncryptDataChunkStructString,
EncryptFunctionData::EncryptBind));

return set;
Expand All @@ -172,9 +320,19 @@ ScalarFunctionSet GetEncryptionStructFunction() {
ScalarFunctionSet GetDecryptionStructFunction() {
ScalarFunctionSet set("decrypt_etypes");

set.AddFunction(ScalarFunction({EncryptionTypes::E_INT(), LogicalType::VARCHAR}, LogicalTypeId::INTEGER, DecryptDataChunkStruct,
// Why is E_INT not working?
// set.AddFunction(ScalarFunction({EncryptionTypes::E_INT(), LogicalType::VARCHAR}, LogicalTypeId::INTEGER, DecryptDataChunkStruct,
// EncryptFunctionData::EncryptBind));

// try with input struct?
set.AddFunction(ScalarFunction({LogicalType::STRUCT(
{{"nonce", LogicalType::INTEGER}, {"value", LogicalType::INTEGER}}), LogicalType::VARCHAR}, LogicalTypeId::INTEGER, DecryptDataChunkStruct,
EncryptFunctionData::EncryptBind));

set.AddFunction(ScalarFunction({LogicalType::STRUCT(
{{"nonce", LogicalType::INTEGER}, {"value", LogicalType::VARCHAR}}), LogicalType::VARCHAR}, LogicalTypeId::VARCHAR, DecryptDataChunkStructString,
EncryptFunctionData::EncryptBind));

return set;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ struct EncryptFunctionData: FunctionData {

// Save the ClientContext
ClientContext &context;
// BoundStatement relation;

EncryptFunctionData(ClientContext &context)
: context(context) {}
Expand Down

0 comments on commit 9a5b021

Please sign in to comment.