diff --git a/datagateway_api/config.yaml.example b/datagateway_api/config.yaml.example index 5224f796..0634e1ac 100644 --- a/datagateway_api/config.yaml.example +++ b/datagateway_api/config.yaml.example @@ -30,3 +30,4 @@ host: "127.0.0.1" port: "5000" test_user_credentials: { username: "root", password: "pw" } test_mechanism: "simple" +url_prefix: "/" diff --git a/datagateway_api/src/api_start_utils.py b/datagateway_api/src/api_start_utils.py index ced74576..eb5d13b7 100644 --- a/datagateway_api/src/api_start_utils.py +++ b/datagateway_api/src/api_start_utils.py @@ -69,8 +69,8 @@ def handle_error(self, e): def configure_datagateway_api_swaggerui_blueprint(flask_app): swaggerui_blueprint = get_swaggerui_blueprint( - base_url=Config.config.datagateway_api.extension, - api_url="/datagateway-api/openapi.json", + base_url=f"{Config.config.url_prefix}{Config.config.datagateway_api.extension}", + api_url=f"{Config.config.url_prefix}/datagateway-api/openapi.json", config={"app_name": "DataGateway API OpenAPI Spec"}, blueprint_name="DataGateway API Swagger UI", ) @@ -83,8 +83,8 @@ def configure_datagateway_api_swaggerui_blueprint(flask_app): def configure_search_api_swaggerui_blueprint(flask_app): swaggerui_blueprint = get_swaggerui_blueprint( - base_url=Config.config.search_api.extension, - api_url="/search-api/openapi.json", + base_url=f"{Config.config.url_prefix}{Config.config.search_api.extension}", + api_url=f"{Config.config.url_prefix}/search-api/openapi.json", config={"app_name": "Search API OpenAPI Spec"}, blueprint_name="Search API Swagger UI", ) diff --git a/datagateway_api/src/common/config.py b/datagateway_api/src/common/config.py index 9f190bb7..245ffc68 100644 --- a/datagateway_api/src/common/config.py +++ b/datagateway_api/src/common/config.py @@ -32,6 +32,8 @@ def validate_extension(extension): raise ValueError("must start with '/'") if extension.endswith("/") and len(extension) != 1: raise ValueError("must not end with '/'") + if extension == "/": + extension = "" return extension @@ -55,6 +57,9 @@ class DataGatewayAPI(BaseModel): _validate_extension = validator("extension", allow_reuse=True)(validate_extension) + def __getitem__(self, item): + return getattr(self, item) + @validator("db_url", always=True) def require_db_config_value(cls, value, values): # noqa: B902, N805 """ @@ -145,6 +150,9 @@ class SearchAPI(BaseModel): _validate_extension = validator("extension", allow_reuse=True)(validate_extension) + def __getitem__(self, item): + return getattr(self, item) + class TestUserCredentials(BaseModel): username: StrictStr @@ -183,6 +191,12 @@ class APIConfig(BaseModel): search_api: Optional[SearchAPI] test_mechanism: Optional[StrictStr] test_user_credentials: Optional[TestUserCredentials] + url_prefix: StrictStr + + _validate_extension = validator("url_prefix", allow_reuse=True)(validate_extension) + + def __getitem__(self, item): + return getattr(self, item) @classmethod def load(cls, path=Path(__file__).parent.parent.parent / "config.yaml"): diff --git a/datagateway_api/src/main.py b/datagateway_api/src/main.py index fd2c04dd..9dffc5f2 100644 --- a/datagateway_api/src/main.py +++ b/datagateway_api/src/main.py @@ -1,6 +1,8 @@ import logging from flask import Flask +from werkzeug.middleware.dispatcher import DispatcherMiddleware +from werkzeug.wrappers import Response from datagateway_api.src.api_start_utils import ( create_api_endpoints, @@ -10,6 +12,7 @@ from datagateway_api.src.common.config import Config from datagateway_api.src.common.logger_setup import setup_logger + setup_logger() log = logging.getLogger() log.info("Logging now setup") @@ -18,8 +21,12 @@ api, specs = create_app_infrastructure(app) create_api_endpoints(app, api, specs) create_openapi_endpoints(app, specs) +app.config["APPLICATION_ROOT"] = Config.config.url_prefix if __name__ == "__main__": + app.wsgi_app = DispatcherMiddleware( + Response("Not Found", status=404), {Config.config.url_prefix: app.wsgi_app}, + ) app.run( host=Config.config.host, port=Config.config.port, diff --git a/datagateway_api/src/swagger/apispec_flask_restful.py b/datagateway_api/src/swagger/apispec_flask_restful.py index b43b8559..e1be334a 100644 --- a/datagateway_api/src/swagger/apispec_flask_restful.py +++ b/datagateway_api/src/swagger/apispec_flask_restful.py @@ -9,6 +9,8 @@ from apispec.exceptions import APISpecError import yaml +from datagateway_api.src.common.config import Config + def deduce_path(resource, **kwargs): """Find resource path using provided API or path itself""" @@ -73,6 +75,7 @@ def path_helper(self, path=None, operations=None, parameters=None, **kwargs): resource = kwargs.pop("resource") path = deduce_path(resource, **kwargs) path = re.sub(r"<(?:[^:<>]+:)?([^<>]+)>", r"{\1}", path) + path = f"{Config.config.url_prefix}{path}" return path except Exception as exc: logging.getLogger(__name__).exception( diff --git a/test/conftest.py b/test/conftest.py index 31b6f12f..b82e9d2a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -59,6 +59,7 @@ def test_config_data(): "port": "5000", "test_user_credentials": {"username": "root", "password": "pw"}, "test_mechanism": "simple", + "url_prefix": "", } diff --git a/test/integration/datagateway_api/db/endpoints/conftest.py b/test/integration/datagateway_api/db/endpoints/conftest.py index 667a9506..ee58e0a0 100644 --- a/test/integration/datagateway_api/db/endpoints/conftest.py +++ b/test/integration/datagateway_api/db/endpoints/conftest.py @@ -47,4 +47,5 @@ def test_config_data(): "port": "5000", "test_user_credentials": {"username": "root", "password": "pw"}, "test_mechanism": "simple", + "url_prefix": "", } diff --git a/test/integration/datagateway_api/test_swagger_ui.py b/test/integration/datagateway_api/test_swagger_ui.py new file mode 100644 index 00000000..cbcaed35 --- /dev/null +++ b/test/integration/datagateway_api/test_swagger_ui.py @@ -0,0 +1,69 @@ +import json +from unittest.mock import mock_open, patch + +from flask import Flask +import pytest +from werkzeug.middleware.dispatcher import DispatcherMiddleware +from werkzeug.wrappers import Response + +from datagateway_api.src.api_start_utils import ( + create_api_endpoints, + create_app_infrastructure, + create_openapi_endpoints, +) +from datagateway_api.src.common.config import APIConfig + + +@pytest.fixture(params=["", "/url-prefix"], ids=["No URL prefix", "Given a URL prefix"]) +def test_config_swagger(test_config_data, request): + test_config_data["url_prefix"] = request.param + test_config_data["datagateway_api"]["extension"] = ( + "" if request.param == "" else "/datagateway-api" + ) + test_config_data["search_api"]["extension"] = ( + "/search-api" if request.param == "" else "" + ) + + with patch("builtins.open", mock_open(read_data=json.dumps(test_config_data))): + return APIConfig.load("test/path") + + +class TestSwaggerUI: + @pytest.mark.parametrize( + "api_type", + [ + pytest.param("datagateway_api", id="DataGateway API"), + pytest.param("search_api", id="Search API"), + ], + ) + def test_swagger_ui(self, test_config_swagger, api_type, request): + # derived from the param IDs set above, used to assert the page title + api_name = request.node.callspec.id.split("-")[1] + with patch( + "datagateway_api.src.common.config.Config.config", test_config_swagger, + ): + test_app = Flask(__name__) + api, spec = create_app_infrastructure(test_app) + create_api_endpoints(test_app, api, spec) + create_openapi_endpoints(test_app, spec) + test_app.wsgi_app = DispatcherMiddleware( + Response("Not Found", status=404), + {test_config_swagger.url_prefix: test_app.wsgi_app}, + ) + test_client = test_app.test_client() + + test_response = test_client.get( + f"{test_config_swagger.url_prefix}{test_config_swagger[api_type].extension}", # noqa: B950 + ) + + test_response_string = test_response.get_data(as_text=True) + + assert f"{api_name} OpenAPI Spec" in test_response_string + assert ( + f"{test_config_swagger.url_prefix}{test_config_swagger[api_type].extension}/swagger-ui" # noqa: B950 + in test_response_string + ) + assert ( + f"{test_config_swagger.url_prefix}/{api_type.replace('_', '-')}/openapi.json" # noqa: B950 + in test_response_string + ) diff --git a/test/unit/test_config.py b/test/unit/test_config.py index f9646bf8..771ddfe9 100644 --- a/test/unit/test_config.py +++ b/test/unit/test_config.py @@ -60,7 +60,7 @@ def test_load_with_same_api_extensions(self, test_config_data): @pytest.mark.parametrize( "input_extension, expected_extension", [ - pytest.param("/", "/", id="Slash"), + pytest.param("/", "", id="Slash"), pytest.param("", "", id="Empty string, implied slash"), pytest.param("/datagateway-api", "/datagateway-api", id="DataGateway API"), pytest.param(