Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds correct docstring to factory generated methods #32

Merged
merged 2 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 37 additions & 34 deletions keystone_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from __future__ import annotations

from functools import cached_property, partial
from functools import cached_property
from typing import Literal, Union
from urllib.parse import urljoin

Expand Down Expand Up @@ -245,46 +245,49 @@ def __new__(cls, *args, **kwargs) -> KeystoneClient:

new: KeystoneClient = super().__new__(cls)

new.retrieve_allocations = partial(new._retrieve_records, cls.schema.data.allocations)
new.retrieve_requests = partial(new._retrieve_records, cls.schema.data.requests)
new.retrieve_research_groups = partial(new._retrieve_records, cls.schema.data.research_groups)
new.retrieve_users = partial(new._retrieve_records, cls.schema.data.users)
new.retrieve_allocations = new._create_retrieve_method(cls.schema.data.allocations)
new.retrieve_requests = new._create_retrieve_method(cls.schema.data.requests)
new.retrieve_research_groups = new._create_retrieve_method(cls.schema.data.research_groups)
new.retrieve_users = new._create_retrieve_method(cls.schema.data.users)

return new

def _retrieve_records(
self,
_endpoint: Endpoint,
pk: int | None = None,
filters: dict | None = None,
timeout=DEFAULT_TIMEOUT
) -> QueryResult:
"""Retrieve data from the specified endpoint with optional primary key and filters
def _create_retrieve_method(self, endpoint: Endpoint) -> callable:
"""Factory function for creating retrieve methods"""

A single record is returned when specifying a primary key, otherwise the returned
object is a list of records. In either case, the return value is `None` when no data
is available for the query.
def retrieve_records(
pk: int | None = None,
filters: dict | None = None,
timeout=DEFAULT_TIMEOUT
) -> QueryResult:
"""Retrieve data from the API endpoint with optional primary key and filters

Args:
pk: Optional primary key to fetch a specific record
filters: Optional query parameters to include in the request
timeout: Seconds before the request times out
A single record is returned when specifying a primary key, otherwise the returned
object is a list of records. In either case, the return value is `None` when no data
is available for the query.

Returns:
The response from the API in JSON format
"""
Args:
pk: Optional primary key to fetch a specific record
filters: Optional query parameters to include in the request
timeout: Seconds before the request times out

Returns:
The response from the API in JSON format
"""

url = endpoint.join_url(self.url)
if pk is not None:
url = urljoin(url, str(pk))

url = _endpoint.join_url(self.url)
if pk is not None:
url = urljoin(url, str(pk))
try:
response = self.http_get(url, params=filters, timeout=timeout)
response.raise_for_status()
return response.json()

try:
response = self.http_get(url, params=filters, timeout=timeout)
response.raise_for_status()
return response.json()
except requests.HTTPError as exception:
if exception.response.status_code == 404:
return None

except requests.HTTPError as exception:
if exception.response.status_code == 404:
return None
raise

raise
return retrieve_records
2 changes: 1 addition & 1 deletion tests/client/test_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from keystone_client.client import HTTPClient


class TestUrl(TestCase):
class Url(TestCase):
"""Tests for the `url` property"""

def test_trailing_slash_removed(self):
Expand Down
Loading