diff --git a/src/crypto/rsa/example_test.go b/src/crypto/rsa/example_test.go index ce5c2d91cd..aa3a71b248 100644 --- a/src/crypto/rsa/example_test.go +++ b/src/crypto/rsa/example_test.go @@ -12,7 +12,6 @@ import ( "crypto/sha256" "encoding/hex" "fmt" - "io" "os" ) @@ -36,21 +35,17 @@ import ( // a buffer that contains a random key. Thus, if the RSA result isn't // well-formed, the implementation uses a random key in constant time. func ExampleDecryptPKCS1v15SessionKey() { - // crypto/rand.Reader is a good source of entropy for blinding the RSA - // operation. - rng := rand.Reader - // The hybrid scheme should use at least a 16-byte symmetric key. Here // we read the random key that will be used if the RSA decryption isn't // well-formed. key := make([]byte, 32) - if _, err := io.ReadFull(rng, key); err != nil { + if _, err := rand.Read(key); err != nil { panic("RNG failure") } rsaCiphertext, _ := hex.DecodeString("aabbccddeeff") - if err := DecryptPKCS1v15SessionKey(rng, rsaPrivateKey, rsaCiphertext, key); err != nil { + if err := DecryptPKCS1v15SessionKey(nil, rsaPrivateKey, rsaCiphertext, key); err != nil { // Any errors that result will be “public” – meaning that they // can be determined without any secret information. (For // instance, if the length of key is impossible given the RSA @@ -86,10 +81,6 @@ func ExampleDecryptPKCS1v15SessionKey() { } func ExampleSignPKCS1v15() { - // crypto/rand.Reader is a good source of entropy for blinding the RSA - // operation. - rng := rand.Reader - message := []byte("message to be signed") // Only small messages can be signed directly; thus the hash of a @@ -99,7 +90,7 @@ func ExampleSignPKCS1v15() { // of writing (2016). hashed := sha256.Sum256(message) - signature, err := SignPKCS1v15(rng, rsaPrivateKey, crypto.SHA256, hashed[:]) + signature, err := SignPKCS1v15(nil, rsaPrivateKey, crypto.SHA256, hashed[:]) if err != nil { fmt.Fprintf(os.Stderr, "Error from signing: %s\n", err) return @@ -151,11 +142,7 @@ func ExampleDecryptOAEP() { ciphertext, _ := hex.DecodeString("4d1ee10e8f286390258c51a5e80802844c3e6358ad6690b7285218a7c7ed7fc3a4c7b950fbd04d4b0239cc060dcc7065ca6f84c1756deb71ca5685cadbb82be025e16449b905c568a19c088a1abfad54bf7ecc67a7df39943ec511091a34c0f2348d04e058fcff4d55644de3cd1d580791d4524b92f3e91695582e6e340a1c50b6c6d78e80b4e42c5b4d45e479b492de42bbd39cc642ebb80226bb5200020d501b24a37bcc2ec7f34e596b4fd6b063de4858dbf5a4e3dd18e262eda0ec2d19dbd8e890d672b63d368768360b20c0b6b8592a438fa275e5fa7f60bef0dd39673fd3989cc54d2cb80c08fcd19dacbc265ee1c6014616b0e04ea0328c2a04e73460") label := []byte("orders") - // crypto/rand.Reader is a good source of entropy for blinding the RSA - // operation. - rng := rand.Reader - - plaintext, err := DecryptOAEP(sha256.New(), rng, test2048Key, ciphertext, label) + plaintext, err := DecryptOAEP(sha256.New(), nil, test2048Key, ciphertext, label) if err != nil { fmt.Fprintf(os.Stderr, "Error from decryption: %s\n", err) return diff --git a/src/crypto/rsa/nat.go b/src/crypto/rsa/nat.go new file mode 100644 index 0000000000..da521c22f3 --- /dev/null +++ b/src/crypto/rsa/nat.go @@ -0,0 +1,626 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rsa + +import ( + "math/big" + "math/bits" +) + +const ( + // _W is the number of bits we use for our limbs. + _W = bits.UintSize - 1 + // _MASK selects _W bits from a full machine word. + _MASK = (1 << _W) - 1 +) + +// choice represents a constant-time boolean. The value of choice is always +// either 1 or 0. We use an int instead of bool in order to make decisions in +// constant time by turning it into a mask. +type choice uint + +func not(c choice) choice { return 1 ^ c } + +const yes = choice(1) +const no = choice(0) + +// ctSelect returns x if on == 1, and y if on == 0. The execution time of this +// function does not depend on its inputs. If on is any value besides 1 or 0, +// the result is undefined. +func ctSelect(on choice, x, y uint) uint { + // When on == 1, mask is 0b111..., otherwise mask is 0b000... + mask := -uint(on) + // When mask is all zeros, we just have y, otherwise, y cancels with itself. + return y ^ (mask & (y ^ x)) +} + +// ctEq returns 1 if x == y, and 0 otherwise. The execution time of this +// function does not depend on its inputs. +func ctEq(x, y uint) choice { + // If x != y, then either x - y or y - x will generate a carry. + _, c1 := bits.Sub(x, y, 0) + _, c2 := bits.Sub(y, x, 0) + return not(choice(c1 | c2)) +} + +// ctGeq returns 1 if x >= y, and 0 otherwise. The execution time of this +// function does not depend on its inputs. +func ctGeq(x, y uint) choice { + // If x < y, then x - y generates a carry. + _, carry := bits.Sub(x, y, 0) + return not(choice(carry)) +} + +// nat represents an arbitrary natural number +// +// Each nat has an announced length, which is the number of limbs it has stored. +// Operations on this number are allowed to leak this length, but will not leak +// any information about the values contained in those limbs. +type nat struct { + // limbs is a little-endian representation in base 2^W with + // W = bits.UintSize - 1. The top bit is always unset between operations. + // + // The top bit is left unset to optimize Montgomery multiplication, in the + // inner loop of exponentiation. Using fully saturated limbs would leave us + // working with 129-bit numbers on 64-bit platforms, wasting a lot of space, + // and thus time. + limbs []uint +} + +// expand expands x to n limbs, leaving its value unchanged. +func (x *nat) expand(n int) *nat { + for len(x.limbs) > n { + if x.limbs[len(x.limbs)-1] != 0 { + panic("rsa: internal error: shrinking nat") + } + x.limbs = x.limbs[:len(x.limbs)-1] + } + if cap(x.limbs) < n { + newLimbs := make([]uint, n) + copy(newLimbs, x.limbs) + x.limbs = newLimbs + return x + } + extraLimbs := x.limbs[len(x.limbs):n] + for i := range extraLimbs { + extraLimbs[i] = 0 + } + x.limbs = x.limbs[:n] + return x +} + +// reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs). +func (x *nat) reset(n int) *nat { + if cap(x.limbs) < n { + x.limbs = make([]uint, n) + return x + } + for i := range x.limbs { + x.limbs[i] = 0 + } + x.limbs = x.limbs[:n] + return x +} + +// clone returns a new nat, with the same value and announced length as x. +func (x *nat) clone() *nat { + out := &nat{make([]uint, len(x.limbs))} + copy(out.limbs, x.limbs) + return out +} + +// natFromBig creates a new natural number from a big.Int. +// +// The announced length of the resulting nat is based on the actual bit size of +// the input, ignoring leading zeroes. +func natFromBig(x *big.Int) *nat { + xLimbs := x.Bits() + bitSize := bigBitLen(x) + requiredLimbs := (bitSize + _W - 1) / _W + + out := &nat{make([]uint, requiredLimbs)} + outI := 0 + shift := 0 + for i := range xLimbs { + xi := uint(xLimbs[i]) + out.limbs[outI] |= (xi << shift) & _MASK + outI++ + if outI == requiredLimbs { + return out + } + out.limbs[outI] = xi >> (_W - shift) + shift++ // this assumes bits.UintSize - _W = 1 + if shift == _W { + shift = 0 + outI++ + } + } + return out +} + +// fillBytes sets bytes to x as a zero-extended big-endian byte slice. +// +// If bytes is not long enough to contain the number or at least len(x.limbs)-1 +// limbs, or has zero length, fillBytes will panic. +func (x *nat) fillBytes(bytes []byte) []byte { + if len(bytes) == 0 { + panic("nat: fillBytes invoked with too small buffer") + } + for i := range bytes { + bytes[i] = 0 + } + shift := 0 + outI := len(bytes) - 1 + for i, limb := range x.limbs { + remainingBits := _W + for remainingBits >= 8 { + bytes[outI] |= byte(limb) << shift + consumed := 8 - shift + limb >>= consumed + remainingBits -= consumed + shift = 0 + outI-- + if outI < 0 { + if limb != 0 || i < len(x.limbs)-1 { + panic("nat: fillBytes invoked with too small buffer") + } + return bytes + } + } + bytes[outI] = byte(limb) + shift = remainingBits + } + return bytes +} + +// natFromBytes converts a slice of big-endian bytes into a nat. +// +// The announced length of the output depends on the length of bytes. Unlike +// big.Int, creating a nat will not remove leading zeros. +func natFromBytes(bytes []byte) *nat { + bitSize := len(bytes) * 8 + requiredLimbs := (bitSize + _W - 1) / _W + + out := &nat{make([]uint, requiredLimbs)} + outI := 0 + shift := 0 + for i := len(bytes) - 1; i >= 0; i-- { + bi := bytes[i] + out.limbs[outI] |= uint(bi) << shift + shift += 8 + if shift >= _W { + shift -= _W + out.limbs[outI] &= _MASK + outI++ + if shift > 0 { + out.limbs[outI] = uint(bi) >> (8 - shift) + } + } + } + return out +} + +// cmpEq returns 1 if x == y, and 0 otherwise. +// +// Both operands must have the same announced length. +func (x *nat) cmpEq(y *nat) choice { + // Eliminate bounds checks in the loop. + size := len(x.limbs) + xLimbs := x.limbs[:size] + yLimbs := y.limbs[:size] + + equal := yes + for i := 0; i < size; i++ { + equal &= ctEq(xLimbs[i], yLimbs[i]) + } + return equal +} + +// cmpGeq returns 1 if x >= y, and 0 otherwise. +// +// Both operands must have the same announced length. +func (x *nat) cmpGeq(y *nat) choice { + // Eliminate bounds checks in the loop. + size := len(x.limbs) + xLimbs := x.limbs[:size] + yLimbs := y.limbs[:size] + + var c uint + for i := 0; i < size; i++ { + c = (xLimbs[i] - yLimbs[i] - c) >> _W + } + // If there was a carry, then subtracting y underflowed, so + // x is not greater than or equal to y. + return not(choice(c)) +} + +// assign sets x <- y if on == 1, and does nothing otherwise. +// +// Both operands must have the same announced length. +func (x *nat) assign(on choice, y *nat) *nat { + // Eliminate bounds checks in the loop. + size := len(x.limbs) + xLimbs := x.limbs[:size] + yLimbs := y.limbs[:size] + + for i := 0; i < size; i++ { + xLimbs[i] = ctSelect(on, yLimbs[i], xLimbs[i]) + } + return x +} + +// add computes x += y if on == 1, and does nothing otherwise. It returns the +// carry of the addition regardless of on. +// +// Both operands must have the same announced length. +func (x *nat) add(on choice, y *nat) (c uint) { + // Eliminate bounds checks in the loop. + size := len(x.limbs) + xLimbs := x.limbs[:size] + yLimbs := y.limbs[:size] + + for i := 0; i < size; i++ { + res := xLimbs[i] + yLimbs[i] + c + xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i]) + c = res >> _W + } + return +} + +// sub computes x -= y if on == 1, and does nothing otherwise. It returns the +// borrow of the subtraction regardless of on. +// +// Both operands must have the same announced length. +func (x *nat) sub(on choice, y *nat) (c uint) { + // Eliminate bounds checks in the loop. + size := len(x.limbs) + xLimbs := x.limbs[:size] + yLimbs := y.limbs[:size] + + for i := 0; i < size; i++ { + res := xLimbs[i] - yLimbs[i] - c + xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i]) + c = res >> _W + } + return +} + +// modulus is used for modular arithmetic, precomputing relevant constants. +// +// Moduli are assumed to be odd numbers. Moduli can also leak the exact +// number of bits needed to store their value, and are stored without padding. +// +// Their actual value is still kept secret. +type modulus struct { + // The underlying natural number for this modulus. + // + // This will be stored without any padding, and shouldn't alias with any + // other natural number being used. + nat *nat + leading int // number of leading zeros in the modulus + m0inv uint // -nat.limbs[0]⁻¹ mod _W +} + +// minusInverseModW computes -x⁻¹ mod _W with x odd. +// +// This operation is used to precompute a constant involved in Montgomery +// multiplication. +func minusInverseModW(x uint) uint { + // Every iteration of this loop doubles the least-significant bits of + // correct inverse in y. The first three bits are already correct (1⁻¹ = 1, + // 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough + // for 61 bits (and wastes only one iteration for 31 bits). + // + // See https://crypto.stackexchange.com/a/47496. + y := x + for i := 0; i < 5; i++ { + y = y * (2 - x*y) + } + return (1 << _W) - (y & _MASK) +} + +// modulusFromNat creates a new modulus from a nat. +// +// The nat should be odd, nonzero, and the number of significant bits in the +// number should be leakable. The nat shouldn't be reused. +func modulusFromNat(nat *nat) *modulus { + m := &modulus{} + m.nat = nat + size := len(m.nat.limbs) + for m.nat.limbs[size-1] == 0 { + size-- + } + m.nat.limbs = m.nat.limbs[:size] + m.leading = _W - bitLen(m.nat.limbs[size-1]) + m.m0inv = minusInverseModW(m.nat.limbs[0]) + return m +} + +// bitLen is a version of bits.Len that only leaks the bit length of n, but not +// its value. bits.Len and bits.LeadingZeros use a lookup table for the +// low-order bits on some architectures. +func bitLen(n uint) int { + var len int + // We assume, here and elsewhere, that comparison to zero is constant time + // with respect to different non-zero values. + for n != 0 { + len++ + n >>= 1 + } + return len +} + +// bigBitLen is a version of big.Int.BitLen that only leaks the bit length of x, +// but not its value. big.Int.BitLen uses bits.Len. +func bigBitLen(x *big.Int) int { + xLimbs := x.Bits() + fullLimbs := len(xLimbs) - 1 + topLimb := uint(xLimbs[len(xLimbs)-1]) + return fullLimbs*bits.UintSize + bitLen(topLimb) +} + +// modulusSize returns the size of m in bytes. +func modulusSize(m *modulus) int { + bits := len(m.nat.limbs)*_W - int(m.leading) + return (bits + 7) / 8 +} + +// shiftIn calculates x = x << _W + y mod m. +// +// This assumes that x is already reduced mod m, and that y < 2^_W. +func (x *nat) shiftIn(y uint, m *modulus) *nat { + d := new(nat).resetFor(m) + + // Eliminate bounds checks in the loop. + size := len(m.nat.limbs) + xLimbs := x.limbs[:size] + dLimbs := d.limbs[:size] + mLimbs := m.nat.limbs[:size] + + // Each iteration of this loop computes x = 2x + b mod m, where b is a bit + // from y. Effectively, it left-shifts x and adds y one bit at a time, + // reducing it every time. + // + // To do the reduction, each iteration computes both 2x + b and 2x + b - m. + // The next iteration (and finally the return line) will use either result + // based on whether the subtraction underflowed. + needSubtraction := no + for i := _W - 1; i >= 0; i-- { + carry := (y >> i) & 1 + var borrow uint + for i := 0; i < size; i++ { + l := ctSelect(needSubtraction, dLimbs[i], xLimbs[i]) + + res := l<<1 + carry + xLimbs[i] = res & _MASK + carry = res >> _W + + res = xLimbs[i] - mLimbs[i] - borrow + dLimbs[i] = res & _MASK + borrow = res >> _W + } + // See modAdd for how carry (aka overflow), borrow (aka underflow), and + // needSubtraction relate. + needSubtraction = ctEq(carry, borrow) + } + return x.assign(needSubtraction, d) +} + +// mod calculates out = x mod m. +// +// This works regardless how large the value of x is. +// +// The output will be resized to the size of m and overwritten. +func (out *nat) mod(x *nat, m *modulus) *nat { + out.resetFor(m) + // Working our way from the most significant to the least significant limb, + // we can insert each limb at the least significant position, shifting all + // previous limbs left by _W. This way each limb will get shifted by the + // correct number of bits. We can insert at least N - 1 limbs without + // overflowing m. After that, we need to reduce every time we shift. + i := len(x.limbs) - 1 + // For the first N - 1 limbs we can skip the actual shifting and position + // them at the shifted position, which starts at min(N - 2, i). + start := len(m.nat.limbs) - 2 + if i < start { + start = i + } + for j := start; j >= 0; j-- { + out.limbs[j] = x.limbs[i] + i-- + } + // We shift in the remaining limbs, reducing modulo m each time. + for i >= 0 { + out.shiftIn(x.limbs[i], m) + i-- + } + return out +} + +// expandFor ensures out has the right size to work with operations modulo m. +// +// This assumes that out has as many or fewer limbs than m, or that the extra +// limbs are all zero (which may happen when decoding a value that has leading +// zeroes in its bytes representation that spill over the limb threshold). +func (out *nat) expandFor(m *modulus) *nat { + return out.expand(len(m.nat.limbs)) +} + +// resetFor ensures out has the right size to work with operations modulo m. +// +// out is zeroed and may start at any size. +func (out *nat) resetFor(m *modulus) *nat { + return out.reset(len(m.nat.limbs)) +} + +// modSub computes x = x - y mod m. +// +// The length of both operands must be the same as the modulus. Both operands +// must already be reduced modulo m. +func (x *nat) modSub(y *nat, m *modulus) *nat { + underflow := x.sub(yes, y) + // If the subtraction underflowed, add m. + x.add(choice(underflow), m.nat) + return x +} + +// modAdd computes x = x + y mod m. +// +// The length of both operands must be the same as the modulus. Both operands +// must already be reduced modulo m. +func (x *nat) modAdd(y *nat, m *modulus) *nat { + overflow := x.add(yes, y) + underflow := not(x.cmpGeq(m.nat)) // x < m + + // Three cases are possible: + // + // - overflow = 0, underflow = 0 + // + // In this case, addition fits in our limbs, but we can still subtract away + // m without an underflow, so we need to perform the subtraction to reduce + // our result. + // + // - overflow = 0, underflow = 1 + // + // The addition fits in our limbs, but we can't subtract m without + // underflowing. The result is already reduced. + // + // - overflow = 1, underflow = 1 + // + // The addition does not fit in our limbs, and the subtraction's borrow + // would cancel out with the addition's carry. We need to subtract m to + // reduce our result. + // + // The overflow = 1, underflow = 0 case is not possible, because y is at + // most m - 1, and if adding m - 1 overflows, then subtracting m must + // necessarily underflow. + needSubtraction := ctEq(overflow, uint(underflow)) + + x.sub(needSubtraction, m.nat) + return x +} + +// montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and +// n = len(m.nat.limbs). +// +// Faster Montgomery multiplication replaces standard modular multiplication for +// numbers in this representation. +// +// This assumes that x is already reduced mod m. +func (x *nat) montgomeryRepresentation(m *modulus) *nat { + for i := 0; i < len(m.nat.limbs); i++ { + x.shiftIn(0, m) // x = x * 2^_W mod m + } + return x +} + +// montgomeryMul calculates d = a * b / R mod m, with R = 2^(_W * n) and +// n = len(m.nat.limbs), using the Montgomery Multiplication technique. +// +// All inputs should be the same length, not aliasing d, and already +// reduced modulo m. d will be resized to the size of m and overwritten. +func (d *nat) montgomeryMul(a *nat, b *nat, m *modulus) *nat { + // See https://bearssl.org/bigint.html#montgomery-reduction-and-multiplication + // for a description of the algorithm. + + // Eliminate bounds checks in the loop. + size := len(m.nat.limbs) + aLimbs := a.limbs[:size] + bLimbs := b.limbs[:size] + dLimbs := d.resetFor(m).limbs[:size] + mLimbs := m.nat.limbs[:size] + + var overflow uint + for i := 0; i < size; i++ { + f := ((dLimbs[0] + aLimbs[i]*bLimbs[0]) * m.m0inv) & _MASK + carry := uint(0) + for j := 0; j < size; j++ { + // z = d[j] + a[i] * b[j] + f * m[j] + carry <= 2^(2W+1) - 2^(W+1) + 2^W + hi, lo := bits.Mul(aLimbs[i], bLimbs[j]) + z_lo, c := bits.Add(dLimbs[j], lo, 0) + z_hi, _ := bits.Add(0, hi, c) + hi, lo = bits.Mul(f, mLimbs[j]) + z_lo, c = bits.Add(z_lo, lo, 0) + z_hi, _ = bits.Add(z_hi, hi, c) + z_lo, c = bits.Add(z_lo, carry, 0) + z_hi, _ = bits.Add(z_hi, 0, c) + if j > 0 { + dLimbs[j-1] = z_lo & _MASK + } + carry = z_hi<<1 | z_lo>>_W // carry <= 2^(W+1) - 2 + } + z := overflow + carry // z <= 2^(W+1) - 1 + dLimbs[size-1] = z & _MASK + overflow = z >> _W // overflow <= 1 + } + // See modAdd for how overflow, underflow, and needSubtraction relate. + underflow := not(d.cmpGeq(m.nat)) // d < m + needSubtraction := ctEq(overflow, uint(underflow)) + d.sub(needSubtraction, m.nat) + + return d +} + +// modMul calculates x *= y mod m. +// +// x and y must already be reduced modulo m, they must share its announced +// length, and they may not alias. +func (x *nat) modMul(y *nat, m *modulus) *nat { + // A Montgomery multiplication by a value out of the Montgomery domain + // takes the result out of Montgomery representation. + xR := x.clone().montgomeryRepresentation(m) // xR = x * R mod m + return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m +} + +// exp calculates out = x^e mod m. +// +// The exponent e is represented in big-endian order. The output will be resized +// to the size of m and overwritten. x must already be reduced modulo m. +func (out *nat) exp(x *nat, e []byte, m *modulus) *nat { + // We use a 4 bit window. For our RSA workload, 4 bit windows are faster + // than 2 bit windows, but use an extra 12 nats worth of scratch space. + // Using bit sizes that don't divide 8 are more complex to implement. + table := make([]*nat, (1<<4)-1) // table[i] = x ^ (i+1) + table[0] = x.clone().montgomeryRepresentation(m) + for i := 1; i < len(table); i++ { + table[i] = new(nat).expandFor(m) + table[i].montgomeryMul(table[i-1], table[0], m) + } + + out.resetFor(m) + out.limbs[0] = 1 + out.montgomeryRepresentation(m) + t0 := new(nat).expandFor(m) + t1 := new(nat).expandFor(m) + for _, b := range e { + for _, j := range []int{4, 0} { + // Square four times. + t1.montgomeryMul(out, out, m) + out.montgomeryMul(t1, t1, m) + t1.montgomeryMul(out, out, m) + out.montgomeryMul(t1, t1, m) + + // Select x^k in constant time from the table. + k := uint((b >> j) & 0b1111) + for i := range table { + t0.assign(ctEq(k, uint(i+1)), table[i]) + } + + // Multiply by x^k, discarding the result if k = 0. + t1.montgomeryMul(out, t0, m) + out.assign(not(ctEq(k, 0)), t1) + } + } + + // By Montgomery multiplying with 1 not in Montgomery representation, we + // convert out back from Montgomery representation, because it works out to + // dividing by R. + t0.assign(yes, out) + t1.resetFor(m) + t1.limbs[0] = 1 + out.montgomeryMul(t0, t1, m) + + return out +} diff --git a/src/crypto/rsa/nat_test.go b/src/crypto/rsa/nat_test.go new file mode 100644 index 0000000000..3e6eb10f61 --- /dev/null +++ b/src/crypto/rsa/nat_test.go @@ -0,0 +1,384 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package rsa + +import ( + "bytes" + "math/big" + "math/bits" + "math/rand" + "reflect" + "testing" + "testing/quick" +) + +// Generate generates an even nat. It's used by testing/quick to produce random +// *nat values for quick.Check invocations. +func (*nat) Generate(r *rand.Rand, size int) reflect.Value { + limbs := make([]uint, size) + for i := 0; i < size; i++ { + limbs[i] = uint(r.Uint64()) & ((1 << _W) - 2) + } + return reflect.ValueOf(&nat{limbs}) +} + +func testModAddCommutative(a *nat, b *nat) bool { + mLimbs := make([]uint, len(a.limbs)) + for i := 0; i < len(mLimbs); i++ { + mLimbs[i] = _MASK + } + m := modulusFromNat(&nat{mLimbs}) + aPlusB := a.clone() + aPlusB.modAdd(b, m) + bPlusA := b.clone() + bPlusA.modAdd(a, m) + return aPlusB.cmpEq(bPlusA) == 1 +} + +func TestModAddCommutative(t *testing.T) { + err := quick.Check(testModAddCommutative, &quick.Config{}) + if err != nil { + t.Error(err) + } +} + +func testModSubThenAddIdentity(a *nat, b *nat) bool { + mLimbs := make([]uint, len(a.limbs)) + for i := 0; i < len(mLimbs); i++ { + mLimbs[i] = _MASK + } + m := modulusFromNat(&nat{mLimbs}) + original := a.clone() + a.modSub(b, m) + a.modAdd(b, m) + return a.cmpEq(original) == 1 +} + +func TestModSubThenAddIdentity(t *testing.T) { + err := quick.Check(testModSubThenAddIdentity, &quick.Config{}) + if err != nil { + t.Error(err) + } +} + +func testMontgomeryRoundtrip(a *nat) bool { + one := &nat{make([]uint, len(a.limbs))} + one.limbs[0] = 1 + aPlusOne := a.clone() + aPlusOne.add(1, one) + m := modulusFromNat(aPlusOne) + monty := a.clone() + monty.montgomeryRepresentation(m) + aAgain := monty.clone() + aAgain.montgomeryMul(monty, one, m) + return a.cmpEq(aAgain) == 1 +} + +func TestMontgomeryRoundtrip(t *testing.T) { + err := quick.Check(testMontgomeryRoundtrip, &quick.Config{}) + if err != nil { + t.Error(err) + } +} + +func TestFromBig(t *testing.T) { + expected := []byte{0x01, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} + theBig := new(big.Int).SetBytes(expected) + actual := natFromBig(theBig).fillBytes(make([]byte, len(expected))) + if !bytes.Equal(actual, expected) { + t.Errorf("%+x != %+x", actual, expected) + } +} + +func TestFillBytes(t *testing.T) { + xBytes := []byte{0xAA, 0xFF, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88} + x := natFromBytes(xBytes) + for l := 20; l >= len(xBytes); l-- { + buf := make([]byte, l) + rand.Read(buf) + actual := x.fillBytes(buf) + expected := make([]byte, l) + copy(expected[l-len(xBytes):], xBytes) + if !bytes.Equal(actual, expected) { + t.Errorf("%d: %+v != %+v", l, actual, expected) + } + } + for l := len(xBytes) - 1; l >= 0; l-- { + (func() { + defer func() { + if recover() == nil { + t.Errorf("%d: expected panic", l) + } + }() + x.fillBytes(make([]byte, l)) + })() + } +} + +func TestFromBytes(t *testing.T) { + f := func(xBytes []byte) bool { + if len(xBytes) == 0 { + return true + } + actual := natFromBytes(xBytes).fillBytes(make([]byte, len(xBytes))) + if !bytes.Equal(actual, xBytes) { + t.Errorf("%+x != %+x", actual, xBytes) + return false + } + return true + } + + err := quick.Check(f, &quick.Config{}) + if err != nil { + t.Error(err) + } + + f([]byte{0xFF, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}) + f(bytes.Repeat([]byte{0xFF}, _W)) +} + +func TestShiftIn(t *testing.T) { + if bits.UintSize != 64 { + t.Skip("examples are only valid in 64 bit") + } + examples := []struct { + m, x, expected []byte + y uint64 + }{{ + m: []byte{13}, + x: []byte{0}, + y: 0x7FFF_FFFF_FFFF_FFFF, + expected: []byte{7}, + }, { + m: []byte{13}, + x: []byte{7}, + y: 0x7FFF_FFFF_FFFF_FFFF, + expected: []byte{11}, + }, { + m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}, + x: make([]byte, 9), + y: 0x7FFF_FFFF_FFFF_FFFF, + expected: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + }, { + m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d}, + x: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, + y: 0, + expected: []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08}, + }} + + for i, tt := range examples { + m := modulusFromNat(natFromBytes(tt.m)) + got := natFromBytes(tt.x).expandFor(m).shiftIn(uint(tt.y), m) + if got.cmpEq(natFromBytes(tt.expected).expandFor(m)) != 1 { + t.Errorf("%d: got %x, expected %x", i, got, tt.expected) + } + } +} + +func TestModulusAndNatSizes(t *testing.T) { + // These are 126 bit (2 * _W on 64-bit architectures) values, serialized as + // 128 bits worth of bytes. If leading zeroes are stripped, they fit in two + // limbs, if they are not, they fit in three. This can be a problem because + // modulus strips leading zeroes and nat does not. + m := modulusFromNat(natFromBytes([]byte{ + 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff})) + x := natFromBytes([]byte{ + 0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe}) + x.expandFor(m) // must not panic for shrinking +} + +func TestExpand(t *testing.T) { + sliced := []uint{1, 2, 3, 4} + examples := []struct { + in []uint + n int + out []uint + }{{ + []uint{1, 2}, + 4, + []uint{1, 2, 0, 0}, + }, { + sliced[:2], + 4, + []uint{1, 2, 0, 0}, + }, { + []uint{1, 2}, + 2, + []uint{1, 2}, + }, { + []uint{1, 2, 0}, + 2, + []uint{1, 2}, + }} + + for i, tt := range examples { + got := (&nat{tt.in}).expand(tt.n) + if len(got.limbs) != len(tt.out) || got.cmpEq(&nat{tt.out}) != 1 { + t.Errorf("%d: got %x, expected %x", i, got, tt.out) + } + } +} + +func TestMod(t *testing.T) { + m := modulusFromNat(natFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d})) + x := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}) + out := new(nat) + out.mod(x, m) + expected := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09}) + if out.cmpEq(expected) != 1 { + t.Errorf("%+v != %+v", out, expected) + } +} + +func TestModSub(t *testing.T) { + m := modulusFromNat(&nat{[]uint{13}}) + x := &nat{[]uint{6}} + y := &nat{[]uint{7}} + x.modSub(y, m) + expected := &nat{[]uint{12}} + if x.cmpEq(expected) != 1 { + t.Errorf("%+v != %+v", x, expected) + } + x.modSub(y, m) + expected = &nat{[]uint{5}} + if x.cmpEq(expected) != 1 { + t.Errorf("%+v != %+v", x, expected) + } +} + +func TestModAdd(t *testing.T) { + m := modulusFromNat(&nat{[]uint{13}}) + x := &nat{[]uint{6}} + y := &nat{[]uint{7}} + x.modAdd(y, m) + expected := &nat{[]uint{0}} + if x.cmpEq(expected) != 1 { + t.Errorf("%+v != %+v", x, expected) + } + x.modAdd(y, m) + expected = &nat{[]uint{7}} + if x.cmpEq(expected) != 1 { + t.Errorf("%+v != %+v", x, expected) + } +} + +func TestExp(t *testing.T) { + m := modulusFromNat(&nat{[]uint{13}}) + x := &nat{[]uint{3}} + out := &nat{[]uint{0}} + out.exp(x, []byte{12}, m) + expected := &nat{[]uint{1}} + if out.cmpEq(expected) != 1 { + t.Errorf("%+v != %+v", out, expected) + } +} + +func makeBenchmarkModulus() *modulus { + m := make([]uint, 32) + for i := 0; i < 32; i++ { + m[i] = _MASK + } + return modulusFromNat(&nat{limbs: m}) +} + +func makeBenchmarkValue() *nat { + x := make([]uint, 32) + for i := 0; i < 32; i++ { + x[i] = _MASK - 1 + } + return &nat{limbs: x} +} + +func makeBenchmarkExponent() []byte { + e := make([]byte, 256) + for i := 0; i < 32; i++ { + e[i] = 0xFF + } + return e +} + +func BenchmarkModAdd(b *testing.B) { + x := makeBenchmarkValue() + y := makeBenchmarkValue() + m := makeBenchmarkModulus() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.modAdd(y, m) + } +} + +func BenchmarkModSub(b *testing.B) { + x := makeBenchmarkValue() + y := makeBenchmarkValue() + m := makeBenchmarkModulus() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.modSub(y, m) + } +} + +func BenchmarkMontgomeryRepr(b *testing.B) { + x := makeBenchmarkValue() + m := makeBenchmarkModulus() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.montgomeryRepresentation(m) + } +} + +func BenchmarkMontgomeryMul(b *testing.B) { + x := makeBenchmarkValue() + y := makeBenchmarkValue() + out := makeBenchmarkValue() + m := makeBenchmarkModulus() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + out.montgomeryMul(x, y, m) + } +} + +func BenchmarkModMul(b *testing.B) { + x := makeBenchmarkValue() + y := makeBenchmarkValue() + m := makeBenchmarkModulus() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.modMul(y, m) + } +} + +func BenchmarkExpBig(b *testing.B) { + out := new(big.Int) + exponentBytes := makeBenchmarkExponent() + x := new(big.Int).SetBytes(exponentBytes) + e := new(big.Int).SetBytes(exponentBytes) + n := new(big.Int).SetBytes(exponentBytes) + one := new(big.Int).SetUint64(1) + n.Add(n, one) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + out.Exp(x, e, n) + } +} + +func BenchmarkExp(b *testing.B) { + x := makeBenchmarkValue() + e := makeBenchmarkExponent() + out := makeBenchmarkValue() + m := makeBenchmarkModulus() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + out.exp(x, e, m) + } +} diff --git a/src/crypto/rsa/pkcs1v15.go b/src/crypto/rsa/pkcs1v15.go index ceb32d0b0d..f8f22013b2 100644 --- a/src/crypto/rsa/pkcs1v15.go +++ b/src/crypto/rsa/pkcs1v15.go @@ -9,7 +9,6 @@ import ( "crypto/subtle" "errors" "io" - "math/big" "crypto/internal/boring" "crypto/internal/randutil" @@ -77,13 +76,11 @@ func EncryptPKCS1v15(random io.Reader, pub *PublicKey, msg []byte) ([]byte, erro return boring.EncryptRSANoPadding(bkey, em) } - m := new(big.Int).SetBytes(em) - c := encrypt(new(big.Int), pub, m) - return c.FillBytes(em), nil + return encrypt(pub, em), nil } // DecryptPKCS1v15 decrypts a plaintext using RSA and the padding scheme from PKCS #1 v1.5. -// If rand != nil, it uses RSA blinding to avoid timing side-channel attacks. +// The random parameter is legacy and ignored, and it can be as nil. // // Note that whether this function returns an error or not discloses secret // information. If an attacker can cause this function to run repeatedly and @@ -107,7 +104,7 @@ func DecryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) ([]byt return out, nil } - valid, out, index, err := decryptPKCS1v15(rand, priv, ciphertext) + valid, out, index, err := decryptPKCS1v15(priv, ciphertext) if err != nil { return nil, err } @@ -118,7 +115,7 @@ func DecryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) ([]byt } // DecryptPKCS1v15SessionKey decrypts a session key using RSA and the padding scheme from PKCS #1 v1.5. -// If rand != nil, it uses RSA blinding to avoid timing side-channel attacks. +// The random parameter is legacy and ignored, and it can be as nil. // It returns an error if the ciphertext is the wrong length or if the // ciphertext is greater than the public modulus. Otherwise, no error is // returned. If the padding is valid, the resulting plaintext message is copied @@ -145,7 +142,7 @@ func DecryptPKCS1v15SessionKey(rand io.Reader, priv *PrivateKey, ciphertext []by return ErrDecryption } - valid, em, index, err := decryptPKCS1v15(rand, priv, ciphertext) + valid, em, index, err := decryptPKCS1v15(priv, ciphertext) if err != nil { return err } @@ -161,13 +158,13 @@ func DecryptPKCS1v15SessionKey(rand io.Reader, priv *PrivateKey, ciphertext []by return nil } -// decryptPKCS1v15 decrypts ciphertext using priv and blinds the operation if -// rand is not nil. It returns one or zero in valid that indicates whether the -// plaintext was correctly structured. In either case, the plaintext is -// returned in em so that it may be read independently of whether it was valid -// in order to maintain constant memory access patterns. If the plaintext was -// valid then index contains the index of the original message in em. -func decryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (valid int, em []byte, index int, err error) { +// decryptPKCS1v15 decrypts ciphertext using priv. It returns one or zero in +// valid that indicates whether the plaintext was correctly structured. +// In either case, the plaintext is returned in em so that it may be read +// independently of whether it was valid in order to maintain constant memory +// access patterns. If the plaintext was valid then index contains the index of +// the original message in em, to allow constant time padding removal. +func decryptPKCS1v15(priv *PrivateKey, ciphertext []byte) (valid int, em []byte, index int, err error) { k := priv.Size() if k < 11 { err = ErrDecryption @@ -185,13 +182,10 @@ func decryptPKCS1v15(rand io.Reader, priv *PrivateKey, ciphertext []byte) (valid return } } else { - c := new(big.Int).SetBytes(ciphertext) - var m *big.Int - m, err = decrypt(rand, priv, c) + em, err = decrypt(priv, ciphertext) if err != nil { return } - em = m.FillBytes(make([]byte, k)) } firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0) @@ -265,8 +259,7 @@ var hashPrefixes = map[crypto.Hash][]byte{ // function. If hash is zero, hashed is signed directly. This isn't // advisable except for interoperability. // -// If rand is not nil then RSA blinding will be used to avoid timing -// side-channel attacks. +// The random parameter is legacy and ignored, and it can be as nil. // // This function is deterministic. Thus, if the set of possible // messages is small, an attacker may be able to build a map from @@ -301,13 +294,7 @@ func SignPKCS1v15(random io.Reader, priv *PrivateKey, hash crypto.Hash, hashed [ copy(em[k-tLen:k-hashLen], prefix) copy(em[k-hashLen:k], hashed) - m := new(big.Int).SetBytes(em) - c, err := decryptAndCheck(random, priv, m) - if err != nil { - return nil, err - } - - return c.FillBytes(em), nil + return decryptAndCheck(priv, em) } // VerifyPKCS1v15 verifies an RSA PKCS #1 v1.5 signature. @@ -345,9 +332,7 @@ func VerifyPKCS1v15(pub *PublicKey, hash crypto.Hash, hashed []byte, sig []byte) return ErrVerification } - c := new(big.Int).SetBytes(sig) - m := encrypt(new(big.Int), pub, c) - em := m.FillBytes(make([]byte, k)) + em := encrypt(pub, sig) // EM = 0x00 || 0x01 || PS || 0x00 || T ok := subtle.ConstantTimeByteEq(em[0], 0) diff --git a/src/crypto/rsa/pss.go b/src/crypto/rsa/pss.go index c59915de2d..aa2831dc7f 100644 --- a/src/crypto/rsa/pss.go +++ b/src/crypto/rsa/pss.go @@ -13,7 +13,6 @@ import ( "errors" "hash" "io" - "math/big" ) // Per RFC 8017, Section 9.1 @@ -208,8 +207,8 @@ func emsaPSSVerify(mHash, em []byte, emBits, sLen int, hash hash.Hash) error { // Note that hashed must be the result of hashing the input message using the // given hash function. salt is a random sequence of bytes whose length will be // later used to verify the signature. -func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) { - emBits := priv.N.BitLen() - 1 +func signPSSWithSalt(priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([]byte, error) { + emBits := bigBitLen(priv.N) - 1 em, err := emsaPSSEncode(hashed, emBits, salt, hash.New()) if err != nil { return nil, err @@ -229,13 +228,20 @@ func signPSSWithSalt(rand io.Reader, priv *PrivateKey, hash crypto.Hash, hashed, return s, nil } - m := new(big.Int).SetBytes(em) - c, err := decryptAndCheck(rand, priv, m) - if err != nil { - return nil, err + // RFC 8017: "Note that the octet length of EM will be one less than k if + // modBits - 1 is divisible by 8 and equal to k otherwise, where k is the + // length in octets of the RSA modulus n." 🙄 + // + // This is extremely annoying, as all other encrypt and decrypt inputs are + // always the exact same size as the modulus. Since it only happens for + // weird modulus sizes, fix it by padding inefficiently. + if emLen, k := len(em), priv.Size(); emLen < k { + emNew := make([]byte, k) + copy(emNew[k-emLen:], em) + em = emNew } - s := make([]byte, priv.Size()) - return c.FillBytes(s), nil + + return decryptAndCheck(priv, em) } const ( @@ -285,7 +291,7 @@ func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte, saltLength := opts.saltLength() switch saltLength { case PSSSaltLengthAuto: - saltLength = (priv.N.BitLen()-1+7)/8 - 2 - hash.Size() + saltLength = (bigBitLen(priv.N)-1+7)/8 - 2 - hash.Size() case PSSSaltLengthEqualsHash: saltLength = hash.Size() } @@ -306,7 +312,7 @@ func SignPSS(rand io.Reader, priv *PrivateKey, hash crypto.Hash, digest []byte, if _, err := io.ReadFull(rand, salt); err != nil { return nil, err } - return signPSSWithSalt(rand, priv, hash, digest, salt) + return signPSSWithSalt(priv, hash, digest, salt) } // VerifyPSS verifies a PSS signature. @@ -330,13 +336,22 @@ func VerifyPSS(pub *PublicKey, hash crypto.Hash, digest []byte, sig []byte, opts if len(sig) != pub.Size() { return ErrVerification } - s := new(big.Int).SetBytes(sig) - m := encrypt(new(big.Int), pub, s) - emBits := pub.N.BitLen() - 1 + + emBits := bigBitLen(pub.N) - 1 emLen := (emBits + 7) / 8 - if m.BitLen() > emLen*8 { - return ErrVerification + em := encrypt(pub, sig) + + // Like in signPSSWithSalt, deal with mismatches between emLen and the size + // of the modulus. The spec would have us wire emLen into the encoding + // function, but we'd rather always encode to the size of the modulus and + // then strip leading zeroes if necessary. This only happens for weird + // modulus sizes anyway. + for len(em) > emLen && len(em) > 0 { + if em[0] != 0 { + return ErrVerification + } + em = em[1:] } - em := m.FillBytes(make([]byte, emLen)) + return emsaPSSVerify(digest, em, emBits, opts.saltLength(), hash.New()) } diff --git a/src/crypto/rsa/pss_test.go b/src/crypto/rsa/pss_test.go index b547a87c71..6909f32847 100644 --- a/src/crypto/rsa/pss_test.go +++ b/src/crypto/rsa/pss_test.go @@ -231,7 +231,10 @@ func TestPSSSigning(t *testing.T) { } } -func TestSignWithPSSSaltLengthAuto(t *testing.T) { +func TestPSS513(t *testing.T) { + // See Issue 42741, and separately, RFC 8017: "Note that the octet length of + // EM will be one less than k if modBits - 1 is divisible by 8 and equal to + // k otherwise, where k is the length in octets of the RSA modulus n." if boring.Enabled() { t.Skip("skipping in boring mode: invalid key length") } @@ -247,8 +250,9 @@ func TestSignWithPSSSaltLengthAuto(t *testing.T) { if err != nil { t.Fatal(err) } - if len(signature) == 0 { - t.Fatal("empty signature returned") + err = VerifyPSS(&key.PublicKey, crypto.SHA256, digest[:], signature, nil) + if err != nil { + t.Error(err) } } diff --git a/src/crypto/rsa/rsa.go b/src/crypto/rsa/rsa.go index 4f1b792839..4b5d74b9e6 100644 --- a/src/crypto/rsa/rsa.go +++ b/src/crypto/rsa/rsa.go @@ -19,13 +19,17 @@ // over the public key primitive, the PrivateKey type implements the // Decrypter and Signer interfaces from the crypto package. // -// The RSA operations in this package are not implemented using constant-time algorithms. +// Operations in this package are implemented using constant-time algorithms, +// except for [GenerateKey], [PrivateKey.Precompute], and [PrivateKey.Validate]. +// Every other operation only leaks the bit size of the involved values, which +// all depend on the selected key size. package rsa import ( "crypto" "crypto/rand" "crypto/subtle" + "encoding/binary" "errors" "hash" "io" @@ -38,7 +42,6 @@ import ( "fmt" ) -var bigZero = big.NewInt(0) var bigOne = big.NewInt(1) // A PublicKey represents the public part of an RSA key. @@ -55,7 +58,7 @@ type PublicKey struct { // Size returns the modulus size in bytes. Raw signatures and ciphertexts // for or by this public key will have the same size. func (pub *PublicKey) Size() int { - return (pub.N.BitLen() + 7) / 8 + return (bigBitLen(pub.N) + 7) / 8 } // Equal reports whether pub and x have the same value. @@ -435,11 +438,20 @@ func mgf1XOR(out []byte, hash hash.Hash, seed []byte) { // too large for the size of the public key. var ErrMessageTooLong = errors.New("crypto/rsa: message too long for RSA public key size") -func encrypt(c *big.Int, pub *PublicKey, m *big.Int) *big.Int { +func encrypt(pub *PublicKey, plaintext []byte) []byte { boring.Unreachable() - e := big.NewInt(int64(pub.E)) - c.Exp(m, e, pub.N) - return c + + N := modulusFromNat(natFromBig(pub.N)) + m := natFromBytes(plaintext).expandFor(N) + + e := make([]byte, 8) + binary.BigEndian.PutUint64(e, uint64(pub.E)) + for len(e) > 1 && e[0] == 0 { + e = e[1:] + } + + out := make([]byte, modulusSize(N)) + return new(nat).exp(m, e, N).fillBytes(out) } // EncryptOAEP encrypts the given message with RSA-OAEP. @@ -507,12 +519,7 @@ func EncryptOAEP(hash hash.Hash, random io.Reader, pub *PublicKey, msg []byte, l return boring.EncryptRSANoPadding(bkey, em) } - m := new(big.Int) - m.SetBytes(em) - c := encrypt(new(big.Int), pub, m) - - out := make([]byte, k) - return c.FillBytes(out), nil + return encrypt(pub, em), nil } // ErrDecryption represents a failure to decrypt a message. @@ -554,101 +561,74 @@ func (priv *PrivateKey) Precompute() { } } -// decrypt performs an RSA decryption, resulting in a plaintext integer. If a -// random source is given, RSA blinding is used. -func decrypt(random io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err error) { +// decrypt performs an RSA decryption of ciphertext into out. +func decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { if len(priv.Primes) <= 2 { boring.Unreachable() } - // TODO(agl): can we get away with reusing blinds? - if c.Cmp(priv.N) > 0 { - err = ErrDecryption - return + + N := modulusFromNat(natFromBig(priv.N)) + c := natFromBytes(ciphertext).expandFor(N) + if c.cmpGeq(N.nat) == 1 { + return nil, ErrDecryption } if priv.N.Sign() == 0 { return nil, ErrDecryption } - var ir *big.Int - if random != nil { - randutil.MaybeReadByte(random) - - // Blinding enabled. Blinding involves multiplying c by r^e. - // Then the decryption operation performs (m^e * r^e)^d mod n - // which equals mr mod n. The factor of r can then be removed - // by multiplying by the multiplicative inverse of r. - - var r *big.Int - ir = new(big.Int) - for { - r, err = rand.Int(random, priv.N) - if err != nil { - return - } - if r.Cmp(bigZero) == 0 { - r = bigOne - } - ok := ir.ModInverse(r, priv.N) - if ok != nil { - break - } - } - bigE := big.NewInt(int64(priv.E)) - rpowe := new(big.Int).Exp(r, bigE, priv.N) // N != 0 - cCopy := new(big.Int).Set(c) - cCopy.Mul(cCopy, rpowe) - cCopy.Mod(cCopy, priv.N) - c = cCopy - } - + // Note that because our private decryption exponents are stored as big.Int, + // we potentially leak the exact number of bits of these exponents. This + // isn't great, but should be fine. if priv.Precomputed.Dp == nil { - m = new(big.Int).Exp(c, priv.D, priv.N) - } else { - // We have the precalculated values needed for the CRT. - m = new(big.Int).Exp(c, priv.Precomputed.Dp, priv.Primes[0]) - m2 := new(big.Int).Exp(c, priv.Precomputed.Dq, priv.Primes[1]) - m.Sub(m, m2) - if m.Sign() < 0 { - m.Add(m, priv.Primes[0]) - } - m.Mul(m, priv.Precomputed.Qinv) - m.Mod(m, priv.Primes[0]) - m.Mul(m, priv.Primes[1]) - m.Add(m, m2) - - for i, values := range priv.Precomputed.CRTValues { - prime := priv.Primes[2+i] - m2.Exp(c, values.Exp, prime) - m2.Sub(m2, m) - m2.Mul(m2, values.Coeff) - m2.Mod(m2, prime) - if m2.Sign() < 0 { - m2.Add(m2, prime) - } - m2.Mul(m2, values.R) - m.Add(m, m2) - } - } - - if ir != nil { - // Unblind. - m.Mul(m, ir) - m.Mod(m, priv.N) - } - - return + out := make([]byte, modulusSize(N)) + return new(nat).exp(c, priv.D.Bytes(), N).fillBytes(out), nil + } + + t0 := new(nat) + P := modulusFromNat(natFromBig(priv.Primes[0])) + Q := modulusFromNat(natFromBig(priv.Primes[1])) + // m = c ^ Dp mod p + m := new(nat).exp(t0.mod(c, P), priv.Precomputed.Dp.Bytes(), P) + // m2 = c ^ Dq mod q + m2 := new(nat).exp(t0.mod(c, Q), priv.Precomputed.Dq.Bytes(), Q) + // m = m - m2 mod p + m.modSub(t0.mod(m2, P), P) + // m = m * Qinv mod p + m.modMul(natFromBig(priv.Precomputed.Qinv).expandFor(P), P) + // m = m * q mod N + m.expandFor(N).modMul(t0.mod(Q.nat, N), N) + // m = m + m2 mod N + m.modAdd(m2.expandFor(N), N) + + for i, values := range priv.Precomputed.CRTValues { + p := modulusFromNat(natFromBig(priv.Primes[2+i])) + // m2 = c ^ Exp mod p + m2.exp(t0.mod(c, p), values.Exp.Bytes(), p) + // m2 = m2 - m mod p + m2.modSub(t0.mod(m, p), p) + // m2 = m2 * Coeff mod p + m2.modMul(natFromBig(values.Coeff).expandFor(p), p) + // m2 = m2 * R mod N + R := natFromBig(values.R).expandFor(N) + m2.expandFor(N).modMul(R, N) + // m = m + m2 mod N + m.modAdd(m2, N) + } + + out := make([]byte, modulusSize(N)) + return m.fillBytes(out), nil } -func decryptAndCheck(random io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int, err error) { - m, err = decrypt(random, priv, c) +func decryptAndCheck(priv *PrivateKey, ciphertext []byte) (m []byte, err error) { + m, err = decrypt(priv, ciphertext) if err != nil { return nil, err } // In order to defend against errors in the CRT computation, m^e is // calculated, which should match the original ciphertext. - check := encrypt(new(big.Int), &priv.PublicKey, m) - if c.Cmp(check) != 0 { + check := encrypt(&priv.PublicKey, m) + if subtle.ConstantTimeCompare(ciphertext, check) != 1 { return nil, errors.New("rsa: internal error") } return m, nil @@ -660,9 +640,7 @@ func decryptAndCheck(random io.Reader, priv *PrivateKey, c *big.Int) (m *big.Int // Encryption and decryption of a given message must use the same hash function // and sha256.New() is a reasonable choice. // -// The random parameter, if not nil, is used to blind the private-key operation -// and avoid timing side-channel attacks. Blinding is purely internal to this -// function – the random data need not match that used when encrypting. +// The random parameter is legacy and ignored, and it can be as nil. // // The label parameter must match the value given when encrypting. See // EncryptOAEP for details. @@ -687,9 +665,8 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext } return out, nil } - c := new(big.Int).SetBytes(ciphertext) - m, err := decrypt(random, priv, c) + em, err := decrypt(priv, ciphertext) if err != nil { return nil, err } @@ -698,10 +675,6 @@ func DecryptOAEP(hash hash.Hash, random io.Reader, priv *PrivateKey, ciphertext lHash := hash.Sum(nil) hash.Reset() - // We probably leak the number of leading zeros. - // It's not clear that we can do anything about this. - em := m.FillBytes(make([]byte, k)) - firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0) seed := em[1 : hash.Size()+1] diff --git a/src/crypto/rsa/rsa_test.go b/src/crypto/rsa/rsa_test.go index 08769e420c..023eb9fa72 100644 --- a/src/crypto/rsa/rsa_test.go +++ b/src/crypto/rsa/rsa_test.go @@ -13,7 +13,6 @@ import ( "fmt" "math/big" "testing" - ) import "crypto/internal/boring" @@ -173,25 +172,20 @@ func testKeyBasics(t *testing.T, priv *PrivateKey) { } pub := &priv.PublicKey + n := modulusSize(modulusFromNat(natFromBig(pub.N))) + mb := make([]byte, n) m := big.NewInt(42) - c := encrypt(new(big.Int), pub, m) + m.FillBytes(mb) + c := encrypt(pub, m.Bytes()) - m2, err := decrypt(nil, priv, c) + m2, err := decrypt(priv, c) if err != nil { t.Errorf("error while decrypting: %s", err) return } - if m.Cmp(m2) != 0 { + if bytes.Compare(mb, m2) != 0 { t.Errorf("got:%v, want:%v (%+v)", m2, m, priv) } - - m3, err := decrypt(rand.Reader, priv, c) - if err != nil { - t.Errorf("error while decrypting (blind): %s", err) - } - if m.Cmp(m3) != 0 { - t.Errorf("(blind) got:%v, want:%v (%#v)", m3, m, priv) - } } func fromBase10(base10 string) *big.Int { @@ -220,16 +214,16 @@ func init() { test2048Key.Precompute() // This is the same testRSA2048PrivateKey from src/crypto/tls/boring_test.go, // just formatted without using the x509 Parser - testRSA2048PrivateKey = &PrivateKey { + testRSA2048PrivateKey = &PrivateKey{ PublicKey: PublicKey{ - N: fromBase10("20191212046465051006148469115982609963794084216822290493008497548603282433337961188011759317867632936762484431807200684727542982286641865915343951546098189846608892055894575224375729344858650310374442622904229900868894242623139807621975608166515302294530216022389036816474348374698399654955992710180316983674809047409565569027596663420090767109285120403886497729233127551307356270679924351259776100107640885071765865832767303853854517356000385050677175012549806941229051812974721510192346810990827150439838227830352248569839727388943852973737249863837089274675024496841834194785931485429238306703429257731792443735979"), - E: 65537, - }, - D: fromBase10("17880854551669112566868255345124108779447961606053558991611260520405836487267781427740459393783689829925402008838157275130340717548134956040019107677074732476577915942750039777107871579671122369249613210066309031335411813988461299033587444447689322284662780986426216011635232478916424602504476935371549462113036228740820951710434375466081011497256196435741125837599218374223248197677547321257961509961401385322723627033844333644253777689603896264679633990939957571483400832267925506777396569554295752505112186882586887396943424085633026984063372469902814987050483471096892524886948283571883744403645335501920852525393"), - Primes: []*big.Int{ - fromBase10("135564917074042739008372452399559667250812269638554028593490636590148234941034106656615266472037321030780472224077878987192393666277731486488609490961161995141171813440923127505183021899359310251888145112092740773465142711876177808655062479870526201006500762429604105802612357839979630776094264195301632424911"), - fromBase10("148941278335581696308445609123523329975323575697232717856977715718810138995490768513650108277383732380774181214791356462453504708304090734692215322335879527529217737837271384209093576836051031684425884921572908683147368296418243939771852059523598364231128661438022752350148969064661946939745752818523498309989"), - }, + N: fromBase10("20191212046465051006148469115982609963794084216822290493008497548603282433337961188011759317867632936762484431807200684727542982286641865915343951546098189846608892055894575224375729344858650310374442622904229900868894242623139807621975608166515302294530216022389036816474348374698399654955992710180316983674809047409565569027596663420090767109285120403886497729233127551307356270679924351259776100107640885071765865832767303853854517356000385050677175012549806941229051812974721510192346810990827150439838227830352248569839727388943852973737249863837089274675024496841834194785931485429238306703429257731792443735979"), + E: 65537, + }, + D: fromBase10("17880854551669112566868255345124108779447961606053558991611260520405836487267781427740459393783689829925402008838157275130340717548134956040019107677074732476577915942750039777107871579671122369249613210066309031335411813988461299033587444447689322284662780986426216011635232478916424602504476935371549462113036228740820951710434375466081011497256196435741125837599218374223248197677547321257961509961401385322723627033844333644253777689603896264679633990939957571483400832267925506777396569554295752505112186882586887396943424085633026984063372469902814987050483471096892524886948283571883744403645335501920852525393"), + Primes: []*big.Int{ + fromBase10("135564917074042739008372452399559667250812269638554028593490636590148234941034106656615266472037321030780472224077878987192393666277731486488609490961161995141171813440923127505183021899359310251888145112092740773465142711876177808655062479870526201006500762429604105802612357839979630776094264195301632424911"), + fromBase10("148941278335581696308445609123523329975323575697232717856977715718810138995490768513650108277383732380774181214791356462453504708304090734692215322335879527529217737837271384209093576836051031684425884921572908683147368296418243939771852059523598364231128661438022752350148969064661946939745752818523498309989"), + }, } testRSA2048PrivateKey.Precompute() @@ -247,7 +241,7 @@ func BenchmarkRSA2048Decrypt(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - decrypt(nil, test2048Key, c) + decrypt(test2048Key, c.Bytes()) } } @@ -286,7 +280,7 @@ func Benchmark3PrimeRSA2048Decrypt(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - decrypt(nil, priv, c) + decrypt(priv, c.Bytes()) } } @@ -374,7 +368,7 @@ func TestEncryptDecryptOAEP(t *testing.T) { if boring.Enabled() && priv.PublicKey.Size() < 256 { t.Logf("skipping check for unsupported key less than 2048 bits") - continue; + continue } t.Logf("running check for supported key size") for j, message := range test.msgs {