diff --git a/scripts/tls.py b/scripts/tls.py index 130bc56f..7017e19d 100755 --- a/scripts/tls.py +++ b/scripts/tls.py @@ -32,6 +32,7 @@ from tlslite.api import * from tlslite.constants import CipherSuite, HashAlgorithm, SignatureAlgorithm, \ GroupName, SignatureScheme +from tlslite.handshakesettings import Keypair, VirtualHost from tlslite import __version__ from tlslite.utils.compat import b2a_hex, a2b_hex, time_stamp from tlslite.utils.dns_utils import is_valid_hostname @@ -74,13 +75,13 @@ def printUsage(s=None): print("""Commands: server - [-k KEY] [-c CERT] [-t TACK] [-v VERIFIERDB] [-d DIR] [-l LABEL] [-L LENGTH] + [-c CERT] [-k KEY] [-t TACK] [-v VERIFIERDB] [-d DIR] [-l LABEL] [-L LENGTH] [--reqcert] [--param DHFILE] [--psk PSK] [--psk-ident IDENTITY] [--psk-sha384] [--ssl3] [--max-ver VER] [--tickets COUNT] HOST:PORT client - [-k KEY] [-c CERT] [-u USER] [-p PASS] [-l LABEL] [-L LENGTH] [-a ALPN] + [-c CERT] [-k KEY] [-u USER] [-p PASS] [-l LABEL] [-L LENGTH] [-a ALPN] [--psk PSK] [--psk-ident IDENTITY] [--psk-sha384] [--resumption] [--ssl3] [--max-ver VER] HOST:PORT @@ -98,6 +99,10 @@ def printUsage(s=None): "tls1.3" --tickets COUNT - how many tickets should server send after handshake is finished + CERT, KEY - the file with key and certificates that will be used by client or + server. The server can accept multiple pairs of `-c` and `-k` options + to configure different certificates (like RSA and ECDSA) + """) sys.exit(-1) @@ -131,6 +136,8 @@ def handleArgs(argv, argString, flagsList=[]): # Default values if arg not present privateKey = None cert_chain = None + virtual_hosts = [] + v_host_cert = None username = None password = None tacks = None @@ -155,14 +162,30 @@ def handleArgs(argv, argString, flagsList=[]): if sys.version_info[0] >= 3: s = str(s, 'utf-8') # OpenSSL/m2crypto does not support RSASSA-PSS certificates - privateKey = parsePEMKey(s, private=True, - implementations=["python"]) + if not privateKey: + privateKey = parsePEMKey(s, private=True, + implementations=["python"]) + else: + if not v_host_cert: + raise ValueError("Virtual host certificate missing " + "(must be listed before key)") + p_key = parsePEMKey(s, private=True, + implementations=["python"]) + if not virtual_hosts: + virtual_hosts.append(VirtualHost()) + virtual_hosts[0].keys.append( + Keypair(p_key, v_host_cert.x509List)) + v_host_cert = None elif opt == "-c": s = open(arg, "rb").read() if sys.version_info[0] >= 3: s = str(s, 'utf-8') - cert_chain = X509CertChain() - cert_chain.parsePemList(s) + if not cert_chain: + cert_chain = X509CertChain() + cert_chain.parsePemList(s) + else: + v_host_cert = X509CertChain() + v_host_cert.parsePemList(s) elif opt == "-u": username = arg elif opt == "-p": @@ -228,6 +251,7 @@ def handleArgs(argv, argString, flagsList=[]): retList.append(privateKey) if "c" in argString: retList.append(cert_chain) + retList.append(virtual_hosts) if "u" in argString: retList.append(username) if "p" in argString: @@ -323,7 +347,8 @@ def printExporter(connection, expLabel, expLength): def clientCmd(argv): - (address, privateKey, cert_chain, username, password, expLabel, + (address, privateKey, cert_chain, virtual_hosts, username, password, + expLabel, expLength, alpn, psk, psk_ident, psk_hash, resumption, ssl3, max_ver) = \ handleArgs(argv, "kcuplLa", ["psk=", "psk-ident=", "psk-sha384", @@ -455,7 +480,8 @@ def clientCmd(argv): def serverCmd(argv): - (address, privateKey, cert_chain, tacks, verifierDB, directory, reqCert, + (address, privateKey, cert_chain, virtual_hosts, tacks, verifierDB, + directory, reqCert, expLabel, expLength, dhparam, psk, psk_ident, psk_hash, ssl3, max_ver, tickets) = \ handleArgs(argv, "kctbvdlL", @@ -502,6 +528,7 @@ def serverCmd(argv): settings.minVersion = (3, 0) if max_ver: settings.maxVersion = max_ver + settings.virtual_hosts = virtual_hosts class MySimpleHTTPHandler(SimpleHTTPRequestHandler): """Buffer the header and body of HTTP message.""" diff --git a/tests/tlstest.py b/tests/tlstest.py index 45470a78..08db530a 100755 --- a/tests/tlstest.py +++ b/tests/tlstest.py @@ -33,6 +33,7 @@ AlertDescription, HTTPTLSConnection, TLSSocketServerMixIn, \ POP3_TLS, m2cryptoLoaded, pycryptoLoaded, gmpyLoaded, tackpyLoaded, \ Checker, __version__ +from tlslite.handshakesettings import VirtualHost, Keypair from tlslite.errors import * from tlslite.utils.cryptomath import prngName, getRandomBytes @@ -360,6 +361,84 @@ def connect(): test_no += 1 + print("Test {0} - good RSA and ECDSA, TLSv1.3, rsa" + .format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 4) + settings.maxVersion = (3, 4) + connection.handshakeClientCert(settings=settings) + testConnClient(connection) + assert connection.session.cipherSuite in\ + constants.CipherSuite.tls13Suites + assert isinstance(connection.session.serverCertChain, X509CertChain) + assert connection.session.serverCertChain.getEndEntityPublicKey().key_type\ + == "rsa" + assert connection.version == (3, 4) + connection.close() + + test_no += 1 + + print("Test {0} - good RSA and ECDSA, TLSv1.3, ecdsa" + .format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 4) + settings.maxVersion = (3, 4) + settings.rsaSigHashes = [] + connection.handshakeClientCert(settings=settings) + testConnClient(connection) + assert connection.session.cipherSuite in\ + constants.CipherSuite.tls13Suites + assert isinstance(connection.session.serverCertChain, X509CertChain) + assert connection.session.serverCertChain.getEndEntityPublicKey().key_type\ + == "ecdsa" + assert connection.version == (3, 4) + connection.close() + + test_no += 1 + + print("Test {0} - good RSA and ECDSA, TLSv1.2, rsa" + .format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 3) + settings.maxVersion = (3, 3) + connection.handshakeClientCert(settings=settings) + testConnClient(connection) + assert connection.session.cipherSuite in\ + constants.CipherSuite.ecdheCertSuites, connection.session.cipherSuite + assert isinstance(connection.session.serverCertChain, X509CertChain) + assert connection.session.serverCertChain.getEndEntityPublicKey().key_type\ + == "rsa" + assert connection.version == (3, 3) + connection.close() + + test_no += 1 + + print("Test {0} - good RSA and ECDSA, TLSv1.2, ecdsa" + .format(test_no)) + synchro.recv(1) + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 3) + settings.maxVersion = (3, 3) + settings.rsaSigHashes = [] + connection.handshakeClientCert(settings=settings) + testConnClient(connection) + assert connection.session.cipherSuite in\ + constants.CipherSuite.ecdheEcdsaSuites, connection.session.cipherSuite + assert isinstance(connection.session.serverCertChain, X509CertChain) + assert connection.session.serverCertChain.getEndEntityPublicKey().key_type\ + == "ecdsa" + assert connection.version == (3, 3) + connection.close() + + test_no += 1 + print("Test {0} - good X.509, mismatched key_share".format(test_no)) synchro.recv(1) connection = connect() @@ -1502,6 +1581,28 @@ def connect(): test_no += 1 + for prot in ["TLSv1.3", "TLSv1.2"]: + for c_type, exp_chain in (("rsa", x509Chain), + ("ecdsa", x509ecdsaChain)): + print("Test {0} - good RSA and ECDSA, {2}, {1}" + .format(test_no, c_type, prot)) + synchro.send(b'R') + connection = connect() + settings = HandshakeSettings() + settings.minVersion = (3, 3) + settings.maxVersion = (3, 4) + v_host = VirtualHost() + v_host.keys = [Keypair(x509ecdsaKey, x509ecdsaChain.x509List)] + settings.virtual_hosts = [v_host] + connection.handshakeServer(certChain=x509Chain, + privateKey=x509Key, settings=settings) + assert connection.extendedMasterSecret + assert connection.session.serverCertChain == exp_chain + testConnServer(connection) + connection.close() + + test_no += 1 + print("Test {0} - good X.509, mismatched key_share".format(test_no)) synchro.send(b'R') connection = connect() diff --git a/tlslite/handshakesettings.py b/tlslite/handshakesettings.py index 2c0c501a..b0d7e96b 100644 --- a/tlslite/handshakesettings.py +++ b/tlslite/handshakesettings.py @@ -49,6 +49,80 @@ PSK_MODES = ["psk_dhe_ke", "psk_ke"] +class Keypair(object): + """ + Key, certificate and related data. + + Stores also certificate associate data like OCSPs and transparency info. + TODO: add the above + + First certificate in certificates needs to match key, remaining should + build a trust path to a root CA. + + :vartype key: RSAKey or ECDSAKey + :ivar key: private key + + :vartype certificates: list of X509 + :ivar certificates: the certificates to send to peer if the key is selected + for use. The first one MUST include the public key of the ``key`` + """ + def __init__(self, key=None, certificates=tuple()): + self.key = key + self.certificates = certificates + + def validate(self): + """Sanity check the keypair.""" + if not self.key or not self.certificates: + raise ValueError("Key or certificate missing in Keypair") + + +class VirtualHost(object): + """ + Configuration of keys and certs for a single virual server. + + This class encapsulates keys and certificates for hosts specified by + server_name (SNI) and ALPN extensions. + + TODO: support SRP as alternative to certificates + TODO: support PSK as alternative to certificates + + :vartype keys: list of :ref:`~Keypair` + :ivar keys: List of certificates and keys to be used in this + virtual host. First keypair able to server ClientHello will be used. + + :vartype hostnames: set of bytes + :ivar hostnames: all the hostnames that server supports + please use :ref:`matches_hostname` to verify if the VirtualHost + can serve a request to a given hostname as that allows wildcard hosts + that always reply True. + + :vartype trust_anchors: list of X509 + :ivar trust_anchors: list of CA certificates supported for client + certificate authentication, sent in CertificateRequest + + :ivar app_protocols: all the application protocols that the server supports + (for ALPN) + """ + + def __init__(self): + """Set up default configuration.""" + self.keys = [] + self.hostnames = set() + self.trust_anchors = [] + self.app_protocols = [] + + def matches_hostname(self, hostname): + """Checks if the virtual host can serve hostname""" + return hostname in self.hostnames + + def validate(self): + """Sanity check the settings""" + if not self.keys: + raise ValueError("Virtual host missing keys") + for i in self.keys: + i.validate() + + class HandshakeSettings(object): """ This class encapsulates various parameters that can be used with @@ -254,6 +328,7 @@ def _init_key_settings(self): self.maxKeySize = 8193 self.rsaSigHashes = list(RSA_SIGNATURE_HASHES) self.rsaSchemes = list(RSA_SCHEMES) + self.virtual_hosts = [] # DH key settings self.eccCurves = list(CURVE_NAMES) self.dhParams = None @@ -312,6 +387,9 @@ def _sanityCheckKeySizes(other): raise ValueError("maxKeySize too large") if other.maxKeySize < other.minKeySize: raise ValueError("maxKeySize smaller than minKeySize") + # check also keys of virtual hosts + for i in other.virtual_hosts: + i.validate() @staticmethod def _not_matching(values, sieve): @@ -530,6 +608,7 @@ def _copy_extension_settings(self, other): """Copy values of settings related to extensions.""" other.useExtendedMasterSecret = self.useExtendedMasterSecret other.requireExtendedMasterSecret = self.requireExtendedMasterSecret + other.useExperimentalTackExtension = self.useExperimentalTackExtension other.sendFallbackSCSV = self.sendFallbackSCSV other.useEncryptThenMAC = self.useEncryptThenMAC other.usePaddingExtension = self.usePaddingExtension @@ -574,6 +653,7 @@ def _copy_key_settings(self, other): other.rsaSigHashes = self.rsaSigHashes other.rsaSchemes = self.rsaSchemes other.ecdsaSigHashes = self.ecdsaSigHashes + other.virtual_hosts = self.virtual_hosts # DH key params other.eccCurves = self.eccCurves other.dhParams = self.dhParams diff --git a/tlslite/tlsconnection.py b/tlslite/tlsconnection.py index 8ca80d5d..70886dfa 100644 --- a/tlslite/tlsconnection.py +++ b/tlslite/tlsconnection.py @@ -1910,11 +1910,13 @@ def handshakeServer(self, verifierDB=None, :type certChain: ~tlslite.x509certchain.X509CertChain :param certChain: The certificate chain to be used if the - client requests server certificate authentication. + client requests server certificate authentication and no virtual + host defined in HandshakeSettings matches ClientHello. :type privateKey: ~tlslite.utils.rsakey.RSAKey :param privateKey: The private key to be used if the client - requests server certificate authentication. + requests server certificate authentication and no virtual host + defined in HandshakeSettings matches ClientHello. :type reqCert: bool :param reqCert: Whether to request client certificate @@ -1941,21 +1943,23 @@ def handshakeServer(self, verifierDB=None, :type reqCAs: list of bytearray :param reqCAs: A collection of DER-encoded DistinguishedNames that - will be sent along with a certificate request. This does not affect - verification. + will be sent along with a certificate request to help client pick + a certificates. This does not affect verification. :type nextProtos: list of str :param nextProtos: A list of upper layer protocols to expose to the clients through the Next-Protocol Negotiation Extension, - if they support it. + if they support it. Deprecated, use the `virtual_hosts` in + HandshakeSettings. :type alpn: list of bytearray :param alpn: names of application layer protocols supported. Note that it will be used instead of NPN if both were advertised by - client. + client. Deprecated, use the `virtual_hosts` in HandshakeSettings. :type sni: bytearray - :param sni: expected virtual name hostname. + :param sni: expected virtual name hostname. Deprecated, use the + `virtual_hosts` in HandshakeSettings. :raises socket.error: If a socket error occurs. :raises tlslite.errors.TLSAbruptCloseError: If the socket is closed @@ -2009,8 +2013,12 @@ def _handshakeServerAsyncHelper(self, verifierDB, self._handshakeStart(client=False) + if not settings: + settings = HandshakeSettings() + settings = settings.validate() + if (not verifierDB) and (not cert_chain) and not anon and \ - not settings.pskConfigs: + not settings.pskConfigs and not settings.virtual_hosts: raise ValueError("Caller passed no authentication credentials") if cert_chain and not privateKey: raise ValueError("Caller passed a cert_chain but no privateKey") @@ -2025,21 +2033,19 @@ def _handshakeServerAsyncHelper(self, verifierDB, if tacks: if not tackpyLoaded: raise ValueError("tackpy is not loaded") - if not settings or not settings.useExperimentalTackExtension: + if not settings.useExperimentalTackExtension: raise ValueError("useExperimentalTackExtension not enabled") if alpn is not None and not alpn: raise ValueError("Empty list of ALPN protocols") - if not settings: - settings = HandshakeSettings() - settings = settings.validate() self.sock.padding_cb = settings.padding_cb # OK Start exchanging messages # ****************************** # Handle ClientHello and resumption - for result in self._serverGetClientHello(settings, cert_chain, + for result in self._serverGetClientHello(settings, privateKey, + cert_chain, verifierDB, sessionCache, anon, alpn, sni): if result in (0,1): yield result @@ -2047,7 +2053,8 @@ def _handshakeServerAsyncHelper(self, verifierDB, self._handshakeDone(resumed=True) return # Handshake was resumed, we're done else: break - (clientHello, cipherSuite, version, scheme) = result + (clientHello, version, cipherSuite, sig_scheme, privateKey, + cert_chain) = result # in TLS 1.3 the handshake is completely different # (extensions go into different messages, format of messages is @@ -2056,7 +2063,7 @@ def _handshakeServerAsyncHelper(self, verifierDB, for result in self._serverTLS13Handshake(settings, clientHello, cipherSuite, privateKey, cert_chain, - version, scheme, + version, sig_scheme, alpn, reqCert): if result in (0, 1): yield result @@ -2194,13 +2201,25 @@ def _handshakeServerAsyncHelper(self, verifierDB, if result in (0, 1): yield result else: break - premasterSecret = result + premasterSecret, privateKey, cert_chain = result # Perform a certificate-based key exchange elif (cipherSuite in CipherSuite.certSuites or cipherSuite in CipherSuite.dheCertSuites or cipherSuite in CipherSuite.ecdheCertSuites or cipherSuite in CipherSuite.ecdheEcdsaSuites): + try: + sig_hash_alg, cert_chain, privateKey = \ + self._pickServerKeyExchangeSig(settings, + clientHello, + cert_chain, + privateKey) + except TLSHandshakeFailure as alert: + for result in self._sendError( + AlertDescription.handshake_failure, + str(alert)): + yield result + if cipherSuite in CipherSuite.certSuites: keyExchange = RSAKeyExchange(cipherSuite, clientHello, @@ -2226,8 +2245,8 @@ def _handshakeServerAsyncHelper(self, verifierDB, defaultCurve) else: assert(False) - for result in self._serverCertKeyExchange(clientHello, serverHello, - cert_chain, keyExchange, + for result in self._serverCertKeyExchange(clientHello, serverHello, + sig_hash_alg, cert_chain, keyExchange, reqCert, reqCAs, cipherSuite, settings): if result in (0,1): yield result @@ -2823,7 +2842,8 @@ def _serverTLS13Handshake(self, settings, clientHello, cipherSuite, yield "finished" - def _serverGetClientHello(self, settings, cert_chain, verifierDB, + def _serverGetClientHello(self, settings, private_key, cert_chain, + verifierDB, sessionCache, anon, alpn, sni): # Tentatively set version to most-desirable version, so if an error # occurs parsing the ClientHello, this will be the version we'll use @@ -2839,6 +2859,8 @@ def _serverGetClientHello(self, settings, cert_chain, verifierDB, else: break clientHello = result + # check if the ClientHello and its extensions are well-formed + #If client's version is too low, reject it real_version = clientHello.client_version if real_version >= (3, 3): @@ -2925,12 +2947,6 @@ def _serverGetClientHello(self, settings, cert_chain, verifierDB, AlertDescription.illegal_parameter, "Host name in SNI is not valid DNS name"): yield result - # warn the client if the name didn't match the expected value - if sni and sni != name: - alert = Alert().create(AlertDescription.unrecognized_name, - AlertLevel.warning) - for result in self._sendMsg(alert): - yield result # sanity check the EMS extension emsExt = clientHello.getExtension(ExtensionType.extended_master_secret) @@ -3070,6 +3086,7 @@ def _serverGetClientHello(self, settings, cert_chain, verifierDB, self._recordLayer.max_early_data = settings.max_early_data self._recordLayer.early_data_ok = True + # negotiate the protocol version for the connection high_ver = None if ver_ext: high_ver = getFirstMatching(settings.versions, @@ -3107,19 +3124,31 @@ def _serverGetClientHello(self, settings, cert_chain, verifierDB, # TODO when TLS 1.3 is final, check the client hello random for # downgrade too - scheme = None - if version >= (3, 4): - try: - scheme = self._pickServerKeyExchangeSig(settings, - clientHello, - cert_chain, - version) - except TLSHandshakeFailure as alert: - for result in self._sendError( - AlertDescription.handshake_failure, - str(alert)): + # start negotiating the parameters of the connection + + sni_ext = clientHello.getExtension(ExtensionType.server_name) + if sni_ext: + name = sni_ext.hostNames[0].decode('ascii', 'strict') + # warn the client if the name didn't match the expected value + if sni and sni != name: + alert = Alert().create(AlertDescription.unrecognized_name, + AlertLevel.warning) + for result in self._sendMsg(alert): yield result + try: + sig_scheme, cert_chain, private_key = \ + self._pickServerKeyExchangeSig(settings, + clientHello, + cert_chain, + private_key, + version) + except TLSHandshakeFailure as alert: + for result in self._sendError( + AlertDescription.handshake_failure, + str(alert)): + yield result + #Check if there's intersection between supported curves by client and #server clientGroups = clientHello.getExtension(ExtensionType.supported_groups) @@ -3628,28 +3657,31 @@ def _serverGetClientHello(self, settings, cert_chain, verifierDB, # we have no session cache, or # the client's session_id was not found in cache: #pylint: disable = undefined-loop-variable - yield (clientHello, cipherSuite, version, scheme) + yield (clientHello, version, cipherSuite, sig_scheme, private_key, + cert_chain) #pylint: enable = undefined-loop-variable def _serverSRPKeyExchange(self, clientHello, serverHello, verifierDB, cipherSuite, privateKey, serverCertChain, settings): """Perform the server side of SRP key exchange""" - keyExchange = SRPKeyExchange(cipherSuite, - clientHello, - serverHello, - privateKey, - verifierDB) - try: - sigHash = self._pickServerKeyExchangeSig(settings, clientHello, - serverCertChain) + sigHash, serverCertChain, privateKey = \ + self._pickServerKeyExchangeSig(settings, clientHello, + serverCertChain, + privateKey) except TLSHandshakeFailure as alert: for result in self._sendError( AlertDescription.handshake_failure, str(alert)): yield result + keyExchange = SRPKeyExchange(cipherSuite, + clientHello, + serverHello, + privateKey, + verifierDB) + #Create ServerKeyExchange, signing it if necessary try: serverKeyExchange = keyExchange.makeServerKeyExchange(sigHash) @@ -3692,9 +3724,9 @@ def _serverSRPKeyExchange(self, clientHello, serverHello, verifierDB, str(alert)): yield result - yield premasterSecret + yield premasterSecret, privateKey, serverCertChain - def _serverCertKeyExchange(self, clientHello, serverHello, + def _serverCertKeyExchange(self, clientHello, serverHello, sigHashAlg, serverCertChain, keyExchange, reqCert, reqCAs, cipherSuite, settings): @@ -3707,14 +3739,6 @@ def _serverCertKeyExchange(self, clientHello, serverHello, msgs.append(serverHello) msgs.append(Certificate(CertificateType.x509).create(serverCertChain)) - try: - sigHashAlg = self._pickServerKeyExchangeSig(settings, clientHello, - serverCertChain) - except TLSHandshakeFailure as alert: - for result in self._sendError( - AlertDescription.handshake_failure, - str(alert)): - yield result try: serverKeyExchange = keyExchange.makeServerKeyExchange(sigHashAlg) except TLSInternalError as alert: @@ -4078,6 +4102,7 @@ def _handshakeWrapperAsync(self, handshaker, checker): @staticmethod def _pickServerKeyExchangeSig(settings, clientHello, certList=None, + private_key=None, version=(3, 3)): """Pick a hash that matches most closely the supported ones""" hashAndAlgsExt = clientHello.getExtension( @@ -4087,26 +4112,30 @@ def _pickServerKeyExchangeSig(settings, clientHello, certList=None, if not hashAndAlgsExt: # the error checking was done before hand, likely we're # doing PSK key exchange - return + return None, certList, private_key if hashAndAlgsExt is None or hashAndAlgsExt.sigalgs is None: # RFC 5246 states that if there are no hashes advertised, # sha1 should be picked - return "sha1" + return "sha1", certList, private_key + + alt_certs = ((X509CertChain(i.certificates), i.key) for vh in + settings.virtual_hosts for i in vh.keys) - supported = TLSConnection._sigHashesToList(settings, - certList=certList, - version=version) + for certs, key in chain([(certList, private_key)], alt_certs): + supported = TLSConnection._sigHashesToList(settings, + certList=certs, + version=version) - for schemeID in supported: - if schemeID in hashAndAlgsExt.sigalgs: - name = SignatureScheme.toRepr(schemeID) - if not name and schemeID[1] in (SignatureAlgorithm.rsa, - SignatureAlgorithm.ecdsa): - name = HashAlgorithm.toRepr(schemeID[0]) + for schemeID in supported: + if schemeID in hashAndAlgsExt.sigalgs: + name = SignatureScheme.toRepr(schemeID) + if not name and schemeID[1] in (SignatureAlgorithm.rsa, + SignatureAlgorithm.ecdsa): + name = HashAlgorithm.toRepr(schemeID[0]) - if name: - return name + if name: + return name, certs, key # if no match, we must abort per RFC 5246 raise TLSHandshakeFailure("No common signature algorithms") diff --git a/tlslite/x509.py b/tlslite/x509.py index 4de34161..f860aac9 100644 --- a/tlslite/x509.py +++ b/tlslite/x509.py @@ -42,6 +42,22 @@ def __init__(self): self.subject = None self.certAlg = None + def __hash__(self): + """Calculate hash of object.""" + return hash(bytes(self.bytes)) + + def __eq__(self, other): + """Compare other object for equality.""" + if not hasattr(other, "bytes"): + return NotImplemented + return self.bytes == other.bytes + + def __ne__(self, other): + """Compare with other object for inequality.""" + if not hasattr(other, "bytes"): + return NotImplemented + return not self == other + def parse(self, s): """ Parse a PEM-encoded X.509 certificate. diff --git a/tlslite/x509certchain.py b/tlslite/x509certchain.py index d7ca81f4..0d1232b1 100644 --- a/tlslite/x509certchain.py +++ b/tlslite/x509certchain.py @@ -30,6 +30,22 @@ def __init__(self, x509List=None): else: self.x509List = [] + def __hash__(self): + """Return hash of the object.""" + return hash(tuple(self.x509List)) + + def __eq__(self, other): + """Compare objects with each-other.""" + if not hasattr(other, "x509List"): + return NotImplemented + return self.x509List == other.x509List + + def __ne__(self, other): + """Compare object for inequality.""" + if not hasattr(other, "x509List"): + return NotImplemented + return self.x509List != other.x509List + def parsePemList(self, s): """Parse a string containing a sequence of PEM certs. diff --git a/unit_tests/test_tlslite_handshakesettings.py b/unit_tests/test_tlslite_handshakesettings.py index 995b6818..32c6b21c 100644 --- a/unit_tests/test_tlslite_handshakesettings.py +++ b/unit_tests/test_tlslite_handshakesettings.py @@ -7,8 +7,12 @@ import unittest2 as unittest except ImportError: import unittest +try: + import mock +except ImportError: + import unittest.mock as mock -from tlslite.handshakesettings import HandshakeSettings +from tlslite.handshakesettings import HandshakeSettings, Keypair, VirtualHost class TestHandshakeSettings(unittest.TestCase): def test___init__(self): @@ -470,5 +474,64 @@ def test_ticket_count_negative(self): self.assertIn("new session tickets", str(e.exception)) +class TestKeypair(unittest.TestCase): + def test___init___(self): + k_p = Keypair() + + self.assertIsInstance(k_p, Keypair) + self.assertIsNone(k_p.key) + self.assertIsInstance(k_p.certificates, tuple) + + def test_validate_with_missing_keys(self): + k_p = Keypair() + + with self.assertRaises(ValueError): + k_p.validate() + + def test_validate_with_missing_certificates(self): + k_p = Keypair() + k_p.key = mock.MagicMock() + + with self.assertRaises(ValueError): + k_p.validate() + + +class TestVirtualHost(unittest.TestCase): + def test___init__(self): + v_h = VirtualHost() + + self.assertIsInstance(v_h, VirtualHost) + self.assertEqual(v_h.keys, []) + self.assertEqual(v_h.hostnames, set()) + self.assertEqual(v_h.trust_anchors, []) + self.assertEqual(v_h.app_protocols, []) + + def test_matches_hostname_with_non_matching_name(self): + v_h = VirtualHost() + v_h.hostnames = set([b'example.com']) + + self.assertFalse(v_h.matches_hostname(b'example.org')) + + def test_matches_hostname_with_matching_name(self): + v_h = VirtualHost() + v_h.hostnames = set([b'example.com']) + + self.assertTrue(v_h.matches_hostname(b'example.com')) + + def test_validate_without_keys(self): + v_h = VirtualHost() + + with self.assertRaises(ValueError): + v_h.validate() + + def test_validate_with_keys(self): + v_h = VirtualHost() + v_h.keys = [mock.MagicMock()] + + v_h.validate() + + v_h.keys[0].validate.assert_called_once_with() + + if __name__ == '__main__': unittest.main() diff --git a/unit_tests/test_tlslite_x509.py b/unit_tests/test_tlslite_x509.py index 85733c51..8c5696c3 100644 --- a/unit_tests/test_tlslite_x509.py +++ b/unit_tests/test_tlslite_x509.py @@ -12,10 +12,12 @@ from tlslite.x509 import X509 from tlslite.utils.python_ecdsakey import Python_ECDSAKey +from tlslite.x509certchain import X509CertChain class TestX509(unittest.TestCase): - def test_pem(self): - data = ( + @classmethod + def setUpClass(cls): + cls.data = ( "-----BEGIN CERTIFICATE-----\n" "MIIBbTCCARSgAwIBAgIJAPM58cskyK+yMAkGByqGSM49BAEwFDESMBAGA1UEAwwJ\n" "bG9jYWxob3N0MB4XDTE3MTAyMzExNDI0MVoXDTE3MTEyMjExNDI0MVowFDESMBAG\n" @@ -26,8 +28,10 @@ def test_pem(self): "KoZIzj0EAQNIADBFAiA6p0YM5ZzfW+klHPRU2r13/IfKgeRfDR3dtBngmPvxUgIh\n" "APTeSDeJvYWVBLzyrKTeSerNDKKHU2Rt7sufipv76+7s\n" "-----END CERTIFICATE-----\n") + + def test_pem(self): x509 = X509() - x509.parse(data) + x509.parse(self.data) self.assertIsNotNone(x509.publicKey) self.assertIsInstance(x509.publicKey, Python_ECDSAKey) @@ -37,3 +41,42 @@ def test_pem(self): 12490546948316647166662676770106859255378658810545502161335656899238893361610) self.assertEqual(x509.publicKey.curve_name, "NIST256p") + def test_hash(self): + x509_1 = X509() + x509_1.parse(self.data) + + x509_2 = X509() + x509_2.parse(self.data) + + self.assertEqual(hash(x509_1), hash(x509_2)) + self.assertEqual(x509_1, x509_2) + + +class TestX509CertChain(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.data = ( + "-----BEGIN CERTIFICATE-----\n" + "MIIBbTCCARSgAwIBAgIJAPM58cskyK+yMAkGByqGSM49BAEwFDESMBAGA1UEAwwJ\n" + "bG9jYWxob3N0MB4XDTE3MTAyMzExNDI0MVoXDTE3MTEyMjExNDI0MVowFDESMBAG\n" + "A1UEAwwJbG9jYWxob3N0MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEyDRjEAJe\n" + "3F5T62MyZbhjoJnPLGL2nrTthLFymBupZ2IbnWYnqVWDkT/L6i8sQhf2zCLrlSjj\n" + "1kn7ERqPx/KZyqNQME4wHQYDVR0OBBYEFPfFTUg9o3t6ehLsschSnC8Te8oaMB8G\n" + "A1UdIwQYMBaAFPfFTUg9o3t6ehLsschSnC8Te8oaMAwGA1UdEwQFMAMBAf8wCQYH\n" + "KoZIzj0EAQNIADBFAiA6p0YM5ZzfW+klHPRU2r13/IfKgeRfDR3dtBngmPvxUgIh\n" + "APTeSDeJvYWVBLzyrKTeSerNDKKHU2Rt7sufipv76+7s\n" + "-----END CERTIFICATE-----\n") + + def test_pem(self): + x509cc = X509CertChain() + x509cc.parsePemList(self.data) + + def test_hash(self): + x509cc1 = X509CertChain() + x509cc1.parsePemList(self.data) + + x509cc2 = X509CertChain() + x509cc2.parsePemList(self.data) + + self.assertEqual(hash(x509cc1), hash(x509cc2)) + self.assertEqual(x509cc1, x509cc2)