Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for CSRF. #41

Merged
merged 4 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion api/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .manage import (block_user_cmd, create_access_keys_cmd,
create_enabled_token_cmd)
from .routes import apiv1
from .services import Web3Singleton
from .services import CSRF, Web3Singleton
from .services.database import db


Expand Down Expand Up @@ -46,6 +46,9 @@ def create_app():
db.init_app(app)
Migrate(app, db)

# Initialize CSRF Library
CSRF(app.config['CSRF_PRIVATE_KEY'], app.config['CSRF_SECRET_SALT'])

# Initialize Web3 class for latter usage
w3 = Web3Singleton(app.config['FAUCET_RPC_URL'], app.config['FAUCET_PRIVATE_KEY'])

Expand Down
22 changes: 16 additions & 6 deletions api/api/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from web3 import Web3

from .const import FaucetRequestType, TokenType
from .services import (AskEndpointValidator, Web3Singleton, claim_native,
from .services import (CSRF, AskEndpointValidator, Web3Singleton, claim_native,
claim_token)
from .services.database import AccessKey, Token, Transaction

Expand All @@ -14,7 +14,7 @@ def status():
return jsonify(status='ok'), 200


@apiv1.route("/info")
@apiv1.route("/info", methods=["GET"])
def info():
enabled_tokens = Token.enabled_tokens()
rate_limit_days = current_app.config['FAUCET_RATE_LIMIT_TIME_LIMIT_SECONDS'] / (24*60*60)
Expand All @@ -27,27 +27,37 @@ def info():
'rateLimitDays': round(rate_limit_days, 2)
} for t in enabled_tokens
]

# it's a singleton, gets instantiated at app creation time
csrf = CSRF.instance
csrf_item = csrf.generate_token()

return jsonify(
enabledTokens=enabled_tokens_json,
chainId=current_app.config['FAUCET_CHAIN_ID'],
chainName=current_app.config['FAUCET_CHAIN_NAME'],
faucetAddress=current_app.config['FAUCET_ADDRESS']
faucetAddress=current_app.config['FAUCET_ADDRESS'],
csrfToken=csrf_item.token,
csrfRequestId=csrf_item.request_id
), 200


def _ask(request_data, validate_captcha=True, access_key=None):
def _ask(request_data, request_headers, validate_captcha=True, validate_csrf=True, access_key=None):
"""Process /ask request

Args:
request_data (object): request object
validate_captcha (bool, optional): True if captcha must be validated, False otherwise. Defaults to True.
validate_csrf (bool, optional): True if CSRF token must be validated, False otherwise. Defaults to True.
access_key (object, optional): AccessKey instance. Defaults to None.

Returns:
tuple: json content, status code
"""
validator = AskEndpointValidator(request_data,
request_headers,
validate_captcha,
validate_csrf,
access_key=access_key)
ok = validator.validate()
if not ok:
Expand Down Expand Up @@ -94,7 +104,7 @@ def _ask(request_data, validate_captcha=True, access_key=None):

@apiv1.route("/ask", methods=["POST"])
def ask():
data, status_code = _ask(request.get_json(), validate_captcha=True, access_key=None)
data, status_code = _ask(request.get_json(), request.headers, validate_captcha=True, access_key=None)
return data, status_code


Expand All @@ -116,5 +126,5 @@ def cli_ask():
validation_errors.append('Access denied')
return jsonify(errors=validation_errors), 403

data, status_code = _ask(request.get_json(), validate_captcha=False, access_key=access_key)
data, status_code = _ask(request.get_json(), request.headers, validate_captcha=False, validate_csrf=False, access_key=access_key)
return data, status_code
1 change: 1 addition & 0 deletions api/api/services/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .csrf import CSRF
from .database import DatabaseSingleton
from .rate_limit import RateLimitStrategy, Strategy
from .token import Token
Expand Down
2 changes: 1 addition & 1 deletion api/api/services/captcha.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import requests
import logging

import requests

logging.basicConfig(level=logging.INFO)

Expand Down
48 changes: 48 additions & 0 deletions api/api/services/csrf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@

import random

from Crypto.Cipher import PKCS1_OAEP
from Crypto.PublicKey import RSA


class CSRFTokenItem:
def __init__(self, request_id, token):
self.request_id = request_id
self.token = token


class CSRFToken:
def __init__(self, privkey, salt):
self._privkey = RSA.import_key(privkey)
self._pubkey = self._privkey.publickey()
self._salt = salt

def generate_token(self):
request_id = '%d' % random.randint(0, 1000)
data_to_encrypt = '%s%s' % (request_id, self._salt)

cipher_rsa = PKCS1_OAEP.new(self._pubkey)
token = cipher_rsa.encrypt(data_to_encrypt.encode())

return CSRFTokenItem(request_id, token.hex())

def validate_token(self, request_id, token):
try:
cipher_rsa = PKCS1_OAEP.new(self._privkey)
decrypted_text = cipher_rsa.decrypt(bytes.fromhex(token)).decode()

expected_text = '%s%s' % (request_id, self._salt)
if decrypted_text == expected_text:
return True
return False
except Exception:
return False


class CSRF:
_instance = None

def __new__(cls, privatekey, salt):
if not hasattr(cls, 'instance'):
cls.instance = CSRFToken(privatekey, salt)
return cls.instance
5 changes: 2 additions & 3 deletions api/api/services/database.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import sqlite3
from datetime import datetime

from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData, func

from api.const import (DEFAULT_ERC20_MAX_AMOUNT_PER_DAY,
DEFAULT_NATIVE_MAX_AMOUNT_PER_DAY, FaucetRequestType)
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData, func

flask_db_convention = {
"ix": 'ix_%(column_0_label)s',
Expand Down
33 changes: 30 additions & 3 deletions api/api/services/validator.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import datetime
import logging

from api.const import TokenType
from flask import current_app, request
from web3 import Web3

from api.const import TokenType

from .captcha import captcha_verify
from .csrf import CSRF
from .database import AccessKeyConfig, BlockedUsers, Token, Transaction
from .rate_limit import Strategy

logging.basicConfig(level=logging.INFO)


class AskEndpointValidator:
errors = []
Expand All @@ -25,14 +28,22 @@ class AskEndpointValidator:
'RATE_LIMIT_EXCEEDED': 'recipient: you have exceeded the limit for today. Try again in %d hours'
}

def __init__(self, request_data, validate_captcha, access_key=None, *args, **kwargs):
def __init__(self, request_data, request_headers, validate_captcha, validate_csrf, access_key=None, *args, **kwargs):
self.request_data = request_data
self.request_headers = request_headers
self.validate_captcha = validate_captcha
self.validate_csrf = validate_csrf
self.access_key = access_key
self.ip_address = request.environ.get('HTTP_X_FORWARDED_FOR', request.remote_addr)
self.errors = []
self.csrf = CSRF.instance

def validate(self):
if self.validate_csrf:
self.csrf_validation()
if len(self.errors) > 0:
return False

self.blocked_user_validation()
if len(self.errors) > 0:
return False
Expand Down Expand Up @@ -64,6 +75,22 @@ def validate(self):
return False
return True

def csrf_validation(self):
token = self.request_headers.get('X-CSRFToken', None)
if not token:
self.errors.append('Bad request')
self.http_return_code = 400

request_id = self.request_data.get('requestId', None)
if not request_id:
self.errors.append('Bad request')
self.http_return_code = 400

csrf_valid = self.csrf.validate_token(request_id, token)
if not csrf_valid:
self.errors.append('Bad request')
self.http_return_code = 400

def blocked_user_validation(self):
recipient = self.request_data.get('recipient', None)
# Run validation on blocked users only if `recipient` is available.
Expand Down
3 changes: 3 additions & 0 deletions api/api/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@
CAPTCHA_VERIFY_ENDPOINT = os.getenv('CAPTCHA_VERIFY_ENDPOINT')
CAPTCHA_SECRET_KEY = os.getenv('CAPTCHA_SECRET_KEY')
CAPTCHA_SITE_KEY = os.getenv('CAPTCHA_SITE_KEY')

CSRF_PRIVATE_KEY = os.getenv('CSRF_PRIVATE_KEY')
CSRF_SECRET_SALT = os.getenv('CSRF_SECRET_SALT')
1 change: 0 additions & 1 deletion api/migrations/versions/4cacf36b2356_.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
"""
import sqlalchemy as sa
from alembic import op

from api.services.database import flask_db_convention

# revision identifiers, used by Alembic.
Expand Down
3 changes: 2 additions & 1 deletion api/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ python-dotenv==1.0.0
web3==6.11.3
pytest==7.4.3
pytest-mock==3.12.0
gunicorn==21.2.0
gunicorn==21.2.0
pycryptodome==3.20.0
16 changes: 14 additions & 2 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
from unittest import TestCase, TestResult, mock

from api.services import CSRF, Strategy
from api.services.database import Token, db
from flask.testing import FlaskClient

from api import create_app
from api.services import Strategy
from api.services.database import Token, db

from .temp_env_var import FAUCET_ENABLED_TOKENS, TEMP_ENV_VARS

Expand Down Expand Up @@ -71,6 +71,10 @@ def setUp(self):
with self.app_ctx:
self._reset_db()

self.csrf = CSRF.instance
# use same token for the whole test
self.csrf_token = self.csrf.generate_token()

def tearDown(self):
'''
Cleanup to do after running each test
Expand Down Expand Up @@ -99,6 +103,10 @@ def setUp(self):
with self.app_ctx:
self._reset_db()

self.csrf = CSRF.instance
# use same token for the whole test
self.csrf_token = self.csrf.generate_token()


class RateLimitIPorAddressBaseTest(BaseTest):
def setUp(self):
Expand All @@ -117,3 +125,7 @@ def setUp(self):

with self.app_ctx:
self._reset_db()

self.csrf = CSRF.instance
# use same token for the whole test
self.csrf_token = self.csrf.generate_token()
7 changes: 6 additions & 1 deletion api/tests/temp_env_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from api.const import (DEFAULT_ERC20_MAX_AMOUNT_PER_DAY,
DEFAULT_NATIVE_MAX_AMOUNT_PER_DAY, NATIVE_TOKEN_ADDRESS,
TokenType)
from Crypto.PublicKey import RSA

ERC20_TOKEN_ADDRESS = "0x" + '1' * 40

Expand All @@ -28,14 +29,18 @@
}
]

privatekey = RSA.generate(1024)

TEMP_ENV_VARS = {
'FAUCET_RPC_URL': 'http://localhost:8545',
'FAUCET_CHAIN_ID': str(FAUCET_CHAIN_ID),
'FAUCET_PRIVATE_KEY': token_bytes(32).hex(),
'FAUCET_RATE_LIMIT_TIME_LIMIT_SECONDS': '10',
'FAUCET_DATABASE_URI': 'sqlite://', # run in-memory
# 'FAUCET_DATABASE_URI': 'sqlite:///test.db',
'CAPTCHA_SECRET_KEY': CAPTCHA_TEST_SECRET_KEY
'CAPTCHA_SECRET_KEY': CAPTCHA_TEST_SECRET_KEY,
'CSRF_PRIVATE_KEY': privatekey.export_key().decode(),
'CSRF_SECRET_SALT': 'testsalt'
}

# Mocked values
Expand Down
Loading
Loading