Skip to content

Commit

Permalink
feat: Improved variable handling for config
Browse files Browse the repository at this point in the history
This allows an environment variable in any position for headers as long
as it is in format `$VAR` or `${VAR}` allowing alphanumeric characters
and underscore `_`. This is now also applied to `remove_schema_url` in
addition to headers.

Fixes mirumee#328
Relates to mirumee#231
  • Loading branch information
bombsimon committed Dec 26, 2024
1 parent 11bfe35 commit cfb870c
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 8 deletions.
21 changes: 16 additions & 5 deletions ariadne_codegen/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass, field
from keyword import iskeyword
from pathlib import Path
import re
from textwrap import dedent
from typing import Dict, List

Expand Down Expand Up @@ -50,6 +51,7 @@ def __post_init__(self):
assert_path_exists(self.schema_path)

self.remote_schema_headers = resolve_headers(self.remote_schema_headers)
self.remote_schema_url = resolve_schema(self.remote_schema_url)


@dataclass
Expand Down Expand Up @@ -276,20 +278,29 @@ def assert_string_is_valid_python_identifier(name: str):
def resolve_headers(headers: Dict) -> Dict:
return {key: get_header_value(value) for key, value in headers.items()}

def resolve_schema(value: str) -> str:
return _replace_env_vars(value)


def get_header_value(value: str) -> str:
env_var_prefix = "$"
if value.startswith(env_var_prefix):
env_var_name = value.lstrip(env_var_prefix)
return _replace_env_vars(value)


def _replace_env_vars(value: str) -> str:
pattern = re.compile(r"\${?([\w_]+)}?")

def replacer(match):
env_var_name = match.group(1)
var_value = os.environ.get(env_var_name)

if not var_value:
raise InvalidConfiguration(
f"Environment variable {env_var_name} not found."
)
return var_value

return value
return var_value

return pattern.sub(replacer, value)

def assert_class_is_defined_in_file(file_path: Path, class_name: str):
file_content = file_path.read_text()
Expand Down
54 changes: 51 additions & 3 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,24 @@ def test_client_settings_without_schema_path_or_remote_schema_url_raises_excepti
ClientSettings(queries_path=queries_path)


@pytest.mark.parametrize(
"configured_header, expected_header",
[
("$TEST_VAR", "test_value"),
("Bearer: $TEST_VAR", "Bearer: test_value"),
("Bearer: ${TEST_VAR}", "Bearer: test_value"),
pytest.param(
"$NOT_SET_VAR",
"",
marks=pytest.mark.xfail(raises=InvalidConfiguration),
),
],
)
def test_client_settings_resolves_env_variable_for_remote_schema_header_with_prefix(
tmp_path, mocker
tmp_path,
mocker,
configured_header,
expected_header,
):
queries_path = tmp_path / "queries.graphql"
queries_path.touch()
Expand All @@ -143,10 +159,42 @@ def test_client_settings_resolves_env_variable_for_remote_schema_header_with_pre
settings = ClientSettings(
queries_path=queries_path,
remote_schema_url="https://test",
remote_schema_headers={"Authorization": "$TEST_VAR"},
remote_schema_headers={"Authorization": configured_header},
)

assert settings.remote_schema_headers["Authorization"] == expected_header


@pytest.mark.parametrize(
"configured_url, expected_url",
[
("$TEST_VAR", "test_value"),
("https://${TEST_VAR}/graphql", "https://test_value/graphql"),
("https://$TEST_VAR/graphql", "https://test_value/graphql"),
("https://TEST_VAR/graphql", "https://TEST_VAR/graphql"),
pytest.param(
"https://${NOT_SET_VAR}/graphql",
"",
marks=pytest.mark.xfail(raises=InvalidConfiguration),
),
],
)
def test_client_settings_resolves_env_variable_for_remote_schema(
tmp_path,
mocker,
configured_url,
expected_url,
):
queries_path = tmp_path / "queries.graphql"
queries_path.touch()
mocker.patch.dict(os.environ, {"TEST_VAR": "test_value"})

settings = ClientSettings(
queries_path=queries_path,
remote_schema_url=configured_url,
)

assert settings.remote_schema_headers["Authorization"] == "test_value"
assert settings.remote_schema_url == expected_url


def test_client_settings_doesnt_resolve_remote_schema_header_without_prefix(tmp_path):
Expand Down

0 comments on commit cfb870c

Please sign in to comment.