diff --git a/lib/charms/tls_certificates_interface/v2/tls_certificates.py b/lib/charms/tls_certificates_interface/v2/tls_certificates.py index 68bf929af6..d7177c3292 100644 --- a/lib/charms/tls_certificates_interface/v2/tls_certificates.py +++ b/lib/charms/tls_certificates_interface/v2/tls_certificates.py @@ -307,7 +307,7 @@ def _on_all_certificates_invalidated(self, event: AllCertificatesInvalidatedEven # Increment this PATCH version before using `charmcraft publish-lib` or reset # to 0 if you are raising the major API version -LIBPATCH = 26 +LIBPATCH = 27 PYDEPS = ["cryptography", "jsonschema"] @@ -440,7 +440,7 @@ def __init__( self.chain = chain def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "certificate": self.certificate, "certificate_signing_request": self.certificate_signing_request, @@ -449,7 +449,7 @@ def snapshot(self) -> dict: } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate = snapshot["certificate"] self.certificate_signing_request = snapshot["certificate_signing_request"] self.ca = snapshot["ca"] @@ -473,11 +473,11 @@ def __init__(self, handle, certificate: str, expiry: str): self.expiry = expiry def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return {"certificate": self.certificate, "expiry": self.expiry} def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate = snapshot["certificate"] self.expiry = snapshot["expiry"] @@ -502,7 +502,7 @@ def __init__( self.chain = chain def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "reason": self.reason, "certificate_signing_request": self.certificate_signing_request, @@ -512,7 +512,7 @@ def snapshot(self) -> dict: } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.reason = snapshot["reason"] self.certificate_signing_request = snapshot["certificate_signing_request"] self.certificate = snapshot["certificate"] @@ -527,11 +527,11 @@ def __init__(self, handle: Handle): super().__init__(handle) def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return {} def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" pass @@ -551,7 +551,7 @@ def __init__( self.is_ca = is_ca def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "certificate_signing_request": self.certificate_signing_request, "relation_id": self.relation_id, @@ -559,7 +559,7 @@ def snapshot(self) -> dict: } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate_signing_request = snapshot["certificate_signing_request"] self.relation_id = snapshot["relation_id"] self.is_ca = snapshot["is_ca"] @@ -583,7 +583,7 @@ def __init__( self.chain = chain def snapshot(self) -> dict: - """Returns snapshot.""" + """Return snapshot.""" return { "certificate": self.certificate, "certificate_signing_request": self.certificate_signing_request, @@ -592,7 +592,7 @@ def snapshot(self) -> dict: } def restore(self, snapshot: dict): - """Restores snapshot.""" + """Restore snapshot.""" self.certificate = snapshot["certificate"] self.certificate_signing_request = snapshot["certificate_signing_request"] self.ca = snapshot["ca"] @@ -600,7 +600,7 @@ def restore(self, snapshot: dict): def _load_relation_data(relation_data_content: RelationDataContent) -> dict: - """Loads relation data from the relation data bag. + """Load relation data from the relation data bag. Json loads all data. @@ -610,7 +610,7 @@ def _load_relation_data(relation_data_content: RelationDataContent) -> dict: Returns: dict: Relation data in dict format. """ - certificate_data = dict() + certificate_data = {} try: for key in relation_data_content: try: @@ -663,7 +663,7 @@ def generate_ca( validity: int = 365, country: str = "US", ) -> bytes: - """Generates a CA Certificate. + """Generate a CA Certificate. Args: private_key (bytes): Private key @@ -732,7 +732,7 @@ def get_certificate_extensions( alt_names: Optional[List[str]], is_ca: bool, ) -> List[x509.Extension]: - """Generates a list of certificate extensions from a CSR and other known information. + """Generate a list of certificate extensions from a CSR and other known information. Args: authority_key_identifier (bytes): Authority key identifier @@ -834,7 +834,7 @@ def generate_certificate( alt_names: Optional[List[str]] = None, is_ca: bool = False, ) -> bytes: - """Generates a TLS certificate based on a CSR. + """Generate a TLS certificate based on a CSR. Args: csr (bytes): CSR @@ -890,7 +890,7 @@ def generate_pfx_package( package_password: str, private_key_password: Optional[bytes] = None, ) -> bytes: - """Generates a PFX package to contain the TLS certificate and private key. + """Generate a PFX package to contain the TLS certificate and private key. Args: certificate (bytes): TLS certificate @@ -921,7 +921,7 @@ def generate_private_key( key_size: int = 2048, public_exponent: int = 65537, ) -> bytes: - """Generates a private key. + """Generate a private key. Args: password (bytes): Password for decrypting the private key @@ -947,7 +947,7 @@ def generate_private_key( return key_bytes -def generate_csr( +def generate_csr( # noqa: C901 private_key: bytes, subject: str, add_unique_id_to_subject_name: bool = True, @@ -961,7 +961,7 @@ def generate_csr( sans_dns: Optional[List[str]] = None, additional_critical_extensions: Optional[List] = None, ) -> bytes: - """Generates a CSR using private key and subject. + """Generate a CSR using private key and subject. Args: private_key (bytes): Private key @@ -1081,12 +1081,12 @@ def __init__(self, charm: CharmBase, relationship_name: str): self.relationship_name = relationship_name def _load_app_relation_data(self, relation: Relation) -> dict: - """Loads relation data from the application relation data bag. + """Load relation data from the application relation data bag. Json loads all data. Args: - relation_object: Relation data from the application databag + relation: Relation data from the application databag Returns: dict: Relation data in dict format. @@ -1104,7 +1104,7 @@ def _add_certificate( ca: str, chain: List[str], ) -> None: - """Adds certificate to relation data. + """Add certificate to relation data. Args: relation_id (int): Relation id @@ -1145,7 +1145,7 @@ def _remove_certificate( certificate: Optional[str] = None, certificate_signing_request: Optional[str] = None, ) -> None: - """Removes certificate from a given relation based on user provided certificate or csr. + """Remove certificate from a given relation based on user provided certificate or csr. Args: relation_id (int): Relation id @@ -1178,7 +1178,7 @@ def _remove_certificate( @staticmethod def _relation_data_is_valid(certificates_data: dict) -> bool: - """Uses JSON schema validator to validate relation data content. + """Use JSON schema validator to validate relation data content. Args: certificates_data (dict): Certificate data dictionary as retrieved from relation data. @@ -1193,7 +1193,7 @@ def _relation_data_is_valid(certificates_data: dict) -> bool: return False def revoke_all_certificates(self) -> None: - """Revokes all certificates of this provider. + """Revoke all certificates of this provider. This method is meant to be used when the Root CA has changed. """ @@ -1212,7 +1212,7 @@ def set_relation_certificate( chain: List[str], relation_id: int, ) -> None: - """Adds certificates to relation data. + """Add certificates to relation data. Args: certificate (str): Certificate @@ -1244,7 +1244,7 @@ def set_relation_certificate( ) def remove_certificate(self, certificate: str) -> None: - """Removes a given certificate from relation data. + """Remove a given certificate from relation data. Args: certificate (str): TLS Certificate @@ -1261,7 +1261,7 @@ def remove_certificate(self, certificate: str) -> None: def get_issued_certificates( self, relation_id: Optional[int] = None ) -> Dict[str, List[Dict[str, str]]]: - """Returns a dictionary of issued certificates. + """Return a dictionary of issued certificates. It returns certificates from all relations if relation_id is not specified. Certificates are returned per application name and CSR. @@ -1296,7 +1296,7 @@ def get_issued_certificates( return certificates def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Handler triggered on relation changed event. + """Handle relation changed event. Looks at the relation data and either emits: - certificate request event: If the unit relation data contains a CSR for which @@ -1343,7 +1343,7 @@ def _on_relation_changed(self, event: RelationChangedEvent) -> None: self._revoke_certificates_for_which_no_csr_exists(relation_id=event.relation.id) def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None: - """Revokes certificates for which no unit has a CSR. + """Revoke certificates for which no unit has a CSR. Goes through all generated certificates and compare against the list of CSRs for all units of a given relationship. @@ -1379,7 +1379,7 @@ def _revoke_certificates_for_which_no_csr_exists(self, relation_id: int) -> None def get_outstanding_certificate_requests( self, relation_id: Optional[int] = None ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: - """Returns CSR's for which no certificate has been issued. + """Return CSR's for which no certificate has been issued. Example return: [ { @@ -1421,7 +1421,7 @@ def get_outstanding_certificate_requests( def get_requirer_csrs( self, relation_id: Optional[int] = None ) -> List[Dict[str, Union[int, str, List[Dict[str, str]]]]]: - """Returns a list of requirers' CSRs grouped by unit. + """Return a list of requirers' CSRs grouped by unit. It returns CSRs from all relations if relation_id is not specified. CSRs are returned per relation id, application name and unit name. @@ -1460,7 +1460,7 @@ def get_requirer_csrs( def certificate_issued_for_csr( self, app_name: str, csr: str, relation_id: Optional[int] ) -> bool: - """Checks whether a certificate has been issued for a given CSR. + """Check whether a certificate has been issued for a given CSR. Args: app_name (str): Application name that the CSR belongs to. @@ -1489,7 +1489,7 @@ def __init__( relationship_name: str, expiry_notification_time: int = 168, ): - """Generates/use private key and observes relation changed event. + """Generate/use private key and observes relation changed event. Args: charm: Charm object @@ -1514,7 +1514,7 @@ def __init__( @property def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]: - """Returns list of requirer's CSRs from relation unit data. + """Return list of requirer's CSRs from relation unit data. Example: [ @@ -1532,7 +1532,7 @@ def _requirer_csrs(self) -> List[Dict[str, Union[bool, str]]]: @property def _provider_certificates(self) -> List[Dict[str, str]]: - """Returns list of certificates from the provider's relation data.""" + """Return list of certificates from the provider's relation data.""" relation = self.model.get_relation(self.relationship_name) if not relation: logger.debug("No relation: %s", self.relationship_name) @@ -1547,7 +1547,7 @@ def _provider_certificates(self) -> List[Dict[str, str]]: return provider_relation_data.get("certificates", []) def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: - """Adds CSR to relation data. + """Add CSR to relation data. Args: csr (str): Certificate Signing Request @@ -1574,7 +1574,7 @@ def _add_requirer_csr(self, csr: str, is_ca: bool) -> None: relation.data[self.model.unit]["certificate_signing_requests"] = json.dumps(requirer_csrs) def _remove_requirer_csr(self, csr: str) -> None: - """Removes CSR from relation data. + """Remove CSR from relation data. Args: csr (str): Certificate signing request @@ -1619,7 +1619,7 @@ def request_certificate_creation( logger.info("Certificate request sent to provider") def request_certificate_revocation(self, certificate_signing_request: bytes) -> None: - """Removes CSR from relation data. + """Remove CSR from relation data. The provider of this relation is then expected to remove certificates associated to this CSR from the relation data as well and emit a request_certificate_revocation event for the @@ -1637,7 +1637,7 @@ def request_certificate_revocation(self, certificate_signing_request: bytes) -> def request_certificate_renewal( self, old_certificate_signing_request: bytes, new_certificate_signing_request: bytes ) -> None: - """Renews certificate. + """Renew certificate. Removes old CSR from relation data and adds new one. @@ -1717,7 +1717,7 @@ def get_certificate_signing_requests( fulfilled_only: bool = False, unfulfilled_only: bool = False, ) -> List[Dict[str, Union[bool, str]]]: - """Gets the list of CSR's that were sent to the provider. + """Get the list of CSR's that were sent to the provider. You can choose to get only the CSR's that have a certificate assigned or only the CSR's that don't. @@ -1747,7 +1747,7 @@ def get_certificate_signing_requests( @staticmethod def _relation_data_is_valid(certificates_data: dict) -> bool: - """Checks whether relation data is valid based on json schema. + """Check whether relation data is valid based on json schema. Args: certificates_data: Certificate data in dict format. @@ -1762,7 +1762,7 @@ def _relation_data_is_valid(certificates_data: dict) -> bool: return False def _on_relation_changed(self, event: RelationChangedEvent) -> None: - """Handler triggered on relation changed events. + """Handle relation changed event. Goes through all providers certificates that match a requested CSR. @@ -1847,7 +1847,7 @@ def _get_next_secret_expiry_time(self, certificate: str) -> Optional[datetime]: return _get_closest_future_time(expiry_notification_time, expiry_time) def _on_relation_broken(self, event: RelationBrokenEvent) -> None: - """Handler triggered on relation broken event. + """Handle relation broken event. Emitting `all_certificates_invalidated` from `relation-broken` rather than `relation-departed` since certs are stored in app data. @@ -1861,7 +1861,7 @@ def _on_relation_broken(self, event: RelationBrokenEvent) -> None: self.on.all_certificates_invalidated.emit() def _on_secret_expired(self, event: SecretExpiredEvent) -> None: - """Triggered when a certificate is set to expire. + """Handle secret expired event. Loads the certificate from the secret, and will emit 1 of 2 events. @@ -1913,7 +1913,7 @@ def _on_secret_expired(self, event: SecretExpiredEvent) -> None: event.secret.remove_all_revisions() def _find_certificate_in_relation_data(self, csr: str) -> Optional[Dict[str, Any]]: - """Returns the certificate that match the given CSR.""" + """Return the certificate that match the given CSR.""" for certificate_dict in self._provider_certificates: if certificate_dict["certificate_signing_request"] != csr: continue @@ -1921,7 +1921,7 @@ def _find_certificate_in_relation_data(self, csr: str) -> Optional[Dict[str, Any return None def _on_update_status(self, event: UpdateStatusEvent) -> None: - """Triggered on update status event. + """Handle update status event. Goes through each certificate in the "certificates" relation and checks their expiry date. If they are close to expire (<7 days), emits a CertificateExpiringEvent event and if diff --git a/pyproject.toml b/pyproject.toml index 533cdd3af8..f5ba078651 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,10 +104,12 @@ target-version = ["py38"] [tool.ruff] # preview and explicit preview are enabled for CPY001 preview = true -explicit-preview-rules = true target-version = "py38" src = ["src", "."] line-length = 99 + +[tool.ruff.lint] +explicit-preview-rules = true select = ["A", "E", "W", "F", "C", "N", "D", "I001", "CPY001"] extend-ignore = [ "D203", @@ -126,16 +128,16 @@ extend-ignore = [ # Ignore D107 Missing docstring in __init__ ignore = ["E501", "D107"] -[tool.ruff.per-file-ignores] +[tool.ruff.lint.per-file-ignores] "tests/*" = ["D100", "D101", "D102", "D103", "D104"] -[tool.ruff.flake8-copyright] +[tool.ruff.lint.flake8-copyright] # Check for properly formatted copyright header in each file author = "Canonical Ltd." notice-rgx = "Copyright\\s\\d{4}([-,]\\d{4})*\\s+" -[tool.ruff.mccabe] +[tool.ruff.lint.mccabe] max-complexity = 10 -[tool.ruff.pydocstyle] +[tool.ruff.lint.pydocstyle] convention = "google" diff --git a/tests/integration/ha_tests/helpers.py b/tests/integration/ha_tests/helpers.py index db6c9d4589..ef55b6c27b 100644 --- a/tests/integration/ha_tests/helpers.py +++ b/tests/integration/ha_tests/helpers.py @@ -21,7 +21,13 @@ wait_fixed, ) -from ..helpers import APPLICATION_NAME, db_connect, get_unit_address, run_command_on_unit +from ..helpers import ( + APPLICATION_NAME, + db_connect, + get_patroni_cluster, + get_unit_address, + run_command_on_unit, +) logger = logging.getLogger(__name__) @@ -109,11 +115,6 @@ async def app_name(ops_test: OpsTest, application_name: str = "postgresql") -> O return None -def get_patroni_cluster(unit_ip: str) -> Dict[str, str]: - resp = requests.get(f"http://{unit_ip}:8008/cluster") - return resp.json() - - async def change_patroni_setting( ops_test: OpsTest, setting: str, value: int, use_random_unit: bool = False ) -> None: diff --git a/tests/integration/ha_tests/test_restore_cluster.py b/tests/integration/ha_tests/test_restore_cluster.py index e55f967337..d6af07e251 100644 --- a/tests/integration/ha_tests/test_restore_cluster.py +++ b/tests/integration/ha_tests/test_restore_cluster.py @@ -10,13 +10,13 @@ CHARM_SERIES, db_connect, get_password, + get_patroni_cluster, get_primary, get_unit_address, set_password, ) from .helpers import ( add_unit_with_storage, - get_patroni_cluster, reused_full_cluster_recovery_storage, storage_id, ) diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py index 0bbecae6c0..3c63afafc9 100644 --- a/tests/integration/helpers.py +++ b/tests/integration/helpers.py @@ -103,6 +103,22 @@ def change_primary_start_timeout( ) +def get_patroni_cluster(unit_ip: str) -> Dict[str, str]: + resp = requests.get(f"http://{unit_ip}:8008/cluster") + return resp.json() + + +def assert_sync_standbys(unit_ip: str, standbys: int) -> None: + for attempt in Retrying(stop=stop_after_delay(60), wait=wait_fixed(3), reraise=True): + with attempt: + cluster = get_patroni_cluster(unit_ip) + cluster_standbys = 0 + for member in cluster["members"]: + if member["role"] == "sync_standby": + cluster_standbys += 1 + assert cluster_standbys >= standbys, "Less than expected standbys" + + async def check_database_users_existence( ops_test: OpsTest, users_that_should_exist: List[str], diff --git a/tests/integration/new_relations/test_new_relations.py b/tests/integration/new_relations/test_new_relations.py index f982e0547d..ec276e17b6 100644 --- a/tests/integration/new_relations/test_new_relations.py +++ b/tests/integration/new_relations/test_new_relations.py @@ -12,7 +12,7 @@ import yaml from pytest_operator.plugin import OpsTest -from ..helpers import CHARM_SERIES, scale_application +from ..helpers import CHARM_SERIES, assert_sync_standbys, scale_application from ..juju_ import juju_major_version from .helpers import ( build_connection_string, @@ -344,6 +344,7 @@ async def test_an_application_can_request_multiple_databases(ops_test: OpsTest): @pytest.mark.group(1) +@pytest.mark.abort_on_fail async def test_relation_data_is_updated_correctly_when_scaling(ops_test: OpsTest): """Test that relation data, like connection data, is updated correctly when scaling.""" # Retrieve the list of current database unit names. @@ -356,10 +357,14 @@ async def test_relation_data_is_updated_correctly_when_scaling(ops_test: OpsTest apps=[DATABASE_APP_NAME], status="active", timeout=1500, wait_for_exact_units=4 ) + assert_sync_standbys( + ops_test.model.applications[DATABASE_APP_NAME].units[0].public_address, 2 + ) + # Remove the original units. await ops_test.model.applications[DATABASE_APP_NAME].destroy_units(*units_to_remove) await ops_test.model.wait_for_idle( - apps=[DATABASE_APP_NAME], status="active", timeout=3000, wait_for_exact_units=2 + apps=[DATABASE_APP_NAME], status="active", timeout=1500, wait_for_exact_units=2 ) # Get the updated connection data and assert it can be used @@ -540,7 +545,7 @@ async def test_invalid_extra_user_roles(ops_test: OpsTest): for app in data_integrator_apps_names: await ops_test.model.add_relation(f"{app}:postgresql", f"{DATABASE_APP_NAME}:database") await ops_test.model.wait_for_idle(apps=[DATABASE_APP_NAME]) - ops_test.model.block_until( + await ops_test.model.block_until( lambda: any( unit.workload_status_message == INVALID_EXTRA_USER_ROLE_BLOCKING_MESSAGE for unit in ops_test.model.applications[DATABASE_APP_NAME].units diff --git a/tests/integration/test_db.py b/tests/integration/test_db.py index 4d10cd24ef..e7d6c877f7 100644 --- a/tests/integration/test_db.py +++ b/tests/integration/test_db.py @@ -16,6 +16,7 @@ APPLICATION_NAME, CHARM_SERIES, DATABASE_APP_NAME, + assert_sync_standbys, build_connection_string, check_database_users_existence, check_databases_creation, @@ -40,6 +41,7 @@ @pytest.mark.group(1) +@pytest.mark.abort_on_fail async def test_mailman3_core_db(ops_test: OpsTest, charm: str) -> None: """Deploy Mailman3 Core to test the 'db' relation.""" async with ops_test.fast_forward(): @@ -106,6 +108,7 @@ async def test_mailman3_core_db(ops_test: OpsTest, charm: str) -> None: @pytest.mark.group(1) +@pytest.mark.abort_on_fail async def test_relation_data_is_updated_correctly_when_scaling(ops_test: OpsTest): """Test that relation data, like connection data, is updated correctly when scaling.""" # Retrieve the list of current database unit names. @@ -118,10 +121,14 @@ async def test_relation_data_is_updated_correctly_when_scaling(ops_test: OpsTest apps=[DATABASE_APP_NAME], status="active", timeout=1500, wait_for_exact_units=4 ) + assert_sync_standbys( + ops_test.model.applications[DATABASE_APP_NAME].units[0].public_address, 2 + ) + # Remove the original units. await ops_test.model.applications[DATABASE_APP_NAME].destroy_units(*units_to_remove) await ops_test.model.wait_for_idle( - apps=[DATABASE_APP_NAME], status="active", timeout=3000, wait_for_exact_units=2 + apps=[DATABASE_APP_NAME], status="active", timeout=1500, wait_for_exact_units=2 ) # Get the updated connection data and assert it can be used