From db980a1e5d90c11dfdf3cb2f32367d53a33f0a49 Mon Sep 17 00:00:00 2001 From: Giacomo Licari Date: Thu, 11 Apr 2024 19:20:21 +0200 Subject: [PATCH] API: add CSRF tests --- api/api/services/csrf.py | 17 ++++++++++------- api/api/services/validator.py | 9 ++------- api/tests/test_csrf.py | 19 +++++++++++++++++++ 3 files changed, 31 insertions(+), 14 deletions(-) create mode 100644 api/tests/test_csrf.py diff --git a/api/api/services/csrf.py b/api/api/services/csrf.py index 5931857..d6811f1 100644 --- a/api/api/services/csrf.py +++ b/api/api/services/csrf.py @@ -27,13 +27,16 @@ def generate_token(self): return CSRFTokenItem(request_id, token.hex()) def validate_token(self, request_id, token): - cipher_rsa = PKCS1_OAEP.new(self._privkey) - decrypted_text = cipher_rsa.decrypt(bytes.fromhex(token)).decode() - - expected_text = '%s%s' % (request_id, self._salt) - if decrypted_text == expected_text: - return True - return False + try: + cipher_rsa = PKCS1_OAEP.new(self._privkey) + decrypted_text = cipher_rsa.decrypt(bytes.fromhex(token)).decode() + + expected_text = '%s%s' % (request_id, self._salt) + if decrypted_text == expected_text: + return True + return False + except Exception: + return False class CSRF: diff --git a/api/api/services/validator.py b/api/api/services/validator.py index 12db851..9347a5b 100644 --- a/api/api/services/validator.py +++ b/api/api/services/validator.py @@ -86,13 +86,8 @@ def csrf_validation(self): self.errors.append('Bad request') self.http_return_code = 400 - try: - csrf_valid = self.csrf.validate_token(request_id, token) - if not csrf_valid: - self.errors.append('Bad request') - self.http_return_code = 400 - except Exception as e: - logging.error(e) + csrf_valid = self.csrf.validate_token(request_id, token) + if not csrf_valid: self.errors.append('Bad request') self.http_return_code = 400 diff --git a/api/tests/test_csrf.py b/api/tests/test_csrf.py new file mode 100644 index 0000000..f40fea2 --- /dev/null +++ b/api/tests/test_csrf.py @@ -0,0 +1,19 @@ +from .conftest import BaseTest + + +class TestCSRF(BaseTest): + + def test_values(self): + token_obj = self.csrf.generate_token() + self.assertTrue( + self.csrf.validate_token(token_obj.request_id, token_obj.token) + ) + self.assertFalse( + self.csrf.validate_token('myfakeid', token_obj.token) + ) + self.assertFalse( + self.csrf.validate_token('myfakeid', 'myfaketoken') + ) + self.assertFalse( + self.csrf.validate_token(token_obj.request_id, 'myfaketoken') + )