Skip to content

Commit

Permalink
feat(models): validate active TransitAgency
Browse files Browse the repository at this point in the history
when active=True, validate that:

- there are values for user-facing info fields like names, phone, etc.
- templates exist
  • Loading branch information
thekaveman committed Nov 6, 2024
1 parent f57345e commit 6012bff
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 29 deletions.
96 changes: 68 additions & 28 deletions benefits/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
from functools import cached_property
import importlib
import logging
from pathlib import Path
import uuid

from django import template
from django.conf import settings
from django.core.exceptions import ValidationError
from django.contrib.auth.models import Group, User
Expand All @@ -24,6 +26,22 @@
logger = logging.getLogger(__name__)


def template_path(template_name: str) -> Path:
"""Get a `pathlib.Path` for the named template, or None if it can't be found.
A `template_name` is the app-local name, e.g. `enrollment/success.html`.
Adapted from https://stackoverflow.com/a/75863472.
"""
for engine in template.engines.all():
for loader in engine.engine.template_loaders:
for origin in loader.get_template_sources(template_name):
path = Path(origin.name)
if path.exists() and path.is_file():
return path
return None


class SecretNameField(models.SlugField):
"""Field that stores the name of a secret held in a secret store.
Expand Down Expand Up @@ -264,6 +282,30 @@ def transit_processor_client_secret(self):
def enrollment_flows(self):
return self.enrollmentflow_set

def clean(self):
if self.active:
errors = {}
message = "This field is required for active transit agencies."
needed = dict(
short_name=self.short_name,
long_name=self.long_name,
phone=self.phone,
info_url=self.info_url,
)
for k, v in needed.items():
if not v:
errors[k] = ValidationError(message)

if not template_path(self.index_template):
errors["index_template"] = ValidationError(f"Template not found: {self.index_template}")
if not template_path(self.eligibility_index_template):
errors["eligibility_index_template"] = ValidationError(
f"Template not found: {self.eligibility_index_template}"
)

if errors:
raise ValidationError(errors)

@staticmethod
def by_id(id):
"""Get a TransitAgency instance by its ID."""
Expand Down Expand Up @@ -493,6 +535,17 @@ def uses_claims_verification(self):
"""True if this flow verifies via the claims provider and has a scope and claim. False otherwise."""
return self.claims_provider is not None and bool(self.claims_scope) and bool(self.claims_eligibility_claim)

@property
def claims_scheme(self):
return self.claims_scheme_override or self.claims_provider.scheme

@property
def claims_all_claims(self):
claims = [self.claims_eligibility_claim]
if self.claims_extra_claims is not None:
claims.extend(self.claims_extra_claims.split())
return claims

@property
def eligibility_verifier(self):
"""A str representing the entity that verifies eligibility for this flow.
Expand Down Expand Up @@ -520,23 +573,6 @@ def enrollment_success_template(self):
else:
return self.enrollment_success_template_override or f"{prefix}--{self.agency_card_name}.html"

def eligibility_form_instance(self, *args, **kwargs):
"""Return an instance of this flow's EligibilityForm, or None."""
if not bool(self.eligibility_form_class):
return None

# inspired by https://stackoverflow.com/a/30941292
module_name, class_name = self.eligibility_form_class.rsplit(".", 1)
FormClass = getattr(importlib.import_module(module_name), class_name)

return FormClass(*args, **kwargs)

@staticmethod
def by_id(id):
"""Get an EnrollmentFlow instance by its ID."""
logger.debug(f"Get {EnrollmentFlow.__name__} by id: {id}")
return EnrollmentFlow.objects.get(id=id)

def clean(self):
supports_expiration = self.supports_expiration
expiration_days = self.expiration_days
Expand All @@ -556,18 +592,22 @@ def clean(self):
if errors:
raise ValidationError(errors)

@property
def claims_scheme(self):
if not self.claims_scheme_override:
return self.claims_provider.scheme
return self.claims_scheme_override
def eligibility_form_instance(self, *args, **kwargs):
"""Return an instance of this flow's EligibilityForm, or None."""
if not bool(self.eligibility_form_class):
return None

@property
def claims_all_claims(self):
claims = [self.claims_eligibility_claim]
if self.claims_extra_claims is not None:
claims.extend(self.claims_extra_claims.split())
return claims
# inspired by https://stackoverflow.com/a/30941292
module_name, class_name = self.eligibility_form_class.rsplit(".", 1)
FormClass = getattr(importlib.import_module(module_name), class_name)

return FormClass(*args, **kwargs)

@staticmethod
def by_id(id):
"""Get an EnrollmentFlow instance by its ID."""
logger.debug(f"Get {EnrollmentFlow.__name__} by id: {id}")
return EnrollmentFlow.objects.get(id=id)


class EnrollmentEvent(models.Model):
Expand Down
54 changes: 53 additions & 1 deletion tests/pytest/core/test_models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
from datetime import timedelta
from pathlib import Path

from django.conf import settings
from django.contrib.auth.models import Group, User
from django.core.exceptions import ValidationError
from django.utils import timezone

import pytest

from benefits.core.models import SecretNameField, EnrollmentFlow, TransitAgency, EnrollmentEvent, EnrollmentMethods
from benefits.core.models import (
template_path,
SecretNameField,
EnrollmentFlow,
TransitAgency,
EnrollmentEvent,
EnrollmentMethods,
)
import benefits.secrets


Expand All @@ -16,6 +25,25 @@ def mock_requests_get_pem_data(mocker):
return mocker.patch("benefits.core.models.requests.get", return_value=mocker.Mock(text="PEM text"))


@pytest.mark.django_db
@pytest.mark.parametrize(
"input_template,expected_path",
[
("error.html", f"{settings.BASE_DIR}/benefits/templates/error.html"),
("core/index.html", f"{settings.BASE_DIR}/benefits/core/templates/core/index.html"),
("eligibility/start.html", f"{settings.BASE_DIR}/benefits/eligibility/templates/eligibility/start.html"),
("", None),
("nope.html", None),
("core/not-there.html", None),
],
)
def test_template_path(input_template, expected_path):
if expected_path:
assert template_path(input_template) == Path(expected_path)
else:
assert template_path(input_template) is None


def test_SecretNameField_init():
field = SecretNameField()

Expand Down Expand Up @@ -518,6 +546,30 @@ def test_TransitAgency_for_user_in_group_not_linked_to_any_agency():
assert TransitAgency.for_user(user) is None


@pytest.mark.django_db
def test_TransitAgency_clean(model_TransitAgency_inactive):
model_TransitAgency_inactive.short_name = ""
model_TransitAgency_inactive.long_name = ""
model_TransitAgency_inactive.phone = ""
model_TransitAgency_inactive.info_url = ""
model_TransitAgency_inactive.index_template_override = "does/not/exist.html"
model_TransitAgency_inactive.eligibility_index_template_override = "does/not/exist.html"
# agency is inactive, OK to have incomplete fields
model_TransitAgency_inactive.clean()

# now mark it active and expect failure on clean()
model_TransitAgency_inactive.active = True
with pytest.raises(ValidationError) as e:
model_TransitAgency_inactive.clean()

assert "short_name" in e.value.error_dict
assert "long_name" in e.value.error_dict
assert "phone" in e.value.error_dict
assert "info_url" in e.value.error_dict
assert "index_template" in e.value.error_dict
assert "eligibility_index_template" in e.value.error_dict


@pytest.mark.django_db
def test_EnrollmentEvent_create(model_TransitAgency, model_EnrollmentFlow):
ts = timezone.now()
Expand Down

0 comments on commit 6012bff

Please sign in to comment.