Skip to content

Commit

Permalink
added tests for ciphertext format
Browse files Browse the repository at this point in the history
  • Loading branch information
PyryL committed Nov 15, 2023
1 parent b05d42f commit e6ab86d
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 6 deletions.
2 changes: 2 additions & 0 deletions kyber/decrypt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ def __init__(self, private_key, ciphertext) -> None:
self._c = ciphertext
if len(self._sk) != 32*12*k:
raise ValueError()
if len(self._c) != du*k*n//8 + dv*n//8:
raise ValueError()

def decrypt(self) -> bytes:
"""
Expand Down
26 changes: 21 additions & 5 deletions tests/test_decryption.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,29 @@
import unittest
from random import seed, randbytes
from kyber.decrypt import Decrypt
from kyber.constants import k
from kyber.constants import k, n, du, dv

class TestDecryption(unittest.TestCase):
def test_decryption_raises_with_invalid_input(self):
# this private key is one byte too long
def setUp(self):
seed(42)

def test_decryption_outputs_valid_shared_secret(self):
private_key = randbytes(32*12*k)
ciphertext = randbytes(du*k*n//8 + dv*n//8)
shared_secret = Decrypt(private_key, ciphertext).decrypt()
self.assertEqual(type(shared_secret), bytes)
self.assertEqual(len(shared_secret), 32)

def test_decryption_raises_with_invalid_private_key(self):
# this private key is one byte too long
invalid_private_key = randbytes(32*12*k + 1)
ciphertext_placeholder = ()
valid_ciphertext = randbytes(du*k*n//8 + dv*n//8)
with self.assertRaises(ValueError):
Decrypt(invalid_private_key, valid_ciphertext)

def test_decryption_raises_with_invalid_ciphertext(self):
# this ciphertext is one byte too short
valid_private_key = randbytes(32*12*k)
invalid_ciphertext = randbytes(du*k*n//8 + dv*n//8 - 1)
with self.assertRaises(ValueError):
Decrypt(invalid_private_key, ciphertext_placeholder)
Decrypt(valid_private_key, invalid_ciphertext)
8 changes: 7 additions & 1 deletion tests/test_encryption.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
from random import seed, randbytes
from kyber.encrypt import Encrypt
from kyber.constants import k, n
from kyber.constants import k, n, du, dv

class TestEncryption(unittest.TestCase):
def test_encryption_raises_with_invalid_input(self):
Expand All @@ -15,3 +15,9 @@ def test_encryption_generates_valid_shared_secret(self):
seed(42)
encrypter = Encrypt(randbytes(12 * k * n//8 + 32))
self.assertEqual(len(encrypter.secret), 32)

def test_encryption_outputs_valid_ciphertext(self):
seed(42)
ciphertext = Encrypt(randbytes(12 * k * n//8 + 32)).encrypt()
self.assertEqual(type(ciphertext), bytes)
self.assertEqual(len(ciphertext), du*k*n//8 + dv*n//8)

0 comments on commit e6ab86d

Please sign in to comment.