From 2a49ab97a9cd5506ee2e41d9df5cebb707de5de0 Mon Sep 17 00:00:00 2001 From: Pyry Lahtinen Date: Wed, 15 Nov 2023 09:27:30 +0200 Subject: [PATCH] streamlined decode util to handle lists instead of single polynomials --- kyber/decrypt.py | 11 +++-------- kyber/encrypt.py | 6 ++---- kyber/utils/encoding.py | 22 +++++++++++++--------- tests/test_encoding.py | 8 ++++---- 4 files changed, 22 insertions(+), 25 deletions(-) diff --git a/kyber/decrypt.py b/kyber/decrypt.py index e831c1c..cdca206 100644 --- a/kyber/decrypt.py +++ b/kyber/decrypt.py @@ -20,17 +20,12 @@ def decrypt(self) -> bytes: :returns Decrypted 32-bit shared secret """ - # split self._sk into chunks of length 32*12 and decode each one of them into a polynomial - s = np.array([ - decode(self._sk[32*12*i : 32*12*(i+1)], 12) for i in range(len(self._sk)//(32*12)) - ]) + s = np.array(decode(self._sk, 12)) u, v = self._c[:du*k*n//8], self._c[du*k*n//8:] - u = np.array([ - decode(u[32*du*i : 32*du*(i+1)], du) for i in range(len(u)//(32*du)) - ]) - v = decode(v, dv) + u = decode(u, du) + v = decode(v, dv)[0] u = np.array([decompress(pol, du) for pol in u]) v = decompress(v, dv) diff --git a/kyber/encrypt.py b/kyber/encrypt.py index 07f53fd..8ec2d6b 100644 --- a/kyber/encrypt.py +++ b/kyber/encrypt.py @@ -34,9 +34,7 @@ def encrypt(self): rb = self._r t, rho = self._pk[:-32], self._pk[-32:] - t = np.array([ - decode(t[32*12*i : 32*12*(i+1)], 12) for i in range(len(t)//(32*12)) - ]) + t = np.array(decode(t, 12)) A = np.empty((k, k), Polynomial) for i in range(k): @@ -60,7 +58,7 @@ def encrypt(self): e2 = polmod(e2) u = np.matmul(A.T, r) + e1 - v = np.matmul(t.T, r) + e2 + decompress(decode(m, 1), 1) + v = np.matmul(t.T, r) + e2 + decompress(decode(m, 1)[0], 1) u = matmod(u) v = polmod(v) diff --git a/kyber/utils/encoding.py b/kyber/utils/encoding.py index 67e4004..2aaa533 100644 --- a/kyber/utils/encoding.py +++ b/kyber/utils/encoding.py @@ -32,17 +32,21 @@ def encode(pols: list[Polynomial], l: int) -> bytes: assert len(result) == 32*l*len(pols) return bytes(result) -def decode(b: bytes, l: int) -> Polynomial: +def decode(b: bytes, l: int) -> list[Polynomial]: """ - Converts the given byte array (length `32*l`) into a polynomial (degree 255) + Converts the given byte array (length `32*l*x` for some integer x) into + a list of polynomials (length x, each degree 255) in which each coefficient is in range `0...2**l-1` (inclusive). """ - if len(b) != 32*l: + if len(b) % 32*l != 0: raise ValueError() - bits = bytes_to_bits(b) - f = np.empty((256, )) - for i in range(256): - f[i] = sum(bits[i*l+j]*2**j for j in range(l)) # accesses each bit exactly once - assert 0 <= f[i] and f[i] <= 2**l-1 - return Polynomial(f) + result = [] + for t in range(len(b) // (32*l)): + bits = bytes_to_bits(b[32*l*t : 32*l*(t+1)]) + f = np.empty((256, )) + for i in range(256): + f[i] = sum(bits[i*l+j]*2**j for j in range(l)) # accesses each bit exactly once + assert 0 <= f[i] and f[i] <= 2**l-1 + result.append(Polynomial(f)) + return result diff --git a/tests/test_encoding.py b/tests/test_encoding.py index fb0b35c..0199d5a 100644 --- a/tests/test_encoding.py +++ b/tests/test_encoding.py @@ -12,17 +12,17 @@ def setUp(self): self.polynomial2 = Polynomial([randint(0, 1) for _ in range(256)]) def test_encoding_symmetry(self): - polynomial = decode(self.data, self.l) - restored_data = encode([polynomial], self.l) + polynomials = decode(self.data, self.l) + restored_data = encode(polynomials, self.l) self.assertEqual(self.data, restored_data) def test_decode_coefficients(self): - polynomial = decode(self.data, self.l) + polynomial = decode(self.data, self.l)[0] for c in polynomial.coef: self.assertTrue(0 <= int(c) or int(c) <= 2**self.l-1) def test_decode_degree(self): - polynomial = decode(self.data, self.l) + polynomial = decode(self.data, self.l)[0] self.assertEqual(len(polynomial.coef), 256) def test_decode_raises_with_invalid_argument_length(self):