From 44ce8d99a2d114bfa9eb0df43abb3c7bb205a3fa Mon Sep 17 00:00:00 2001 From: Marcos Prieto Date: Wed, 23 Oct 2024 13:27:43 +0200 Subject: [PATCH] Consider LTI role to grant admin status in the dashboards Give access to the dashboards for full organizations to users that the LMS identifies as system wide admins (scope: system, type admin in our LTIRoles table) --- lms/models/lms_course.py | 8 +++ lms/models/lti_user.py | 2 +- lms/services/dashboard.py | 69 +++++++++++++++++++---- lms/views/dashboard/views.py | 6 -- tests/factories/__init__.py | 6 +- tests/factories/lms_course.py | 4 ++ tests/unit/lms/services/dashboard_test.py | 62 ++++++++++++++++---- 7 files changed, 126 insertions(+), 31 deletions(-) diff --git a/lms/models/lms_course.py b/lms/models/lms_course.py index 6d01cfbdde..558be24957 100644 --- a/lms/models/lms_course.py +++ b/lms/models/lms_course.py @@ -7,6 +7,8 @@ - LMSCourse membership stores role information, GroupingMembership doesn't. """ +from typing import TYPE_CHECKING + import sqlalchemy as sa from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -14,6 +16,9 @@ from lms.models import ApplicationInstance from lms.models._mixins import CreatedUpdatedMixin +if TYPE_CHECKING: + from lms.models import LMSUser, LTIRole + class LMSCourse(CreatedUpdatedMixin, Base): __tablename__ = "lms_course" @@ -75,12 +80,15 @@ class LMSCourseMembership(CreatedUpdatedMixin, Base): lms_course_id: Mapped[int] = mapped_column( sa.ForeignKey("lms_course.id", ondelete="cascade"), index=True ) + lms_course: Mapped[LMSCourse] = relationship() lms_user_id: Mapped[int] = mapped_column( sa.ForeignKey("lms_user.id", ondelete="cascade"), index=True ) + lms_user: Mapped["LMSUser"] = relationship() lti_role_id: Mapped[int] = mapped_column( sa.ForeignKey("lti_role.id", ondelete="cascade"), index=True, ) + lti_role: Mapped["LTIRole"] = relationship() diff --git a/lms/models/lti_user.py b/lms/models/lti_user.py index 0f62ca6cf6..32d05789ac 100644 --- a/lms/models/lti_user.py +++ b/lms/models/lti_user.py @@ -65,7 +65,7 @@ class LTIUser: """The user's email address.""" @property - def h_user(self): + def h_user(self) -> HUser: """Return a models.HUser generated from this LTIUser.""" return HUser.from_lti_user(self) diff --git a/lms/services/dashboard.py b/lms/services/dashboard.py index f12e8915f3..9e65dd99b7 100644 --- a/lms/services/dashboard.py +++ b/lms/services/dashboard.py @@ -1,25 +1,38 @@ from pyramid.httpexceptions import HTTPNotFound, HTTPUnauthorized -from sqlalchemy import select - -from lms.models import Assignment, Organization +from sqlalchemy import select, union + +from lms.models import ( + ApplicationInstance, + Assignment, + LMSCourse, + LMSCourseApplicationInstance, + LMSCourseMembership, + LMSUser, + LTIRole, + Organization, + RoleScope, + RoleType, +) from lms.models.dashboard_admin import DashboardAdmin from lms.security import Permissions from lms.services.organization import OrganizationService class DashboardService: - def __init__( + def __init__( # noqa: PLR0913 self, request, assignment_service, course_service, organization_service: OrganizationService, + h_authority: str, ): self._db = request.db self._assignment_service = assignment_service self._course_service = course_service self._organization_service = organization_service + self._h_authority = h_authority def get_request_assignment(self, request) -> Assignment: """Get and authorize an assignment for the given request.""" @@ -72,14 +85,44 @@ def get_request_course(self, request): return course - def get_organizations_by_admin_email(self, email: str) -> list[Organization]: - """Get a list of organizations where the user with email `email` is an admin in.""" + def get_organizations_where_admin( + self, h_userid: str, email: str + ) -> list[Organization]: + """Get a list of organizations where the user h_userid with email `email` is an admin in.""" organization_ids = [] + # A user can be an admin in an organization via having a matching email in DashboardAdmin + organization_id_by_email = select(DashboardAdmin.organization_id).where( + DashboardAdmin.email == email + ) + + # It also can be an admin via having the relevant LTI role in an organization. + organization_id_by_lti_admin = ( + select(Organization.id) + .join( + ApplicationInstance, + ApplicationInstance.organization_id == Organization.id, + ) + .join( + LMSCourseApplicationInstance, + LMSCourseApplicationInstance.application_instance_id + == ApplicationInstance.id, + ) + .join(LMSCourse, LMSCourse.id == LMSCourseApplicationInstance.lms_course_id) + .join( + LMSCourseMembership, LMSCourseMembership.lms_course_id == LMSCourse.id + ) + .join(LMSUser, LMSCourseMembership.lms_user_id == LMSUser.id) + .join(LTIRole) + .where( + LMSUser.h_userid == h_userid, + LTIRole.type == RoleType.ADMIN, + LTIRole.scope == RoleScope.SYSTEM, + ) + ) + for org_id in self._db.scalars( - select(DashboardAdmin.organization_id) - .where(DashboardAdmin.email == email) - .distinct() + union(organization_id_by_email, organization_id_by_lti_admin) ).all(): organization_ids.extend( self._organization_service.get_hierarchy_ids(org_id) @@ -125,8 +168,11 @@ def get_request_admin_organizations(self, request) -> list[Organization]: ) ).all() - return self.get_organizations_by_admin_email( - request.lti_user.email if request.lti_user else request.identity.userid + return self.get_organizations_where_admin( + h_userid=request.lti_user.h_user.userid(self._h_authority), + email=request.lti_user.email + if request.lti_user + else request.identity.userid, ) @@ -136,4 +182,5 @@ def factory(_context, request): assignment_service=request.find_service(name="assignment"), course_service=request.find_service(name="course"), organization_service=request.find_service(OrganizationService), + h_authority=request.registry.settings["h_authority"], ) diff --git a/lms/views/dashboard/views.py b/lms/views/dashboard/views.py index a66141d104..785c972e9c 100644 --- a/lms/views/dashboard/views.py +++ b/lms/views/dashboard/views.py @@ -44,12 +44,6 @@ def __init__(self, request) -> None: ) self.dashboard_service = request.find_service(name="dashboard") - self.admin_organizations = ( - self.dashboard_service.get_organizations_by_admin_email( - request.lti_user.email if request.lti_user else request.identity.userid - ) - ) - @view_config( route_name="dashboard.launch.assignment", permission=Permissions.DASHBOARD_VIEW, diff --git a/tests/factories/__init__.py b/tests/factories/__init__.py index 22a32e71aa..9a09eea392 100644 --- a/tests/factories/__init__.py +++ b/tests/factories/__init__.py @@ -32,7 +32,11 @@ from tests.factories.h_user import HUser from tests.factories.hubspot_company import HubSpotCompany from tests.factories.jwt_oauth2_token import JWTOAuth2Token -from tests.factories.lms_course import LMSCourse, LMSCourseApplicationInstance +from tests.factories.lms_course import ( + LMSCourse, + LMSCourseApplicationInstance, + LMSCourseMembership, +) from tests.factories.lms_user import LMSUser from tests.factories.lti_registration import LTIRegistration from tests.factories.lti_role import LTIRole, LTIRoleOverride diff --git a/tests/factories/lms_course.py b/tests/factories/lms_course.py index 6e03da44a0..1105556f52 100644 --- a/tests/factories/lms_course.py +++ b/tests/factories/lms_course.py @@ -16,3 +16,7 @@ LMSCourseApplicationInstance = make_factory( models.LMSCourseApplicationInstance, FACTORY_CLASS=SQLAlchemyModelFactory ) + +LMSCourseMembership = make_factory( + models.LMSCourseMembership, FACTORY_CLASS=SQLAlchemyModelFactory +) diff --git a/tests/unit/lms/services/dashboard_test.py b/tests/unit/lms/services/dashboard_test.py index b54eb5be5f..8c16077554 100644 --- a/tests/unit/lms/services/dashboard_test.py +++ b/tests/unit/lms/services/dashboard_test.py @@ -3,7 +3,7 @@ import pytest from pyramid.httpexceptions import HTTPNotFound, HTTPUnauthorized -from lms.models.dashboard_admin import DashboardAdmin +from lms.models import DashboardAdmin, RoleScope, RoleType from lms.services.dashboard import DashboardService, factory from tests import factories @@ -132,23 +132,56 @@ def test_delete_dashboard_admin(self, svc, db_session, organization): assert not db_session.query(DashboardAdmin).filter_by(id=admin.id).first() - def test_get_organizations_by_admin_email( + def test_get_organizations_where_admin( self, svc, db_session, organization, organization_service ): + # Admin user + lms_admin = factories.LMSUser(h_userid="admin") + + # Organization where just a teacher + organization_lti_teacher = factories.Organization() + teacher_course = factories.LMSCourse() + ai = factories.ApplicationInstance(organization=organization_lti_teacher) + factories.LMSCourseApplicationInstance( + lms_course=teacher_course, application_instance=ai + ) + factories.LMSCourseMembership( + lms_course=teacher_course, + lms_user=lms_admin, + lti_role=factories.LTIRole( + type=RoleType.INSTRUCTOR, scope=RoleScope.COURSE + ), + ) + + # Organization where admin via an LTIRole + organization_lti_admin = factories.Organization(parent=organization_lti_teacher) + course = factories.LMSCourse() + ai = factories.ApplicationInstance(organization=organization_lti_admin) + factories.LMSCourseApplicationInstance( + lms_course=course, application_instance=ai + ) + factories.LMSCourseMembership( + lms_course=course, + lms_user=lms_admin, + lti_role=factories.LTIRole(type=RoleType.ADMIN, scope=RoleScope.SYSTEM), + ) + # Organization where admin via email child_organization = factories.Organization(parent=organization) - admin = factories.DashboardAdmin( + email_admin = factories.DashboardAdmin( organization=organization, email="testing@example.com", created_by="creator" ) db_session.flush() - organization_service.get_hierarchy_ids.return_value = [ - organization.id, - child_organization.id, - ] + organization_service.get_hierarchy_ids.side_effect = ( + [ + organization.id, + child_organization.id, + ], + [organization_lti_admin.id], + ) - assert set(svc.get_organizations_by_admin_email(admin.email)) == { - organization, - child_organization, - } + assert set( + svc.get_organizations_where_admin(lms_admin.h_userid, email_admin.email) + ) == {organization, child_organization, organization_lti_admin} def test_get_request_admin_organizations_for_non_staff(self, pyramid_request, svc): pyramid_request.params = {"org_public_id": sentinel.public_id} @@ -200,7 +233,11 @@ def svc( self, assignment_service, course_service, organization_service, pyramid_request ): return DashboardService( - pyramid_request, assignment_service, course_service, organization_service + pyramid_request, + assignment_service, + course_service, + organization_service, + "authority", ) @pytest.fixture(autouse=True) @@ -244,6 +281,7 @@ def test_it( assignment_service=assignment_service, course_service=course_service, organization_service=organization_service, + h_authority="lms.hypothes.is", ) assert service == DashboardService.return_value