From 8c4631e36ae1e6724a7f97860c84436e1f134a43 Mon Sep 17 00:00:00 2001 From: Axel Benjaminsson Date: Sat, 9 Dec 2023 17:38:10 +0100 Subject: [PATCH] Fix eliptic curve valid point detection --- include/curve25519.hpp | 39 ++++++++++- src/curve25519.cpp | 148 ++++++++++++++++++++++++++++++++--------- src/pubkey.cpp | 19 ++++-- 3 files changed, 169 insertions(+), 37 deletions(-) diff --git a/include/curve25519.hpp b/include/curve25519.hpp index a57ad7f9..85c511a9 100644 --- a/include/curve25519.hpp +++ b/include/curve25519.hpp @@ -3,20 +3,57 @@ #include +/* +class uint128_t{ +private: + uint64_t data[2]; +public: + uint128_t(){ + data[0] = 0; + data[1] = 0; + } + uint128_t(uint64_t a, uint64_t b){ + data[0] = a; + data[1] = b; + } + + uint128_t operator +(const uint128_t& other){ + uint128_t res; + res.data[0] = other.data[0] + this->data[0]; + res.data[1] = other.data[1] + this->data[1]; + + // check overflow. + if((res.data[0] < other.data[0]) || (res.data[0] < this->data[0])){ + res.data[1] += 1; + } + } + + static uint128_t m(uint64_t a, uint64_t b){ + return uint128_t(); + } + + uint128_t operator =(const uint128_t& other){ + data[0] = other.data[0]; + data[1] = other.data[1]; + } +};*/ + class FieldElement{ private: uint64_t nums[5]; - FieldElement pow2k(uint32_t m); + FieldElement& reduce(); static uint64_t load8(const uint8_t *data); public: +FieldElement pow2k(uint32_t m); static const FieldElement ONE; static const FieldElement EDWARDS_D; FieldElement(); FieldElement(const uint64_t from[5]); FieldElement(const uint8_t *bytes); + FieldElement(const FieldElement& other); void conditional_assign(const FieldElement& other, bool condition); void conditional_negate(bool condition); void pow22501(FieldElement &t3, FieldElement &t19) const; diff --git a/src/curve25519.cpp b/src/curve25519.cpp index a0459fbd..9b65aadb 100644 --- a/src/curve25519.cpp +++ b/src/curve25519.cpp @@ -41,8 +41,8 @@ FieldElement FieldElement::pow2k(uint32_t k){ } while(k > 0){ - uint64_t a3_19 = 19 * a[3]; - uint64_t a4_19 = 19 * a[4]; + uint64_t a3_19 = a[3] * 19; + uint64_t a4_19 = a[4] * 19; __uint128_t c0 = m(a[0], a[0]) + 2*( m(a[1], a4_19) + m(a[2], a3_19) ); __uint128_t c1 = m(a[3], a3_19) + 2*( m(a[0], a[1]) + m(a[2], a4_19) ); @@ -132,20 +132,98 @@ FieldElement::FieldElement(const uint64_t from[5]){ } } +FieldElement::FieldElement(const FieldElement& other){ + for(int i = 0; i < 5; i++){ + nums[i] = other.nums[i]; + } +} + FieldElement::FieldElement(const uint8_t *bytes){ - const uint64_t low_51_bit_mask = (uint64_t(1) << 51) - 1; - - const uint64_t new_nums[5] = { - load8(bytes) & low_51_bit_mask, - // load bits [ 48,112), shift to [ 51,112) - (load8(bytes + 6) >> 3) & low_51_bit_mask, - // load bits [ 96,160), shift to [102,160) - (load8(bytes + 12) >> 6) & low_51_bit_mask, - // load bits [152,216), shift to [153,216) - (load8(bytes + 19) >> 1) & low_51_bit_mask, - // load bits [192,256), shift to [204,112) - (load8(bytes + 24) >> 12) & low_51_bit_mask - }; + uint8_t temp[32]; + for(unsigned int i = 0; i < 32; i++){ + temp[i] = bytes[i]; + } + temp[31] &= 127; + + uint64_t x1 = (((uint64_t)temp[31]) << 44); + uint64_t x2 = (((uint64_t)temp[30]) << 36); + uint64_t x3 = (((uint64_t)temp[29]) << 28); + uint64_t x4 = (((uint64_t)temp[28]) << 20); + uint64_t x5 = (((uint64_t)temp[27]) << 12); + uint64_t x6 = (((uint64_t)temp[26]) << 4); + uint64_t x7 = (((uint64_t)temp[25]) << 47); + uint64_t x8 = (((uint64_t)temp[24]) << 39); + uint64_t x9 = (((uint64_t)temp[23]) << 31); + uint64_t x10 = (((uint64_t)temp[22]) << 23); + uint64_t x11 = (((uint64_t)temp[21]) << 15); + uint64_t x12 = (((uint64_t)temp[20]) << 7); + uint64_t x13 = (((uint64_t)temp[19]) << 50); + uint64_t x14 = (((uint64_t)temp[18]) << 42); + uint64_t x15 = (((uint64_t)temp[17]) << 34); + uint64_t x16 = (((uint64_t)temp[16]) << 26); + uint64_t x17 = (((uint64_t)temp[15]) << 18); + uint64_t x18 = (((uint64_t)temp[14]) << 10); + uint64_t x19 = (((uint64_t)temp[13]) << 2); + uint64_t x20 = (((uint64_t)temp[12]) << 45); + uint64_t x21 = (((uint64_t)temp[11]) << 37); + uint64_t x22 = (((uint64_t)temp[10]) << 29); + uint64_t x23 = (((uint64_t)temp[9]) << 21); + uint64_t x24 = (((uint64_t)temp[8]) << 13); + uint64_t x25 = (((uint64_t)temp[7]) << 5); + uint64_t x26 = (((uint64_t)temp[6]) << 48); + uint64_t x27 = (((uint64_t)temp[5]) << 40); + uint64_t x28 = (((uint64_t)temp[4]) << 32); + uint64_t x29 = (((uint64_t)temp[3]) << 24); + uint64_t x30 = (((uint64_t)temp[2]) << 16); + uint64_t x31 = (((uint64_t)temp[1]) << 8); + uint64_t x32 = (temp[0]); + uint64_t x33 = (x31 + ((uint64_t)x32)); + uint64_t x34 = (x30 + x33); + uint64_t x35 = (x29 + x34); + uint64_t x36 = (x28 + x35); + uint64_t x37 = (x27 + x36); + uint64_t x38 = (x26 + x37); + uint64_t x39 = (x38 & 0x7ffffffffffff); + uint64_t x40 = (uint8_t)(x38 >> 51); + uint64_t x41 = (x25 + ((uint64_t)x40)); + uint64_t x42 = (x24 + x41); + uint64_t x43 = (x23 + x42); + uint64_t x44 = (x22 + x43); + uint64_t x45 = (x21 + x44); + uint64_t x46 = (x20 + x45); + uint64_t x47 = (x46 & 0x7ffffffffffff); + uint64_t x48 = (uint8_t)(x46 >> 51); + uint64_t x49 = (x19 + ((uint64_t)x48)); + uint64_t x50 = (x18 + x49); + uint64_t x51 = (x17 + x50); + uint64_t x52 = (x16 + x51); + uint64_t x53 = (x15 + x52); + uint64_t x54 = (x14 + x53); + uint64_t x55 = (x13 + x54); + uint64_t x56 = (x55 & 0x7ffffffffffff); + uint64_t x57 = (uint8_t)(x55 >> 51); + uint64_t x58 = (x12 + ((uint64_t)x57)); + uint64_t x59 = (x11 + x58); + uint64_t x60 = (x10 + x59); + uint64_t x61 = (x9 + x60); + uint64_t x62 = (x8 + x61); + uint64_t x63 = (x7 + x62); + uint64_t x64 = (x63 & 0x7ffffffffffff); + uint64_t x65 = (uint8_t)(x63 >> 51); + uint64_t x66 = (x6 + ((uint64_t)x65)); + uint64_t x67 = (x5 + x66); + uint64_t x68 = (x4 + x67); + uint64_t x69 = (x3 + x68); + uint64_t x70 = (x2 + x69); + uint64_t x71 = (x1 + x70); + + uint64_t new_nums[5]; + + new_nums[0] = x39; + new_nums[1] = x47; + new_nums[2] = x56; + new_nums[3] = x64; + new_nums[4] = x71; for(int i = 0; i < 5; i++){ this->nums[i] = new_nums[i]; @@ -170,15 +248,18 @@ bool FieldElement::operator==(const FieldElement &other) const{ FieldElement FieldElement::square() const{ FieldElement ret; - ret = *this; - ret.pow2k(1); - return ret; + ret = (*this); + return ret.pow2k(1); } FieldElement FieldElement::operator*(const FieldElement &other) const{ - const uint64_t *b = other.nums; - const uint64_t *a = this->nums; + uint64_t b[5]; + uint64_t a[5]; + for(unsigned int i = 0; i < 5; i++){ + b[i] = other.nums[i]; + a[i] = this->nums[i]; + } const uint64_t b1_19 = b[1] * 19; const uint64_t b2_19 = b[2] * 19; @@ -186,14 +267,17 @@ FieldElement FieldElement::operator*(const FieldElement &other) const{ const uint64_t b4_19 = b[4] * 19; // Multiply to get 128-bit coefficients of output - const __uint128_t c0 = m(a[0], b[0]) + m(a[4], b1_19) + m(a[3], b2_19) + m(a[2], b3_19) + m(a[1], b4_19); + __uint128_t c0 = m(a[0], b[0]) + m(a[4], b1_19) + m(a[3], b2_19) + m(a[2], b3_19) + m(a[1], b4_19); __uint128_t c1 = m(a[1], b[0]) + m(a[0], b[1]) + m(a[4], b2_19) + m(a[3], b3_19) + m(a[2], b4_19); __uint128_t c2 = m(a[2], b[0]) + m(a[1], b[1]) + m(a[0], b[2]) + m(a[4], b3_19) + m(a[3], b4_19); __uint128_t c3 = m(a[3], b[0]) + m(a[2], b[1]) + m(a[1], b[2]) + m(a[0], b[3]) + m(a[4], b4_19); __uint128_t c4 = m(a[4], b[0]) + m(a[3], b[1]) + m(a[2], b[2]) + m(a[1], b[3]) + m(a[0] , b[4]); const uint64_t LOW_51_BIT_MASK = (((uint64_t) 1) << 51) - 1; - uint64_t out[5] = {0}; + uint64_t out[5]; + for(unsigned int i = 0; i < 5; i++){ + out[0] = 0; + } c1 += (__uint128_t)((uint64_t)(c0 >> 51)); out[0] = ((uint64_t)c0) & LOW_51_BIT_MASK; @@ -208,6 +292,7 @@ FieldElement FieldElement::operator*(const FieldElement &other) const{ out[3] = ((uint64_t)c3) & LOW_51_BIT_MASK; uint64_t carry = (uint64_t)(c4 >> 51); + out[4] = ((uint64_t)c4) & LOW_51_BIT_MASK; out[0] += carry * 19; @@ -293,15 +378,16 @@ void FieldElement::pow22501(FieldElement &t3, FieldElement &t19) const{ FieldElement FieldElement::pow_p58() const{ FieldElement t19; FieldElement dummy; - pow22501(t19, dummy); + pow22501(dummy, t19); + FieldElement t20 = t19.pow2k(2); - FieldElement t21 = *this * t20; + FieldElement t21 = (*this) * t20; return t21; } FieldElement sqrt_ratio_i(const FieldElement &u, const FieldElement &v, bool &was_nonzero_square){ - FieldElement v3 = u.square() * v; + FieldElement v3 = v.square() * v; FieldElement v7 = v3.square() * v; FieldElement r = (u * v3) * (u * v7).pow_p58(); @@ -309,16 +395,16 @@ FieldElement sqrt_ratio_i(const FieldElement &u, const FieldElement &v, bool &wa FieldElement check = v * r.square(); FieldElement i = SQRT_M1; - bool correct_sign_sqrt = check == u; - bool flipped_sign_sqrt = check == (-u); - bool flipped_sign_sqrt_i = check == ((-u) * i); + bool correct_sign_sqrt = (check == u); + bool flipped_sign_sqrt = (check == (-u)); + bool flipped_sign_sqrt_i = (check == ((-u) * i)); FieldElement r_prime = SQRT_M1 * r; - r.conditional_assign(r_prime, flipped_sign_sqrt | flipped_sign_sqrt_i); + r.conditional_assign(r_prime, flipped_sign_sqrt || flipped_sign_sqrt_i); bool r_is_negative = r.is_negative(); r.conditional_negate(r_is_negative); - was_nonzero_square = correct_sign_sqrt | flipped_sign_sqrt; + was_nonzero_square = (correct_sign_sqrt || flipped_sign_sqrt); return r; } @@ -339,7 +425,7 @@ bool decompress_step_1(const uint8_t *repr, FieldElement &x, FieldElement &y, Fi bool is_y_point_valid(const uint8_t *repr){ FieldElement y(repr); - + FieldElement YY = y.square(); FieldElement u = YY - FieldElement::ONE; diff --git a/src/pubkey.cpp b/src/pubkey.cpp index 8a9d10bc..6fd826ca 100644 --- a/src/pubkey.cpp +++ b/src/pubkey.cpp @@ -254,17 +254,27 @@ Variant Pubkey::new_associated_token_address(const Variant &wallet_address, cons TypedArray arr; arr.append(Pubkey(wallet_address).get_bytes()); - arr.append(Pubkey(token_mint_address).get_bytes()); arr.append(Object::cast_to(TokenProgram::get_pid())->get_bytes()); + arr.append(Pubkey(token_mint_address).get_bytes()); + + arr.append(PackedByteArray()); String pid = String(SolanaSDK::SPL_ASSOCIATED_TOKEN_ADDRESS.c_str()); Variant pid_key = Pubkey::new_from_string(pid); Pubkey *res = memnew(Pubkey); - res->create_program_address_bytes(arr, pid_key); - - return res; + for(uint8_t i = 255; i > 0; i--){ + PackedByteArray bump_seed; + bump_seed.push_back(i); + arr[3] = bump_seed; + if(res->create_program_address_bytes(arr, pid_key)){ + return res; + } + } + + internal::gdextension_interface_print_warning("y points were not valid", "new_associated_token_address", __FILE__, __LINE__, false); + return nullptr; } @@ -302,7 +312,6 @@ bool Pubkey::create_program_address_bytes(const Array seeds, const Variant &prog delete[] hash_ptr; if(is_y_point_valid(hash)){ - internal::gdextension_interface_print_warning("y point is not valid", "create_program_address", __FILE__, __LINE__, false); return false; }