Skip to content

Commit

Permalink
add timeouts to requests
Browse files Browse the repository at this point in the history
  • Loading branch information
Comeani committed Jun 13, 2024
1 parent e35f44f commit 12cc6a6
Showing 1 changed file with 50 additions and 16 deletions.
66 changes: 50 additions & 16 deletions apps/utils/keystone.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, base_url: str = KEYSTONE_URL) -> None:

self.base_url = base_url
self._token: Optional[str] = None
self._timeout: int = 10

def login(self, username: str, password: str, endpoint: str = KEYSTONE_AUTH_ENDPOINT) -> None:
"""Logs in to the Keystone API and caches the JWT token.
Expand All @@ -40,7 +41,11 @@ def login(self, username: str, password: str, endpoint: str = KEYSTONE_AUTH_ENDP
requests.HTTPError: If the login request fails
"""

response = requests.post(f"{self.base_url}/{endpoint}", json={"username": username, "password": password})
response = requests.post(
f"{self.base_url}/{endpoint}",
json={"username": username, "password": password},
timeout=self._timeout
)
response.raise_for_status()
self._token = response.json().get("token")

Expand Down Expand Up @@ -90,7 +95,7 @@ def _process_response(response: requests.Response, response_type: ResponseConten
raise ValueError(f"Invalid response type: {response_type}")

def get(
self, endpoint: str, params: Optional[Dict[str, Any]] = None, response_type: ResponseContentType = 'json'
self, endpoint: str, params: Optional[Dict[str, Any]] = None, response_type: ResponseContentType = 'json'
) -> ParsedResponseContent:
"""Makes a GET request to the specified endpoint.
Expand All @@ -106,12 +111,16 @@ def get(
requests.HTTPError: If the GET request fails
"""

response = requests.get(f"{self.base_url}/{endpoint}", headers=self._get_headers(), params=params)
response = requests.get(f"{self.base_url}/{endpoint}",
headers=self._get_headers(),
params=params,
timeout=self._timeout
)
response.raise_for_status()
return self._process_response(response, response_type)

def post(
self, endpoint: str, data: Optional[Dict[str, Any]] = None, response_type: ResponseContentType = 'json'
self, endpoint: str, data: Optional[Dict[str, Any]] = None, response_type: ResponseContentType = 'json'
) -> ParsedResponseContent:
"""Makes a POST request to the specified endpoint.
Expand All @@ -127,12 +136,16 @@ def post(
requests.HTTPError: If the POST request fails
"""

response = requests.post(f"{self.base_url}/{endpoint}", headers=self._get_headers(), json=data)
response = requests.post(f"{self.base_url}/{endpoint}",
headers=self._get_headers(),
json=data,
timeout=self._timeout
)
response.raise_for_status()
return self._process_response(response, response_type)

def patch(
self, endpoint: str, data: Optional[Dict[str, Any]] = None, response_type: ResponseContentType = 'json'
self, endpoint: str, data: Optional[Dict[str, Any]] = None, response_type: ResponseContentType = 'json'
) -> ParsedResponseContent:
"""Makes a PATCH request to the specified endpoint.
Expand All @@ -148,12 +161,16 @@ def patch(
requests.HTTPError: If the PATCH request fails
"""

response = requests.patch(f"{self.base_url}/{endpoint}", headers=self._get_headers(), json=data)
response = requests.patch(f"{self.base_url}/{endpoint}",
headers=self._get_headers(),
json=data,
timeout=self._timeout
)
response.raise_for_status()
return self._process_response(response, response_type)

def put(
self, endpoint: str, data: Optional[Dict[str, Any]] = None, response_type: ResponseContentType = 'json'
self, endpoint: str, data: Optional[Dict[str, Any]] = None, response_type: ResponseContentType = 'json'
) -> ParsedResponseContent:
"""Makes a PUT request to the specified endpoint.
Expand All @@ -169,7 +186,11 @@ def put(
requests.HTTPError: If the PUT request fails
"""

response = requests.put(f"{self.base_url}/{endpoint}", headers=self._get_headers(), json=data)
response = requests.put(f"{self.base_url}/{endpoint}",
headers=self._get_headers(),
json=data,
timeout=self._timeout
)
response.raise_for_status()
return self._process_response(response, response_type)

Expand All @@ -187,15 +208,18 @@ def delete(self, endpoint: str, response_type: ResponseContentType = 'json') ->
requests.HTTPError: If the DELETE request fails
"""

response = requests.delete(f"{self.base_url}/{endpoint}", headers=self._get_headers())
response = requests.delete(f"{self.base_url}/{endpoint}",
headers=self._get_headers(),
timeout=self._timeout
)
response.raise_for_status()
return self._process_response(response, response_type)


def get_auth_header(keystone_url: str, auth_header: dict) -> dict:
""" Generate an authorization header to be used for accessing information from keystone"""

response = requests.post(f"{keystone_url}/authentication/new/", json=auth_header)
response = requests.post(f"{keystone_url}/authentication/new/", json=auth_header, timeout=10)
response.raise_for_status()
tokens = response.json()
return {"Authorization": f"Bearer {tokens['access']}"}
Expand All @@ -204,7 +228,10 @@ def get_auth_header(keystone_url: str, auth_header: dict) -> dict:
def get_request_allocations(keystone_url: str, request_pk: int, auth_header: dict) -> dict:
"""Get All Allocation information from keystone for a given request"""

response = requests.get(f"{keystone_url}/allocations/allocations/?request={request_pk}", headers=auth_header)
response = requests.get(f"{keystone_url}/allocations/allocations/?request={request_pk}",
headers=auth_header,
timeout=10
)
response.raise_for_status()
return response.json()

Expand All @@ -215,15 +242,20 @@ def get_active_requests(keystone_url: str, group_pk: int, auth_header: dict) ->
today = date.today().isoformat()
response = requests.get(
f"{keystone_url}/allocations/requests/?group={group_pk}&status=AP&active__lte={today}&expire__gt={today}",
headers=auth_header)
headers=auth_header,
timeout=10
)
response.raise_for_status()
return [request for request in response.json()]


def get_researchgroup_id(keystone_url: str, account_name: str, auth_header: dict) -> int:
"""Get the Researchgroup ID from keystone for the specified Slurm account"""

response = requests.get(f"{keystone_url}/users/researchgroups/?name={account_name}", headers=auth_header)
response = requests.get(f"{keystone_url}/users/researchgroups/?name={account_name}",
headers=auth_header,
timeout=10
)
response.raise_for_status()

try:
Expand Down Expand Up @@ -255,7 +287,9 @@ def get_most_recent_expired_request(keystone_url: str, group_pk: int, auth_heade
today = date.today().isoformat()
response = requests.get(
f"{keystone_url}/allocations/requests/?ordering=-expire&group={group_pk}&status=AP&expire__lte={today}",
headers=auth_header)
headers=auth_header,
timeout=10
)
response.raise_for_status()

return [response.json()[0]]
Expand All @@ -264,7 +298,7 @@ def get_most_recent_expired_request(keystone_url: str, group_pk: int, auth_heade
def get_enabled_cluster_ids(keystone_url: str, auth_header: dict) -> dict():
"""Get the list of enabled clusters defined in Keystone along with their IDs"""

response = requests.get(f"{keystone_url}/allocations/clusters/?enabled=True", headers=auth_header)
response = requests.get(f"{keystone_url}/allocations/clusters/?enabled=True", headers=auth_header, timeout=10)
response.raise_for_status()
clusters = {}
for cluster in response.json():
Expand Down

0 comments on commit 12cc6a6

Please sign in to comment.