diff --git a/oauthenticator/google.py b/oauthenticator/google.py index f9d386b1..72ef7bd9 100644 --- a/oauthenticator/google.py +++ b/oauthenticator/google.py @@ -4,6 +4,7 @@ import os +import aiohttp from jupyterhub.auth import LocalAuthenticator from tornado.auth import GoogleOAuth2Mixin from tornado.web import HTTPError @@ -243,7 +244,7 @@ async def update_auth_model(self, auth_model): user_groups = set() if self.allowed_google_groups or self.admin_google_groups: - user_groups = self._fetch_member_groups(user_email, user_domain) + user_groups = await self._fetch_member_groups(user_email, user_domain) # sets are not JSONable, cast to list for auth_state user_info["google_groups"] = list(user_groups) @@ -314,7 +315,7 @@ async def check_allowed(self, username, auth_model): # users should be explicitly allowed via config, otherwise they aren't return False - def _service_client_credentials(self, scopes, user_email_domain): + async def _service_client_credentials(self, scopes, user_email_domain): """ Return a configured service client credentials for the API. """ @@ -338,47 +339,28 @@ def _service_client_credentials(self, scopes, user_email_domain): return credentials - def _service_client(self, service_name, service_version, credentials, http=None): + async def _setup_credentials(self, user_email_domain): """ - Return a configured service client for the API. + Set up the service client for Google API. """ + credentials = await self._service_client_credentials( + scopes=[f"{self.google_api_url}/auth/admin.directory.group.readonly"], + user_email_domain=user_email_domain, + ) + try: - from googleapiclient.discovery import build + from google.auth.transport import requests except: raise ImportError( - "Could not import googleapiclient.discovery's build," + "Could not import google.auth.transport's requests," "you may need to run 'pip install oauthenticator[googlegroups]' or not declare google groups" ) - self.log.debug( - f"service_name is {service_name}, service_version is {service_version}" - ) - - return build( - serviceName=service_name, - version=service_version, - credentials=credentials, - cache_discovery=False, - http=http, - ) - - def _setup_service(self, user_email_domain, http=None): - """ - Set up the service client for Google API. - """ - credentials = self._service_client_credentials( - scopes=[f"{self.google_api_url}/auth/admin.directory.group.readonly"], - user_email_domain=user_email_domain, - ) - service = self._service_client( - service_name='admin', - service_version='directory_v1', - credentials=credentials, - http=http, - ) - return service + request = requests.Request() + credentials.refresh(request) + return credentials - def _fetch_member_groups( + async def _fetch_member_groups( self, member_email, user_email_domain, @@ -389,17 +371,22 @@ def _fetch_member_groups( """ Return a set with the google groups a given user/group is a member of, including nested groups if allowed. """ - # FIXME: When this function is used and waiting for web request - # responses, JupyterHub gets blocked from doing other things. - # Ideally the web requests should be made using an async client - # that can be awaited while JupyterHub handles other things. - # - if not hasattr(self, 'service'): - self.service = self._setup_service(user_email_domain, http) - - resp = self.service.groups().list(userKey=member_email).execute() + if not hasattr(self, 'credentials'): + self.credentials = await self._setup_credentials(user_email_domain) + + headers = {'Authorization': f'Bearer {self.credentials.token}'} + + async with aiohttp.ClientSession() as session: + async with session.get( + f'https://www.googleapis.com/admin/directory/v1/groups?userKey={member_email}', + headers=headers, + ) as resp: + group_data = await resp.json() + member_groups = { - g['email'].split('@')[0] for g in resp.get('groups', []) if g.get('email') + g['email'].split('@')[0] + for g in group_data.get('groups', []) + if g.get('email') } self.log.debug(f"Fetched groups for {member_email}: {member_groups}") @@ -410,7 +397,7 @@ def _fetch_member_groups( for group in member_groups: if group not in processed_groups: processed_groups.add(group) - nested_groups = self._fetch_member_groups( + nested_groups = await self._fetch_member_groups( f"{group}@{user_email_domain}", user_email_domain, http, diff --git a/oauthenticator/tests/test_google.py b/oauthenticator/tests/test_google.py index 4c3a9e0c..070de014 100644 --- a/oauthenticator/tests/test_google.py +++ b/oauthenticator/tests/test_google.py @@ -3,6 +3,7 @@ import logging import re from unittest import mock +from unittest.mock import AsyncMock from pytest import fixture, mark, raises from traitlets.config import Config @@ -211,7 +212,7 @@ async def test_google( handled_user_model = user_model("user1@example.com", "user1") handler = google_client.handler_for_user(handled_user_model) with mock.patch.object( - authenticator, "_fetch_member_groups", lambda *args: {"group1"} + authenticator, "_fetch_member_groups", AsyncMock(return_value={"group1"}) ): auth_model = await authenticator.get_authenticated_user(handler, None)