diff --git a/CHANGELOG.md b/CHANGELOG.md index 641b536c..82607772 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - new interface for transactions: split transaction into different kinds, e.g. descriptor_transaction, metric_state_transaction, etc. - `SdcConsumer` provides a dictionary with the current connection status of each subscription it is subscribed to [#271](https://github.com/Draegerwerk/sdc11073/issues/271) +- added `force_ssl_connect` parameter to constructor of consumer. ## [2.0.0a7] - 2024-01-04 diff --git a/src/sdc11073/consumer/consumerimpl.py b/src/sdc11073/consumer/consumerimpl.py index e39e5996..b8103a97 100644 --- a/src/sdc11073/consumer/consumerimpl.py +++ b/src/sdc11073/consumer/consumerimpl.py @@ -10,7 +10,7 @@ import uuid from dataclasses import dataclass from typing import TYPE_CHECKING, Any -from urllib.parse import urlparse, urlsplit +from urllib.parse import urlparse from lxml import etree as etree_ @@ -197,10 +197,13 @@ def __init__(self, device_location: str, # noqa: PLR0913 default_components: SdcConsumerComponents | None = None, specific_components: SdcConsumerComponents | None = None, request_chunk_size: int = 0, - socket_timeout: int = 5): + socket_timeout: int = 5, + force_ssl_connect: bool = False, + ): """Construct a SdcConsumer. - :param device_location: the XAddr location for meta data, e.g. http://10.52.219.67:62616/72c08f50-74cc-11e0-8092-027599143341 + :param device_location: the XAddr location for meta data, + e.g. http://10.52.219.67:62616/72c08f50-74cc-11e0-8092-027599143341 :param sdc_definitions: a class derived from BaseDefinitions :param epr: the path of this client in http server :param ssl_context_container: used for ssl connection to device and for own HTTP Server (notifications receiver) @@ -209,9 +212,24 @@ def __init__(self, device_location: str, # noqa: PLR0913 :param specific_components: a SdcConsumerComponents instance or None :param request_chunk_size: if value > 0, message is split into chunks of this size :param socket_timeout: timeout for connections to provider + :param force_ssl_connect: True: only accept ssl connections (requires a ssl_context_container) + False: if ssl_context_container is provided, consumer first tries an + encrypted connection, and if this raises an SSLError, + it tries an unencrypted connection """ if not device_location.startswith('http'): raise ValueError('Invalid device_location, it must be match http(s):// syntax') + self.is_ssl_connection: bool | None + if force_ssl_connect: + if ssl_context_container is None: + raise ValueError( + 'Invalid combination of ssl_connect (True) and ssl_context_container (None) parameters') + self.is_ssl_connection = True + elif ssl_context_container is None: + self.is_ssl_connection = False + else: + self.is_ssl_connection = None # options allow both, needs to be decided when connecting + self._device_location = device_location self.sdc_definitions = sdc_definitions if default_components is None: @@ -219,8 +237,6 @@ def __init__(self, device_location: str, # noqa: PLR0913 self._components = copy.deepcopy(default_components) if specific_components is not None: self._components.merge(specific_components) - splitted = urlsplit(self._device_location) - self._device_uses_https = splitted.scheme.lower() == 'https' self.subscription_status: dict[str, bool] = {} @@ -321,6 +337,8 @@ def base_url(self) -> str: Replace servers ip address with own ip address (server might have 0.0.0.0). """ + if self._http_server is None: + return '' p = urlparse(self._http_server.base_url) tmp = f'{p.scheme}://{self._network_adapter.ip}:{p.port}{p.path}' sep = '' if tmp.endswith('/') else '/' @@ -344,6 +362,7 @@ def update_subscription_status(subscription_filter: str, status: bool): subscription_status = dict(self.subscription_status) subscription_status[subscription_filter] = status self.subscription_status = subscription_status # trigger observable if status has changed + properties.strongbind(subscription, is_subscribed=functools.partial(update_subscription_status, filter_type.text)) @@ -453,9 +472,10 @@ def start_all(self, not_subscribed_actions: Iterable[str] | None = None, which is the minimal requirement for a sdc provider. :return: None """ - if self.host_description is None: - self._logger.debug('reading meta data from {}', self._device_location) # noqa: PLE1205 - self.host_description = self._get_metadata() + self._logger.debug('connecting to %s', self._device_location) + self._connect() + self._logger.debug('reading meta data from %s', self._device_location) + self.host_description = self._get_metadata() # now query also metadata of hosted services self._mk_hosted_services(self.host_description) @@ -525,6 +545,7 @@ def start_all(self, not_subscribed_actions: Iterable[str] | None = None, def _update_is_connected(subscription_status: dict[str, bool]): self.is_connected = all(subscription_status.values()) and any(subscription_status) + properties.strongbind(self, subscription_status=_update_is_connected) _update_is_connected(self.subscription_status) @@ -552,24 +573,38 @@ def set_used_compression(self, *compression_methods: str): del self._compression_methods[:] self._compression_methods.extend(compression_methods) - def _get_metadata(self) -> mex_types.Metadata: - _url = urlparse(self._device_location) - wsc = self.get_soap_client(self._device_location) - - if self._ssl_context_container is not None and _url.scheme == 'https': - if wsc.is_closed(): - wsc.connect() - sock = wsc.sock + def _connect(self): + soap_client = self.get_soap_client(self._device_location) + if self.is_ssl_connection is not None: + # decision was already made in constructor + soap_client.connect() + else: + try: + soap_client.connect() + self.is_ssl_connection = True + except ssl.SSLError: + # could not connect with ssl, try without it + soap_client.close() + self._forget_soap_client(soap_client) + self.is_ssl_connection = False + soap_client = self.get_soap_client(self._device_location) + # if this also fails, something else is wrong and error needs handling on application level. + soap_client.connect() + if self.is_ssl_connection: + sock = soap_client.sock self.peer_certificate = sock.getpeercert(binary_form=False) self.binary_peer_certificate = sock.getpeercert(binary_form=True) # in case the application needs it... - self._logger.info('Peer Certificate: {}', self.peer_certificate) # noqa: PLE1205 + + def _get_metadata(self) -> mex_types.Metadata: + _url = urlparse(self._device_location) + soap_client = self.get_soap_client(self._device_location) nsh = self.sdc_definitions.data_model.ns_helper inf = HeaderInformationBlock(action=f'{nsh.WXF.namespace}/Get', addr_to=self._device_location) message = self.msg_factory.mk_soap_message_etree_payload(inf, payload_element=None) - received_message_data = wsc.post_message_to(_url.path, message, msg='getMetadata') + received_message_data = soap_client.post_message_to(_url.path, message, msg='getMetadata') return mex_types.Metadata.from_node(received_message_data.p_msg.body_node) def send_probe(self) -> ProbeMatchesType: @@ -587,20 +622,25 @@ def send_probe(self) -> ProbeMatchesType: def get_soap_client(self, address: str) -> SoapClientProtocol: """Return the soap client for address. - Method creates a new soap client if needed. + Method creates a new soap client if needed and considers self.is_ssl_connection value. """ _url = urlparse(address) - key = (_url.scheme, _url.netloc) + use_ssl = self.is_ssl_connection is not False # if is_ssl_connection is still None, default to use_ssl = True + key = (use_ssl, _url.netloc) soap_client = self._soap_clients.get(key) if soap_client is None: - soap_client = self._mk_soap_client(_url.scheme, _url.netloc) + soap_client = self._mk_soap_client(use_ssl, _url.netloc) self._soap_clients[key] = soap_client return soap_client - def _mk_soap_client(self, scheme: str, - netloc: str) -> SoapClientProtocol: - _ssl_context = \ - self._ssl_context_container.client_context if scheme == "https" and self._ssl_context_container else None + def _forget_soap_client(self, soap_client: SoapClientProtocol): + for key, value in self._soap_clients.items(): + if value is soap_client: + del self._soap_clients[key] + return + + def _mk_soap_client(self, use_ssl: bool, netloc: str) -> SoapClientProtocol: + _ssl_context = self._ssl_context_container.client_context if use_ssl else None cls = self._components.soap_client_class return cls(netloc, self._socket_timeout, @@ -641,7 +681,7 @@ def _mk_hosted_service_client(self, port_type: str, def _start_event_sink(self, shared_http_server: Any): if shared_http_server is None: self._is_internal_http_server = True - ssl_context_container = self._ssl_context_container if self._device_uses_https else None + ssl_context_container = self._ssl_context_container if self.is_ssl_connection else None logger = loghelper.get_logger_adapter('sdc.client.notif_dispatch', self.log_prefix) self._http_server = HttpServerThreadBase( str(self._network_adapter.ip), @@ -652,7 +692,7 @@ def _start_event_sink(self, shared_http_server: Any): self._http_server.start() self._http_server.started_evt.wait(timeout=5) # it sometimes still happens that http server is not completely started without waiting. - #TODO: find better solution, see issue #320 + # TODO: find better solution, see issue #320 time.sleep(1) self._logger.info('serving EventSink on {}', self._http_server.base_url) # noqa: PLE1205 else: diff --git a/tests/certificates/readme.txt b/tests/certificates/readme.txt new file mode 100644 index 00000000..9bda7fb5 --- /dev/null +++ b/tests/certificates/readme.txt @@ -0,0 +1 @@ +The certificates in this folder are self-signed and only serve the purpose of testing ssl related topics. \ No newline at end of file diff --git a/tests/certificates/test_certificate.pem b/tests/certificates/test_certificate.pem new file mode 100644 index 00000000..d15d9104 --- /dev/null +++ b/tests/certificates/test_certificate.pem @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDVzCCAj8CFBYK3jJbSBIQews2OEgNn2b5Gq4YMA0GCSqGSIb3DQEBCwUAMGcx +CzAJBgNVBAYTAkRFMRAwDgYDVQQIDAdHZXJtYW55MRAwDgYDVQQHDAdMdWViZWNr +MRUwEwYDVQQKDAxUZXN0IENvbXBhbnkxDTALBgNVBAsMBFVuaXQxDjAMBgNVBAMM +BUNOYW1lMCAXDTI0MDEyNjEzNTU1OVoYDzIwNTEwNjEzMTM1NTU5WjBnMQswCQYD +VQQGEwJERTEQMA4GA1UECAwHR2VybWFueTEQMA4GA1UEBwwHTHVlYmVjazEVMBMG +A1UECgwMVGVzdCBDb21wYW55MQ0wCwYDVQQLDARVbml0MQ4wDAYDVQQDDAVDTmFt +ZTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAN4nEXeOyPnsveySKxuq +ZQd42U0hGlSD8acUxcqj7xktI9HWj88NvGkbx30R9tOwPN9EH67j9nD6+dbqca+2 +Yillg9bXuA5m44q0j/h8HdWxdsACKRMp5Y3YUrAF18cdI4GiBZr8JUpmTcmwfFHN +tXCM8KPMF1s26zZbWb+cW6UoPCInZ6khJS3RScMyXN8YNRiyibf0VwthZfI7p7uq +K0ZvbqBnfRtpASHdp25WOYN42o/8qUQcOSNy5v9SOV1tXzbyYgTJpgDN5vLAlEDl +b5T3w4BgjeIK0hoDS+JLaZfwOAvpfxBXHoDhdOaxpcRlFl8uqi2afp1pusA21GPa +UNkCAwEAATANBgkqhkiG9w0BAQsFAAOCAQEAF4N7K6Kof8Vn8lIeUoAwtK0cx9uE +L4wzatPQWYPCseO7PVKJtBH3hdnGURDUgI1G5cv/xh0RVebnpKapRGA30gxQYrvC +4YNVz6Ofi4QY34SrxyO+25iS6CjnUFchsCVhOq3nbTiQQRDTpzuVOc9fCnJSVrfO +PZI9bhUkX/peIPn+oMS3KWSCPabsIy1ygiHyItcfVAQN3+CmJlfsih9w7xDWrVj6 +XCzQ6z6837d+I36nlJEledxuDkAvM8yTBhVxyIfH0ANJhsxwTKUCkNvLBAXqEfq5 +xgyBTGj65XK/n2cc+4Zg3SDGUE3WnB3Kbn5dfqc0ziCucbQJ+xDyKjD5bQ== +-----END CERTIFICATE----- diff --git a/tests/certificates/test_private_key.pem b/tests/certificates/test_private_key.pem new file mode 100644 index 00000000..8b3701c0 --- /dev/null +++ b/tests/certificates/test_private_key.pem @@ -0,0 +1,30 @@ +-----BEGIN RSA PRIVATE KEY----- +Proc-Type: 4,ENCRYPTED +DEK-Info: AES-256-CBC,9F8DD9ED304414AE99BB98A7E0A2F157 + +/puk2rdoPrNlyQDlq1P4NT2Da+5NSkJ0JOQnbO9GTNlP1K2ANhLcg8XtkQ2rWfTU +AbFldIMCOdto46GCC+mGoR6bA/30s3atr6LxxaMaeU3gaIEQeauhtO3i5WeiT/ns +vPUpGOTgfr5HoxWvb9PIcy6GkRjS302sp8WCXi06wplmNB2B3a5QqlflETl5PHvV +PgHEa4YB705Bm8U7X2OGlsG4SD4rzhEyYc9mBXs2gnMweVaxDD33GmrZW4Zw7woE +Tbqtrmfwv9IE3kVhB679VwJgVwp/Dy8UZru1+nuOEn5aabkf7557tV9Tx/7tWfb/ +ywMIU1fSS6aYRccKHiqQaMNgyrhQsmuAQ5WG5chtJhqrjr7QKAC7/bLUMtbKUWnK +6djy1rO6bnaPaIahgmhKXU9Ngwy1tKUc2FePW5zNgcJXkKlpmSfCNLJ87Y/PhIR9 +oDYNDz76a906S0FTr9cr/nXk/R4/AfqlnspiCl2vn1UHNY0Y1GKxtvMHkA3f//xA +jPhoMK5syeAfcA+H3CUWIjwtTazOnBsgLwoI8GSSy+7DmCYexrTasFwqTJhwL15k +nMZG+ERonl0ATnPF0Lh//XMK13Nr/3cG4uKC6fbbolhGiHIGBnBhqla+8L4J+IhQ ++2DPXPaSfix9kx9Cnh+cCBulQ8pLyS78ebS5HTA7w1aj+SJRuM2gMTih/oUPN0nX +be7LbtDJ5braer6vk3j2CTHeuN7hEpcVG/pJvgAkFTKGZcSK9yuDtwcwxG16EWvc +gvk5NtdQCaBZnWqW3pYISYYNWmGWL/t5j0AaNykvzdsYW7Bl3AlEFLiOEBn6LC6c +cOHttMc+yPDUw3zhhloo0YwBZqEeCERnwn3UwFGmn1694joFI8t4DP+tcWXss3H1 +P8PK2/ioCVFXiebTPSlwFL11MPjflu2N+Fwe798mAab3hrBReC3ayjg+/gUEhieY +mzUlpB4Cgh7LJkd8EpUASDlSuGJzwv/Zff1/egbYCKfkFyLBMoSjVhA1OsRZdI8U +zJyMBrDzT04bJq4K1OqBI7dzmvpjl46wK40hdog1obSUTHuVOYqNQOwGY2pGLdFV +si0NDb9HCdGntMpGtBrjo62F0XuzJwsPP8k/bxntKJX8C41xy5OOFiZYlS6cu6X8 +0EsUJL5z79pgmOba78NtuBZ+ti+e1G/pNbJKl8m5IztJ8IeA4Uv9//8ZGqrUVphP +CDJ/BwLfGMQjxDT6QMdkFGcrKY8+ZcEJqyiwf9hkvCcTfmQ+NjJF1t2bOaHaCvSs +q9ZZonEtDYit2SVCEHjNBgJuwf00RtuuMARvwJnn3cH0fw4XtC25TkT0LXPaNOf+ +rXh8yzViNALKu6MZ+McHUUnTFODFSo0kLBu44ypjw5ESgs6lOxDf5gjsB12/cH0Y +RA3cHeHEmVcDquE49xuTU4Cmg1aqE+q1oyzCgcvt8zKXfMOD8CViN/y4TK1gOkdL +7YDE8D3TtEYibOn1rM+Ke3SFWWdpc3zgOXelUVSWnvZewp3eWuQFbmri0prxTH75 +QT33ca0tfy4gum55qQ/4fS70RFVFdYaPy6mQGNFBEW39S9sUbPwDGLf54rhwbpLg +-----END RSA PRIVATE KEY----- diff --git a/tests/test_client_device.py b/tests/test_client_device.py index 97ca4467..b5bbeec0 100644 --- a/tests/test_client_device.py +++ b/tests/test_client_device.py @@ -12,7 +12,7 @@ import uuid from decimal import Decimal from itertools import product - +from http.client import NotConnected from lxml import etree as etree_ import sdc11073.certloader @@ -1602,3 +1602,113 @@ def test_realtime_samples_sync(self): def test_metric_report_sync(self): runtest_metric_reports(self, self.sdc_device, self.sdc_client, self.logger) + + +class TestEncryptionCombinations(unittest.TestCase): + """Check combinations of encrypted and unencrypted connections.""" + + def setUp(self): + basic_logging_setup() + self.logger = get_logger_adapter('sdc.test') + self.logger.info('############### start setUp %s ##############', self._testMethodName) + + # test uses a simple self signed certificate, certificate verify would fail + client_ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + server_ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + client_ssl_context.check_hostname = False + client_ssl_context.verify_mode = ssl.CERT_NONE + server_ssl_context.verify_mode = ssl.CERT_NONE + + ca_folder = pathlib.Path(__file__).parent.joinpath('certificates') + client_ssl_context.load_cert_chain(certfile=ca_folder.joinpath('test_certificate.pem'), + keyfile=ca_folder.joinpath('test_private_key.pem'), + password='password') + server_ssl_context.load_cert_chain(certfile=ca_folder.joinpath('test_certificate.pem'), + keyfile=ca_folder.joinpath('test_private_key.pem'), + password='password') + + self.ssl_context_container = sdc11073.certloader.SSLContextContainer(client_context=client_ssl_context, + server_context=server_ssl_context) + + self.wsd = WSDiscovery('127.0.0.1') + self.wsd.start() + self.sdc_device = SomeDevice.from_mdib_file(self.wsd, None, mdib_70041, + default_components=default_sdc_provider_components_async, + max_subscription_duration=10) # shorter duration for faster tests + self.sdc_device_ssl = SomeDevice.from_mdib_file(self.wsd, None, mdib_70041, + default_components=default_sdc_provider_components_async, + max_subscription_duration=10, # shorter duration for faster tests + ssl_context_container=self.ssl_context_container) + + self.sdc_device.start_all() + self._loc_validators = [pm_types.InstanceIdentifier('Validator', extension_string='System')] + self.sdc_device.set_location(utils.random_location(), self._loc_validators) + + self.sdc_device_ssl.start_all() + self._loc_validators = [pm_types.InstanceIdentifier('Validator', extension_string='System')] + self.sdc_device_ssl.set_location(utils.random_location(), self._loc_validators) + + time.sleep(0.5) # allow init of devices to complete + self.logger.info('############### setUp done %s ##############', self._testMethodName) + time.sleep(0.5) + self.log_watcher = loghelper.LogWatcher(logging.getLogger('sdc'), level=logging.ERROR) + + def tearDown(self): + self.logger.info('############### tearDown %s ... ##############\n', self._testMethodName) + self.log_watcher.setPaused(True) + self.sdc_device.stop_all() + self.sdc_device_ssl.stop_all() + try: + self.log_watcher.check() + except loghelper.LogWatchError as ex: + sys.stderr.write(repr(ex)) + raise + self.logger.info('############### tearDown %s done ##############\n', self._testMethodName) + + def test_basic_connect(self): + """Verify correct behavior of different combinations (un)encrypted provider and (un)encrypted consumer.""" + # test connect to unencrypted provider + x_addr = self.sdc_device.get_xaddrs()[0] + + # verify that invalid combination of arguments raises a ValueError + self.assertRaises(ValueError, SdcConsumer, x_addr, self.sdc_device.mdib.sdc_definitions, + ssl_context_container=None, force_ssl_connect=True) + + # verify that a forced ssl connect to an unencrypted provider raises an SSL Error + consumer = SdcConsumer(x_addr, + self.sdc_device.mdib.sdc_definitions, + self.ssl_context_container, + force_ssl_connect=True) + self.assertRaises(ssl.SSLError, consumer.start_all) + + # verify that an unforced ssl connect to an unencrypted provider is successful + consumer = SdcConsumer(x_addr, + self.sdc_device.mdib.sdc_definitions, + self.ssl_context_container, + force_ssl_connect=False) + try: + consumer.start_all() + self.assertTrue(consumer.is_connected) + self.assertFalse(consumer.is_ssl_connection) + finally: + consumer.stop_all(unsubscribe=False) + + # test connection to encrypted provider + x_addr = self.sdc_device_ssl.get_xaddrs()[0] + + # verify that connect without certificates raises an error + consumer = SdcConsumer(x_addr, + self.sdc_device.mdib.sdc_definitions, + ssl_context_container=None, + force_ssl_connect=False) + self.assertRaises(NotConnected, consumer.start_all) + + + # verify that connect with certificates works + consumer = SdcConsumer(x_addr, + self.sdc_device.mdib.sdc_definitions, + self.ssl_context_container, + force_ssl_connect=False) + consumer.start_all() + self.assertTrue(consumer.is_connected) + self.assertTrue(consumer.is_ssl_connection)