From 6012bff42ce0b2850adc740286297c8de56d9f98 Mon Sep 17 00:00:00 2001 From: Kegan Maher Date: Thu, 31 Oct 2024 19:11:27 +0000 Subject: [PATCH] feat(models): validate active TransitAgency when active=True, validate that: - there are values for user-facing info fields like names, phone, etc. - templates exist --- benefits/core/models.py | 96 ++++++++++++++++++++++---------- tests/pytest/core/test_models.py | 54 +++++++++++++++++- 2 files changed, 121 insertions(+), 29 deletions(-) diff --git a/benefits/core/models.py b/benefits/core/models.py index 0d20bf30ac..225862fad0 100644 --- a/benefits/core/models.py +++ b/benefits/core/models.py @@ -5,8 +5,10 @@ from functools import cached_property import importlib import logging +from pathlib import Path import uuid +from django import template from django.conf import settings from django.core.exceptions import ValidationError from django.contrib.auth.models import Group, User @@ -24,6 +26,22 @@ logger = logging.getLogger(__name__) +def template_path(template_name: str) -> Path: + """Get a `pathlib.Path` for the named template, or None if it can't be found. + + A `template_name` is the app-local name, e.g. `enrollment/success.html`. + + Adapted from https://stackoverflow.com/a/75863472. + """ + for engine in template.engines.all(): + for loader in engine.engine.template_loaders: + for origin in loader.get_template_sources(template_name): + path = Path(origin.name) + if path.exists() and path.is_file(): + return path + return None + + class SecretNameField(models.SlugField): """Field that stores the name of a secret held in a secret store. @@ -264,6 +282,30 @@ def transit_processor_client_secret(self): def enrollment_flows(self): return self.enrollmentflow_set + def clean(self): + if self.active: + errors = {} + message = "This field is required for active transit agencies." + needed = dict( + short_name=self.short_name, + long_name=self.long_name, + phone=self.phone, + info_url=self.info_url, + ) + for k, v in needed.items(): + if not v: + errors[k] = ValidationError(message) + + if not template_path(self.index_template): + errors["index_template"] = ValidationError(f"Template not found: {self.index_template}") + if not template_path(self.eligibility_index_template): + errors["eligibility_index_template"] = ValidationError( + f"Template not found: {self.eligibility_index_template}" + ) + + if errors: + raise ValidationError(errors) + @staticmethod def by_id(id): """Get a TransitAgency instance by its ID.""" @@ -493,6 +535,17 @@ def uses_claims_verification(self): """True if this flow verifies via the claims provider and has a scope and claim. False otherwise.""" return self.claims_provider is not None and bool(self.claims_scope) and bool(self.claims_eligibility_claim) + @property + def claims_scheme(self): + return self.claims_scheme_override or self.claims_provider.scheme + + @property + def claims_all_claims(self): + claims = [self.claims_eligibility_claim] + if self.claims_extra_claims is not None: + claims.extend(self.claims_extra_claims.split()) + return claims + @property def eligibility_verifier(self): """A str representing the entity that verifies eligibility for this flow. @@ -520,23 +573,6 @@ def enrollment_success_template(self): else: return self.enrollment_success_template_override or f"{prefix}--{self.agency_card_name}.html" - def eligibility_form_instance(self, *args, **kwargs): - """Return an instance of this flow's EligibilityForm, or None.""" - if not bool(self.eligibility_form_class): - return None - - # inspired by https://stackoverflow.com/a/30941292 - module_name, class_name = self.eligibility_form_class.rsplit(".", 1) - FormClass = getattr(importlib.import_module(module_name), class_name) - - return FormClass(*args, **kwargs) - - @staticmethod - def by_id(id): - """Get an EnrollmentFlow instance by its ID.""" - logger.debug(f"Get {EnrollmentFlow.__name__} by id: {id}") - return EnrollmentFlow.objects.get(id=id) - def clean(self): supports_expiration = self.supports_expiration expiration_days = self.expiration_days @@ -556,18 +592,22 @@ def clean(self): if errors: raise ValidationError(errors) - @property - def claims_scheme(self): - if not self.claims_scheme_override: - return self.claims_provider.scheme - return self.claims_scheme_override + def eligibility_form_instance(self, *args, **kwargs): + """Return an instance of this flow's EligibilityForm, or None.""" + if not bool(self.eligibility_form_class): + return None - @property - def claims_all_claims(self): - claims = [self.claims_eligibility_claim] - if self.claims_extra_claims is not None: - claims.extend(self.claims_extra_claims.split()) - return claims + # inspired by https://stackoverflow.com/a/30941292 + module_name, class_name = self.eligibility_form_class.rsplit(".", 1) + FormClass = getattr(importlib.import_module(module_name), class_name) + + return FormClass(*args, **kwargs) + + @staticmethod + def by_id(id): + """Get an EnrollmentFlow instance by its ID.""" + logger.debug(f"Get {EnrollmentFlow.__name__} by id: {id}") + return EnrollmentFlow.objects.get(id=id) class EnrollmentEvent(models.Model): diff --git a/tests/pytest/core/test_models.py b/tests/pytest/core/test_models.py index e2be8c339d..6290350b80 100644 --- a/tests/pytest/core/test_models.py +++ b/tests/pytest/core/test_models.py @@ -1,4 +1,6 @@ from datetime import timedelta +from pathlib import Path + from django.conf import settings from django.contrib.auth.models import Group, User from django.core.exceptions import ValidationError @@ -6,7 +8,14 @@ import pytest -from benefits.core.models import SecretNameField, EnrollmentFlow, TransitAgency, EnrollmentEvent, EnrollmentMethods +from benefits.core.models import ( + template_path, + SecretNameField, + EnrollmentFlow, + TransitAgency, + EnrollmentEvent, + EnrollmentMethods, +) import benefits.secrets @@ -16,6 +25,25 @@ def mock_requests_get_pem_data(mocker): return mocker.patch("benefits.core.models.requests.get", return_value=mocker.Mock(text="PEM text")) +@pytest.mark.django_db +@pytest.mark.parametrize( + "input_template,expected_path", + [ + ("error.html", f"{settings.BASE_DIR}/benefits/templates/error.html"), + ("core/index.html", f"{settings.BASE_DIR}/benefits/core/templates/core/index.html"), + ("eligibility/start.html", f"{settings.BASE_DIR}/benefits/eligibility/templates/eligibility/start.html"), + ("", None), + ("nope.html", None), + ("core/not-there.html", None), + ], +) +def test_template_path(input_template, expected_path): + if expected_path: + assert template_path(input_template) == Path(expected_path) + else: + assert template_path(input_template) is None + + def test_SecretNameField_init(): field = SecretNameField() @@ -518,6 +546,30 @@ def test_TransitAgency_for_user_in_group_not_linked_to_any_agency(): assert TransitAgency.for_user(user) is None +@pytest.mark.django_db +def test_TransitAgency_clean(model_TransitAgency_inactive): + model_TransitAgency_inactive.short_name = "" + model_TransitAgency_inactive.long_name = "" + model_TransitAgency_inactive.phone = "" + model_TransitAgency_inactive.info_url = "" + model_TransitAgency_inactive.index_template_override = "does/not/exist.html" + model_TransitAgency_inactive.eligibility_index_template_override = "does/not/exist.html" + # agency is inactive, OK to have incomplete fields + model_TransitAgency_inactive.clean() + + # now mark it active and expect failure on clean() + model_TransitAgency_inactive.active = True + with pytest.raises(ValidationError) as e: + model_TransitAgency_inactive.clean() + + assert "short_name" in e.value.error_dict + assert "long_name" in e.value.error_dict + assert "phone" in e.value.error_dict + assert "info_url" in e.value.error_dict + assert "index_template" in e.value.error_dict + assert "eligibility_index_template" in e.value.error_dict + + @pytest.mark.django_db def test_EnrollmentEvent_create(model_TransitAgency, model_EnrollmentFlow): ts = timezone.now()