diff --git a/pgcrypto/apps.py b/pgcrypto/apps.py new file mode 100644 index 0000000..92bdd0b --- /dev/null +++ b/pgcrypto/apps.py @@ -0,0 +1,7 @@ +from django.apps import AppConfig + +from . import checks # NOQA + + +class PgcryptoConfig(AppConfig): + name = "pgcrypto" diff --git a/pgcrypto/checks.py b/pgcrypto/checks.py new file mode 100644 index 0000000..274d8a4 --- /dev/null +++ b/pgcrypto/checks.py @@ -0,0 +1,74 @@ +import django.apps +from django.conf import settings +from django.core.checks import Error, register + +from .mixins import PGPPublicKeyFieldMixin, PGPSymmetricKeyFieldMixin + + +@register() +def check_required_settings_exist(app_configs, **kwargs): + """Make sure PGCRYPTO_KEY/PUBLIC_PGP_KEY/PRIVATE_PGP_KEY are set.""" + all_models = django.apps.apps.get_models() + errors = [] + + has_pgp_symmetric_field = _contains_pgp_symmetric_field(all_models) + if has_pgp_symmetric_field: + found_keys = [ + getattr(settings, "PGCRYPTO_KEY", None), + ] + + db_settings = [*settings.DATABASES.values()] + found_keys = found_keys + [x.get("PGCRYPTO_KEY", None) for x in db_settings] + found_keys = list(filter(bool, found_keys)) + + if len(found_keys) == 0: + errors = [ + *errors, + Error( + "Missing PGCRYPTO_KEY setting", + id="pgcrypto.E001", + ), + ] + + has_pgp_public_field = _contains_pgp_public_key_field(all_models) + if has_pgp_public_field: + found_keys = [ + ( + getattr(settings, "PUBLIC_PGP_KEY", None), + getattr(settings, "PRIVATE_PGP_KEY", None), + ) + ] + + db_settings = [*settings.DATABASES.values()] + found_keys = found_keys + [ + (x.get("PUBLIC_PGP_KEY", None), x.get("PRIVATE_PGP_KEY", None)) + for x in db_settings + ] + found_keys = list(filter(lambda x: bool(x[0]) and bool(x[1]), found_keys)) + + if len(found_keys) == 0: + errors = [ + *errors, + Error( + "Missing PGCRYPTO_KEY setting", + id="pgcrypto.E001", + ), + ] + + return errors + + +def _contains_pgp_symmetric_field(models): + for model in models: + for field in model._meta.fields: + if isinstance(field, PGPSymmetricKeyFieldMixin): + return True + return False + + +def _contains_pgp_public_key_field(models): + for model in models: + for field in model._meta.fields: + if isinstance(field, PGPPublicKeyFieldMixin): + return True + return False diff --git a/tests/test_check.py b/tests/test_check.py new file mode 100644 index 0000000..4f4a2a6 --- /dev/null +++ b/tests/test_check.py @@ -0,0 +1,72 @@ +# flake8: noqa +from django.conf import settings +from django.test import override_settings, TestCase + +from pgcrypto.checks import check_required_settings_exist + + +class TestChecks(TestCase): + # noqa: D103 + def test_pgcrypto_key_exist(self): + errors = check_required_settings_exist(None) + self.assertEqual(len(errors), 0) + + @override_settings() + def test_missing_pgcrypto_key_raises_error(self): + del settings.PGCRYPTO_KEY + + key_value = settings.DATABASES["diff_keys"]["PGCRYPTO_KEY"] + del settings.DATABASES["diff_keys"]["PGCRYPTO_KEY"] + + errors = check_required_settings_exist(None) + self.assertEqual(errors[0].id, "pgcrypto.E001") + + settings.DATABASES["diff_keys"]["PGCRYPTO_KEY"] = key_value + + @override_settings() + def test_error_not_raised_if_key_is_in_db_settings(self): + del settings.PGCRYPTO_KEY + + errors = check_required_settings_exist(None) + self.assertEqual(len(errors), 0) + + @override_settings() + def test_empty_pgcrypto_raises_error(self): + settings.PGCRYPTO_KEY = None + key_value = settings.DATABASES["diff_keys"]["PGCRYPTO_KEY"] + del settings.DATABASES["diff_keys"]["PGCRYPTO_KEY"] + + errors = check_required_settings_exist(None) + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0].id, "pgcrypto.E001") + + settings.DATABASES["diff_keys"]["PGCRYPTO_KEY"] = key_value + + def test_public_key_exist(self): + errors = check_required_settings_exist(None) + self.assertEqual(len(errors), 0) + + @override_settings() + def test_missing_public_pgp_key_raises_error(self): + del settings.PUBLIC_PGP_KEY + + key_value = settings.DATABASES["diff_keys"]["PUBLIC_PGP_KEY"] + del settings.DATABASES["diff_keys"]["PUBLIC_PGP_KEY"] + + errors = check_required_settings_exist(None) + self.assertEqual(errors[0].id, "pgcrypto.E001") + + settings.DATABASES["diff_keys"]["PUBLIC_PGP_KEY"] = key_value + + @override_settings() + def test_missing_private_pgp_key_raises_error(self): + del settings.PRIVATE_PGP_KEY + + key_value = settings.DATABASES["diff_keys"]["PRIVATE_PGP_KEY"] + del settings.DATABASES["diff_keys"]["PRIVATE_PGP_KEY"] + + errors = check_required_settings_exist(None) + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0].id, "pgcrypto.E001") + + settings.DATABASES["diff_keys"]["PRIVATE_PGP_KEY"] = key_value