From 82ce15d98d07c51c744f0f3733cc19e56f7480ab Mon Sep 17 00:00:00 2001 From: juuso-j Date: Mon, 8 Apr 2024 09:26:45 +0300 Subject: [PATCH 1/4] Add class StartPollRateThrottle --- profiles/api/utils.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/profiles/api/utils.py b/profiles/api/utils.py index 56819b2..375e1e4 100644 --- a/profiles/api/utils.py +++ b/profiles/api/utils.py @@ -1,5 +1,6 @@ import django_filters from rest_framework.exceptions import ValidationError +from rest_framework.throttling import AnonRateThrottle from profiles.models import PostalCodeResult @@ -15,6 +16,15 @@ def blur_count(count, threshold=5): return count +class StartPollRateThrottle(AnonRateThrottle): + """ + The AnonRateThrottle will only ever throttle unauthenticated users. + The IP address of the incoming request is used to generate a unique key to throttle against. + """ + + rate = "10/day" + + class CustomValidationError(ValidationError): # The detail field is shown also when DEBUG=False # Ensures the error message is displayed to the user From e8553325d61449c61c09096fbd093d742181c25f Mon Sep 17 00:00:00 2001 From: juuso-j Date: Mon, 8 Apr 2024 09:27:21 +0300 Subject: [PATCH 2/4] Add StartPollRateThrottle to start_poll throttle_classes --- profiles/api/views.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/profiles/api/views.py b/profiles/api/views.py index bd35944..848ba6d 100644 --- a/profiles/api/views.py +++ b/profiles/api/views.py @@ -56,7 +56,7 @@ ) from profiles.utils import encrypt_text, generate_password, get_user_result -from .utils import PostalCodeResultFilter +from .utils import PostalCodeResultFilter, StartPollRateThrottle logger = logging.getLogger(__name__) @@ -185,6 +185,7 @@ def list(self, request, *args, **kwargs): detail=False, methods=["POST"], permission_classes=[AllowAny], + throttle_classes=[StartPollRateThrottle], ) def start_poll(self, request): uuid4 = uuid.uuid4() From 1f45a73e8591effd798a3a574bfa0c8cde1d001b Mon Sep 17 00:00:00 2001 From: juuso-j Date: Mon, 8 Apr 2024 09:30:10 +0300 Subject: [PATCH 3/4] Add throttling test for start_poll --- profiles/tests/api/test_answer.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/profiles/tests/api/test_answer.py b/profiles/tests/api/test_answer.py index eb3af19..89f2965 100644 --- a/profiles/tests/api/test_answer.py +++ b/profiles/tests/api/test_answer.py @@ -1,8 +1,11 @@ +import time + import pytest from rest_framework.authtoken.models import Token from rest_framework.reverse import reverse from account.models import User +from profiles.api.views import QuestionViewSet from profiles.models import Answer, Option, Question, SubQuestion @@ -21,6 +24,29 @@ def test_start_poll(api_client): assert User.objects.all().count() == 1 +@pytest.mark.django_db +@pytest.mark.parametrize( + "ip_address", + [ + ("240.231.131.14"), + ], +) +def test_start_poll_throttling(api_client_with_custom_ip_address): + # Set number of requests to be made from the rate. The rate is stored as a string, e.g., rate = "10/day" + num_requests = int( + QuestionViewSet.start_poll.kwargs["throttle_classes"][0].rate.split("/")[0] + ) + url = reverse("profiles:question-start-poll") + count = 0 + while count < num_requests: + response = api_client_with_custom_ip_address.post(url) + assert response.status_code == 200 + count += 1 + time.sleep(2) + response = api_client_with_custom_ip_address.post(url) + assert response.status_code == 429 + + @pytest.mark.django_db def test_post_answer(api_client_authenticated, users, questions, options): user = users.get(username="test1") From 9aa45c0a026c373245ae74e578aa1d1401a08c76 Mon Sep 17 00:00:00 2001 From: juuso-j Date: Mon, 8 Apr 2024 09:42:20 +0300 Subject: [PATCH 4/4] Read num_request from the defined rate --- account/tests/test_api.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/account/tests/test_api.py b/account/tests/test_api.py index b063537..d686f4a 100644 --- a/account/tests/test_api.py +++ b/account/tests/test_api.py @@ -7,6 +7,7 @@ from freezegun import freeze_time from rest_framework.reverse import reverse +from account.api.views import ProfileViewSet from account.models import MailingList, MailingListEmail, User from profiles.models import PostalCode, PostalCodeResult @@ -48,7 +49,10 @@ def test_unauthenticated_cannot_do_anything(api_client, users): def test_mailing_list_unsubscribe_throttling( api_client_with_custom_ip_address, mailing_list_emails ): - num_requests = 10 + # Set number of requests to be made from the rate. The rate is stored as a string, e.g., rate = "10/day" + num_requests = int( + ProfileViewSet.unsubscribe.kwargs["throttle_classes"][0].rate.split("/")[0] + ) url = reverse("account:profiles-unsubscribe") count = 0 while count < num_requests: