Skip to content

Commit

Permalink
Fix eliptic curve valid point detection
Browse files Browse the repository at this point in the history
  • Loading branch information
Virus-Axel committed Dec 9, 2023
1 parent 504c51c commit 8c4631e
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 37 deletions.
39 changes: 38 additions & 1 deletion include/curve25519.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,57 @@

#include <godot_cpp/classes/node.hpp>

/*
class uint128_t{
private:
uint64_t data[2];
public:
uint128_t(){
data[0] = 0;
data[1] = 0;
}
uint128_t(uint64_t a, uint64_t b){
data[0] = a;
data[1] = b;
}
uint128_t operator +(const uint128_t& other){
uint128_t res;
res.data[0] = other.data[0] + this->data[0];
res.data[1] = other.data[1] + this->data[1];
// check overflow.
if((res.data[0] < other.data[0]) || (res.data[0] < this->data[0])){
res.data[1] += 1;
}
}
static uint128_t m(uint64_t a, uint64_t b){
return uint128_t();
}
uint128_t operator =(const uint128_t& other){
data[0] = other.data[0];
data[1] = other.data[1];
}
};*/

class FieldElement{
private:
uint64_t nums[5];

FieldElement pow2k(uint32_t m);

FieldElement& reduce();
static uint64_t load8(const uint8_t *data);
public:
FieldElement pow2k(uint32_t m);
static const FieldElement ONE;
static const FieldElement EDWARDS_D;

FieldElement();
FieldElement(const uint64_t from[5]);
FieldElement(const uint8_t *bytes);
FieldElement(const FieldElement& other);
void conditional_assign(const FieldElement& other, bool condition);
void conditional_negate(bool condition);
void pow22501(FieldElement &t3, FieldElement &t19) const;
Expand Down
148 changes: 117 additions & 31 deletions src/curve25519.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ FieldElement FieldElement::pow2k(uint32_t k){
}

while(k > 0){
uint64_t a3_19 = 19 * a[3];
uint64_t a4_19 = 19 * a[4];
uint64_t a3_19 = a[3] * 19;
uint64_t a4_19 = a[4] * 19;

__uint128_t c0 = m(a[0], a[0]) + 2*( m(a[1], a4_19) + m(a[2], a3_19) );
__uint128_t c1 = m(a[3], a3_19) + 2*( m(a[0], a[1]) + m(a[2], a4_19) );
Expand Down Expand Up @@ -132,20 +132,98 @@ FieldElement::FieldElement(const uint64_t from[5]){
}
}

FieldElement::FieldElement(const FieldElement& other){
for(int i = 0; i < 5; i++){
nums[i] = other.nums[i];
}
}

FieldElement::FieldElement(const uint8_t *bytes){
const uint64_t low_51_bit_mask = (uint64_t(1) << 51) - 1;

const uint64_t new_nums[5] = {
load8(bytes) & low_51_bit_mask,
// load bits [ 48,112), shift to [ 51,112)
(load8(bytes + 6) >> 3) & low_51_bit_mask,
// load bits [ 96,160), shift to [102,160)
(load8(bytes + 12) >> 6) & low_51_bit_mask,
// load bits [152,216), shift to [153,216)
(load8(bytes + 19) >> 1) & low_51_bit_mask,
// load bits [192,256), shift to [204,112)
(load8(bytes + 24) >> 12) & low_51_bit_mask
};
uint8_t temp[32];
for(unsigned int i = 0; i < 32; i++){
temp[i] = bytes[i];
}
temp[31] &= 127;

uint64_t x1 = (((uint64_t)temp[31]) << 44);
uint64_t x2 = (((uint64_t)temp[30]) << 36);
uint64_t x3 = (((uint64_t)temp[29]) << 28);
uint64_t x4 = (((uint64_t)temp[28]) << 20);
uint64_t x5 = (((uint64_t)temp[27]) << 12);
uint64_t x6 = (((uint64_t)temp[26]) << 4);
uint64_t x7 = (((uint64_t)temp[25]) << 47);
uint64_t x8 = (((uint64_t)temp[24]) << 39);
uint64_t x9 = (((uint64_t)temp[23]) << 31);
uint64_t x10 = (((uint64_t)temp[22]) << 23);
uint64_t x11 = (((uint64_t)temp[21]) << 15);
uint64_t x12 = (((uint64_t)temp[20]) << 7);
uint64_t x13 = (((uint64_t)temp[19]) << 50);
uint64_t x14 = (((uint64_t)temp[18]) << 42);
uint64_t x15 = (((uint64_t)temp[17]) << 34);
uint64_t x16 = (((uint64_t)temp[16]) << 26);
uint64_t x17 = (((uint64_t)temp[15]) << 18);
uint64_t x18 = (((uint64_t)temp[14]) << 10);
uint64_t x19 = (((uint64_t)temp[13]) << 2);
uint64_t x20 = (((uint64_t)temp[12]) << 45);
uint64_t x21 = (((uint64_t)temp[11]) << 37);
uint64_t x22 = (((uint64_t)temp[10]) << 29);
uint64_t x23 = (((uint64_t)temp[9]) << 21);
uint64_t x24 = (((uint64_t)temp[8]) << 13);
uint64_t x25 = (((uint64_t)temp[7]) << 5);
uint64_t x26 = (((uint64_t)temp[6]) << 48);
uint64_t x27 = (((uint64_t)temp[5]) << 40);
uint64_t x28 = (((uint64_t)temp[4]) << 32);
uint64_t x29 = (((uint64_t)temp[3]) << 24);
uint64_t x30 = (((uint64_t)temp[2]) << 16);
uint64_t x31 = (((uint64_t)temp[1]) << 8);
uint64_t x32 = (temp[0]);
uint64_t x33 = (x31 + ((uint64_t)x32));
uint64_t x34 = (x30 + x33);
uint64_t x35 = (x29 + x34);
uint64_t x36 = (x28 + x35);
uint64_t x37 = (x27 + x36);
uint64_t x38 = (x26 + x37);
uint64_t x39 = (x38 & 0x7ffffffffffff);
uint64_t x40 = (uint8_t)(x38 >> 51);
uint64_t x41 = (x25 + ((uint64_t)x40));
uint64_t x42 = (x24 + x41);
uint64_t x43 = (x23 + x42);
uint64_t x44 = (x22 + x43);
uint64_t x45 = (x21 + x44);
uint64_t x46 = (x20 + x45);
uint64_t x47 = (x46 & 0x7ffffffffffff);
uint64_t x48 = (uint8_t)(x46 >> 51);
uint64_t x49 = (x19 + ((uint64_t)x48));
uint64_t x50 = (x18 + x49);
uint64_t x51 = (x17 + x50);
uint64_t x52 = (x16 + x51);
uint64_t x53 = (x15 + x52);
uint64_t x54 = (x14 + x53);
uint64_t x55 = (x13 + x54);
uint64_t x56 = (x55 & 0x7ffffffffffff);
uint64_t x57 = (uint8_t)(x55 >> 51);
uint64_t x58 = (x12 + ((uint64_t)x57));
uint64_t x59 = (x11 + x58);
uint64_t x60 = (x10 + x59);
uint64_t x61 = (x9 + x60);
uint64_t x62 = (x8 + x61);
uint64_t x63 = (x7 + x62);
uint64_t x64 = (x63 & 0x7ffffffffffff);
uint64_t x65 = (uint8_t)(x63 >> 51);
uint64_t x66 = (x6 + ((uint64_t)x65));
uint64_t x67 = (x5 + x66);
uint64_t x68 = (x4 + x67);
uint64_t x69 = (x3 + x68);
uint64_t x70 = (x2 + x69);
uint64_t x71 = (x1 + x70);

uint64_t new_nums[5];

new_nums[0] = x39;
new_nums[1] = x47;
new_nums[2] = x56;
new_nums[3] = x64;
new_nums[4] = x71;

for(int i = 0; i < 5; i++){
this->nums[i] = new_nums[i];
Expand All @@ -170,30 +248,36 @@ bool FieldElement::operator==(const FieldElement &other) const{

FieldElement FieldElement::square() const{
FieldElement ret;
ret = *this;
ret.pow2k(1);
return ret;
ret = (*this);
return ret.pow2k(1);
}

FieldElement FieldElement::operator*(const FieldElement &other) const{

const uint64_t *b = other.nums;
const uint64_t *a = this->nums;
uint64_t b[5];
uint64_t a[5];
for(unsigned int i = 0; i < 5; i++){
b[i] = other.nums[i];
a[i] = this->nums[i];
}

const uint64_t b1_19 = b[1] * 19;
const uint64_t b2_19 = b[2] * 19;
const uint64_t b3_19 = b[3] * 19;
const uint64_t b4_19 = b[4] * 19;

// Multiply to get 128-bit coefficients of output
const __uint128_t c0 = m(a[0], b[0]) + m(a[4], b1_19) + m(a[3], b2_19) + m(a[2], b3_19) + m(a[1], b4_19);
__uint128_t c0 = m(a[0], b[0]) + m(a[4], b1_19) + m(a[3], b2_19) + m(a[2], b3_19) + m(a[1], b4_19);
__uint128_t c1 = m(a[1], b[0]) + m(a[0], b[1]) + m(a[4], b2_19) + m(a[3], b3_19) + m(a[2], b4_19);
__uint128_t c2 = m(a[2], b[0]) + m(a[1], b[1]) + m(a[0], b[2]) + m(a[4], b3_19) + m(a[3], b4_19);
__uint128_t c3 = m(a[3], b[0]) + m(a[2], b[1]) + m(a[1], b[2]) + m(a[0], b[3]) + m(a[4], b4_19);
__uint128_t c4 = m(a[4], b[0]) + m(a[3], b[1]) + m(a[2], b[2]) + m(a[1], b[3]) + m(a[0] , b[4]);

const uint64_t LOW_51_BIT_MASK = (((uint64_t) 1) << 51) - 1;
uint64_t out[5] = {0};
uint64_t out[5];
for(unsigned int i = 0; i < 5; i++){
out[0] = 0;
}

c1 += (__uint128_t)((uint64_t)(c0 >> 51));
out[0] = ((uint64_t)c0) & LOW_51_BIT_MASK;
Expand All @@ -208,6 +292,7 @@ FieldElement FieldElement::operator*(const FieldElement &other) const{
out[3] = ((uint64_t)c3) & LOW_51_BIT_MASK;

uint64_t carry = (uint64_t)(c4 >> 51);

out[4] = ((uint64_t)c4) & LOW_51_BIT_MASK;

out[0] += carry * 19;
Expand Down Expand Up @@ -293,32 +378,33 @@ void FieldElement::pow22501(FieldElement &t3, FieldElement &t19) const{
FieldElement FieldElement::pow_p58() const{
FieldElement t19;
FieldElement dummy;
pow22501(t19, dummy);
pow22501(dummy, t19);

FieldElement t20 = t19.pow2k(2);
FieldElement t21 = *this * t20;
FieldElement t21 = (*this) * t20;

return t21;
}

FieldElement sqrt_ratio_i(const FieldElement &u, const FieldElement &v, bool &was_nonzero_square){
FieldElement v3 = u.square() * v;
FieldElement v3 = v.square() * v;
FieldElement v7 = v3.square() * v;

FieldElement r = (u * v3) * (u * v7).pow_p58();

FieldElement check = v * r.square();
FieldElement i = SQRT_M1;

bool correct_sign_sqrt = check == u;
bool flipped_sign_sqrt = check == (-u);
bool flipped_sign_sqrt_i = check == ((-u) * i);
bool correct_sign_sqrt = (check == u);
bool flipped_sign_sqrt = (check == (-u));
bool flipped_sign_sqrt_i = (check == ((-u) * i));

FieldElement r_prime = SQRT_M1 * r;
r.conditional_assign(r_prime, flipped_sign_sqrt | flipped_sign_sqrt_i);
r.conditional_assign(r_prime, flipped_sign_sqrt || flipped_sign_sqrt_i);
bool r_is_negative = r.is_negative();
r.conditional_negate(r_is_negative);

was_nonzero_square = correct_sign_sqrt | flipped_sign_sqrt;
was_nonzero_square = (correct_sign_sqrt || flipped_sign_sqrt);

return r;
}
Expand All @@ -339,7 +425,7 @@ bool decompress_step_1(const uint8_t *repr, FieldElement &x, FieldElement &y, Fi

bool is_y_point_valid(const uint8_t *repr){
FieldElement y(repr);

FieldElement YY = y.square();
FieldElement u = YY - FieldElement::ONE;

Expand Down
19 changes: 14 additions & 5 deletions src/pubkey.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,17 +254,27 @@ Variant Pubkey::new_associated_token_address(const Variant &wallet_address, cons
TypedArray<PackedByteArray> arr;

arr.append(Pubkey(wallet_address).get_bytes());
arr.append(Pubkey(token_mint_address).get_bytes());
arr.append(Object::cast_to<Pubkey>(TokenProgram::get_pid())->get_bytes());
arr.append(Pubkey(token_mint_address).get_bytes());

arr.append(PackedByteArray());

String pid = String(SolanaSDK::SPL_ASSOCIATED_TOKEN_ADDRESS.c_str());

Variant pid_key = Pubkey::new_from_string(pid);

Pubkey *res = memnew(Pubkey);
res->create_program_address_bytes(arr, pid_key);

return res;
for(uint8_t i = 255; i > 0; i--){
PackedByteArray bump_seed;
bump_seed.push_back(i);
arr[3] = bump_seed;
if(res->create_program_address_bytes(arr, pid_key)){
return res;
}
}

internal::gdextension_interface_print_warning("y points were not valid", "new_associated_token_address", __FILE__, __LINE__, false);
return nullptr;
}


Expand Down Expand Up @@ -302,7 +312,6 @@ bool Pubkey::create_program_address_bytes(const Array seeds, const Variant &prog
delete[] hash_ptr;

if(is_y_point_valid(hash)){
internal::gdextension_interface_print_warning("y point is not valid", "create_program_address", __FILE__, __LINE__, false);
return false;
}

Expand Down

0 comments on commit 8c4631e

Please sign in to comment.