Skip to content

Commit

Permalink
Add support for CSRF
Browse files Browse the repository at this point in the history
  • Loading branch information
giacomognosis committed Apr 11, 2024
1 parent dd0a132 commit 0c73372
Show file tree
Hide file tree
Showing 17 changed files with 227 additions and 43 deletions.
5 changes: 4 additions & 1 deletion api/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .manage import (block_user_cmd, create_access_keys_cmd,
create_enabled_token_cmd)
from .routes import apiv1
from .services import Web3Singleton
from .services import CSRF, Web3Singleton
from .services.database import db


Expand Down Expand Up @@ -46,6 +46,9 @@ def create_app():
db.init_app(app)
Migrate(app, db)

# Initialize CSRF Library
CSRF(app.config['CSRF_PRIVATE_KEY'], app.config['CSRF_SECRET_SALT'])

# Initialize Web3 class for latter usage
w3 = Web3Singleton(app.config['FAUCET_RPC_URL'], app.config['FAUCET_PRIVATE_KEY'])

Expand Down
22 changes: 16 additions & 6 deletions api/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from web3 import Web3

from .const import FaucetRequestType, TokenType
from .services import (AskEndpointValidator, Web3Singleton, claim_native,
from .services import (CSRF, AskEndpointValidator, Web3Singleton, claim_native,
claim_token)
from .services.database import AccessKey, Token, Transaction

Expand All @@ -14,7 +14,7 @@ def status():
return jsonify(status='ok'), 200


@apiv1.route("/info")
@apiv1.route("/info", methods=["GET"])
def info():
enabled_tokens = Token.enabled_tokens()
rate_limit_days = current_app.config['FAUCET_RATE_LIMIT_TIME_LIMIT_SECONDS'] / (24*60*60)
Expand All @@ -27,27 +27,37 @@ def info():
'rateLimitDays': round(rate_limit_days, 2)
} for t in enabled_tokens
]

# it's a singleton, gets instantiated at app creation time
csrf = CSRF.instance
csrf_item = csrf.generate_token()

return jsonify(
enabledTokens=enabled_tokens_json,
chainId=current_app.config['FAUCET_CHAIN_ID'],
chainName=current_app.config['FAUCET_CHAIN_NAME'],
faucetAddress=current_app.config['FAUCET_ADDRESS']
faucetAddress=current_app.config['FAUCET_ADDRESS'],
csrfToken=csrf_item.token,
csrfRequestId=csrf_item.request_id
), 200


def _ask(request_data, validate_captcha=True, access_key=None):
def _ask(request_data, request_headers, validate_captcha=True, validate_csrf=True, access_key=None):
"""Process /ask request
Args:
request_data (object): request object
validate_captcha (bool, optional): True if captcha must be validated, False otherwise. Defaults to True.
validate_csrf (bool, optional): True if CSRF token must be validated, False otherwise. Defaults to True.
access_key (object, optional): AccessKey instance. Defaults to None.
Returns:
tuple: json content, status code
"""
validator = AskEndpointValidator(request_data,
request_headers,
validate_captcha,
validate_csrf,
access_key=access_key)
ok = validator.validate()
if not ok:
Expand Down Expand Up @@ -94,7 +104,7 @@ def _ask(request_data, validate_captcha=True, access_key=None):

@apiv1.route("/ask", methods=["POST"])
def ask():
data, status_code = _ask(request.get_json(), validate_captcha=True, access_key=None)
data, status_code = _ask(request.get_json(), request.headers, validate_captcha=True, access_key=None)
return data, status_code


Expand All @@ -116,5 +126,5 @@ def cli_ask():
validation_errors.append('Access denied')
return jsonify(errors=validation_errors), 403

data, status_code = _ask(request.get_json(), validate_captcha=False, access_key=access_key)
data, status_code = _ask(request.get_json(), request.headers, validate_captcha=False, validate_csrf=False, access_key=access_key)
return data, status_code
1 change: 1 addition & 0 deletions api/api/services/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .csrf import CSRF
from .database import DatabaseSingleton
from .rate_limit import RateLimitStrategy, Strategy
from .token import Token
Expand Down
2 changes: 1 addition & 1 deletion api/api/services/captcha.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import requests
import logging

import requests

logging.basicConfig(level=logging.INFO)

Expand Down
45 changes: 45 additions & 0 deletions api/api/services/csrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@

import random

from Crypto.Cipher import PKCS1_OAEP
from Crypto.PublicKey import RSA


class CSRFTokenItem:
def __init__(self, request_id, token):
self.request_id = request_id
self.token = token


class CSRFToken:
def __init__(self, privkey, salt):
self._privkey = RSA.import_key(privkey)
self._pubkey = self._privkey.publickey()
self._salt = salt

def generate_token(self):
request_id = '%d' % random.randint(0, 1000)
data_to_encrypt = '%s%s' % (request_id, self._salt)

cipher_rsa = PKCS1_OAEP.new(self._pubkey)
token = cipher_rsa.encrypt(data_to_encrypt.encode())

return CSRFTokenItem(request_id, token.hex())

def validate_token(self, request_id, token):
cipher_rsa = PKCS1_OAEP.new(self._privkey)
decrypted_text = cipher_rsa.decrypt(bytes.fromhex(token)).decode()

expected_text = '%s%s' % (request_id, self._salt)
if decrypted_text == expected_text:
return True
return False


class CSRF:
_instance = None

def __new__(cls, privatekey, salt):
if not hasattr(cls, 'instance'):
cls.instance = CSRFToken(privatekey, salt)
return cls.instance
5 changes: 2 additions & 3 deletions api/api/services/database.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import sqlite3
from datetime import datetime

from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData, func

from api.const import (DEFAULT_ERC20_MAX_AMOUNT_PER_DAY,
DEFAULT_NATIVE_MAX_AMOUNT_PER_DAY, FaucetRequestType)
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData, func

flask_db_convention = {
"ix": 'ix_%(column_0_label)s',
Expand Down
39 changes: 36 additions & 3 deletions api/api/services/validator.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import datetime
import logging

from api.const import TokenType
from flask import current_app, request
from web3 import Web3

from api.const import TokenType

from .captcha import captcha_verify
from .csrf import CSRF
from .database import AccessKeyConfig, BlockedUsers, Token, Transaction
from .rate_limit import Strategy

logging.basicConfig(level=logging.INFO)


class AskEndpointValidator:
errors = []
Expand All @@ -25,14 +28,23 @@ class AskEndpointValidator:
'RATE_LIMIT_EXCEEDED': 'recipient: you have exceeded the limit for today. Try again in %d hours'
}

def __init__(self, request_data, validate_captcha, access_key=None, *args, **kwargs):
def __init__(self, request_data, request_headers, validate_captcha, validate_csrf, access_key=None, *args, **kwargs):
self.request_data = request_data
self.request_headers = request_headers
self.validate_captcha = validate_captcha
self.validate_csrf = validate_csrf
self.access_key = access_key
self.ip_address = request.environ.get('HTTP_X_FORWARDED_FOR', request.remote_addr)
self.errors = []
self.csrf = CSRF.instance

def validate(self):
import pdb; pdb.set_trace()
if self.validate_csrf:
self.csrf_validation()
if len(self.errors) > 0:
return False

self.blocked_user_validation()
if len(self.errors) > 0:
return False
Expand Down Expand Up @@ -64,6 +76,27 @@ def validate(self):
return False
return True

def csrf_validation(self):
token = self.request_headers.get('X-CSRFToken', None)
if not token:
self.errors.append('Bad request')
self.http_return_code = 400

request_id = self.request_data.get('requestId', None)
if not request_id:
self.errors.append('Bad request')
self.http_return_code = 400

try:
csrf_valid = self.csrf.validate_token(request_id, token)
if not csrf_valid:
self.errors.append('Bad request')
self.http_return_code = 400
except Exception as e:
logging.error(e)
self.errors.append('Bad request')
self.http_return_code = 400

def blocked_user_validation(self):
recipient = self.request_data.get('recipient', None)
# Run validation on blocked users only if `recipient` is available.
Expand Down
3 changes: 3 additions & 0 deletions api/api/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@
CAPTCHA_VERIFY_ENDPOINT = os.getenv('CAPTCHA_VERIFY_ENDPOINT')
CAPTCHA_SECRET_KEY = os.getenv('CAPTCHA_SECRET_KEY')
CAPTCHA_SITE_KEY = os.getenv('CAPTCHA_SITE_KEY')

CSRF_PRIVATE_KEY = os.getenv('CSRF_PRIVATE_KEY')
CSRF_SECRET_SALT = os.getenv('CSRF_SECRET_SALT')
1 change: 0 additions & 1 deletion api/migrations/versions/4cacf36b2356_.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""
import sqlalchemy as sa
from alembic import op

from api.services.database import flask_db_convention

# revision identifiers, used by Alembic.
Expand Down
3 changes: 2 additions & 1 deletion api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ python-dotenv==1.0.0
web3==6.11.3
pytest==7.4.3
pytest-mock==3.12.0
gunicorn==21.2.0
gunicorn==21.2.0
pycryptodome==3.20.0
16 changes: 14 additions & 2 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
from unittest import TestCase, TestResult, mock

from api.services import CSRF, Strategy
from api.services.database import Token, db
from flask.testing import FlaskClient

from api import create_app
from api.services import Strategy
from api.services.database import Token, db

from .temp_env_var import FAUCET_ENABLED_TOKENS, TEMP_ENV_VARS

Expand Down Expand Up @@ -71,6 +71,10 @@ def setUp(self):
with self.app_ctx:
self._reset_db()

self.csrf = CSRF.instance
# use same token for the whole test
self.csrf_token = self.csrf.generate_token()

def tearDown(self):
'''
Cleanup to do after running each test
Expand Down Expand Up @@ -99,6 +103,10 @@ def setUp(self):
with self.app_ctx:
self._reset_db()

self.csrf = CSRF.instance
# use same token for the whole test
self.csrf_token = self.csrf.generate_token()


class RateLimitIPorAddressBaseTest(BaseTest):
def setUp(self):
Expand All @@ -117,3 +125,7 @@ def setUp(self):

with self.app_ctx:
self._reset_db()

self.csrf = CSRF.instance
# use same token for the whole test
self.csrf_token = self.csrf.generate_token()
7 changes: 6 additions & 1 deletion api/tests/temp_env_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from api.const import (DEFAULT_ERC20_MAX_AMOUNT_PER_DAY,
DEFAULT_NATIVE_MAX_AMOUNT_PER_DAY, NATIVE_TOKEN_ADDRESS,
TokenType)
from Crypto.PublicKey import RSA

ERC20_TOKEN_ADDRESS = "0x" + '1' * 40

Expand All @@ -28,14 +29,18 @@
}
]

privatekey = RSA.generate(1024)

TEMP_ENV_VARS = {
'FAUCET_RPC_URL': 'http://localhost:8545',
'FAUCET_CHAIN_ID': str(FAUCET_CHAIN_ID),
'FAUCET_PRIVATE_KEY': token_bytes(32).hex(),
'FAUCET_RATE_LIMIT_TIME_LIMIT_SECONDS': '10',
'FAUCET_DATABASE_URI': 'sqlite://', # run in-memory
# 'FAUCET_DATABASE_URI': 'sqlite:///test.db',
'CAPTCHA_SECRET_KEY': CAPTCHA_TEST_SECRET_KEY
'CAPTCHA_SECRET_KEY': CAPTCHA_TEST_SECRET_KEY,
'CSRF_PRIVATE_KEY': privatekey.export_key().decode(),
'CSRF_SECRET_SALT': 'testsalt'
}

# Mocked values
Expand Down
Loading

0 comments on commit 0c73372

Please sign in to comment.