From f5104a7a1df47923cea1855eb5a1cf181877e69d Mon Sep 17 00:00:00 2001 From: Simon Sawert Date: Thu, 26 Dec 2024 15:41:50 +0100 Subject: [PATCH] feat: Improved variable handling for config This allows an environment variable in any position for headers as long as it is in format `$VAR` or `${VAR}` allowing alphanumeric characters and underscore `_`. This is now also applied to `remove_schema_url` in addition to headers. Fixes #328 Relates to #231 --- ariadne_codegen/settings.py | 21 +++++++++++---- tests/test_settings.py | 54 ++++++++++++++++++++++++++++++++++--- 2 files changed, 67 insertions(+), 8 deletions(-) diff --git a/ariadne_codegen/settings.py b/ariadne_codegen/settings.py index 808397ba..3a61a0e5 100644 --- a/ariadne_codegen/settings.py +++ b/ariadne_codegen/settings.py @@ -1,5 +1,6 @@ import enum import os +import re from dataclasses import dataclass, field from keyword import iskeyword from pathlib import Path @@ -50,6 +51,7 @@ def __post_init__(self): assert_path_exists(self.schema_path) self.remote_schema_headers = resolve_headers(self.remote_schema_headers) + self.remote_schema_url = resolve_schema(self.remote_schema_url) @dataclass @@ -276,20 +278,29 @@ def assert_string_is_valid_python_identifier(name: str): def resolve_headers(headers: Dict) -> Dict: return {key: get_header_value(value) for key, value in headers.items()} +def resolve_schema(value: str) -> str: + return _replace_env_vars(value) + def get_header_value(value: str) -> str: - env_var_prefix = "$" - if value.startswith(env_var_prefix): - env_var_name = value.lstrip(env_var_prefix) + return _replace_env_vars(value) + + +def _replace_env_vars(value: str) -> str: + pattern = re.compile(r"\${?([\w_]+)}?") + + def replacer(match): + env_var_name = match.group(1) var_value = os.environ.get(env_var_name) + if not var_value: raise InvalidConfiguration( f"Environment variable {env_var_name} not found." ) - return var_value - return value + return var_value + return pattern.sub(replacer, value) def assert_class_is_defined_in_file(file_path: Path, class_name: str): file_content = file_path.read_text() diff --git a/tests/test_settings.py b/tests/test_settings.py index 11d03523..13c076a9 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -133,8 +133,24 @@ def test_client_settings_without_schema_path_or_remote_schema_url_raises_excepti ClientSettings(queries_path=queries_path) +@pytest.mark.parametrize( + "configured_header, expected_header", + [ + ("$TEST_VAR", "test_value"), + ("Bearer: $TEST_VAR", "Bearer: test_value"), + ("Bearer: ${TEST_VAR}", "Bearer: test_value"), + pytest.param( + "$NOT_SET_VAR", + "", + marks=pytest.mark.xfail(raises=InvalidConfiguration), + ), + ], +) def test_client_settings_resolves_env_variable_for_remote_schema_header_with_prefix( - tmp_path, mocker + tmp_path, + mocker, + configured_header, + expected_header, ): queries_path = tmp_path / "queries.graphql" queries_path.touch() @@ -143,10 +159,42 @@ def test_client_settings_resolves_env_variable_for_remote_schema_header_with_pre settings = ClientSettings( queries_path=queries_path, remote_schema_url="https://test", - remote_schema_headers={"Authorization": "$TEST_VAR"}, + remote_schema_headers={"Authorization": configured_header}, + ) + + assert settings.remote_schema_headers["Authorization"] == expected_header + + +@pytest.mark.parametrize( + "configured_url, expected_url", + [ + ("$TEST_VAR", "test_value"), + ("https://${TEST_VAR}/graphql", "https://test_value/graphql"), + ("https://$TEST_VAR/graphql", "https://test_value/graphql"), + ("https://TEST_VAR/graphql", "https://TEST_VAR/graphql"), + pytest.param( + "https://${NOT_SET_VAR}/graphql", + "", + marks=pytest.mark.xfail(raises=InvalidConfiguration), + ), + ], +) +def test_client_settings_resolves_env_variable_for_remote_schema( + tmp_path, + mocker, + configured_url, + expected_url, +): + queries_path = tmp_path / "queries.graphql" + queries_path.touch() + mocker.patch.dict(os.environ, {"TEST_VAR": "test_value"}) + + settings = ClientSettings( + queries_path=queries_path, + remote_schema_url=configured_url, ) - assert settings.remote_schema_headers["Authorization"] == "test_value" + assert settings.remote_schema_url == expected_url def test_client_settings_doesnt_resolve_remote_schema_header_without_prefix(tmp_path):