diff --git a/account/api/views.py b/account/api/views.py index 37a580f..588a214 100644 --- a/account/api/views.py +++ b/account/api/views.py @@ -10,6 +10,7 @@ from rest_framework.throttling import AnonRateThrottle from account.models import MailingList, MailingListEmail, Profile +from profiles.api.views import update_postal_code_result from profiles.models import Result from .serializers import ProfileSerializer, SubscribeSerializer, UnSubscribeSerializer @@ -46,6 +47,16 @@ def update(self, request, *args, **kwargs): ) if serializer.is_valid(): serializer.save() + """ + To ensure the postal code results are updated before the user exists the poll. + After the questions are answered the front-end updates the profile information with a put request. + """ + + if ( + request.method == "PUT" + and serializer.data.get("postal_code", None) is not None + ): + update_postal_code_result(user) return Response(data=serializer.data, status=status.HTTP_200_OK) else: return Response(data=serializer.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/account/tests/conftest.py b/account/tests/conftest.py index 958fc90..97b022d 100644 --- a/account/tests/conftest.py +++ b/account/tests/conftest.py @@ -26,8 +26,8 @@ def api_client_with_custom_ip_address(ip_address): @pytest.fixture -def users(): - User.objects.create(username="test1") +def users(results): + User.objects.create(username="test1", result=results.first()) return User.objects.all() diff --git a/account/tests/test_api.py b/account/tests/test_api.py index a1ff211..028b5cf 100644 --- a/account/tests/test_api.py +++ b/account/tests/test_api.py @@ -8,6 +8,7 @@ from rest_framework.reverse import reverse from account.models import MailingList, MailingListEmail, User +from profiles.models import PostalCode, PostalCodeResult from .utils import check_method_status_codes, patch @@ -75,6 +76,45 @@ def test_profile_created(api_client): assert user.profile.postal_code is None +@pytest.mark.django_db +def test_profile_put_creates_postal_code_result( + api_client_authenticated, users, profiles +): + user = users.get(username="test1") + url = reverse("account:profiles-detail", args=[user.id]) + data = { + "postal_code": "20210", + "optional_postal_code": "20220", + } + response = api_client_authenticated.put(url, data) + assert response.status_code == 200 + assert PostalCodeResult.objects.count() == 2 + assert ( + PostalCodeResult.objects.get( + postal_code=PostalCode.objects.get(postal_code="20210") + ).count + == 1 + ) + assert ( + PostalCodeResult.objects.get( + postal_code=PostalCode.objects.get(postal_code="20220") + ).count + == 1 + ) + + +@pytest.mark.django_db +def test_profile_put_not_creates_postal_code_result( + api_client_authenticated, users, profiles +): + user = users.get(username="test1") + url = reverse("account:profiles-detail", args=[user.id]) + data = {"year_of_birth": 42} + response = api_client_authenticated.put(url, data) + assert response.status_code == 200 + assert PostalCodeResult.objects.count() == 0 + + @pytest.mark.django_db def test_profile_put(api_client_authenticated, users, profiles): user = users.get(username="test1") @@ -123,6 +163,7 @@ def test_profile_patch_geneder(api_client_authenticated, users, profiles): patch(api_client_authenticated, url, {"gender": "X"}) user.refresh_from_db() assert user.profile.gender == "X" + assert PostalCodeResult.objects.count() == 0 @pytest.mark.django_db @@ -132,6 +173,7 @@ def test_profile_patch_postal_code(api_client_authenticated, users, profiles): patch(api_client_authenticated, url, {"postal_code": "20210"}) user.refresh_from_db() assert user.profile.postal_code == "20210" + assert PostalCodeResult.objects.count() == 0 @pytest.mark.django_db