Skip to content

Commit

Permalink
Fix Invenio credentials handling
Browse files Browse the repository at this point in the history
Only ask for token when is really required
  • Loading branch information
davelopez committed May 29, 2024
1 parent c2438b2 commit 7cfba26
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
7 changes: 1 addition & 6 deletions lib/galaxy/files/sources/_rdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from typing_extensions import Unpack

from galaxy.exceptions import AuthenticationRequired
from galaxy.files import ProvidesUserFileSourcesUserContext
from galaxy.files.sources import (
BaseFilesSource,
Expand Down Expand Up @@ -193,15 +192,11 @@ def _serialization_props(self, user_context: OptionalUserContext = None):
effective_props[key] = self._evaluate_prop(val, user_context=user_context)
return effective_props

def get_authorization_token(self, user_context: OptionalUserContext) -> str:
def get_authorization_token(self, user_context: OptionalUserContext) -> Optional[str]:
token = None
if user_context:
effective_props = self._serialization_props(user_context)
token = effective_props.get("token")
if not token:
raise AuthenticationRequired(
f"Please provide a personal access token in your user's preferences for '{self.label}'"
)
return token

def get_public_name(self, user_context: OptionalUserContext) -> Optional[str]:
Expand Down
31 changes: 18 additions & 13 deletions lib/galaxy/files/sources/invenio.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,12 +217,7 @@ def create_draft_record(
},
}

headers = self._get_request_headers(user_context)
if "Authorization" not in headers:
raise Exception(
"Cannot create record without authentication token. Please set your personal access token in your Galaxy preferences."
)

headers = self._get_request_headers(user_context, auth_required=True)
response = requests.post(self.records_url, json=create_record_request, headers=headers)
self._ensure_response_has_expected_status_code(response, 201)
record = response.json()
Expand All @@ -238,7 +233,7 @@ def upload_file_to_draft_record(
):
record = self._get_draft_record(record_id, user_context=user_context)
upload_file_url = record["links"]["files"]
headers = self._get_request_headers(user_context)
headers = self._get_request_headers(user_context, auth_required=True)

# Add file metadata entry
response = requests.post(upload_file_url, json=[{"key": filename}], headers=headers)
Expand Down Expand Up @@ -394,28 +389,38 @@ def _get_creator_from_public_name(self, public_name: Optional[str] = None) -> Cr
}

def _get_response(
self, user_context: OptionalUserContext, request_url: str, params: Optional[Dict[str, Any]] = None
self,
user_context: OptionalUserContext,
request_url: str,
params: Optional[Dict[str, Any]] = None,
auth_required: bool = False,
) -> dict:
headers = self._get_request_headers(user_context)
headers = self._get_request_headers(user_context, auth_required)
response = requests.get(request_url, params=params, headers=headers)
self._ensure_response_has_expected_status_code(response, 200)
return response.json()

def _get_request_headers(self, user_context: OptionalUserContext):
def _get_request_headers(self, user_context: OptionalUserContext, auth_required: bool = False):
token = self.plugin.get_authorization_token(user_context)
headers = {"Authorization": f"Bearer {token}"} if token else {}
if auth_required and token is None:
self._raise_auth_required()
return headers

def _ensure_response_has_expected_status_code(self, response, expected_status_code: int):
if response.status_code == 403:
record_url = response.url.replace("/api", "").replace("/files", "")
raise AuthenticationRequired(f"Please make sure you have the necessary permissions to access: {record_url}")
if response.status_code != expected_status_code:
if response.status_code == 403:
self._raise_auth_required()
error_message = self._get_response_error_message(response)
raise Exception(
f"Request to {response.url} failed with status code {response.status_code}: {error_message}"
)

def _raise_auth_required(self):
raise AuthenticationRequired(
f"Please provide a personal access token in your user's preferences for '{self.plugin.label}'"
)

def _get_response_error_message(self, response):
response_json = response.json()
error_message = response_json.get("message") if response.status_code == 400 else response.text
Expand Down

0 comments on commit 7cfba26

Please sign in to comment.