diff --git a/modules/terraform-aws-ca-lambda/unittests/test_validate_sans.py b/modules/terraform-aws-ca-lambda/unittests/test_validate_sans.py new file mode 100644 index 00000000..bc39ede3 --- /dev/null +++ b/modules/terraform-aws-ca-lambda/unittests/test_validate_sans.py @@ -0,0 +1,41 @@ +from utils.certs.types import filter_and_validate_sans + + +def test_filter_and_validate_sans(): + sans = ["example.com", "example.org", "example.net"] + output = filter_and_validate_sans("example.com", sans) + expected = ["example.com", "example.org", "example.net"] + + assert output == expected + + +def test_filter_and_validate_sans_invalid_domain(): + sans = ["example.com", "example.org", "net"] + output = filter_and_validate_sans("example.com", sans) + expected = ["example.com", "example.org"] + + assert output == expected + + +def test_filter_and_validate_sans_wildcard_allowed(): + sans = ["example.com", "example.org", "*.example.net"] + output = filter_and_validate_sans("example.com", sans) + expected = ["example.com", "example.org", "*.example.net"] + + assert output == expected + + +def test_filter_and_validate_sans_wildcard_disallowed_if_base_domain_invalid(): + sans = ["example.com", "example.org", "*.net"] + output = filter_and_validate_sans("example.com", sans) + expected = ["example.com", "example.org"] + + assert output == expected + + +def test_filter_and_validate_sans_mixed_domains(): + sans = ["example.com", "example.org", "*.example.net", "*.net", "Invalid DNS name"] + output = filter_and_validate_sans("example.com", sans) + expected = ["example.com", "example.org", "*.example.net"] + + assert output == expected diff --git a/modules/terraform-aws-ca-lambda/utils/certs/types.py b/modules/terraform-aws-ca-lambda/utils/certs/types.py index 019d0427..52872990 100644 --- a/modules/terraform-aws-ca-lambda/utils/certs/types.py +++ b/modules/terraform-aws-ca-lambda/utils/certs/types.py @@ -81,8 +81,17 @@ def filter_and_validate_sans(common_name: str, sans: list[str]) -> list[str]: if (_sans is None or _sans == []) and valid_common_name: _sans = [common_name] + # log invalid SANs + for san in _sans: + # allow wildcard SANs provided base domain is valid + if san.split(".")[0] == "*" and domain_validator(san[2:]): + continue + # log invalid SANs + if not domain_validator(san): + print(f"Invalid domain {san} excluded from SANs") + # remove invalid SANs - _sans = [s for s in _sans if domain_validator(s)] + _sans = [s for s in _sans if domain_validator(s) or s.split(".")[0] == "*" and domain_validator(s[2:])] return _sans diff --git a/tests/test_issued_certs.py b/tests/test_issued_certs.py index 6b6b6ebe..9e69526a 100644 --- a/tests/test_issued_certs.py +++ b/tests/test_issued_certs.py @@ -19,7 +19,6 @@ ) from utils.modules.aws.kms import get_kms_details -from utils.modules.aws.lambdas import get_lambda_name, invoke_lambda from utils.modules.aws.s3 import delete_s3_object, get_s3_bucket, list_s3_object_keys, put_s3_object from .helper import ( helper_create_csr_info, @@ -148,8 +147,8 @@ def test_issued_cert_includes_correct_dns_names(): Test issued certificate contains correct DNS names in Subject Alternative Name extension """ common_name = "pipeline-test-dn-csr-no-passphrase.example.com" - sans = ["test1.example.com", "test2.example.com", "invalid DNS name"] - expected_result = ["test1.example.com", "test2.example.com"] + sans = ["test1.example.com", "test2.example.com", "*.example.com", "*.com", "invalid DNS name"] + expected_result = ["test1.example.com", "test2.example.com", "*.example.com"] purposes = ["server_auth"]