diff --git a/datagateway_api/src/common/config.py b/datagateway_api/src/common/config.py index f713c211..40cfe6ea 100644 --- a/datagateway_api/src/common/config.py +++ b/datagateway_api/src/common/config.py @@ -27,10 +27,11 @@ def validate_extension(extension): """ extension = extension.strip() - if not extension.startswith("/"): - raise ValueError("must start with '/'") - if extension.endswith("/"): - raise ValueError("must not end with '/'") + if extension: + if not extension.startswith("/"): + raise ValueError("must start with '/'") + if extension.endswith("/") and len(extension) != 1: + raise ValueError("must not end with '/'") return extension diff --git a/test/test_config.py b/test/test_config.py index 86c52750..8b01bbbc 100644 --- a/test/test_config.py +++ b/test/test_config.py @@ -3,7 +3,7 @@ import pytest -from datagateway_api.src.common.config import APIConfig +from datagateway_api.src.common.config import APIConfig, validate_extension @pytest.fixture() @@ -103,3 +103,39 @@ def test_set_backend_type(self, test_config): test_config.datagateway_api.set_backend_type("backend_name_changed") assert test_config.datagateway_api.backend == "backend_name_changed" + + @pytest.mark.parametrize( + "input_extension, expected_extension", + [ + pytest.param("/", "/", id="Slash"), + pytest.param("", "", id="Empty string, implied slash"), + pytest.param("/datagateway-api", "/datagateway-api", id="DataGateway API"), + pytest.param( + " /datagateway-api ", + "/datagateway-api", + id="DataGateway API with trailing and leading spaces", + ), + pytest.param("/search-api", "/search-api", id="Search API"), + pytest.param( + " /search-api ", + "/search-api", + id="Search API with trailing and leading spaces", + ), + ], + ) + def test_valid_extension_validation(self, input_extension, expected_extension): + test_extension = validate_extension(input_extension) + + assert test_extension == expected_extension + + @pytest.mark.parametrize( + "input_extension", + [ + pytest.param("datagateway-api", id="DataGateway API with no leading slash"), + pytest.param("search-api", id="Search API with no leading slash"), + pytest.param("/my-extension/", id="Extension with trailing slash"), + ], + ) + def test_invalid_extension_validation(self, input_extension): + with pytest.raises(ValueError): + validate_extension(input_extension)