From 6ec85ea8628f6ad1c3731bad4fb20411814ee8a5 Mon Sep 17 00:00:00 2001 From: Derek Parker Date: Thu, 1 Feb 2024 15:11:38 -0800 Subject: [PATCH] Backport Marvin fix --- patches/001-initial-openssl-for-fips.patch | 62 +- patches/005-marvin.patch | 1750 ++++++++++++++++++++ 2 files changed, 1782 insertions(+), 30 deletions(-) create mode 100644 patches/005-marvin.patch diff --git a/patches/001-initial-openssl-for-fips.patch b/patches/001-initial-openssl-for-fips.patch index 91b1fa4716..0368359a6a 100644 --- a/patches/001-initial-openssl-for-fips.patch +++ b/patches/001-initial-openssl-for-fips.patch @@ -5545,7 +5545,8 @@ index 64c83c21c5..f48c57adff 100644 key := C._goboringcrypto_RSA_new() if key == nil { - return nil, fail("RSA_new") -- } ++ return nil, NewOpenSSLError("RSA_new failed") + } - if !bigToBn(&key.n, N) || - !bigToBn(&key.e, E) || - !bigToBn(&key.d, D) || @@ -5555,8 +5556,6 @@ index 64c83c21c5..f48c57adff 100644 - !bigToBn(&key.dmq1, Dq) || - !bigToBn(&key.iqmp, Qinv) { - return nil, fail("BN_bin2bn") -+ return nil, NewOpenSSLError("RSA_new failed") -+ } + var n, e, d, p, q, dp, dq, qinv *C.GO_BIGNUM + n = bigToBN(N) + e = bigToBN(E) @@ -5728,9 +5727,6 @@ index 64c83c21c5..f48c57adff 100644 return out[:outLen], nil } -- md := cryptoHashToMD(h) -- if md == nil { -- return nil, errors.New("crypto/rsa: unsupported hash function: " + strconv.Itoa(int(h))) + var out []byte + var outLen C.size_t + @@ -5738,45 +5734,44 @@ index 64c83c21c5..f48c57adff 100644 + return C._goboringcrypto_EVP_RSA_sign(md, base(msg), C.uint(len(msg)), base(out), &outLen, key) + }) == 0 { + return nil, NewOpenSSLError("RSA_sign") - } -- nid := C._goboringcrypto_EVP_MD_type(md) ++ } + return out[:outLen], nil +} + +func signRSAPKCS1v15Raw(priv *PrivateKeyRSA, msg []byte, md *C.GO_EVP_MD) ([]byte, error) { - var out []byte -- var outLen C.uint ++ var out []byte + var outLen C.size_t + PanicIfStrictFIPS("You must provide a raw unhashed message for PKCS1v15 signing and use HashSignPKCS1v15 instead of SignPKCS1v15") + - if priv.withKey(func(key *C.GO_RSA) C.int { - out = make([]byte, C._goboringcrypto_RSA_size(key)) -- return C._goboringcrypto_RSA_sign(nid, base(hashed), C.uint(len(hashed)), -- base(out), &outLen, key) ++ if priv.withKey(func(key *C.GO_RSA) C.int { ++ out = make([]byte, C._goboringcrypto_RSA_size(key)) + outLen = C.size_t(len(out)) + return C._goboringcrypto_EVP_sign_raw(md, nil, base(msg), + C.size_t(len(msg)), base(out), &outLen, key) - }) == 0 { -- return nil, fail("RSA_sign") ++ }) == 0 { + return nil, NewOpenSSLError("RSA_sign") - } ++ } + runtime.KeepAlive(priv) - return out[:outLen], nil - } - --func VerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte) error { -- if h == 0 { -- var out []byte -- var outLen C.size_t ++ return out[:outLen], nil ++} ++ +func VerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, msg, sig []byte, msgIsHashed bool) error { + if h == 0 && ExecutingTest() { + return verifyRSAPKCS1v15Raw(pub, msg, sig) + } + -+ md := cryptoHashToMD(h) -+ if md == nil { + md := cryptoHashToMD(h) + if md == nil { +- return nil, errors.New("crypto/rsa: unsupported hash function: " + strconv.Itoa(int(h))) + return errors.New("crypto/rsa: unsupported hash function") -+ } + } +- nid := C._goboringcrypto_EVP_MD_type(md) +- var out []byte +- var outLen C.uint +- if priv.withKey(func(key *C.GO_RSA) C.int { +- out = make([]byte, C._goboringcrypto_RSA_size(key)) +- return C._goboringcrypto_RSA_sign(nid, base(hashed), C.uint(len(hashed)), +- base(out), &outLen, key) + + if 
pub.withKey(func(key *C.GO_RSA) C.int { + size := int(C._goboringcrypto_RSA_size(key)) @@ -5784,10 +5779,17 @@ index 64c83c21c5..f48c57adff 100644 + return 0 + } + return 1 -+ }) == 0 { + }) == 0 { +- return nil, fail("RSA_sign") + return errors.New("crypto/rsa: verification error") -+ } -+ + } +- return out[:outLen], nil +-} + +-func VerifyRSAPKCS1v15(pub *PublicKeyRSA, h crypto.Hash, hashed, sig []byte) error { +- if h == 0 { +- var out []byte +- var outLen C.size_t + if msgIsHashed { + PanicIfStrictFIPS("You must provide a raw unhashed message for PKCS1v15 verification and use HashVerifyPKCS1v15 instead of VerifyPKCS1v15") + nid := C._goboringcrypto_EVP_MD_type(md) diff --git a/patches/005-marvin.patch b/patches/005-marvin.patch new file mode 100644 index 0000000000..c1b9a0f7b4 --- /dev/null +++ b/patches/005-marvin.patch @@ -0,0 +1,1750 @@ +From 2be9d3fd57246dcd64ca8a214508d0731d87f46c Mon Sep 17 00:00:00 2001 +From: Lúcás Meier +Date: Tue, 08 Jun 2021 21:36:06 +0200 +Subject: [PATCH] crypto/rsa: replace big.Int for encryption and decryption + +Infamously, big.Int does not provide constant-time arithmetic, making +its use in cryptographic code quite tricky. RSA uses big.Int +pervasively, in its public API, for key generation, precomputation, and +for encryption and decryption. This is a known problem. One mitigation, +blinding, is already in place during decryption. This helps mitigate the +very leaky exponentiation operation. Because big.Int is fundamentally +not constant-time, it's unfortunately difficult to guarantee that +mitigations like these are completely effective. + +This patch removes the use of big.Int for encryption and decryption, +replacing it with an internal nat type instead. Signing and verification +are also affected, because they depend on encryption and decryption. + +Overall, this patch degrades performance by 55% for private key +operations, and 4-5x for (much faster) public key operations. +(Signatures do both, so the slowdown is worse than decryption.) + +name old time/op new time/op delta +DecryptPKCS1v15/2048-8 1.50ms ± 0% 2.34ms ± 0% +56.44% (p=0.000 n=8+10) +DecryptPKCS1v15/3072-8 4.40ms ± 0% 6.79ms ± 0% +54.33% (p=0.000 n=10+9) +DecryptPKCS1v15/4096-8 9.31ms ± 0% 15.14ms ± 0% +62.60% (p=0.000 n=10+10) +EncryptPKCS1v15/2048-8 8.16µs ± 0% 355.58µs ± 0% +4258.90% (p=0.000 n=10+9) +DecryptOAEP/2048-8 1.50ms ± 0% 2.34ms ± 0% +55.68% (p=0.000 n=10+9) +EncryptOAEP/2048-8 8.51µs ± 0% 355.95µs ± 0% +4082.75% (p=0.000 n=10+9) +SignPKCS1v15/2048-8 1.51ms ± 0% 2.69ms ± 0% +77.94% (p=0.000 n=10+10) +VerifyPKCS1v15/2048-8 7.25µs ± 0% 354.34µs ± 0% +4789.52% (p=0.000 n=9+9) +SignPSS/2048-8 1.51ms ± 0% 2.70ms ± 0% +78.80% (p=0.000 n=9+10) +VerifyPSS/2048-8 8.27µs ± 1% 355.65µs ± 0% +4199.39% (p=0.000 n=10+10) + +Keep in mind that this is without any assembly at all, and that further +improvements are likely possible. I think having a review of the logic +and the cryptography would be a good idea at this stage, before we +complicate the code too much through optimization. + +The bulk of the work is in nat.go. This introduces two new types: nat, +representing natural numbers, and modulus, representing moduli used in +modular arithmetic. + +A nat has an "announced size", which may be larger than its "true size", +the number of bits needed to represent this number. Operations on a nat +will only ever leak its announced size, never its true size, or other +information about its value. The size of a nat is always clear based on +how its value is set. 
For example, x.mod(y, m) will make the announced +size of x match that of m, since x is reduced modulo m. + +Operations assume that the announced size of the operands match what's +expected (with a few exceptions). For example, x.modAdd(y, m) assumes +that x and y have the same announced size as m, and that they're reduced +modulo m. + +Nats are represented over unsatured bits.UintSize - 1 bit limbs. This +means that we can't reuse the assembly routines for big.Int, which use +saturated bits.UintSize limbs. The advantage of unsaturated limbs is +that it makes Montgomery multiplication faster, by needing fewer +registers in a hot loop. This makes exponentiation faster, which +consists of many Montgomery multiplications. + +Moduli use nat internally. Unlike nat, the true size of a modulus always +matches its announced size. When creating a modulus, any zero padding is +removed. Moduli will also precompute constants when created, which is +another reason why having a separate type is desirable. + +Updates #20654 + +Co-authored-by: Filippo Valsorda +Change-Id: I73b61f87d58ab912e80a9644e255d552cbadcced +--- + +diff --git a/src/crypto/rsa/example_test.go b/src/crypto/rsa/example_test.go +index 8c3a997..d07ee7d 100644 +--- a/src/crypto/rsa/example_test.go ++++ b/src/crypto/rsa/example_test.go +@@ -13,7 +13,6 @@ + "crypto/sha256" + "encoding/hex" + "fmt" +- "io" + "os" + ) + +@@ -37,21 +36,17 @@ + // a buffer that contains a random key. Thus, if the RSA result isn't + // well-formed, the implementation uses a random key in constant time. + func ExampleDecryptPKCS1v15SessionKey() { +- // crypto/rand.Reader is a good source of entropy for blinding the RSA +- // operation. +- rng := rand.Reader +- + // The hybrid scheme should use at least a 16-byte symmetric key. Here + // we read the random key that will be used if the RSA decryption isn't + // well-formed. + key := make([]byte, 32) +- if _, err := io.ReadFull(rng, key); err != nil { ++ if _, err := rand.Read(key); err != nil { + panic("RNG failure") + } + + rsaCiphertext, _ := hex.DecodeString("aabbccddeeff") + +- if err := DecryptPKCS1v15SessionKey(rng, rsaPrivateKey, rsaCiphertext, key); err != nil { ++ if err := DecryptPKCS1v15SessionKey(nil, rsaPrivateKey, rsaCiphertext, key); err != nil { + // Any errors that result will be “public” – meaning that they + // can be determined without any secret information. (For + // instance, if the length of key is impossible given the RSA +@@ -87,10 +82,6 @@ + } + + func ExampleSignPKCS1v15() { +- // crypto/rand.Reader is a good source of entropy for blinding the RSA +- // operation. +- rng := rand.Reader +- + message := []byte("message to be signed") + + // Only small messages can be signed directly; thus the hash of a +@@ -100,7 +91,7 @@ + // of writing (2016). 
+ hashed := sha256.Sum256(message) + +- signature, err := SignPKCS1v15(rng, rsaPrivateKey, crypto.SHA256, hashed[:]) ++ signature, err := SignPKCS1v15(nil, rsaPrivateKey, crypto.SHA256, hashed[:]) + if err != nil { + fmt.Fprintf(os.Stderr, "Error from signing: %s\n", err) + return +@@ -152,11 +143,7 @@ + ciphertext, _ := hex.DecodeString("4d1ee10e8f286390258c51a5e80802844c3e6358ad6690b7285218a7c7ed7fc3a4c7b950fbd04d4b0239cc060dcc7065ca6f84c1756deb71ca5685cadbb82be025e16449b905c568a19c088a1abfad54bf7ecc67a7df39943ec511091a34c0f2348d04e058fcff4d55644de3cd1d580791d4524b92f3e91695582e6e340a1c50b6c6d78e80b4e42c5b4d45e479b492de42bbd39cc642ebb80226bb5200020d501b24a37bcc2ec7f34e596b4fd6b063de4858dbf5a4e3dd18e262eda0ec2d19dbd8e890d672b63d368768360b20c0b6b8592a438fa275e5fa7f60bef0dd39673fd3989cc54d2cb80c08fcd19dacbc265ee1c6014616b0e04ea0328c2a04e73460") + label := []byte("orders") + +- // crypto/rand.Reader is a good source of entropy for blinding the RSA +- // operation. +- rng := rand.Reader +- +- plaintext, err := DecryptOAEP(sha256.New(), rng, test2048Key, ciphertext, label) ++ plaintext, err := DecryptOAEP(sha256.New(), nil, test2048Key, ciphertext, label) + if err != nil { + fmt.Fprintf(os.Stderr, "Error from decryption: %s\n", err) + return +diff --git a/src/crypto/rsa/nat.go b/src/crypto/rsa/nat.go +new file mode 100644 +index 0000000..da521c2 +--- /dev/null ++++ b/src/crypto/rsa/nat.go +@@ -0,0 +1,626 @@ ++// Copyright 2021 The Go Authors. All rights reserved. ++// Use of this source code is governed by a BSD-style ++// license that can be found in the LICENSE file. ++ ++package rsa ++ ++import ( ++ "math/big" ++ "math/bits" ++) ++ ++const ( ++ // _W is the number of bits we use for our limbs. ++ _W = bits.UintSize - 1 ++ // _MASK selects _W bits from a full machine word. ++ _MASK = (1 << _W) - 1 ++) ++ ++// choice represents a constant-time boolean. The value of choice is always ++// either 1 or 0. We use an int instead of bool in order to make decisions in ++// constant time by turning it into a mask. ++type choice uint ++ ++func not(c choice) choice { return 1 ^ c } ++ ++const yes = choice(1) ++const no = choice(0) ++ ++// ctSelect returns x if on == 1, and y if on == 0. The execution time of this ++// function does not depend on its inputs. If on is any value besides 1 or 0, ++// the result is undefined. ++func ctSelect(on choice, x, y uint) uint { ++ // When on == 1, mask is 0b111..., otherwise mask is 0b000... ++ mask := -uint(on) ++ // When mask is all zeros, we just have y, otherwise, y cancels with itself. ++ return y ^ (mask & (y ^ x)) ++} ++ ++// ctEq returns 1 if x == y, and 0 otherwise. The execution time of this ++// function does not depend on its inputs. ++func ctEq(x, y uint) choice { ++ // If x != y, then either x - y or y - x will generate a carry. ++ _, c1 := bits.Sub(x, y, 0) ++ _, c2 := bits.Sub(y, x, 0) ++ return not(choice(c1 | c2)) ++} ++ ++// ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this ++// function does not depend on its inputs. ++func ctGeq(x, y uint) choice { ++ // If x < y, then x - y generates a carry. ++ _, carry := bits.Sub(x, y, 0) ++ return not(choice(carry)) ++} ++ ++// nat represents an arbitrary natural number ++// ++// Each nat has an announced length, which is the number of limbs it has stored. ++// Operations on this number are allowed to leak this length, but will not leak ++// any information about the values contained in those limbs. 
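The ctSelect and ctEq helpers above hinge on turning a 0/1 choice into a full-word mask via -uint(on). The following standalone sketch (not part of the patch; the two function bodies are copied for illustration and cross-checked against ordinary branching code) shows the same trick in isolation:

package main

import (
	"fmt"
	"math/bits"
	"math/rand"
)

// ctSelect returns x if on == 1 and y if on == 0, without branching on "on".
func ctSelect(on, x, y uint) uint {
	mask := -on // on == 1 -> all ones; on == 0 -> all zeros
	return y ^ (mask & (y ^ x))
}

// ctEq returns 1 if x == y and 0 otherwise: if x != y, one of the two
// subtractions must produce a borrow.
func ctEq(x, y uint) uint {
	_, c1 := bits.Sub(x, y, 0)
	_, c2 := bits.Sub(y, x, 0)
	return 1 ^ (c1 | c2)
}

func main() {
	for i := 0; i < 1000; i++ {
		x, y := uint(rand.Uint64()), uint(rand.Uint64())
		if ctSelect(1, x, y) != x || ctSelect(0, x, y) != y {
			panic("ctSelect mismatch")
		}
		want := uint(0)
		if x == y {
			want = 1
		}
		if ctEq(x, y) != want {
			panic("ctEq mismatch")
		}
	}
	fmt.Println("ok")
}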
++type nat struct { ++ // limbs is a little-endian representation in base 2^W with ++ // W = bits.UintSize - 1. The top bit is always unset between operations. ++ // ++ // The top bit is left unset to optimize Montgomery multiplication, in the ++ // inner loop of exponentiation. Using fully saturated limbs would leave us ++ // working with 129-bit numbers on 64-bit platforms, wasting a lot of space, ++ // and thus time. ++ limbs []uint ++} ++ ++// expand expands x to n limbs, leaving its value unchanged. ++func (x *nat) expand(n int) *nat { ++ for len(x.limbs) > n { ++ if x.limbs[len(x.limbs)-1] != 0 { ++ panic("rsa: internal error: shrinking nat") ++ } ++ x.limbs = x.limbs[:len(x.limbs)-1] ++ } ++ if cap(x.limbs) < n { ++ newLimbs := make([]uint, n) ++ copy(newLimbs, x.limbs) ++ x.limbs = newLimbs ++ return x ++ } ++ extraLimbs := x.limbs[len(x.limbs):n] ++ for i := range extraLimbs { ++ extraLimbs[i] = 0 ++ } ++ x.limbs = x.limbs[:n] ++ return x ++} ++ ++// reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs). ++func (x *nat) reset(n int) *nat { ++ if cap(x.limbs) < n { ++ x.limbs = make([]uint, n) ++ return x ++ } ++ for i := range x.limbs { ++ x.limbs[i] = 0 ++ } ++ x.limbs = x.limbs[:n] ++ return x ++} ++ ++// clone returns a new nat, with the same value and announced length as x. ++func (x *nat) clone() *nat { ++ out := &nat{make([]uint, len(x.limbs))} ++ copy(out.limbs, x.limbs) ++ return out ++} ++ ++// natFromBig creates a new natural number from a big.Int. ++// ++// The announced length of the resulting nat is based on the actual bit size of ++// the input, ignoring leading zeroes. ++func natFromBig(x *big.Int) *nat { ++ xLimbs := x.Bits() ++ bitSize := bigBitLen(x) ++ requiredLimbs := (bitSize + _W - 1) / _W ++ ++ out := &nat{make([]uint, requiredLimbs)} ++ outI := 0 ++ shift := 0 ++ for i := range xLimbs { ++ xi := uint(xLimbs[i]) ++ out.limbs[outI] |= (xi << shift) & _MASK ++ outI++ ++ if outI == requiredLimbs { ++ return out ++ } ++ out.limbs[outI] = xi >> (_W - shift) ++ shift++ // this assumes bits.UintSize - _W = 1 ++ if shift == _W { ++ shift = 0 ++ outI++ ++ } ++ } ++ return out ++} ++ ++// fillBytes sets bytes to x as a zero-extended big-endian byte slice. ++// ++// If bytes is not long enough to contain the number or at least len(x.limbs)-1 ++// limbs, or has zero length, fillBytes will panic. ++func (x *nat) fillBytes(bytes []byte) []byte { ++ if len(bytes) == 0 { ++ panic("nat: fillBytes invoked with too small buffer") ++ } ++ for i := range bytes { ++ bytes[i] = 0 ++ } ++ shift := 0 ++ outI := len(bytes) - 1 ++ for i, limb := range x.limbs { ++ remainingBits := _W ++ for remainingBits >= 8 { ++ bytes[outI] |= byte(limb) << shift ++ consumed := 8 - shift ++ limb >>= consumed ++ remainingBits -= consumed ++ shift = 0 ++ outI-- ++ if outI < 0 { ++ if limb != 0 || i < len(x.limbs)-1 { ++ panic("nat: fillBytes invoked with too small buffer") ++ } ++ return bytes ++ } ++ } ++ bytes[outI] = byte(limb) ++ shift = remainingBits ++ } ++ return bytes ++} ++ ++// natFromBytes converts a slice of big-endian bytes into a nat. ++// ++// The announced length of the output depends on the length of bytes. Unlike ++// big.Int, creating a nat will not remove leading zeros. 
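natFromBytes and fillBytes above pack values into little-endian limbs of _W = bits.UintSize - 1 bits, the unsaturated representation discussed in the commit message. The sketch below is an independent illustration of that base-2^63 split, assuming a 64-bit platform; toLimbs and fromLimbs are illustrative names, not functions from the patch, and math/big is used only as a reference:

package main

import (
	"fmt"
	"math/big"
)

const _W = 63 // limb width used by nat on 64-bit platforms

var _MASK = big.NewInt(1<<_W - 1)

// toLimbs splits x into little-endian base-2^63 limbs.
func toLimbs(x *big.Int) []uint {
	v := new(big.Int).Set(x)
	var limbs []uint
	for v.Sign() != 0 {
		limb := new(big.Int).And(v, _MASK)
		limbs = append(limbs, uint(limb.Uint64()))
		v.Rsh(v, _W)
	}
	return limbs
}

// fromLimbs reassembles the little-endian limbs into a big.Int.
func fromLimbs(limbs []uint) *big.Int {
	v := new(big.Int)
	for i := len(limbs) - 1; i >= 0; i-- {
		v.Lsh(v, _W)
		v.Or(v, new(big.Int).SetUint64(uint64(limbs[i])))
	}
	return v
}

func main() {
	x, _ := new(big.Int).SetString("123456789012345678901234567890123456789", 10)
	limbs := toLimbs(x)
	fmt.Println("limbs:", limbs) // each limb fits in 63 bits, top bit unset
	fmt.Println("roundtrip ok:", fromLimbs(limbs).Cmp(x) == 0)
}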
++func natFromBytes(bytes []byte) *nat { ++ bitSize := len(bytes) * 8 ++ requiredLimbs := (bitSize + _W - 1) / _W ++ ++ out := &nat{make([]uint, requiredLimbs)} ++ outI := 0 ++ shift := 0 ++ for i := len(bytes) - 1; i >= 0; i-- { ++ bi := bytes[i] ++ out.limbs[outI] |= uint(bi) << shift ++ shift += 8 ++ if shift >= _W { ++ shift -= _W ++ out.limbs[outI] &= _MASK ++ outI++ ++ if shift > 0 { ++ out.limbs[outI] = uint(bi) >> (8 - shift) ++ } ++ } ++ } ++ return out ++} ++ ++// cmpEq returns 1 if x == y, and 0 otherwise. ++// ++// Both operands must have the same announced length. ++func (x *nat) cmpEq(y *nat) choice { ++ // Eliminate bounds checks in the loop. ++ size := len(x.limbs) ++ xLimbs := x.limbs[:size] ++ yLimbs := y.limbs[:size] ++ ++ equal := yes ++ for i := 0; i < size; i++ { ++ equal &= ctEq(xLimbs[i], yLimbs[i]) ++ } ++ return equal ++} ++ ++// cmpGeq returns 1 if x >= y, and 0 otherwise. ++// ++// Both operands must have the same announced length. ++func (x *nat) cmpGeq(y *nat) choice { ++ // Eliminate bounds checks in the loop. ++ size := len(x.limbs) ++ xLimbs := x.limbs[:size] ++ yLimbs := y.limbs[:size] ++ ++ var c uint ++ for i := 0; i < size; i++ { ++ c = (xLimbs[i] - yLimbs[i] - c) >> _W ++ } ++ // If there was a carry, then subtracting y underflowed, so ++ // x is not greater than or equal to y. ++ return not(choice(c)) ++} ++ ++// assign sets x <- y if on == 1, and does nothing otherwise. ++// ++// Both operands must have the same announced length. ++func (x *nat) assign(on choice, y *nat) *nat { ++ // Eliminate bounds checks in the loop. ++ size := len(x.limbs) ++ xLimbs := x.limbs[:size] ++ yLimbs := y.limbs[:size] ++ ++ for i := 0; i < size; i++ { ++ xLimbs[i] = ctSelect(on, yLimbs[i], xLimbs[i]) ++ } ++ return x ++} ++ ++// add computes x += y if on == 1, and does nothing otherwise. It returns the ++// carry of the addition regardless of on. ++// ++// Both operands must have the same announced length. ++func (x *nat) add(on choice, y *nat) (c uint) { ++ // Eliminate bounds checks in the loop. ++ size := len(x.limbs) ++ xLimbs := x.limbs[:size] ++ yLimbs := y.limbs[:size] ++ ++ for i := 0; i < size; i++ { ++ res := xLimbs[i] + yLimbs[i] + c ++ xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i]) ++ c = res >> _W ++ } ++ return ++} ++ ++// sub computes x -= y if on == 1, and does nothing otherwise. It returns the ++// borrow of the subtraction regardless of on. ++// ++// Both operands must have the same announced length. ++func (x *nat) sub(on choice, y *nat) (c uint) { ++ // Eliminate bounds checks in the loop. ++ size := len(x.limbs) ++ xLimbs := x.limbs[:size] ++ yLimbs := y.limbs[:size] ++ ++ for i := 0; i < size; i++ { ++ res := xLimbs[i] - yLimbs[i] - c ++ xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i]) ++ c = res >> _W ++ } ++ return ++} ++ ++// modulus is used for modular arithmetic, precomputing relevant constants. ++// ++// Moduli are assumed to be odd numbers. Moduli can also leak the exact ++// number of bits needed to store their value, and are stored without padding. ++// ++// Their actual value is still kept secret. ++type modulus struct { ++ // The underlying natural number for this modulus. ++ // ++ // This will be stored without any padding, and shouldn't alias with any ++ // other natural number being used. ++ nat *nat ++ leading int // number of leading zeros in the modulus ++ m0inv uint // -nat.limbs[0]⁻¹ mod _W ++} ++ ++// minusInverseModW computes -x⁻¹ mod _W with x odd. 
++// ++// This operation is used to precompute a constant involved in Montgomery ++// multiplication. ++func minusInverseModW(x uint) uint { ++ // Every iteration of this loop doubles the least-significant bits of ++ // correct inverse in y. The first three bits are already correct (1⁻¹ = 1, ++ // 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough ++ // for 61 bits (and wastes only one iteration for 31 bits). ++ // ++ // See https://crypto.stackexchange.com/a/47496. ++ y := x ++ for i := 0; i < 5; i++ { ++ y = y * (2 - x*y) ++ } ++ return (1 << _W) - (y & _MASK) ++} ++ ++// modulusFromNat creates a new modulus from a nat. ++// ++// The nat should be odd, nonzero, and the number of significant bits in the ++// number should be leakable. The nat shouldn't be reused. ++func modulusFromNat(nat *nat) *modulus { ++ m := &modulus{} ++ m.nat = nat ++ size := len(m.nat.limbs) ++ for m.nat.limbs[size-1] == 0 { ++ size-- ++ } ++ m.nat.limbs = m.nat.limbs[:size] ++ m.leading = _W - bitLen(m.nat.limbs[size-1]) ++ m.m0inv = minusInverseModW(m.nat.limbs[0]) ++ return m ++} ++ ++// bitLen is a version of bits.Len that only leaks the bit length of n, but not ++// its value. bits.Len and bits.LeadingZeros use a lookup table for the ++// low-order bits on some architectures. ++func bitLen(n uint) int { ++ var len int ++ // We assume, here and elsewhere, that comparison to zero is constant time ++ // with respect to different non-zero values. ++ for n != 0 { ++ len++ ++ n >>= 1 ++ } ++ return len ++} ++ ++// bigBitLen is a version of big.Int.BitLen that only leaks the bit length of x, ++// but not its value. big.Int.BitLen uses bits.Len. ++func bigBitLen(x *big.Int) int { ++ xLimbs := x.Bits() ++ fullLimbs := len(xLimbs) - 1 ++ topLimb := uint(xLimbs[len(xLimbs)-1]) ++ return fullLimbs*bits.UintSize + bitLen(topLimb) ++} ++ ++// modulusSize returns the size of m in bytes. ++func modulusSize(m *modulus) int { ++ bits := len(m.nat.limbs)*_W - int(m.leading) ++ return (bits + 7) / 8 ++} ++ ++// shiftIn calculates x = x << _W + y mod m. ++// ++// This assumes that x is already reduced mod m, and that y < 2^_W. ++func (x *nat) shiftIn(y uint, m *modulus) *nat { ++ d := new(nat).resetFor(m) ++ ++ // Eliminate bounds checks in the loop. ++ size := len(m.nat.limbs) ++ xLimbs := x.limbs[:size] ++ dLimbs := d.limbs[:size] ++ mLimbs := m.nat.limbs[:size] ++ ++ // Each iteration of this loop computes x = 2x + b mod m, where b is a bit ++ // from y. Effectively, it left-shifts x and adds y one bit at a time, ++ // reducing it every time. ++ // ++ // To do the reduction, each iteration computes both 2x + b and 2x + b - m. ++ // The next iteration (and finally the return line) will use either result ++ // based on whether the subtraction underflowed. ++ needSubtraction := no ++ for i := _W - 1; i >= 0; i-- { ++ carry := (y >> i) & 1 ++ var borrow uint ++ for i := 0; i < size; i++ { ++ l := ctSelect(needSubtraction, dLimbs[i], xLimbs[i]) ++ ++ res := l<<1 + carry ++ xLimbs[i] = res & _MASK ++ carry = res >> _W ++ ++ res = xLimbs[i] - mLimbs[i] - borrow ++ dLimbs[i] = res & _MASK ++ borrow = res >> _W ++ } ++ // See modAdd for how carry (aka overflow), borrow (aka underflow), and ++ // needSubtraction relate. ++ needSubtraction = ctEq(carry, borrow) ++ } ++ return x.assign(needSubtraction, d) ++} ++ ++// mod calculates out = x mod m. ++// ++// This works regardless how large the value of x is. ++// ++// The output will be resized to the size of m and overwritten. 
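minusInverseModW above is a Hensel/Newton iteration: each y = y*(2 - x*y) step doubles the number of correct low-order bits of the inverse, and three bits are already correct for odd x. A small self-contained check of the resulting identity x·(-x⁻¹) ≡ -1 (mod 2^_W), assuming a 64-bit platform and reusing the same loop purely for illustration:

package main

import "fmt"

const (
	_W    = 63
	_MASK = (1 << _W) - 1
)

// minusInverseModW mirrors the routine above: -x⁻¹ mod 2^_W for odd x.
func minusInverseModW(x uint) uint {
	y := x
	for i := 0; i < 5; i++ {
		y = y * (2 - x*y) // doubles the number of correct low bits
	}
	return (1 << _W) - (y & _MASK)
}

func main() {
	for _, x := range []uint{1, 3, 0xDEADBEEF, 0x7FFFFFFFFFFFFFFF} {
		m0inv := minusInverseModW(x)
		// x * (-x⁻¹) ≡ -1 ≡ _MASK (mod 2^_W)
		fmt.Printf("x=%#x ok=%v\n", x, (x*m0inv)&_MASK == _MASK)
	}
}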
++func (out *nat) mod(x *nat, m *modulus) *nat { ++ out.resetFor(m) ++ // Working our way from the most significant to the least significant limb, ++ // we can insert each limb at the least significant position, shifting all ++ // previous limbs left by _W. This way each limb will get shifted by the ++ // correct number of bits. We can insert at least N - 1 limbs without ++ // overflowing m. After that, we need to reduce every time we shift. ++ i := len(x.limbs) - 1 ++ // For the first N - 1 limbs we can skip the actual shifting and position ++ // them at the shifted position, which starts at min(N - 2, i). ++ start := len(m.nat.limbs) - 2 ++ if i < start { ++ start = i ++ } ++ for j := start; j >= 0; j-- { ++ out.limbs[j] = x.limbs[i] ++ i-- ++ } ++ // We shift in the remaining limbs, reducing modulo m each time. ++ for i >= 0 { ++ out.shiftIn(x.limbs[i], m) ++ i-- ++ } ++ return out ++} ++ ++// expandFor ensures out has the right size to work with operations modulo m. ++// ++// This assumes that out has as many or fewer limbs than m, or that the extra ++// limbs are all zero (which may happen when decoding a value that has leading ++// zeroes in its bytes representation that spill over the limb threshold). ++func (out *nat) expandFor(m *modulus) *nat { ++ return out.expand(len(m.nat.limbs)) ++} ++ ++// resetFor ensures out has the right size to work with operations modulo m. ++// ++// out is zeroed and may start at any size. ++func (out *nat) resetFor(m *modulus) *nat { ++ return out.reset(len(m.nat.limbs)) ++} ++ ++// modSub computes x = x - y mod m. ++// ++// The length of both operands must be the same as the modulus. Both operands ++// must already be reduced modulo m. ++func (x *nat) modSub(y *nat, m *modulus) *nat { ++ underflow := x.sub(yes, y) ++ // If the subtraction underflowed, add m. ++ x.add(choice(underflow), m.nat) ++ return x ++} ++ ++// modAdd computes x = x + y mod m. ++// ++// The length of both operands must be the same as the modulus. Both operands ++// must already be reduced modulo m. ++func (x *nat) modAdd(y *nat, m *modulus) *nat { ++ overflow := x.add(yes, y) ++ underflow := not(x.cmpGeq(m.nat)) // x < m ++ ++ // Three cases are possible: ++ // ++ // - overflow = 0, underflow = 0 ++ // ++ // In this case, addition fits in our limbs, but we can still subtract away ++ // m without an underflow, so we need to perform the subtraction to reduce ++ // our result. ++ // ++ // - overflow = 0, underflow = 1 ++ // ++ // The addition fits in our limbs, but we can't subtract m without ++ // underflowing. The result is already reduced. ++ // ++ // - overflow = 1, underflow = 1 ++ // ++ // The addition does not fit in our limbs, and the subtraction's borrow ++ // would cancel out with the addition's carry. We need to subtract m to ++ // reduce our result. ++ // ++ // The overflow = 1, underflow = 0 case is not possible, because y is at ++ // most m - 1, and if adding m - 1 overflows, then subtracting m must ++ // necessarily underflow. ++ needSubtraction := ctEq(overflow, uint(underflow)) ++ ++ x.sub(needSubtraction, m.nat) ++ return x ++} ++ ++// montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and ++// n = len(m.nat.limbs). ++// ++// Faster Montgomery multiplication replaces standard modular multiplication for ++// numbers in this representation. ++// ++// This assumes that x is already reduced mod m. 
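As the comment above notes, montgomeryRepresentation produces x·R mod m simply by calling shiftIn(0, m) once per limb of m, i.e. by multiplying by 2^_W modulo m, n times. A short math/big reference of that identity (illustration only, since big.Int is not constant time; the modulus and value here are arbitrary):

package main

import (
	"fmt"
	"math/big"
)

const _W = 63

func main() {
	m, _ := new(big.Int).SetString("8311234567890123456789012345678901234567891", 10) // odd modulus
	x := big.NewInt(0xCAFEBABE)
	n := (m.BitLen() + _W - 1) / _W // number of limbs in the modulus

	// Repeatedly shift in zero bits: x = x * 2^_W mod m, n times.
	shifted := new(big.Int).Set(x)
	for i := 0; i < n; i++ {
		shifted.Lsh(shifted, _W).Mod(shifted, m)
	}

	// Directly compute x * R mod m with R = 2^(_W * n).
	R := new(big.Int).Lsh(big.NewInt(1), uint(_W*n))
	direct := new(big.Int).Mul(x, R)
	direct.Mod(direct, m)

	fmt.Println("x*R mod m matches:", shifted.Cmp(direct) == 0)
}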
++func (x *nat) montgomeryRepresentation(m *modulus) *nat { ++ for i := 0; i < len(m.nat.limbs); i++ { ++ x.shiftIn(0, m) // x = x * 2^_W mod m ++ } ++ return x ++} ++ ++// montgomeryMul calculates d = a * b / R mod m, with R = 2^(_W * n) and ++// n = len(m.nat.limbs), using the Montgomery Multiplication technique. ++// ++// All inputs should be the same length, not aliasing d, and already ++// reduced modulo m. d will be resized to the size of m and overwritten. ++func (d *nat) montgomeryMul(a *nat, b *nat, m *modulus) *nat { ++ // See https://bearssl.org/bigint.html#montgomery-reduction-and-multiplication ++ // for a description of the algorithm. ++ ++ // Eliminate bounds checks in the loop. ++ size := len(m.nat.limbs) ++ aLimbs := a.limbs[:size] ++ bLimbs := b.limbs[:size] ++ dLimbs := d.resetFor(m).limbs[:size] ++ mLimbs := m.nat.limbs[:size] ++ ++ var overflow uint ++ for i := 0; i < size; i++ { ++ f := ((dLimbs[0] + aLimbs[i]*bLimbs[0]) * m.m0inv) & _MASK ++ carry := uint(0) ++ for j := 0; j < size; j++ { ++ // z = d[j] + a[i] * b[j] + f * m[j] + carry <= 2^(2W+1) - 2^(W+1) + 2^W ++ hi, lo := bits.Mul(aLimbs[i], bLimbs[j]) ++ z_lo, c := bits.Add(dLimbs[j], lo, 0) ++ z_hi, _ := bits.Add(0, hi, c) ++ hi, lo = bits.Mul(f, mLimbs[j]) ++ z_lo, c = bits.Add(z_lo, lo, 0) ++ z_hi, _ = bits.Add(z_hi, hi, c) ++ z_lo, c = bits.Add(z_lo, carry, 0) ++ z_hi, _ = bits.Add(z_hi, 0, c) ++ if j > 0 { ++ dLimbs[j-1] = z_lo & _MASK ++ } ++ carry = z_hi<<1 | z_lo>>_W // carry <= 2^(W+1) - 2 ++ } ++ z := overflow + carry // z <= 2^(W+1) - 1 ++ dLimbs[size-1] = z & _MASK ++ overflow = z >> _W // overflow <= 1 ++ } ++ // See modAdd for how overflow, underflow, and needSubtraction relate. ++ underflow := not(d.cmpGeq(m.nat)) // d < m ++ needSubtraction := ctEq(overflow, uint(underflow)) ++ d.sub(needSubtraction, m.nat) ++ ++ return d ++} ++ ++// modMul calculates x *= y mod m. ++// ++// x and y must already be reduced modulo m, they must share its announced ++// length, and they may not alias. ++func (x *nat) modMul(y *nat, m *modulus) *nat { ++ // A Montgomery multiplication by a value out of the Montgomery domain ++ // takes the result out of Montgomery representation. ++ xR := x.clone().montgomeryRepresentation(m) // xR = x * R mod m ++ return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m ++} ++ ++// exp calculates out = x^e mod m. ++// ++// The exponent e is represented in big-endian order. The output will be resized ++// to the size of m and overwritten. x must already be reduced modulo m. ++func (out *nat) exp(x *nat, e []byte, m *modulus) *nat { ++ // We use a 4 bit window. For our RSA workload, 4 bit windows are faster ++ // than 2 bit windows, but use an extra 12 nats worth of scratch space. ++ // Using bit sizes that don't divide 8 are more complex to implement. ++ table := make([]*nat, (1<<4)-1) // table[i] = x ^ (i+1) ++ table[0] = x.clone().montgomeryRepresentation(m) ++ for i := 1; i < len(table); i++ { ++ table[i] = new(nat).expandFor(m) ++ table[i].montgomeryMul(table[i-1], table[0], m) ++ } ++ ++ out.resetFor(m) ++ out.limbs[0] = 1 ++ out.montgomeryRepresentation(m) ++ t0 := new(nat).expandFor(m) ++ t1 := new(nat).expandFor(m) ++ for _, b := range e { ++ for _, j := range []int{4, 0} { ++ // Square four times. ++ t1.montgomeryMul(out, out, m) ++ out.montgomeryMul(t1, t1, m) ++ t1.montgomeryMul(out, out, m) ++ out.montgomeryMul(t1, t1, m) ++ ++ // Select x^k in constant time from the table. 
++ k := uint((b >> j) & 0b1111) ++ for i := range table { ++ t0.assign(ctEq(k, uint(i+1)), table[i]) ++ } ++ ++ // Multiply by x^k, discarding the result if k = 0. ++ t1.montgomeryMul(out, t0, m) ++ out.assign(not(ctEq(k, 0)), t1) ++ } ++ } ++ ++ // By Montgomery multiplying with 1 not in Montgomery representation, we ++ // convert out back from Montgomery representation, because it works out to ++ // dividing by R. ++ t0.assign(yes, out) ++ t1.resetFor(m) ++ t1.limbs[0] = 1 ++ out.montgomeryMul(t0, t1, m) ++ ++ return out ++} +diff --git a/src/crypto/rsa/nat_test.go b/src/crypto/rsa/nat_test.go +new file mode 100644 +index 0000000..3e6eb10 +--- /dev/null ++++ b/src/crypto/rsa/nat_test.go +@@ -0,0 +1,384 @@ ++// Copyright 2021 The Go Authors. All rights reserved. ++// Use of this source code is governed by a BSD-style ++// license that can be found in the LICENSE file. ++ ++package rsa ++ ++import ( ++ "bytes" ++ "math/big" ++ "math/bits" ++ "math/rand" ++ "reflect" ++ "testing" ++ "testing/quick" ++) ++ ++// Generate generates an even nat. It's used by testing/quick to produce random ++// *nat values for quick.Check invocations. ++func (*nat) Generate(r *rand.Rand, size int) reflect.Value { ++ limbs := make([]uint, size) ++ for i := 0; i < size; i++ { ++ limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2) ++ } ++ return reflect.ValueOf(&nat{limbs}) ++} ++ ++func testModAddCommutative(a *nat, b *nat) bool { ++ mLimbs := make([]uint, len(a.limbs)) ++ for i := 0; i < len(mLimbs); i++ { ++ mLimbs[i] = _MASK ++ } ++ m := modulusFromNat(&nat{mLimbs}) ++ aPlusB := a.clone() ++ aPlusB.modAdd(b, m) ++ bPlusA := b.clone() ++ bPlusA.modAdd(a, m) ++ return aPlusB.cmpEq(bPlusA) == 1 ++} ++ ++func TestModAddCommutative(t *testing.T) { ++ err := quick.Check(testModAddCommutative, &quick.Config{}) ++ if err != nil { ++ t.Error(err) ++ } ++} ++ ++func testModSubThenAddIdentity(a *nat, b *nat) bool { ++ mLimbs := make([]uint, len(a.limbs)) ++ for i := 0; i < len(mLimbs); i++ { ++ mLimbs[i] = _MASK ++ } ++ m := modulusFromNat(&nat{mLimbs}) ++ original := a.clone() ++ a.modSub(b, m) ++ a.modAdd(b, m) ++ return a.cmpEq(original) == 1 ++} ++ ++func TestModSubThenAddIdentity(t *testing.T) { ++ err := quick.Check(testModSubThenAddIdentity, &quick.Config{}) ++ if err != nil { ++ t.Error(err) ++ } ++} ++ ++func testMontgomeryRoundtrip(a *nat) bool { ++ one := &nat{make([]uint, len(a.limbs))} ++ one.limbs[0] = 1 ++ aPlusOne := a.clone() ++ aPlusOne.add(1, one) ++ m := modulusFromNat(aPlusOne) ++ monty := a.clone() ++ monty.montgomeryRepresentation(m) ++ aAgain := monty.clone() ++ aAgain.montgomeryMul(monty, one, m) ++ return a.cmpEq(aAgain) == 1 ++} ++ ++func TestMontgomeryRoundtrip(t *testing.T) { ++ err := quick.Check(testMontgomeryRoundtrip, &quick.Config{}) ++ if err != nil { ++ t.Error(err) ++ } ++} ++ ++func TestFromBig(t *testing.T) { ++ expected := []byte{0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} ++ theBig := new(big.Int).SetBytes(expected) ++ actual := natFromBig(theBig).fillBytes(make([]byte, len(expected))) ++ if !bytes.Equal(actual, expected) { ++ t.Errorf("%+x != %+x", actual, expected) ++ } ++} ++ ++func TestFillBytes(t *testing.T) { ++ xBytes := []byte{0xAA, 0xFF, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} ++ x := natFromBytes(xBytes) ++ for l := 20; l >= len(xBytes); l-- { ++ buf := make([]byte, l) ++ rand.Read(buf) ++ actual := x.fillBytes(buf) ++ expected := make([]byte, l) ++ copy(expected[l-len(xBytes):], xBytes) ++ if !bytes.Equal(actual, 
expected) { ++ t.Errorf("%d: %+v != %+v", l, actual, expected) ++ } ++ } ++ for l := len(xBytes) - 1; l >= 0; l-- { ++ (func() { ++ defer func() { ++ if recover() == nil { ++ t.Errorf("%d: expected panic", l) ++ } ++ }() ++ x.fillBytes(make([]byte, l)) ++ })() ++ } ++} ++ ++func TestFromBytes(t *testing.T) { ++ f := func(xBytes []byte) bool { ++ if len(xBytes) == 0 { ++ return true ++ } ++ actual := natFromBytes(xBytes).fillBytes(make([]byte, len(xBytes))) ++ if !bytes.Equal(actual, xBytes) { ++ t.Errorf("%+x != %+x", actual, xBytes) ++ return false ++ } ++ return true ++ } ++ ++ err := quick.Check(f, &quick.Config{}) ++ if err != nil { ++ t.Error(err) ++ } ++ ++ f([]byte{0xFF, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) ++ f(bytes.Repeat([]byte{0xFF}, _W)) ++} ++ ++func TestShiftIn(t *testing.T) { ++ if bits.UintSize != 64 { ++ t.Skip("examples are only valid in 64 bit") ++ } ++ examples := []struct { ++ m, x, expected []byte ++ y uint64 ++ }{{ ++ m: []byte{13}, ++ x: []byte{0}, ++ y: 0x7FFF_FFFF_FFFF_FFFF, ++ expected: []byte{7}, ++ }, { ++ m: []byte{13}, ++ x: []byte{7}, ++ y: 0x7FFF_FFFF_FFFF_FFFF, ++ expected: []byte{11}, ++ }, { ++ m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}, ++ x: make([]byte, 9), ++ y: 0x7FFF_FFFF_FFFF_FFFF, ++ expected: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, ++ }, { ++ m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}, ++ x: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, ++ y: 0, ++ expected: []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08}, ++ }} ++ ++ for i, tt := range examples { ++ m := modulusFromNat(natFromBytes(tt.m)) ++ got := natFromBytes(tt.x).expandFor(m).shiftIn(uint(tt.y), m) ++ if got.cmpEq(natFromBytes(tt.expected).expandFor(m)) != 1 { ++ t.Errorf("%d: got %x, expected %x", i, got, tt.expected) ++ } ++ } ++} ++ ++func TestModulusAndNatSizes(t *testing.T) { ++ // These are 126 bit (2 * _W on 64-bit architectures) values, serialized as ++ // 128 bits worth of bytes. If leading zeroes are stripped, they fit in two ++ // limbs, if they are not, they fit in three. This can be a problem because ++ // modulus strips leading zeroes and nat does not. 
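The exp routine earlier in nat.go walks the exponent in 4-bit windows: square four times, then multiply by a table entry chosen in constant time. The window schedule itself can be sanity-checked against big.Int.Exp; the sketch below does that (illustration only: big.Int is not constant time, windowedExp is an illustrative name, and the table entry is selected by ordinary indexing rather than the constant-time scan used above):

package main

import (
	"fmt"
	"math/big"
)

// windowedExp computes x^e mod m with a 4-bit fixed window, e given as
// big-endian bytes, mirroring the schedule used by nat.exp.
func windowedExp(x *big.Int, e []byte, m *big.Int) *big.Int {
	// table[i] = x^i mod m for i = 0..15 (nat.exp stores only x^1..x^15).
	table := make([]*big.Int, 16)
	table[0] = big.NewInt(1)
	for i := 1; i < 16; i++ {
		table[i] = new(big.Int).Mul(table[i-1], x)
		table[i].Mod(table[i], m)
	}

	out := big.NewInt(1)
	for _, b := range e {
		for _, shift := range []uint{4, 0} {
			// Square four times: out = out^16 mod m.
			for i := 0; i < 4; i++ {
				out.Mul(out, out).Mod(out, m)
			}
			// Multiply by x^k for the current 4-bit window.
			k := (b >> shift) & 0b1111
			out.Mul(out, table[k]).Mod(out, m)
		}
	}
	return out
}

func main() {
	m, _ := new(big.Int).SetString("100000000000000000000000000000000000133", 10)
	x := big.NewInt(3)
	e := []byte{0x01, 0x00, 0x01} // 65537, big-endian

	want := new(big.Int).Exp(x, big.NewInt(65537), m)
	fmt.Println("matches big.Int.Exp:", windowedExp(x, e, m).Cmp(want) == 0)
}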
++ m := modulusFromNat(natFromBytes([]byte{ ++ 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, ++ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})) ++ x := natFromBytes([]byte{ ++ 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, ++ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}) ++ x.expandFor(m) // must not panic for shrinking ++} ++ ++func TestExpand(t *testing.T) { ++ sliced := []uint{1, 2, 3, 4} ++ examples := []struct { ++ in []uint ++ n int ++ out []uint ++ }{{ ++ []uint{1, 2}, ++ 4, ++ []uint{1, 2, 0, 0}, ++ }, { ++ sliced[:2], ++ 4, ++ []uint{1, 2, 0, 0}, ++ }, { ++ []uint{1, 2}, ++ 2, ++ []uint{1, 2}, ++ }, { ++ []uint{1, 2, 0}, ++ 2, ++ []uint{1, 2}, ++ }} ++ ++ for i, tt := range examples { ++ got := (&nat{tt.in}).expand(tt.n) ++ if len(got.limbs) != len(tt.out) || got.cmpEq(&nat{tt.out}) != 1 { ++ t.Errorf("%d: got %x, expected %x", i, got, tt.out) ++ } ++ } ++} ++ ++func TestMod(t *testing.T) { ++ m := modulusFromNat(natFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d})) ++ x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}) ++ out := new(nat) ++ out.mod(x, m) ++ expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09}) ++ if out.cmpEq(expected) != 1 { ++ t.Errorf("%+v != %+v", out, expected) ++ } ++} ++ ++func TestModSub(t *testing.T) { ++ m := modulusFromNat(&nat{[]uint{13}}) ++ x := &nat{[]uint{6}} ++ y := &nat{[]uint{7}} ++ x.modSub(y, m) ++ expected := &nat{[]uint{12}} ++ if x.cmpEq(expected) != 1 { ++ t.Errorf("%+v != %+v", x, expected) ++ } ++ x.modSub(y, m) ++ expected = &nat{[]uint{5}} ++ if x.cmpEq(expected) != 1 { ++ t.Errorf("%+v != %+v", x, expected) ++ } ++} ++ ++func TestModAdd(t *testing.T) { ++ m := modulusFromNat(&nat{[]uint{13}}) ++ x := &nat{[]uint{6}} ++ y := &nat{[]uint{7}} ++ x.modAdd(y, m) ++ expected := &nat{[]uint{0}} ++ if x.cmpEq(expected) != 1 { ++ t.Errorf("%+v != %+v", x, expected) ++ } ++ x.modAdd(y, m) ++ expected = &nat{[]uint{7}} ++ if x.cmpEq(expected) != 1 { ++ t.Errorf("%+v != %+v", x, expected) ++ } ++} ++ ++func TestExp(t *testing.T) { ++ m := modulusFromNat(&nat{[]uint{13}}) ++ x := &nat{[]uint{3}} ++ out := &nat{[]uint{0}} ++ out.exp(x, []byte{12}, m) ++ expected := &nat{[]uint{1}} ++ if out.cmpEq(expected) != 1 { ++ t.Errorf("%+v != %+v", out, expected) ++ } ++} ++ ++func makeBenchmarkModulus() *modulus { ++ m := make([]uint, 32) ++ for i := 0; i < 32; i++ { ++ m[i] = _MASK ++ } ++ return modulusFromNat(&nat{limbs: m}) ++} ++ ++func makeBenchmarkValue() *nat { ++ x := make([]uint, 32) ++ for i := 0; i < 32; i++ { ++ x[i] = _MASK - 1 ++ } ++ return &nat{limbs: x} ++} ++ ++func makeBenchmarkExponent() []byte { ++ e := make([]byte, 256) ++ for i := 0; i < 32; i++ { ++ e[i] = 0xFF ++ } ++ return e ++} ++ ++func BenchmarkModAdd(b *testing.B) { ++ x := makeBenchmarkValue() ++ y := makeBenchmarkValue() ++ m := makeBenchmarkModulus() ++ ++ b.ResetTimer() ++ for i := 0; i < b.N; i++ { ++ x.modAdd(y, m) ++ } ++} ++ ++func BenchmarkModSub(b *testing.B) { ++ x := makeBenchmarkValue() ++ y := makeBenchmarkValue() ++ m := makeBenchmarkModulus() ++ ++ b.ResetTimer() ++ for i := 0; i < b.N; i++ { ++ x.modSub(y, m) ++ } ++} ++ ++func BenchmarkMontgomeryRepr(b *testing.B) { ++ x := makeBenchmarkValue() ++ m := makeBenchmarkModulus() ++ ++ b.ResetTimer() ++ for i := 0; i < b.N; i++ { ++ x.montgomeryRepresentation(m) ++ } ++} ++ ++func BenchmarkMontgomeryMul(b *testing.B) { ++ x := makeBenchmarkValue() ++ y := makeBenchmarkValue() ++ 
out := makeBenchmarkValue() ++ m := makeBenchmarkModulus() ++ ++ b.ResetTimer() ++ for i := 0; i < b.N; i++ { ++ out.montgomeryMul(x, y, m) ++ } ++} ++ ++func BenchmarkModMul(b *testing.B) { ++ x := makeBenchmarkValue() ++ y := makeBenchmarkValue() ++ m := makeBenchmarkModulus() ++ ++ b.ResetTimer() ++ for i := 0; i < b.N; i++ { ++ x.modMul(y, m) ++ } ++} ++ ++func BenchmarkExpBig(b *testing.B) { ++ out := new(big.Int) ++ exponentBytes := makeBenchmarkExponent() ++ x := new(big.Int).SetBytes(exponentBytes) ++ e := new(big.Int).SetBytes(exponentBytes) ++ n := new(big.Int).SetBytes(exponentBytes) ++ one := new(big.Int).SetUint64(1) ++ n.Add(n, one) ++ ++ b.ResetTimer() ++ for i := 0; i < b.N; i++ { ++ out.Exp(x, e, n) ++ } ++} ++ ++func BenchmarkExp(b *testing.B) { ++ x := makeBenchmarkValue() ++ e := makeBenchmarkExponent() ++ out := makeBenchmarkValue() ++ m := makeBenchmarkModulus() ++ ++ b.ResetTimer() ++ for i := 0; i < b.N; i++ { ++ out.exp(x, e, m) ++ } ++} +diff --git a/src/crypto/rsa/pkcs1v15.go b/src/crypto/rsa/pkcs1v15.go +index ea5a878..59cf0e6 100644 +--- a/src/crypto/rsa/pkcs1v15.go ++++ b/src/crypto/rsa/pkcs1v15.go +@@ -11,7 +11,6 @@ + "crypto/subtle" + "errors" + "io" +- "math/big" + ) + + // This file implements encryption and decryption using PKCS #1 v1.5 padding. +@@ -76,13 +75,11 @@ + return boring.EncryptRSANoPadding(bkey, em) + } + +- m := new(big.Int).SetBytes(em) +- c := encrypt(new(big.Int), pub, m) +- return c.FillBytes(em), nil ++ return encrypt(pub, em), nil + } + + // DecryptPKCS1v15 decrypts a plaintext using RSA and the padding scheme from PKCS #1 v1.5. +-// If random != nil, it uses RSA blinding to avoid timing side-channel attacks. ++// The random parameter is legacy and ignored, and it can be as nil. + // + // Note that whether this function returns an error or not discloses secret + // information. If an attacker can cause this function to run repeatedly and +@@ -106,7 +103,7 @@ + return out, nil + } + +- valid, out, index, err := decryptPKCS1v15(random, priv, ciphertext) ++ valid, out, index, err := decryptPKCS1v15(priv, ciphertext) + if err != nil { + return nil, err + } +@@ -117,7 +114,7 @@ + } + + // DecryptPKCS1v15SessionKey decrypts a session key using RSA and the padding scheme from PKCS #1 v1.5. +-// If random != nil, it uses RSA blinding to avoid timing side-channel attacks. ++// The random parameter is legacy and ignored, and it can be as nil. + // It returns an error if the ciphertext is the wrong length or if the + // ciphertext is greater than the public modulus. Otherwise, no error is + // returned. If the padding is valid, the resulting plaintext message is copied +@@ -144,7 +141,7 @@ + return ErrDecryption + } + +- valid, em, index, err := decryptPKCS1v15(random, priv, ciphertext) ++ valid, em, index, err := decryptPKCS1v15(priv, ciphertext) + if err != nil { + return err + } +@@ -160,13 +157,13 @@ + return nil + } + +-// decryptPKCS1v15 decrypts ciphertext using priv and blinds the operation if +-// random is not nil. It returns one or zero in valid that indicates whether the +-// plaintext was correctly structured. In either case, the plaintext is +-// returned in em so that it may be read independently of whether it was valid +-// in order to maintain constant memory access patterns. If the plaintext was +-// valid then index contains the index of the original message in em. 
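The doc comment above keeps the point that decryptPKCS1v15 reports validity and a message index without early exits, so padding removal stays constant time. The scan it refers to follows the usual crypto/subtle pattern; the sketch below is a rough standalone version of checking EM = 0x00 || 0x02 || PS || 0x00 || M (checkPKCS1v15 is an illustrative name, not the patch's code):

package main

import (
	"crypto/subtle"
	"fmt"
)

// checkPKCS1v15 scans em without branching on secret bytes. It returns
// valid (1 or 0) and, when valid, the index where the message starts.
func checkPKCS1v15(em []byte) (valid, index int) {
	firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
	secondByteIsTwo := subtle.ConstantTimeByteEq(em[1], 2)

	// Find the first zero byte after the 0x00 0x02 prefix, touching every byte.
	lookingForIndex := 1
	for i := 2; i < len(em); i++ {
		equals0 := subtle.ConstantTimeByteEq(em[i], 0)
		index = subtle.ConstantTimeSelect(lookingForIndex&equals0, i, index)
		lookingForIndex = subtle.ConstantTimeSelect(equals0, 0, lookingForIndex)
	}

	// The padding string PS must be at least 8 bytes long.
	validPS := subtle.ConstantTimeLessOrEq(2+8, index)

	valid = firstByteIsZero & secondByteIsTwo & (^lookingForIndex & 1) & validPS
	index = subtle.ConstantTimeSelect(valid, index+1, 0)
	return valid, index
}

func main() {
	em := append([]byte{0x00, 0x02, 9, 9, 9, 9, 9, 9, 9, 9, 0x00}, []byte("hi")...)
	valid, index := checkPKCS1v15(em)
	fmt.Println(valid == 1, string(em[index:])) // true hi
}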
+-func decryptPKCS1v15(random io.Reader, priv *PrivateKey, ciphertext []byte) (valid int, em []byte, index int, err error) { ++// decryptPKCS1v15 decrypts ciphertext using priv. It returns one or zero in ++// valid that indicates whether the plaintext was correctly structured. ++// In either case, the plaintext is returned in em so that it may be read ++// independently of whether it was valid in order to maintain constant memory ++// access patterns. If the plaintext was valid then index contains the index of ++// the original message in em, to allow constant time padding removal. ++func decryptPKCS1v15(priv *PrivateKey, ciphertext []byte) (valid int, em []byte, index int, err error) { + k := priv.Size() + if k < 11 { + err = ErrDecryption +@@ -184,13 +181,10 @@ + return + } + } else { +- c := new(big.Int).SetBytes(ciphertext) +- var m *big.Int +- m, err = decrypt(random, priv, c) ++ em, err = decrypt(priv, ciphertext) + if err != nil { + return + } +- em = m.FillBytes(make([]byte, k)) + } + + firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0) +@@ -266,8 +260,7 @@ + // function. If hash is zero, hashed is signed directly. This isn't + // advisable except for interoperability. + // +-// If random is not nil then RSA blinding will be used to avoid timing +-// side-channel attacks. ++// The random parameter is legacy and ignored, and it can be as nil. + // + // This function is deterministic. Thus, if the set of possible + // messages is small, an attacker may be able to build a map from +@@ -302,13 +295,7 @@ + copy(em[k-tLen:k-hashLen], prefix) + copy(em[k-hashLen:k], hashed) + +- m := new(big.Int).SetBytes(em) +- c, err := decryptAndCheck(random, priv, m) +- if err != nil { +- return nil, err +- } +- +- return c.FillBytes(em), nil ++ return decryptAndCheck(priv, em) + } + + // VerifyPKCS1v15 verifies an RSA PKCS #1 v1.5 signature. +@@ -346,9 +333,7 @@ + return ErrVerification + } + +- c := new(big.Int).SetBytes(sig) +- m := encrypt(new(big.Int), pub, c) +- em := m.FillBytes(make([]byte, k)) ++ em := encrypt(pub, sig) + // EM = 0x00 || 0x01 || PS || 0x00 || T + + ok := subtle.ConstantTimeByteEq(em[0], 0) +diff --git a/src/crypto/rsa/pss.go b/src/crypto/rsa/pss.go +index fd9fc2e..e4e217b 100644 +--- a/src/crypto/rsa/pss.go ++++ b/src/crypto/rsa/pss.go +@@ -13,7 +13,6 @@ + "errors" + "hash" + "io" +- "math/big" + ) + + // Per RFC 8017, Section 9.1 +@@ -208,8 +207,8 @@ + // Note that hashed must be the result of hashing the input message using the + // given hash function. salt is a random sequence of bytes whose length will be + // later used to verify the signature. +-func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) { +- emBits := priv.N.BitLen() - 1 ++func signPSSWithSalt(priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) { ++ emBits := bigBitLen(priv.N) - 1 + em, err := emsaPSSEncode(hashed, emBits, salt, hash.New()) + if err != nil { + return nil, err +@@ -229,13 +228,20 @@ + return s, nil + } + +- m := new(big.Int).SetBytes(em) +- c, err := decryptAndCheck(rand, priv, m) +- if err != nil { +- return nil, err ++ // RFC 8017: "Note that the octet length of EM will be one less than k if ++ // modBits - 1 is divisible by 8 and equal to k otherwise, where k is the ++ // length in octets of the RSA modulus n." 🙄 ++ // ++ // This is extremely annoying, as all other encrypt and decrypt inputs are ++ // always the exact same size as the modulus. 
Since it only happens for ++ // weird modulus sizes, fix it by padding inefficiently. ++ if emLen, k := len(em), priv.Size(); emLen < k { ++ emNew := make([]byte, k) ++ copy(emNew[k-emLen:], em) ++ em = emNew + } +- s := make([]byte, priv.Size()) +- return c.FillBytes(s), nil ++ ++ return decryptAndCheck(priv, em) + } + + const ( +@@ -296,7 +302,7 @@ + saltLength := opts.saltLength() + switch saltLength { + case PSSSaltLengthAuto: +- saltLength = (priv.N.BitLen()-1+7)/8 - 2 - hash.Size() ++ saltLength = (bigBitLen(priv.N)-1+7)/8 - 2 - hash.Size() + case PSSSaltLengthEqualsHash: + saltLength = hash.Size() + } +@@ -310,7 +316,7 @@ + if _, err := io.ReadFull(rand, salt); err != nil { + return nil, err + } +- return signPSSWithSalt(rand, priv, hash, digest, salt) ++ return signPSSWithSalt(priv, hash, digest, salt) + } + + // VerifyPSS verifies a PSS signature. +@@ -339,13 +345,22 @@ + if len(sig) != pub.Size() { + return ErrVerification + } +- s := new(big.Int).SetBytes(sig) +- m := encrypt(new(big.Int), pub, s) +- emBits := pub.N.BitLen() - 1 ++ ++ emBits := bigBitLen(pub.N) - 1 + emLen := (emBits + 7) / 8 +- if m.BitLen() > emLen*8 { +- return ErrVerification ++ em := encrypt(pub, sig) ++ ++ // Like in signPSSWithSalt, deal with mismatches between emLen and the size ++ // of the modulus. The spec would have us wire emLen into the encoding ++ // function, but we'd rather always encode to the size of the modulus and ++ // then strip leading zeroes if necessary. This only happens for weird ++ // modulus sizes anyway. ++ for len(em) > emLen && len(em) > 0 { ++ if em[0] != 0 { ++ return ErrVerification ++ } ++ em = em[1:] + } +- em := m.FillBytes(make([]byte, emLen)) ++ + return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New()) + } +diff --git a/src/crypto/rsa/pss_test.go b/src/crypto/rsa/pss_test.go +index f1f1704..cf03e3c 100644 +--- a/src/crypto/rsa/pss_test.go ++++ b/src/crypto/rsa/pss_test.go +@@ -236,7 +236,10 @@ + } + } + +-func TestSignWithPSSSaltLengthAuto(t *testing.T) { ++func TestPSS513(t *testing.T) { ++ // See Issue 42741, and separately, RFC 8017: "Note that the octet length of ++ // EM will be one less than k if modBits - 1 is divisible by 8 and equal to ++ // k otherwise, where k is the length in octets of the RSA modulus n." + key, err := GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) +@@ -249,8 +252,9 @@ + if err != nil { + t.Fatal(err) + } +- if len(signature) == 0 { +- t.Fatal("empty signature returned") ++ err = VerifyPSS(&key.PublicKey, crypto.SHA256, digest[:], signature, nil) ++ if err != nil { ++ t.Error(err) + } + } + +diff --git a/src/crypto/rsa/rsa.go b/src/crypto/rsa/rsa.go +index 34218e5..237d745 100644 +--- a/src/crypto/rsa/rsa.go ++++ b/src/crypto/rsa/rsa.go +@@ -19,7 +19,10 @@ + // over the public key primitive, the PrivateKey type implements the + // Decrypter and Signer interfaces from the crypto package. + // +-// The RSA operations in this package are not implemented using constant-time algorithms. ++// Operations in this package are implemented using constant-time algorithms, ++// except for [GenerateKey], [PrivateKey.Precompute], and [PrivateKey.Validate]. ++// Every other operation only leaks the bit size of the involved values, which ++// all depend on the selected key size. 
+ package rsa + + import ( +@@ -29,6 +32,7 @@ + "crypto/internal/randutil" + "crypto/rand" + "crypto/subtle" ++ "encoding/binary" + "errors" + "hash" + "io" +@@ -36,7 +40,6 @@ + "math/big" + ) + +-var bigZero = big.NewInt(0) + var bigOne = big.NewInt(1) + + // A PublicKey represents the public part of an RSA key. +@@ -51,7 +54,7 @@ + // Size returns the modulus size in bytes. Raw signatures and ciphertexts + // for or by this public key will have the same size. + func (pub *PublicKey) Size() int { +- return (pub.N.BitLen() + 7) / 8 ++ return (bigBitLen(pub.N) + 7) / 8 + } + + // Equal reports whether pub and x have the same value. +@@ -429,11 +432,20 @@ + // too large for the size of the public key. + var ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA public key size") + +-func encrypt(c *big.Int, pub *PublicKey, m *big.Int) *big.Int { ++func encrypt(pub *PublicKey, plaintext []byte) []byte { + boring.Unreachable() +- e := big.NewInt(int64(pub.E)) +- c.Exp(m, e, pub.N) +- return c ++ ++ N := modulusFromNat(natFromBig(pub.N)) ++ m := natFromBytes(plaintext).expandFor(N) ++ ++ e := make([]byte, 8) ++ binary.BigEndian.PutUint64(e, uint64(pub.E)) ++ for len(e) > 1 && e[0] == 0 { ++ e = e[1:] ++ } ++ ++ out := make([]byte, modulusSize(N)) ++ return new(nat).exp(m, e, N).fillBytes(out) + } + + // EncryptOAEP encrypts the given message with RSA-OAEP. +@@ -501,12 +513,7 @@ + return boring.EncryptRSANoPadding(bkey, em) + } + +- m := new(big.Int) +- m.SetBytes(em) +- c := encrypt(new(big.Int), pub, m) +- +- out := make([]byte, k) +- return c.FillBytes(out), nil ++ return encrypt(pub, em), nil + } + + // ErrDecryption represents a failure to decrypt a message. +@@ -548,101 +555,74 @@ + } + } + +-// decrypt performs an RSA decryption, resulting in a plaintext integer. If a +-// random source is given, RSA blinding is used. +-func decrypt(random io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err error) { ++// decrypt performs an RSA decryption of ciphertext into out. ++func decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { + if len(priv.Primes) <= 2 { + boring.Unreachable() + } +- // TODO(agl): can we get away with reusing blinds? +- if c.Cmp(priv.N) > 0 { +- err = ErrDecryption +- return ++ ++ N := modulusFromNat(natFromBig(priv.N)) ++ c := natFromBytes(ciphertext).expandFor(N) ++ if c.cmpGeq(N.nat) == 1 { ++ return nil, ErrDecryption + } + if priv.N.Sign() == 0 { + return nil, ErrDecryption + } + +- var ir *big.Int +- if random != nil { +- randutil.MaybeReadByte(random) +- +- // Blinding enabled. Blinding involves multiplying c by r^e. +- // Then the decryption operation performs (m^e * r^e)^d mod n +- // which equals mr mod n. The factor of r can then be removed +- // by multiplying by the multiplicative inverse of r. +- +- var r *big.Int +- ir = new(big.Int) +- for { +- r, err = rand.Int(random, priv.N) +- if err != nil { +- return +- } +- if r.Cmp(bigZero) == 0 { +- r = bigOne +- } +- ok := ir.ModInverse(r, priv.N) +- if ok != nil { +- break +- } +- } +- bigE := big.NewInt(int64(priv.E)) +- rpowe := new(big.Int).Exp(r, bigE, priv.N) // N != 0 +- cCopy := new(big.Int).Set(c) +- cCopy.Mul(cCopy, rpowe) +- cCopy.Mod(cCopy, priv.N) +- c = cCopy +- } +- ++ // Note that because our private decryption exponents are stored as big.Int, ++ // we potentially leak the exact number of bits of these exponents. This ++ // isn't great, but should be fine. 
+ if priv.Precomputed.Dp == nil { +- m = new(big.Int).Exp(c, priv.D, priv.N) +- } else { +- // We have the precalculated values needed for the CRT. +- m = new(big.Int).Exp(c, priv.Precomputed.Dp, priv.Primes[0]) +- m2 := new(big.Int).Exp(c, priv.Precomputed.Dq, priv.Primes[1]) +- m.Sub(m, m2) +- if m.Sign() < 0 { +- m.Add(m, priv.Primes[0]) +- } +- m.Mul(m, priv.Precomputed.Qinv) +- m.Mod(m, priv.Primes[0]) +- m.Mul(m, priv.Primes[1]) +- m.Add(m, m2) +- +- for i, values := range priv.Precomputed.CRTValues { +- prime := priv.Primes[2+i] +- m2.Exp(c, values.Exp, prime) +- m2.Sub(m2, m) +- m2.Mul(m2, values.Coeff) +- m2.Mod(m2, prime) +- if m2.Sign() < 0 { +- m2.Add(m2, prime) +- } +- m2.Mul(m2, values.R) +- m.Add(m, m2) +- } ++ out := make([]byte, modulusSize(N)) ++ return new(nat).exp(c, priv.D.Bytes(), N).fillBytes(out), nil + } + +- if ir != nil { +- // Unblind. +- m.Mul(m, ir) +- m.Mod(m, priv.N) ++ t0 := new(nat) ++ P := modulusFromNat(natFromBig(priv.Primes[0])) ++ Q := modulusFromNat(natFromBig(priv.Primes[1])) ++ // m = c ^ Dp mod p ++ m := new(nat).exp(t0.mod(c, P), priv.Precomputed.Dp.Bytes(), P) ++ // m2 = c ^ Dq mod q ++ m2 := new(nat).exp(t0.mod(c, Q), priv.Precomputed.Dq.Bytes(), Q) ++ // m = m - m2 mod p ++ m.modSub(t0.mod(m2, P), P) ++ // m = m * Qinv mod p ++ m.modMul(natFromBig(priv.Precomputed.Qinv).expandFor(P), P) ++ // m = m * q mod N ++ m.expandFor(N).modMul(t0.mod(Q.nat, N), N) ++ // m = m + m2 mod N ++ m.modAdd(m2.expandFor(N), N) ++ ++ for i, values := range priv.Precomputed.CRTValues { ++ p := modulusFromNat(natFromBig(priv.Primes[2+i])) ++ // m2 = c ^ Exp mod p ++ m2.exp(t0.mod(c, p), values.Exp.Bytes(), p) ++ // m2 = m2 - m mod p ++ m2.modSub(t0.mod(m, p), p) ++ // m2 = m2 * Coeff mod p ++ m2.modMul(natFromBig(values.Coeff).expandFor(p), p) ++ // m2 = m2 * R mod N ++ R := natFromBig(values.R).expandFor(N) ++ m2.expandFor(N).modMul(R, N) ++ // m = m + m2 mod N ++ m.modAdd(m2, N) + } + +- return ++ out := make([]byte, modulusSize(N)) ++ return m.fillBytes(out), nil + } + +-func decryptAndCheck(random io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err error) { +- m, err = decrypt(random, priv, c) ++func decryptAndCheck(priv *PrivateKey, ciphertext []byte) (m []byte, err error) { ++ m, err = decrypt(priv, ciphertext) + if err != nil { + return nil, err + } + + // In order to defend against errors in the CRT computation, m^e is + // calculated, which should match the original ciphertext. +- check := encrypt(new(big.Int), &priv.PublicKey, m) +- if c.Cmp(check) != 0 { ++ check := encrypt(&priv.PublicKey, m) ++ if subtle.ConstantTimeCompare(ciphertext, check) != 1 { + return nil, errors.New("rsa: internal error") + } + return m, nil +@@ -654,9 +634,7 @@ + // Encryption and decryption of a given message must use the same hash function + // and sha256.New() is a reasonable choice. + // +-// The random parameter, if not nil, is used to blind the private-key operation +-// and avoid timing side-channel attacks. Blinding is purely internal to this +-// function – the random data need not match that used when encrypting. ++// The random parameter is legacy and ignored, and it can be as nil. + // + // The label parameter must match the value given when encrypting. See + // EncryptOAEP for details. 
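The rewritten decrypt above performs the same CRT recombination as before, just over nat instead of big.Int: m1 = c^Dp mod p, m2 = c^Dq mod q, then m = ((m1 - m2)·Qinv mod p)·q + m2, with any extra CRTValues folded in afterwards. For reference, the sketch below runs that recombination with math/big on a textbook toy key and checks it against plain c^d mod n (illustration only; avoiding big.Int here is the whole point of the patch):

package main

import (
	"fmt"
	"math/big"
)

func main() {
	// Textbook toy RSA key: p = 61, q = 53, n = 3233, e = 17, d = 2753.
	p := big.NewInt(61)
	q := big.NewInt(53)
	n := new(big.Int).Mul(p, q)
	d := big.NewInt(2753)

	dp := new(big.Int).Mod(d, new(big.Int).Sub(p, big.NewInt(1))) // d mod (p-1)
	dq := new(big.Int).Mod(d, new(big.Int).Sub(q, big.NewInt(1))) // d mod (q-1)
	qinv := new(big.Int).ModInverse(q, p)                         // q⁻¹ mod p

	c := big.NewInt(2790) // ciphertext of m = 65 under e = 17

	// Straightforward decryption: m = c^d mod n.
	want := new(big.Int).Exp(c, d, n)

	// CRT decryption, following the recombination decrypt() performs with nat.
	m1 := new(big.Int).Exp(c, dp, p) // m1 = c^Dp mod p
	m2 := new(big.Int).Exp(c, dq, q) // m2 = c^Dq mod q
	h := new(big.Int).Sub(m1, m2)    // h = (m1 - m2) * Qinv mod p
	h.Mul(h, qinv).Mod(h, p)
	got := new(big.Int).Mul(h, q) // m = h*q + m2
	got.Add(got, m2)

	fmt.Println(got, want, got.Cmp(want) == 0) // 65 65 true
}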
+@@ -685,9 +663,8 @@ + } + return out, nil + } +- c := new(big.Int).SetBytes(ciphertext) + +- m, err := decrypt(random, priv, c) ++ em, err := decrypt(priv, ciphertext) + if err != nil { + return nil, err + } +@@ -696,10 +673,6 @@ + lHash := hash.Sum(nil) + hash.Reset() + +- // We probably leak the number of leading zeros. +- // It's not clear that we can do anything about this. +- em := m.FillBytes(make([]byte, k)) +- + firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0) + + seed := em[1 : hash.Size()+1] +diff --git a/src/crypto/rsa/rsa_test.go b/src/crypto/rsa/rsa_test.go +index d94a234d26..1ad5186d57 100644 +--- a/src/crypto/rsa/rsa_test.go ++++ b/src/crypto/rsa/rsa_test.go +@@ -39,7 +39,7 @@ func TestKeyGeneration(t *testing.T) { + } + if boring.Enabled() && size < 1024 { + t.Logf("skipping short key with BoringCrypto: %d", size) +- continue; ++ continue + } + testKeyBasics(t, priv) + if testing.Short() { +@@ -148,7 +148,7 @@ func testKeyBasics(t *testing.T, priv *PrivateKey) { + // longer than 2048 bits). + if bits := priv.N.BitLen(); bits < 2048 { + t.Logf("skipping short key with BoringCrypto: %d", bits) +- return; ++ return + } + sha256 := sha256.New() + msg := []byte("hi!") +@@ -169,25 +169,20 @@ func testKeyBasics(t *testing.T, priv *PrivateKey) { + } + + pub := &priv.PublicKey ++ n := modulusSize(modulusFromNat(natFromBig(pub.N))) ++ mb := make([]byte, n) + m := big.NewInt(42) +- c := encrypt(new(big.Int), pub, m) ++ m.FillBytes(mb) ++ c := encrypt(pub, m.Bytes()) + +- m2, err := decrypt(nil, priv, c) ++ m2, err := decrypt(priv, c) + if err != nil { + t.Errorf("error while decrypting: %s", err) + return + } +- if m.Cmp(m2) != 0 { ++ if bytes.Compare(mb, m2) != 0 { + t.Errorf("got:%v, want:%v (%+v)", m2, m, priv) + } +- +- m3, err := decrypt(rand.Reader, priv, c) +- if err != nil { +- t.Errorf("error while decrypting (blind): %s", err) +- } +- if m.Cmp(m3) != 0 { +- t.Errorf("(blind) got:%v, want:%v (%#v)", m3, m, priv) +- } + } + + func fromBase10(base10 string) *big.Int { +@@ -227,7 +222,7 @@ func BenchmarkRSA2048Decrypt(b *testing.B) { + b.StartTimer() + + for i := 0; i < b.N; i++ { +- decrypt(nil, test2048Key, c) ++ decrypt(test2048Key, c.Bytes()) + } + } + +@@ -266,7 +261,7 @@ func Benchmark3PrimeRSA2048Decrypt(b *testing.B) { + b.StartTimer() + + for i := 0; i < b.N; i++ { +- decrypt(nil, priv, c) ++ decrypt(priv, c.Bytes()) + } + } +