From 92d70ae5d9feb5319b6e2cf5e896cd8f3ecac0c4 Mon Sep 17 00:00:00 2001 From: Stefan Borer Date: Fri, 22 May 2020 17:58:21 +0200 Subject: [PATCH 1/2] feat: remove ldap auth, simplify test client setup Remove support for LDAP auth, in order for preparing to OIDC auth instead. Until now, we relied on a custom `JSONAPIClient` which is unnecessary, hence we drop it. Use the standard rest_framework `APIClient` instead with `force_authenticate`, cutting out the authentication middleware during testing. --- requirements.txt | 2 - timed/conftest.py | 109 +++++++----------- .../employment/tests/test_absence_balance.py | 4 +- timed/employment/tests/test_user.py | 17 +-- .../employment/tests/test_worktime_balance.py | 8 +- timed/projects/tests/test_project.py | 4 +- .../reports/tests/test_customer_statistic.py | 4 +- timed/reports/tests/test_project_statistic.py | 2 +- timed/reports/tests/test_task_statistic.py | 2 +- timed/reports/tests/test_work_report.py | 4 +- timed/settings.py | 32 ++--- timed/tests/client.py | 90 --------------- timed/tests/test_client.py | 20 ---- timed/tracking/tests/test_report.py | 2 +- timed/urls.py | 3 - 15 files changed, 69 insertions(+), 234 deletions(-) delete mode 100644 timed/tests/client.py delete mode 100644 timed/tests/test_client.py diff --git a/requirements.txt b/requirements.txt index 78a0c649..009f06aa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,9 @@ python-dateutil==2.8.1 django==2.2.13 -django-auth-ldap==2.1.1 # might remove this once we find out how the jsonapi extras_require work django-filter==2.3.0 django-multiselectfield==0.1.12 djangorestframework==3.11.0 -djangorestframework-simplejwt==4.4.0 djangorestframework-jsonapi[django-filter]==3.1.0 psycopg2==2.8.5 pytz==2020.1 diff --git a/timed/conftest.py b/timed/conftest.py index 8657d323..ad813889 100644 --- a/timed/conftest.py +++ b/timed/conftest.py @@ -1,15 +1,14 @@ import inspect -import mockldap import pytest from django.contrib.auth import get_user_model from factory.base import FactoryMetaClass from pytest_factoryboy import register +from rest_framework.test import APIClient from timed.employment import factories as employment_factories from timed.projects import factories as projects_factories from timed.subscription import factories as subscription_factories -from timed.tests.client import JSONAPIClient from timed.tracking import factories as tracking_factories @@ -25,48 +24,9 @@ def register_module(module): register_module(tracking_factories) -@pytest.fixture(autouse=True, scope="session") -def ldap_directory(): - top = ("o=test", {"o": "test"}) - people = ("ou=people,o=test", {"ou": "people"}) - groups = ("ou=groups,o=test", {"ou": "groups"}) - ldapuser = ( - "uid=ldapuser,ou=people,o=test", - { - "uid": ["ldapuser"], - "objectClass": [ - "person", - "organizationalPerson", - "inetOrgPerson", - "posixAccount", - ], - "userPassword": ["Test1234!"], - "uidNumber": ["1000"], - "gidNumber": ["1000"], - "givenName": ["givenName"], - "mail": ["ldapuser@example.net"], - "sn": ["LdapUser"], - }, - ) - - directory = dict([top, people, groups, ldapuser]) - mock = mockldap.MockLdap(directory) - mock.start() - - yield - - mock.stop() - - -@pytest.fixture -def client(db): - return JSONAPIClient() - - @pytest.fixture -def auth_client(db): - """Return instance of a JSONAPIClient that is logged in as test user.""" - user = get_user_model().objects.create_user( +def auth_user(db): + return get_user_model().objects.create_user( username="user", password="123qweasd", first_name="Test", @@ -75,43 +35,58 @@ def auth_client(db): is_staff=False, ) - client = JSONAPIClient() - client.user = user - client.login("user", "123qweasd") - return client - @pytest.fixture -def admin_client(db): - """Return instance of a JSONAPIClient that is logged in as a staff user.""" - user = get_user_model().objects.create_user( - username="user", +def admin_user(db): + return get_user_model().objects.create_user( + username="admin", password="123qweasd", - first_name="Test", + first_name="Admin", last_name="User", is_superuser=False, is_staff=True, ) - client = JSONAPIClient() - client.user = user - client.login("user", "123qweasd") - return client - @pytest.fixture -def superadmin_client(db): - """Return instance of a JSONAPIClient that is logged in as superuser.""" - user = get_user_model().objects.create_user( - username="user", +def superadmin_user(db): + return get_user_model().objects.create_user( + username="superadmin", password="123qweasd", - first_name="Test", + first_name="Superadmin", last_name="User", - is_staff=True, is_superuser=True, + is_staff=True, ) - client = JSONAPIClient() - client.user = user - client.login("user", "123qweasd") + +@pytest.fixture +def client(): + return APIClient() + + +@pytest.fixture +def auth_client(auth_user): + """Return instance of a APIClient that is logged in as test user.""" + client = APIClient() + client.force_authenticate(user=auth_user) + client.user = auth_user + return client + + +@pytest.fixture +def admin_client(admin_user): + """Return instance of a APIClient that is logged in as a staff user.""" + client = APIClient() + client.force_authenticate(user=admin_user) + client.user = admin_user + return client + + +@pytest.fixture +def superadmin_client(superadmin_user): + """Return instance of a APIClient that is logged in as superuser.""" + client = APIClient() + client.force_authenticate(user=superadmin_user) + client.user = superadmin_user return client diff --git a/timed/employment/tests/test_absence_balance.py b/timed/employment/tests/test_absence_balance.py index 90174ac8..8111777a 100644 --- a/timed/employment/tests/test_absence_balance.py +++ b/timed/employment/tests/test_absence_balance.py @@ -30,7 +30,7 @@ def test_absence_balance_full_day(auth_client, django_assert_num_queries): url = reverse("absence-balance-list") - with django_assert_num_queries(7): + with django_assert_num_queries(6): result = auth_client.get( url, data={ @@ -73,7 +73,7 @@ def test_absence_balance_fill_worktime(auth_client, django_assert_num_queries): AbsenceFactory.create(date=day, user=user, type=absence_type) url = reverse("absence-balance-list") - with django_assert_num_queries(12): + with django_assert_num_queries(11): result = auth_client.get( url, data={ diff --git a/timed/employment/tests/test_user.py b/timed/employment/tests/test_user.py index 637d97b1..6efaf22c 100644 --- a/timed/employment/tests/test_user.py +++ b/timed/employment/tests/test_user.py @@ -1,7 +1,6 @@ from datetime import date, timedelta import pytest -from django.contrib.auth import get_user_model from django.urls import reverse from rest_framework import status @@ -20,27 +19,19 @@ def test_user_list_unauthenticated(client): assert response.status_code == status.HTTP_401_UNAUTHORIZED -def test_user_update_unauthenticated(client): +def test_user_update_unauthenticated(client, db): user = UserFactory.create() url = reverse("user-detail", args=[user.id]) response = client.patch(url) assert response.status_code == status.HTTP_401_UNAUTHORIZED -def test_user_login_ldap(client): - client.login("ldapuser", "Test1234!") - user = get_user_model().objects.get(username="ldapuser") - assert user.first_name == "givenName" - assert user.last_name == "LdapUser" - assert user.email == "ldapuser@example.net" - - -def test_user_list(auth_client, django_assert_num_queries): +def test_user_list(db, auth_client, django_assert_num_queries): UserFactory.create_batch(2) url = reverse("user-list") - with django_assert_num_queries(8): + with django_assert_num_queries(7): response = auth_client.get(url) assert response.status_code == status.HTTP_200_OK @@ -135,7 +126,7 @@ def test_user_delete_superuser(superadmin_client): assert response.status_code == status.HTTP_204_NO_CONTENT -def test_user_delete_with_reports_superuser(superadmin_client): +def test_user_delete_with_reports_superuser(superadmin_client, db): """Test that user with reports may not be deleted.""" user = UserFactory.create() ReportFactory.create(user=user) diff --git a/timed/employment/tests/test_worktime_balance.py b/timed/employment/tests/test_worktime_balance.py index 58e1aee2..72d5076b 100644 --- a/timed/employment/tests/test_worktime_balance.py +++ b/timed/employment/tests/test_worktime_balance.py @@ -24,7 +24,7 @@ def test_worktime_balance_create(auth_client): def test_worktime_balance_no_employment(auth_client, django_assert_num_queries): url = reverse("worktime-balance-list") - with django_assert_num_queries(4): + with django_assert_num_queries(3): result = auth_client.get( url, data={"user": auth_client.user.id, "date": "2017-01-01"} ) @@ -89,7 +89,7 @@ def test_worktime_balance_with_employments(auth_client, django_assert_num_querie args=["{0}_{1}".format(auth_client.user.id, end_date.strftime("%Y-%m-%d"))], ) - with django_assert_num_queries(12): + with django_assert_num_queries(11): result = auth_client.get(url) assert result.status_code == status.HTTP_200_OK @@ -180,7 +180,7 @@ def test_worktime_balance_list_last_reported_date_no_reports( url = reverse("worktime-balance-list") - with django_assert_num_queries(2): + with django_assert_num_queries(1): result = auth_client.get(url, data={"last_reported_date": 1}) assert result.status_code == status.HTTP_200_OK @@ -215,7 +215,7 @@ def test_worktime_balance_list_last_reported_date( url = reverse("worktime-balance-list") - with django_assert_num_queries(10): + with django_assert_num_queries(9): result = auth_client.get(url, data={"last_reported_date": 1}) assert result.status_code == status.HTTP_200_OK diff --git a/timed/projects/tests/test_project.py b/timed/projects/tests/test_project.py index bf50c98a..8a425a22 100644 --- a/timed/projects/tests/test_project.py +++ b/timed/projects/tests/test_project.py @@ -29,7 +29,7 @@ def test_project_list_include(auth_client, django_assert_num_queries, project): url = reverse("project-list") - with django_assert_num_queries(8): + with django_assert_num_queries(7): response = auth_client.get( url, data={"include": ",".join(ProjectSerializer.included_serializers.keys())}, @@ -41,7 +41,7 @@ def test_project_list_include(auth_client, django_assert_num_queries, project): assert json["data"][0]["id"] == str(project.id) -def test_project_detail_no_auth(client, project): +def test_project_detail_no_auth(db, client, project): url = reverse("project-detail", args=[project.id]) res = client.get(url) diff --git a/timed/reports/tests/test_customer_statistic.py b/timed/reports/tests/test_customer_statistic.py index db295de3..2bbdf63b 100644 --- a/timed/reports/tests/test_customer_statistic.py +++ b/timed/reports/tests/test_customer_statistic.py @@ -11,7 +11,7 @@ def test_customer_statistic_list(auth_client, django_assert_num_queries): report2 = ReportFactory.create(duration=timedelta(hours=4)) url = reverse("customer-statistic-list") - with django_assert_num_queries(4): + with django_assert_num_queries(3): result = auth_client.get( url, data={"ordering": "duration", "include": "customer"} ) @@ -56,7 +56,7 @@ def test_customer_statistic_detail(auth_client, django_assert_num_queries): report = ReportFactory.create(duration=timedelta(hours=1)) url = reverse("customer-statistic-detail", args=[report.task.project.customer.id]) - with django_assert_num_queries(3): + with django_assert_num_queries(2): result = auth_client.get( url, data={"ordering": "duration", "include": "customer"} ) diff --git a/timed/reports/tests/test_project_statistic.py b/timed/reports/tests/test_project_statistic.py index 249fd037..aa4df7e7 100644 --- a/timed/reports/tests/test_project_statistic.py +++ b/timed/reports/tests/test_project_statistic.py @@ -11,7 +11,7 @@ def test_project_statistic_list(auth_client, django_assert_num_queries): report2 = ReportFactory.create(duration=timedelta(hours=4)) url = reverse("project-statistic-list") - with django_assert_num_queries(5): + with django_assert_num_queries(4): result = auth_client.get( url, data={"ordering": "duration", "include": "project,project.customer"} ) diff --git a/timed/reports/tests/test_task_statistic.py b/timed/reports/tests/test_task_statistic.py index 4c409857..ee5bb286 100644 --- a/timed/reports/tests/test_task_statistic.py +++ b/timed/reports/tests/test_task_statistic.py @@ -14,7 +14,7 @@ def test_task_statistic_list(auth_client, django_assert_num_queries): ReportFactory.create(duration=timedelta(hours=2), task=task_z) url = reverse("task-statistic-list") - with django_assert_num_queries(5): + with django_assert_num_queries(4): result = auth_client.get( url, data={ diff --git a/timed/reports/tests/test_work_report.py b/timed/reports/tests/test_work_report.py index c16ae6e6..a24e9c59 100644 --- a/timed/reports/tests/test_work_report.py +++ b/timed/reports/tests/test_work_report.py @@ -25,7 +25,7 @@ def test_work_report_single_project(auth_client, django_assert_num_queries): ) url = reverse("work-report-list") - with django_assert_num_queries(4): + with django_assert_num_queries(3): res = auth_client.get( url, data={ @@ -60,7 +60,7 @@ def test_work_report_multiple_projects(auth_client, django_assert_num_queries): ReportFactory.create_batch(10, user=user, task=task, date=report_date) url = reverse("work-report-list") - with django_assert_num_queries(4): + with django_assert_num_queries(3): res = auth_client.get(url, data={"user": auth_client.user.id, "verified": 0}) assert res.status_code == status.HTTP_200_OK assert "20170901-WorkReports.zip" in (res["Content-Disposition"]) diff --git a/timed/settings.py b/timed/settings.py index a6d56783..7bd0ee11 100644 --- a/timed/settings.py +++ b/timed/settings.py @@ -1,4 +1,3 @@ -import datetime import os import re @@ -159,6 +158,12 @@ def default(default_dev=env.NOTSET, default_prod=env.NOTSET): "EXCEPTION_HANDLER": "rest_framework_json_api.exceptions.exception_handler", "DEFAULT_PAGINATION_CLASS": "rest_framework_json_api.pagination.JsonApiPageNumberPagination", "DEFAULT_RENDERER_CLASSES": ("rest_framework_json_api.renderers.JSONRenderer",), + "TEST_REQUEST_RENDERER_CLASSES": ( + "rest_framework_json_api.renderers.JSONRenderer", + "rest_framework.renderers.JSONRenderer", + "rest_framework.renderers.MultiPartRenderer", + ), + "TEST_REQUEST_DEFAULT_FORMAT": "vnd.api+json", } JSON_API_FORMAT_FIELD_NAMES = "dasherize" @@ -167,32 +172,11 @@ def default(default_dev=env.NOTSET, default_prod=env.NOTSET): APPEND_SLASH = False -# Authentication definition - -AUTHENTICATION_BACKENDS = ["django.contrib.auth.backends.ModelBackend"] - -AUTH_LDAP_ENABLED = env.bool("DJANGO_AUTH_LDAP_ENABLED", default=False) -if AUTH_LDAP_ENABLED: - AUTH_LDAP_USER_ATTR_MAP = env.dict( - "DJANGO_AUTH_LDAP_USER_ATTR_MAP", - default={"first_name": "givenName", "last_name": "sn", "email": "mail"}, - ) - - AUTH_LDAP_SERVER_URI = env.str("DJANGO_AUTH_LDAP_SERVER_URI") - AUTH_LDAP_BIND_DN = env.str("DJANGO_AUTH_LDAP_BIND_DN", default="") - AUTH_LDAP_BIND_PASSWORD = env.str("DJANGO_AUTH_LDAP_BIND_PASSWORD", default="") - AUTH_LDAP_USER_DN_TEMPLATE = env.str("DJANGO_AUTH_LDAP_USER_DN_TEMPLATE") - AUTHENTICATION_BACKENDS.insert(0, "django_auth_ldap.backend.LDAPBackend") +# Authentication +# AUTHENTICATION_BACKENDS = ["django.contrib.auth.backends.ModelBackend"] AUTH_USER_MODEL = "employment.User" -SIMPLE_AUTH = { - "ACCESS_TOKEN_LIFETIME": datetime.timedelta(days=2), - "REFRESH_TOKEN_LIFETIME": datetime.timedelta(days=7), - # TODO check if this is ROTATE_REFRESH_TOKENS - # "JWT_ALLOW_REFRESH": True, -} - AUTH_PASSWORD_VALIDATORS = [ { "NAME": "django.contrib.auth.password_validation.UserAttributeSimilarityValidator" # noqa diff --git a/timed/tests/client.py b/timed/tests/client.py deleted file mode 100644 index eedcf483..00000000 --- a/timed/tests/client.py +++ /dev/null @@ -1,90 +0,0 @@ -"""Helpers for testing with JSONAPI.""" - -import json - -from django.urls import reverse -from rest_framework import exceptions, status -from rest_framework.test import APIClient - - -class JSONAPIClient(APIClient): - """Base API client for testing CRUD methods with JSONAPI format.""" - - def __init__(self, *args, **kwargs): - """Initialize the API client.""" - super().__init__(*args, **kwargs) - - self._content_type = "application/vnd.api+json" - - def _parse_data(self, data): - return json.dumps(data) if data else data - - def get(self, path, data=None, **kwargs): - """Patched GET method to enforce JSONAPI format. - - :param str path: The URL to call - :param dict data: The data of the request - """ - return super().get( - path=path, data=data, content_type=self._content_type, **kwargs - ) - - def post(self, path, data=None, **kwargs): - """Patched POST method to enforce JSONAPI format. - - :param str path: The URL to call - :param dict data: The data of the request - """ - return super().post( - path=path, - data=self._parse_data(data), - content_type=self._content_type, - **kwargs, - ) - - def delete(self, path, data=None, **kwargs): - """Patched DELETE method to enforce JSONAPI format. - - :param str path: The URL to call - :param dict data: The data of the request - """ - return super().delete( - path=path, - data=self._parse_data(data), - content_type=self._content_type, - **kwargs, - ) - - def patch(self, path, data=None, **kwargs): - """Patched PATCH method to enforce JSONAPI format. - - :param str path: The URL to call - :param dict data: The data of the request - """ - return super().patch( - path=path, - data=self._parse_data(data), - content_type=self._content_type, - **kwargs, - ) - - def login(self, username, password): - """Authenticate a user. - - :param str username: Username of the user - :param str password: Password of the user - :raises: exceptions.AuthenticationFailed - """ - data = { - "data": { - "attributes": {"username": username, "password": password}, - "type": "token-obtain-pair-views", - } - } - - response = self.post(reverse("login"), data) - - if response.status_code != status.HTTP_200_OK: - raise exceptions.AuthenticationFailed() - - self.credentials(HTTP_AUTHORIZATION=f"Bearer {response.data['access']}") diff --git a/timed/tests/test_client.py b/timed/tests/test_client.py deleted file mode 100644 index 21ab16f5..00000000 --- a/timed/tests/test_client.py +++ /dev/null @@ -1,20 +0,0 @@ -import pytest -from django.contrib.auth import get_user_model -from rest_framework import exceptions - -from timed.tests.client import JSONAPIClient - - -def test_client_login(db): - get_user_model().objects.create_user( - username="user", password="123qweasd", first_name="Test", last_name="User" - ) - - client = JSONAPIClient() - client.login("user", "123qweasd") - - -def test_client_login_fails(db): - client = JSONAPIClient() - with pytest.raises(exceptions.AuthenticationFailed): - client.login("someuser", "invalidpw") diff --git a/timed/tracking/tests/test_report.py b/timed/tracking/tests/test_report.py index a362cf23..2d96aeda 100644 --- a/timed/tracking/tests/test_report.py +++ b/timed/tracking/tests/test_report.py @@ -795,7 +795,7 @@ def test_report_export( url = reverse("report-export") - with django_assert_num_queries(2): + with django_assert_num_queries(1): response = auth_client.get(url, data={"file_type": file_type}) assert response.status_code == status.HTTP_200_OK diff --git a/timed/urls.py b/timed/urls.py index bc0f76d9..1dd4aabf 100644 --- a/timed/urls.py +++ b/timed/urls.py @@ -2,12 +2,9 @@ from django.conf.urls import include, url from django.contrib import admin -from rest_framework_simplejwt.views import TokenObtainPairView, TokenRefreshView urlpatterns = [ url(r"^admin/", admin.site.urls), - url(r"^api/v1/auth/login", TokenObtainPairView.as_view(), name="login"), - url(r"^api/v1/auth/refresh", TokenRefreshView.as_view(), name="refresh"), url(r"^api/v1/", include("timed.employment.urls")), url(r"^api/v1/", include("timed.projects.urls")), url(r"^api/v1/", include("timed.tracking.urls")), From e59447596ea525bb3a753cbb63f5e27b93e586db Mon Sep 17 00:00:00 2001 From: Stefan Borer Date: Tue, 26 May 2020 09:40:36 +0200 Subject: [PATCH 2/2] feat(auth): implement oidc authentication Implement OIDC-based authentication using `mozilla-django-oidc`. Additionally to the default ModelBackend (used for django-admin), we subclass the Mozilla OIDC backend and customize it for our needs. That is, we accept two kinds of OIDC flows: * authorization code (using a public client) * client credentials grant (using a confidential client). The first is the accepted for the timed-frontend, the latter is used for server-to-server communication by services consuming the timed api. --- dev-config/nginx.conf | 2 +- docker-compose.override.yml | 6 +- docker-compose.yml | 17 +++- requirements-dev.txt | 1 + requirements.txt | 1 + timed/authentication.py | 103 ++++++++++++++++++++++++ timed/conftest.py | 6 ++ timed/settings.py | 52 ++++++++++++- timed/tests/test_authentication.py | 121 +++++++++++++++++++++++++++++ 9 files changed, 304 insertions(+), 5 deletions(-) create mode 100644 timed/authentication.py create mode 100644 timed/tests/test_authentication.py diff --git a/dev-config/nginx.conf b/dev-config/nginx.conf index 21cacb43..7fe439ee 100644 --- a/dev-config/nginx.conf +++ b/dev-config/nginx.conf @@ -14,7 +14,7 @@ server { client_max_body_size 50m; # db-flush may not be exposed in PRODUCTION! - location ~ ^/(api|admin|db-flush)/ { + location ~ ^/(api|admin|static|db-flush)/ { set $backend http://backend; proxy_pass $backend; } diff --git a/docker-compose.override.yml b/docker-compose.override.yml index aa35375e..b9b83b64 100644 --- a/docker-compose.override.yml +++ b/docker-compose.override.yml @@ -1,4 +1,4 @@ -version: "3" +version: "3.7" services: backend: @@ -14,6 +14,8 @@ services: volumes: - ./:/app command: /bin/sh -c "wait-for-it.sh -t 60 db:5432 -- ./manage.py migrate && ./manage.py runserver 0.0.0.0:80" + networks: + - timed.local mailhog: image: mailhog/mailhog @@ -21,3 +23,5 @@ services: - 8025:8025 environment: - MH_UI_WEB_PATH=mailhog + networks: + - timed.local diff --git a/docker-compose.yml b/docker-compose.yml index fa8fdab1..770fe94e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,4 +1,4 @@ -version: "3" +version: "3.7" services: db: @@ -10,6 +10,8 @@ services: environment: - POSTGRES_USER=timed - POSTGRES_PASSWORD=timed + networks: + - timed.local frontend: image: adfinissygroup/timed-frontend:latest @@ -17,6 +19,8 @@ services: - backend ports: - 4200:80 + networks: + - timed.local backend: build: . @@ -29,6 +33,8 @@ services: - DJANGO_DATABASE_PORT=5432 - ENV=docker - STATIC_ROOT=/var/www/static + networks: + - timed.local keycloak: image: jboss/keycloak:10.0.1 @@ -44,6 +50,8 @@ services: - DB_PASSWORD=timed - PROXY_ADDRESS_FORWARDING=true command: ["-Dkeycloak.migration.action=import", "-Dkeycloak.migration.provider=singleFile", "-Dkeycloak.migration.file=/etc/keycloak/keycloak-config.json", "-b", "0.0.0.0"] + networks: + - timed.local proxy: image: nginx:1.17.10-alpine @@ -51,6 +59,13 @@ services: - 80:80 volumes: - ./dev-config/nginx.conf:/etc/nginx/conf.d/default.conf:ro + networks: + timed.local: + aliases: + - timed.local volumes: dbdata: + +networks: + timed.local: diff --git a/requirements-dev.txt b/requirements-dev.txt index 86728d43..c80422e8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -21,4 +21,5 @@ pytest-factoryboy==2.0.3 pytest-freezegun==0.4.1 pytest-mock==3.1.1 pytest-randomly==3.4.0 +requests-mock==1.8.0 snapshottest==0.5.1 diff --git a/requirements.txt b/requirements.txt index 009f06aa..516f44ca 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,6 +5,7 @@ django-filter==2.3.0 django-multiselectfield==0.1.12 djangorestframework==3.11.0 djangorestframework-jsonapi[django-filter]==3.1.0 +mozilla-django-oidc==1.2.3 psycopg2==2.8.5 pytz==2020.1 pyexcel-webio==0.1.4 diff --git a/timed/authentication.py b/timed/authentication.py new file mode 100644 index 00000000..e0de2b60 --- /dev/null +++ b/timed/authentication.py @@ -0,0 +1,103 @@ +import base64 +import functools +import hashlib + +import requests +from django.conf import settings +from django.core.cache import cache +from django.core.exceptions import SuspiciousOperation +from django.utils.encoding import force_bytes +from mozilla_django_oidc.auth import LOGGER, OIDCAuthenticationBackend + + +class TimedOIDCAuthenticationBackend(OIDCAuthenticationBackend): + def get_introspection(self, access_token, id_token, payload): + """Return user details dictionary.""" + + basic = base64.b64encode( + f"{settings.OIDC_OP_INTROSPECT_CLIENT_ID}:{settings.OIDC_OP_INTROSPECT_CLIENT_SECRET}".encode( + "utf-8" + ) + ).decode() + headers = { + "Authorization": f"Basic {basic}", + "Content-Type": "application/x-www-form-urlencoded", + } + response = requests.post( + settings.OIDC_OP_INTROSPECT_ENDPOINT, + verify=settings.OIDC_VERIFY_SSL, + headers=headers, + data={"token": access_token}, + ) + response.raise_for_status() + return response.json() + + def get_userinfo_or_introspection(self, access_token): + try: + claims = self.cached_request( + self.get_userinfo, access_token, "auth.userinfo" + ) + except requests.HTTPError as e: + if not ( + e.response.status_code in [401, 403] and settings.OIDC_CHECK_INTROSPECT + ): + raise e + + # check introspection if userinfo fails (confidental client) + claims = self.cached_request( + self.get_introspection, access_token, "auth.introspection" + ) + if "client_id" not in claims: + raise SuspiciousOperation("client_id not present in introspection") + + return claims + + def get_or_create_user(self, access_token, id_token, payload): + """Verify claims and return user, otherwise raise an Exception.""" + + claims = self.get_userinfo_or_introspection(access_token) + + users = self.filter_users_by_claims(claims) + + if len(users) == 1: + return users[0] + elif settings.OIDC_CREATE_USER: + return self.create_user(claims) + else: + LOGGER.debug( + "Login failed: No user with username %s found, and " + "OIDC_CREATE_USER is False", + self.get_username(claims), + ) + return None + + def filter_users_by_claims(self, claims): + username = self.get_username(claims) + return self.UserModel.objects.filter(username=username) + + def cached_request(self, method, token, cache_prefix): + token_hash = hashlib.sha256(force_bytes(token)).hexdigest() + + func = functools.partial(method, token, None, None) + + return cache.get_or_set( + f"{cache_prefix}.{token_hash}", + func, + timeout=settings.OIDC_BEARER_TOKEN_REVALIDATION_TIME, + ) + + def create_user(self, claims): + """Return object for a newly created user account.""" + username = self.get_username(claims) + email = claims.get(settings.OIDC_EMAIL_CLAIM, "") + first_name = claims.get(settings.OIDC_FIRSTNAME_CLAIM, "") + last_name = claims.get(settings.OIDC_LASTNAME_CLAIM, "") + return self.UserModel.objects.create( + username=username, email=email, first_name=first_name, last_name=last_name + ) + + def get_username(self, claims): + try: + return claims[settings.OIDC_USERNAME_CLAIM] + except KeyError: + raise SuspiciousOperation("Couldn't find username claim") diff --git a/timed/conftest.py b/timed/conftest.py index ad813889..2e9f5c9a 100644 --- a/timed/conftest.py +++ b/timed/conftest.py @@ -2,6 +2,7 @@ import pytest from django.contrib.auth import get_user_model +from django.core.cache import cache from factory.base import FactoryMetaClass from pytest_factoryboy import register from rest_framework.test import APIClient @@ -90,3 +91,8 @@ def superadmin_client(superadmin_user): client.force_authenticate(user=superadmin_user) client.user = superadmin_user return client + + +@pytest.fixture(scope="function", autouse=True) +def _autoclear_cache(): + cache.clear() diff --git a/timed/settings.py b/timed/settings.py index 7bd0ee11..14aff22e 100644 --- a/timed/settings.py +++ b/timed/settings.py @@ -60,6 +60,7 @@ def default(default_dev=env.NOTSET, default_prod=env.NOTSET): "rest_framework", "django_filters", "djmoney", + "mozilla_django_oidc", "timed.employment", "timed.projects", "timed.tracking", @@ -141,6 +142,17 @@ def default(default_dev=env.NOTSET, default_prod=env.NOTSET): STATIC_URL = env.str("STATIC_URL", "/static/") STATIC_ROOT = env.str("STATIC_ROOT", None) +# Cache + +CACHES = { + "default": { + "BACKEND": env.str( + "CACHE_BACKEND", default="django.core.cache.backends.locmem.LocMemCache" + ), + "LOCATION": env.str("CACHE_LOCATION", ""), + } +} + # Rest framework definition REST_FRAMEWORK = { @@ -152,7 +164,7 @@ def default(default_dev=env.NOTSET, default_prod=env.NOTSET): "DEFAULT_PARSER_CLASSES": ("rest_framework_json_api.parsers.JSONParser",), "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",), "DEFAULT_AUTHENTICATION_CLASSES": ( - "rest_framework_simplejwt.authentication.JWTAuthentication", + "mozilla_django_oidc.contrib.drf.OIDCAuthentication", ), "DEFAULT_METADATA_CLASS": "rest_framework_json_api.metadata.JSONAPIMetadata", "EXCEPTION_HANDLER": "rest_framework_json_api.exceptions.exception_handler", @@ -174,8 +186,11 @@ def default(default_dev=env.NOTSET, default_prod=env.NOTSET): # Authentication -# AUTHENTICATION_BACKENDS = ["django.contrib.auth.backends.ModelBackend"] AUTH_USER_MODEL = "employment.User" +AUTHENTICATION_BACKENDS = [ + "django.contrib.auth.backends.ModelBackend", + "timed.authentication.TimedOIDCAuthenticationBackend", +] AUTH_PASSWORD_VALIDATORS = [ { @@ -188,6 +203,39 @@ def default(default_dev=env.NOTSET, default_prod=env.NOTSET): }, ] +# OIDC + +OIDC_DEFAULT_BASE_URL = "http://timed.local/auth/realms/timed/protocol/openid-connect" + +OIDC_OP_USER_ENDPOINT = env.str( + "OIDC_USERINFO_ENDPOINT", default=default(f"{OIDC_DEFAULT_BASE_URL}/userinfo") +) +OIDC_OP_TOKEN_ENDPOINT = env.str( + "OIDC_TOKEN_ENDPOINT", default=default(f"{OIDC_DEFAULT_BASE_URL}/token") +) +OIDC_RP_CLIENT_ID = env.str("OIDC_CLIENT_ID", default=None) +OIDC_RP_CLIENT_SECRET = env.str("OIDC_CLIENT_SECRET", default=None) +OIDC_VERIFY_SSL = env.bool("OIDC_VERIFY_SSL", default=default(False, True)) +OIDC_CREATE_USER = env.bool("OIDC_CREATE_USER", default=False) + +OIDC_USERNAME_CLAIM = env.str("OIDC_USERNAME_CLAIM", default="preferred_username") +OIDC_EMAIL_CLAIM = env.str("OIDC_EMAIL_CLAIM", default="email") +OIDC_FIRSTNAME_CLAIM = env.str("OIDC_FIRSTNAME_CLAIM", default="given_name") +OIDC_LASTNAME_CLAIM = env.str("OIDC_LASTNAME_CLAIM", default="family_name") +# time in seconds +OIDC_BEARER_TOKEN_REVALIDATION_TIME = env.int( + "OIDC_BEARER_TOKEN_REVALIDATION_TIME", default=60 +) +OIDC_CHECK_INTROSPECT = env.bool("OIDC_CHECK_INTROSPECT", default=True) +OIDC_OP_INTROSPECT_ENDPOINT = env.str( + "OIDC_INTROSPECT_ENDPOINT", + default=default(f"{OIDC_DEFAULT_BASE_URL}/token/introspect"), +) +OIDC_OP_INTROSPECT_CLIENT_ID = env.str("OIDC_INTROSPECT_CLIENT_ID", default=None) +OIDC_OP_INTROSPECT_CLIENT_SECRET = env.str( + "OIDC_INTROSPECT_CLIENT_SECRET", default=None +) + # Email definition EMAIL_CONFIG = env.email_url("EMAIL_URL", default="smtp://localhost:25") diff --git a/timed/tests/test_authentication.py b/timed/tests/test_authentication.py new file mode 100644 index 00000000..a50fa19a --- /dev/null +++ b/timed/tests/test_authentication.py @@ -0,0 +1,121 @@ +import hashlib +import json + +import pytest +from django.contrib.auth import get_user_model +from django.core.cache import cache +from mozilla_django_oidc.contrib.drf import OIDCAuthentication +from requests.exceptions import HTTPError +from rest_framework import exceptions, status +from rest_framework.exceptions import AuthenticationFailed + + +@pytest.mark.parametrize("is_id_token", [True, False]) +@pytest.mark.parametrize( + "authentication_header,authenticated,error", + [ + ("", False, False), + ("Bearer", False, True), + ("Bearer Too many params", False, True), + ("Basic Auth", False, True), + ("Bearer Token", True, False), + ], +) +@pytest.mark.parametrize("user__username", ["1"]) +def test_authentication( + db, + user, + rf, + authentication_header, + authenticated, + error, + is_id_token, + requests_mock, + settings, +): + userinfo = {"preferred_username": "1"} + requests_mock.get(settings.OIDC_OP_USER_ENDPOINT, text=json.dumps(userinfo)) + + if not is_id_token: + userinfo = {"client_id": "test_client", "preferred_username": "1"} + requests_mock.get( + settings.OIDC_OP_USER_ENDPOINT, status_code=status.HTTP_401_UNAUTHORIZED + ) + requests_mock.post( + settings.OIDC_OP_INTROSPECT_ENDPOINT, text=json.dumps(userinfo) + ) + + request = rf.get("/openid", HTTP_AUTHORIZATION=authentication_header) + try: + result = OIDCAuthentication().authenticate(request) + except exceptions.AuthenticationFailed: + assert error + else: + if result: + key = "userinfo" if is_id_token else "introspection" + user, auth = result + assert user.is_authenticated + assert ( + cache.get(f"auth.{key}.{hashlib.sha256(b'Token').hexdigest()}") + == userinfo + ) + + +@pytest.mark.parametrize( + "create_user,username,expected_count", + [(False, "", 0), (True, "", 1), (True, "foo@example.com", 1)], +) +def test_authentication_new_user( + db, rf, requests_mock, settings, create_user, username, expected_count +): + settings.OIDC_CREATE_USER = create_user + user_model = get_user_model() + assert user_model.objects.filter(username=username).count() == 0 + + userinfo = {"preferred_username": username} + requests_mock.get(settings.OIDC_OP_USER_ENDPOINT, text=json.dumps(userinfo)) + + request = rf.get("/openid", HTTP_AUTHORIZATION="Bearer Token") + + try: + user, _ = OIDCAuthentication().authenticate(request) + except AuthenticationFailed: + assert not create_user + else: + assert user.username == username + + assert user_model.objects.count() == expected_count + + +def test_authentication_idp_502(db, rf, requests_mock, settings): + requests_mock.get( + settings.OIDC_OP_USER_ENDPOINT, status_code=status.HTTP_502_BAD_GATEWAY + ) + + request = rf.get("/openid", HTTP_AUTHORIZATION="Bearer Token") + with pytest.raises(HTTPError): + OIDCAuthentication().authenticate(request) + + +def test_authentication_idp_missing_claim(db, rf, requests_mock, settings): + settings.OIDC_USERNAME_CLAIM = "missing" + userinfo = {"preferred_username": "1"} + requests_mock.get(settings.OIDC_OP_USER_ENDPOINT, text=json.dumps(userinfo)) + + request = rf.get("/openid", HTTP_AUTHORIZATION="Bearer Token") + with pytest.raises(AuthenticationFailed): + OIDCAuthentication().authenticate(request) + + +def test_authentication_no_client(db, rf, requests_mock, settings): + requests_mock.get( + settings.OIDC_OP_USER_ENDPOINT, status_code=status.HTTP_401_UNAUTHORIZED + ) + requests_mock.post( + settings.OIDC_OP_INTROSPECT_ENDPOINT, + text=json.dumps({"preferred_username": "1"}), + ) + + request = rf.get("/openid", HTTP_AUTHORIZATION="Bearer Token") + with pytest.raises(AuthenticationFailed): + OIDCAuthentication().authenticate(request)