Skip to content

Commit

Permalink
API: add CSRF tests
Browse files Browse the repository at this point in the history
  • Loading branch information
giacomognosis committed Apr 11, 2024
1 parent 1425b63 commit db980a1
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 14 deletions.
17 changes: 10 additions & 7 deletions api/api/services/csrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,16 @@ def generate_token(self):
return CSRFTokenItem(request_id, token.hex())

def validate_token(self, request_id, token):
cipher_rsa = PKCS1_OAEP.new(self._privkey)
decrypted_text = cipher_rsa.decrypt(bytes.fromhex(token)).decode()

expected_text = '%s%s' % (request_id, self._salt)
if decrypted_text == expected_text:
return True
return False
try:
cipher_rsa = PKCS1_OAEP.new(self._privkey)
decrypted_text = cipher_rsa.decrypt(bytes.fromhex(token)).decode()

expected_text = '%s%s' % (request_id, self._salt)
if decrypted_text == expected_text:
return True
return False
except Exception:
return False


class CSRF:
Expand Down
9 changes: 2 additions & 7 deletions api/api/services/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,8 @@ def csrf_validation(self):
self.errors.append('Bad request')
self.http_return_code = 400

try:
csrf_valid = self.csrf.validate_token(request_id, token)
if not csrf_valid:
self.errors.append('Bad request')
self.http_return_code = 400
except Exception as e:
logging.error(e)
csrf_valid = self.csrf.validate_token(request_id, token)
if not csrf_valid:
self.errors.append('Bad request')
self.http_return_code = 400

Expand Down
19 changes: 19 additions & 0 deletions api/tests/test_csrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from .conftest import BaseTest


class TestCSRF(BaseTest):

def test_values(self):
token_obj = self.csrf.generate_token()
self.assertTrue(
self.csrf.validate_token(token_obj.request_id, token_obj.token)
)
self.assertFalse(
self.csrf.validate_token('myfakeid', token_obj.token)
)
self.assertFalse(
self.csrf.validate_token('myfakeid', 'myfaketoken')
)
self.assertFalse(
self.csrf.validate_token(token_obj.request_id, 'myfaketoken')
)

0 comments on commit db980a1

Please sign in to comment.