Skip to content

Commit

Permalink
Minor refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
tadhg-ohiggins authored and timoballard committed May 17, 2023
1 parent 1d64ac9 commit 36a5468
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 71 deletions.
37 changes: 21 additions & 16 deletions backend/audit/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,23 @@ def __init__(self, message, eligible_users):


def has_access(sac, user):
accesses = Access.objects.filter(sac=sac, user=user)
if not accesses:
return False

return True
"""Does a user have permission to access a submission?"""
return bool(Access.objects.filter(sac=sac, user=user))


def has_role(sac, user, role):
accesses = Access.objects.filter(sac=sac, user=user, role=role)
if not accesses:
return False

return True
"""Does a user have a specific role on a submission?"""
return bool(Access.objects.filter(sac=sac, user=user, role=role))


class SingleAuditChecklistAccessRequiredMixin(LoginRequiredMixin):
"""
View mixin to require that a user is logged in and has access to the submission.
"""

def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
try:
report_id = kwargs["report_id"]
sac = SingleAuditChecklist.objects.get(report_id=report_id)
sac = SingleAuditChecklist.objects.get(report_id=kwargs["report_id"])

if not has_access(sac, request.user):
raise PermissionDenied("You do not have access to this audit.")
Expand All @@ -48,11 +45,15 @@ def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpRespo


class CertifyingAuditeeRequiredMixin(LoginRequiredMixin):
"""
View mixin to require that a user is logged in, has access to the submission, and has
the ``certifying_auditee_contact`` role.
"""

def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
role = "certifying_auditee_contact"
try:
report_id = kwargs["report_id"]
sac = SingleAuditChecklist.objects.get(report_id=report_id)
sac = SingleAuditChecklist.objects.get(report_id=kwargs["report_id"])

if not has_access(sac, request.user):
raise PermissionDenied("You do not have access to this audit")
Expand All @@ -73,11 +74,15 @@ def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpRespo


class CertifyingAuditorRequiredMixin(LoginRequiredMixin):
"""
View mixin to require that a user is logged in, has access to the submission, and has
the ``certifying_auditor_contact`` role.
"""

def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
role = "certifying_auditor_contact"
try:
report_id = kwargs["report_id"]
sac = SingleAuditChecklist.objects.get(report_id=report_id)
sac = SingleAuditChecklist.objects.get(report_id=kwargs["report_id"])

if not has_access(sac, request.user):
raise PermissionDenied("You do not have access to this audit")
Expand Down
58 changes: 20 additions & 38 deletions backend/audit/test_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,15 @@ def get(self, request, *args, **kwargs):
pass

def test_missing_report_id_raises(self):
factory = RequestFactory()
request = factory.get("/")

request = RequestFactory().get("/")
view = self.ViewStub()

self.assertRaises(KeyError, view.dispatch, request)

def test_nonexistent_sac_raises(self):
factory = RequestFactory()
request = factory.get("/")

request = RequestFactory().get("/")
view = self.ViewStub()

self.assertRaises(
PermissionDenied, view.dispatch, request, report_id="not-a-real-report-id"
)
Expand All @@ -41,9 +39,7 @@ def test_no_access_raises(self):
user = baker.make(User)
sac = baker.make(SingleAuditChecklist)

factory = RequestFactory()
request = factory.get("/")

request = RequestFactory().get("/")
request.user = user

view = self.ViewStub()
Expand All @@ -56,9 +52,7 @@ def test_has_access(self):
sac = baker.make(SingleAuditChecklist)
baker.make(Access, sac=sac, user=user)

factory = RequestFactory()
request = factory.get("/")

request = RequestFactory().get("/")
request.user = user

view = self.ViewStub()
Expand All @@ -71,17 +65,15 @@ def get(self, request, *args, **kwargs):
pass

def test_missing_report_id_raises(self):
factory = RequestFactory()
request = factory.get("/")

request = RequestFactory().get("/")
view = self.ViewStub()

self.assertRaises(KeyError, view.dispatch, request)

def test_nonexistent_sac_raises(self):
factory = RequestFactory()
request = factory.get("/")

request = RequestFactory().get("/")
view = self.ViewStub()

self.assertRaises(
PermissionDenied, view.dispatch, request, report_id="not-a-real-report-id"
)
Expand All @@ -90,12 +82,10 @@ def test_no_access_raises(self):
user = baker.make(User)
sac = baker.make(SingleAuditChecklist)

factory = RequestFactory()
request = factory.get("/")

request = RequestFactory().get("/")
request.user = user

view = self.ViewStub()

self.assertRaises(
PermissionDenied, view.dispatch, request, report_id=sac.report_id
)
Expand Down Expand Up @@ -134,9 +124,7 @@ def test_has_role(self):

baker.make(Access, sac=sac, user=user, role="certifying_auditee_contact")

factory = RequestFactory()
request = factory.get("/")

request = RequestFactory().get("/")
request.user = user

view = self.ViewStub()
Expand All @@ -149,17 +137,15 @@ def get(self, request, *args, **kwargs):
pass

def test_missing_report_id_raises(self):
factory = RequestFactory()
request = factory.get("/")

request = RequestFactory().get("/")
view = self.ViewStub()

self.assertRaises(KeyError, view.dispatch, request)

def test_nonexistent_sac_raises(self):
factory = RequestFactory()
request = factory.get("/")

request = RequestFactory().get("/")
view = self.ViewStub()

self.assertRaises(
PermissionDenied, view.dispatch, request, report_id="not-a-real-report-id"
)
Expand All @@ -168,12 +154,10 @@ def test_no_access_raises(self):
user = baker.make(User)
sac = baker.make(SingleAuditChecklist)

factory = RequestFactory()
request = factory.get("/")

request = RequestFactory().get("/")
request.user = user

view = self.ViewStub()

self.assertRaises(
PermissionDenied, view.dispatch, request, report_id=sac.report_id
)
Expand Down Expand Up @@ -212,9 +196,7 @@ def test_has_role(self):

baker.make(Access, sac=sac, user=user, role="certifying_auditor_contact")

factory = RequestFactory()
request = factory.get("/")

request = RequestFactory().get("/")
request.user = user

view = self.ViewStub()
Expand Down
30 changes: 13 additions & 17 deletions backend/audit/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@
# Mocking the user login and file scan functions
def _mock_login_and_scan(client, mock_scan_file):
"""Helper function to mock the login and file scan functions"""
user = baker.make(User)

sac = baker.make(SingleAuditChecklist)
user, sac = _make_user_and_sac()

baker.make(Access, user=user, sac=sac)

Expand All @@ -93,6 +91,12 @@ def _authed_post(client, user, view_str, kwargs=None, data=None):
return _client_post(client, view_str, kwargs, data)


def _make_user_and_sac(**kwargs):
user = baker.make(User)
sac = baker.make(SingleAuditChecklist, **kwargs)
return user, sac


class MySubmissionsViewTests(TestCase):
def setUp(self):
self.user = baker.make(User)
Expand Down Expand Up @@ -187,10 +191,7 @@ def test_auditor_certification(self):
"""
Test that certifying auditor contacts can provide auditor certification
"""
user = baker.make(User)
sac = baker.make(
SingleAuditChecklist, submission_status="ready_for_certification"
)
user, sac = _make_user_and_sac(submission_status="ready_for_certification")
baker.make(Access, sac=sac, user=user, role="certifying_auditor_contact")

kwargs = {"report_id": sac.report_id}
Expand All @@ -204,8 +205,7 @@ def test_auditee_certification(self):
"""
Test that certifying auditee contacts can provide auditee certification
"""
user = baker.make(User)
sac = baker.make(SingleAuditChecklist, submission_status="auditor_certified")
user, sac = _make_user_and_sac(submission_status="auditor_certified")
baker.make(Access, sac=sac, user=user, role="certifying_auditee_contact")

kwargs = {"report_id": sac.report_id}
Expand All @@ -219,8 +219,7 @@ def test_submission(self):
"""
Test that certifying auditee contacts can perform submission
"""
user = baker.make(User)
sac = baker.make(SingleAuditChecklist, submission_status="auditee_certified")
user, sac = _make_user_and_sac(submission_status="auditee_certified")
baker.make(Access, sac=sac, user=user, role="certifying_auditee_contact")

kwargs = {"report_id": sac.report_id}
Expand Down Expand Up @@ -291,8 +290,7 @@ def test_bad_report_id_returns_403(self):

def test_inaccessible_audit_returns_403(self):
"""When a request is made for an audit that is inaccessible for this user, a 403 error should be returned"""
user = baker.make(User)
sac = baker.make(SingleAuditChecklist)
user, sac = _make_user_and_sac()

self.client.force_login(user)
for form_section in EXCEL_FILES:
Expand All @@ -307,8 +305,7 @@ def test_inaccessible_audit_returns_403(self):

def test_no_file_attached_returns_400(self):
"""When a request is made with no file attached, a 400 error should be returned"""
user = baker.make(User)
sac = baker.make(SingleAuditChecklist)
user, sac = _make_user_and_sac()
baker.make(Access, user=user, sac=sac)

self.client.force_login(user)
Expand All @@ -325,8 +322,7 @@ def test_no_file_attached_returns_400(self):

def test_invalid_file_upload_returns_400(self):
"""When an invalid Excel file is uploaded, a 400 error should be returned"""
user = baker.make(User)
sac = baker.make(SingleAuditChecklist)
user, sac = _make_user_and_sac()
baker.make(Access, user=user, sac=sac)

self.client.force_login(user)
Expand Down

0 comments on commit 36a5468

Please sign in to comment.