Skip to content

Commit

Permalink
Add async group lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
jrdnbradford committed Sep 26, 2024
1 parent 96723da commit e7c3bde
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 46 deletions.
77 changes: 32 additions & 45 deletions oauthenticator/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import os

import aiohttp
from jupyterhub.auth import LocalAuthenticator
from tornado.auth import GoogleOAuth2Mixin
from tornado.web import HTTPError
Expand Down Expand Up @@ -243,7 +244,7 @@ async def update_auth_model(self, auth_model):

user_groups = set()
if self.allowed_google_groups or self.admin_google_groups:
user_groups = self._fetch_member_groups(user_email, user_domain)
user_groups = await self._fetch_member_groups(user_email, user_domain)
# sets are not JSONable, cast to list for auth_state
user_info["google_groups"] = list(user_groups)

Expand Down Expand Up @@ -314,7 +315,7 @@ async def check_allowed(self, username, auth_model):
# users should be explicitly allowed via config, otherwise they aren't
return False

def _service_client_credentials(self, scopes, user_email_domain):
async def _service_client_credentials(self, scopes, user_email_domain):
"""
Return a configured service client credentials for the API.
"""
Expand All @@ -338,47 +339,28 @@ def _service_client_credentials(self, scopes, user_email_domain):

return credentials

def _service_client(self, service_name, service_version, credentials, http=None):
async def _setup_credentials(self, user_email_domain):
"""
Return a configured service client for the API.
Set up the service client for Google API.
"""
credentials = await self._service_client_credentials(
scopes=[f"{self.google_api_url}/auth/admin.directory.group.readonly"],
user_email_domain=user_email_domain,
)

try:
from googleapiclient.discovery import build
from google.auth.transport import requests
except:
raise ImportError(
"Could not import googleapiclient.discovery's build,"
"Could not import google.auth.transport's requests,"
"you may need to run 'pip install oauthenticator[googlegroups]' or not declare google groups"
)

self.log.debug(
f"service_name is {service_name}, service_version is {service_version}"
)

return build(
serviceName=service_name,
version=service_version,
credentials=credentials,
cache_discovery=False,
http=http,
)

def _setup_service(self, user_email_domain, http=None):
"""
Set up the service client for Google API.
"""
credentials = self._service_client_credentials(
scopes=[f"{self.google_api_url}/auth/admin.directory.group.readonly"],
user_email_domain=user_email_domain,
)
service = self._service_client(
service_name='admin',
service_version='directory_v1',
credentials=credentials,
http=http,
)
return service
request = requests.Request()
credentials.refresh(request)
return credentials

def _fetch_member_groups(
async def _fetch_member_groups(
self,
member_email,
user_email_domain,
Expand All @@ -389,17 +371,22 @@ def _fetch_member_groups(
"""
Return a set with the google groups a given user/group is a member of, including nested groups if allowed.
"""
# FIXME: When this function is used and waiting for web request
# responses, JupyterHub gets blocked from doing other things.
# Ideally the web requests should be made using an async client
# that can be awaited while JupyterHub handles other things.
#
if not hasattr(self, 'service'):
self.service = self._setup_service(user_email_domain, http)

resp = self.service.groups().list(userKey=member_email).execute()
if not hasattr(self, 'credentials'):
self.credentials = await self._setup_credentials(user_email_domain)

headers = {'Authorization': f'Bearer {self.credentials.token}'}

async with aiohttp.ClientSession() as session:
async with session.get(
f'https://www.googleapis.com/admin/directory/v1/groups?userKey={member_email}',
headers=headers,
) as resp:
group_data = await resp.json()

member_groups = {
g['email'].split('@')[0] for g in resp.get('groups', []) if g.get('email')
g['email'].split('@')[0]
for g in group_data.get('groups', [])
if g.get('email')
}
self.log.debug(f"Fetched groups for {member_email}: {member_groups}")

Expand All @@ -410,7 +397,7 @@ def _fetch_member_groups(
for group in member_groups:
if group not in processed_groups:
processed_groups.add(group)
nested_groups = self._fetch_member_groups(
nested_groups = await self._fetch_member_groups(
f"{group}@{user_email_domain}",
user_email_domain,
http,
Expand Down
3 changes: 2 additions & 1 deletion oauthenticator/tests/test_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import re
from unittest import mock
from unittest.mock import AsyncMock

from pytest import fixture, mark, raises
from traitlets.config import Config
Expand Down Expand Up @@ -211,7 +212,7 @@ async def test_google(
handled_user_model = user_model("[email protected]", "user1")
handler = google_client.handler_for_user(handled_user_model)
with mock.patch.object(
authenticator, "_fetch_member_groups", lambda *args: {"group1"}
authenticator, "_fetch_member_groups", AsyncMock(return_value={"group1"})
):
auth_model = await authenticator.get_authenticated_user(handler, None)

Expand Down

0 comments on commit e7c3bde

Please sign in to comment.