diff --git a/benefits/core/models.py b/benefits/core/models.py index 68e289b8b4..2ab2927fb6 100644 --- a/benefits/core/models.py +++ b/benefits/core/models.py @@ -8,7 +8,7 @@ from django.conf import settings from django.core.exceptions import ValidationError -from django.contrib.auth.models import Group +from django.contrib.auth.models import Group, User from django.db import models from django.urls import reverse @@ -435,3 +435,13 @@ def all_active(): """Get all TransitAgency instances marked active.""" logger.debug(f"Get all active {TransitAgency.__name__}") return TransitAgency.objects.filter(active=True) + + @staticmethod + def for_user(user: User): + group = user.groups.first() + + if group is not None: + # TransitAgency to Group is one-to-one, so there will be either 0 or 1 returned + return TransitAgency.objects.filter(group=group).first() + else: + return None diff --git a/tests/pytest/core/test_models.py b/tests/pytest/core/test_models.py index 54e46e8d9c..a40562a7d6 100644 --- a/tests/pytest/core/test_models.py +++ b/tests/pytest/core/test_models.py @@ -1,4 +1,5 @@ from django.conf import settings +from django.contrib.auth.models import Group, User from django.core.exceptions import ValidationError import pytest @@ -464,3 +465,18 @@ def test_TransitAgency_all_active(model_TransitAgency): assert len(result) > 0 assert model_TransitAgency in result assert inactive_agency not in result + + +@pytest.mark.django_db +def test_TransitAgency_for_user(model_TransitAgency): + group = Group.objects.create(name="test_group") + + agency_for_user = TransitAgency.by_id(model_TransitAgency.id) + agency_for_user.pk = None + agency_for_user.group = group + agency_for_user.save() + + user = User.objects.create_user(username="test_user", email="test_user@example.com", password="test", is_staff=True) + user.groups.add(group) + + assert TransitAgency.for_user(user) == agency_for_user