From 576210bf2cd8fe2ee58e0ca5d95523fcae1f5af9 Mon Sep 17 00:00:00 2001 From: Daniel Perrefort Date: Fri, 26 Jul 2024 11:01:14 -0400 Subject: [PATCH] Extend API endpoint URL joining (#36) --- keystone_client/client.py | 4 +-- keystone_client/schema.py | 13 +++++++-- tests/schema/test_endpoint.py | 54 +++++++++++++++++++++++++++++++---- 3 files changed, 59 insertions(+), 12 deletions(-) diff --git a/keystone_client/client.py b/keystone_client/client.py index 757e42e..4b35ec4 100644 --- a/keystone_client/client.py +++ b/keystone_client/client.py @@ -275,9 +275,7 @@ def retrieve_records( The response from the API in JSON format """ - url = endpoint.join_url(self.url) - if pk is not None: - url = urljoin(url, str(pk)) + url = endpoint.join_url(self.url, pk) try: response = self.http_get(url, params=filters, timeout=timeout) diff --git a/keystone_client/schema.py b/keystone_client/schema.py index 64a06e7..e8bb3f6 100644 --- a/keystone_client/schema.py +++ b/keystone_client/schema.py @@ -2,24 +2,31 @@ from dataclasses import dataclass, field from urllib.parse import urljoin +from os import path class Endpoint(str): + """API endpoint agnostic th to baseAPI URL""" - def join_url(self, url: str) -> str: + def join_url(self, base: str, *append) -> str: """Join the endpoint with a base URL This method returns URLs in a format that avoids trailing slash redirects from the Keystone API. Args: - url: The base URL + base: The base URL + *append: Partial paths to append onto the url Returns: The base URL join with the endpoint """ - return urljoin(url, self).rstrip('/') + '/' + url = urljoin(base, self) + for partial_path in filter(lambda x: x is not None, append): + url = path.join(url, str(partial_path)) + + return url.rstrip('/') + '/' @dataclass diff --git a/tests/schema/test_endpoint.py b/tests/schema/test_endpoint.py index 17f54df..264a7a2 100644 --- a/tests/schema/test_endpoint.py +++ b/tests/schema/test_endpoint.py @@ -9,7 +9,7 @@ class JoinUrl(TestCase): """Tests for the `join_url` method""" def test_with_trailing_slash(self) -> None: - """Test join_url with a base URL that has a trailing slash""" + """Test `join_url` with a base URL that has a trailing slash""" endpoint = Endpoint("authentication/new") base_url = "https://api.example.com/" @@ -17,7 +17,7 @@ def test_with_trailing_slash(self) -> None: self.assertEqual(expected_result, endpoint.join_url(base_url)) def test_without_trailing_slash(self) -> None: - """Test join_url with a base URL that does not have a trailing slash""" + """Test `join_url` with a base URL that does not have a trailing slash""" endpoint = Endpoint("authentication/new") base_url = "https://api.example.com" @@ -25,7 +25,7 @@ def test_without_trailing_slash(self) -> None: self.assertEqual(expected_result, endpoint.join_url(base_url)) def test_with_endpoint_trailing_slash(self) -> None: - """Test join_url with an endpoint that has a trailing slash""" + """Test `join_url` with an endpoint that has a trailing slash""" endpoint = Endpoint("authentication/new/") base_url = "https://api.example.com" @@ -33,17 +33,59 @@ def test_with_endpoint_trailing_slash(self) -> None: self.assertEqual(expected_result, endpoint.join_url(base_url)) def test_without_endpoint_trailing_slash(self) -> None: - """Test join_url with an endpoint that does not have a trailing slash""" + """Test `join_url` with an endpoint that does not have a trailing slash""" endpoint = Endpoint("authentication/new") base_url = "https://api.example.com" expected_result = "https://api.example.com/authentication/new/" self.assertEqual(expected_result, endpoint.join_url(base_url)) - def test_with_complete_url_as_endpoint(self) -> None: - """Test join_url when the endpoint is a complete URL""" + def test_with_append_trailing_slash(self) -> None: + endpoint = Endpoint("authentication") + base_url = "https://api.example.com" + append_path = "new/" + expected_result = "https://api.example.com/authentication/new/" + self.assertEqual(expected_result, endpoint.join_url(base_url, append_path)) + + def test_without_append_trailing_slash(self) -> None: + endpoint = Endpoint("authentication") + base_url = "https://api.example.com" + append_path = "new" + expected_result = "https://api.example.com/authentication/new/" + self.assertEqual(expected_result, endpoint.join_url(base_url, append_path)) + + def test_with_mixed_trailing_slash_in_append(self) -> None: + """Test `join_url` with mixed trailing slashes in append arguments""" + + endpoint = Endpoint("authentication") + base_url = "https://api.example.com" + append_path1 = "new/" + append_path2 = "extra" + expected_result = "https://api.example.com/authentication/new/extra/" + self.assertEqual(expected_result, endpoint.join_url(base_url, append_path1, append_path2)) + + def test_complete_url_as_endpoint(self) -> None: + """Test `join_url` when the endpoint is a complete URL""" endpoint = Endpoint("https://anotherapi.com/authentication/new") base_url = "https://api.example.com" expected_result = "https://anotherapi.com/authentication/new/" self.assertEqual(expected_result, endpoint.join_url(base_url)) + + def test_int_append_argument(self) -> None: + """Test `join_url` with an `int` append argument""" + + endpoint = Endpoint("authentication") + base_url = "https://api.example.com" + append_path = 123 + expected_result = "https://api.example.com/authentication/123/" + self.assertEqual(expected_result, endpoint.join_url(base_url, str(append_path))) + + def test_none_append_argument(self) -> None: + """Test `join_url` with a `None` append argument""" + + endpoint = Endpoint("authentication") + base_url = "https://api.example.com" + append_path = None + expected_result = "https://api.example.com/authentication/" + self.assertEqual(expected_result, endpoint.join_url(base_url, append_path))