Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for DTLS raw keys. #2242

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ import java.security.KeyPair
data class CertificateInfo(
val keyPair: KeyPair,
val certificate: org.bouncycastle.tls.Certificate,
val rawKeyCertificate: org.bouncycastle.tls.Certificate,
val localFingerprintHashFunction: String,
val localFingerprint: String,
val localRawKeyFingerprint: String,
val creationTimestampMs: Long
)
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class DtlsConfig {
}
}

val negotiateRawKeyFingerprints: Boolean by config {
"jmt.dtls.negotiate-raw-key-fingerprints".from(JitsiConfig.newConfig)
}

companion object {
val config = DtlsConfig()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,19 @@ class DtlsStack(
val localFingerprint: String
get() = certificateInfo.localFingerprint

val localRawKeyFingerprint: String
get() = certificateInfo.localRawKeyFingerprint

/**
* The remote fingerprints sent to us over the signaling path.
*/
var remoteFingerprints: Map<String, String> = HashMap()

/**
* The remote raw key fingerprints.
*/
var remoteRawKeyFingerprints: Map<String, String> = HashMap()

/**
* A handler which will be invoked when DTLS application data is received
*/
Expand Down Expand Up @@ -174,7 +182,7 @@ class DtlsStack(
*/
private fun verifyAndValidateRemoteCertificate(remoteCertificate: Certificate?) {
remoteCertificate?.let {
DtlsUtils.verifyAndValidateCertificate(it, remoteFingerprints)
DtlsUtils.verifyAndValidateCertificate(it, remoteFingerprints, remoteRawKeyFingerprints)
// The above throws an exception if the checks fail.
logger.cdebug { "Fingerprints verified." }
} ?: run {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,27 @@
package org.jitsi.nlj.dtls

import org.bouncycastle.asn1.ASN1Encoding
import org.bouncycastle.asn1.ASN1Object
import org.bouncycastle.asn1.x500.X500Name
import org.bouncycastle.asn1.x500.X500NameBuilder
import org.bouncycastle.asn1.x500.style.BCStyle
import org.bouncycastle.asn1.x509.Certificate
import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo
import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder
import org.bouncycastle.jce.ECNamedCurveTable
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.operator.DefaultDigestAlgorithmIdentifierFinder
import org.bouncycastle.operator.bc.BcDefaultDigestProvider
import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder
import org.bouncycastle.tls.AlertDescription
import org.bouncycastle.tls.CertificateEntry
import org.bouncycastle.tls.CertificateType
import org.bouncycastle.tls.TlsContext
import org.bouncycastle.tls.TlsUtils
import org.bouncycastle.tls.crypto.TlsSecret
import org.bouncycastle.tls.crypto.impl.bc.BcTlsCertificate
import org.bouncycastle.tls.crypto.impl.bc.BcTlsCrypto
import org.bouncycastle.tls.crypto.impl.bc.BcTlsRawKeyCertificate
import org.jitsi.utils.logging2.Logger
import org.jitsi.utils.logging2.cdebug
import org.jitsi.utils.logging2.cerror
Expand Down Expand Up @@ -68,14 +73,24 @@ class DtlsUtils {
val localFingerprintHashFunction = x509certificate.getHashFunction()
val localFingerprint = x509certificate.getFingerprint(localFingerprintHashFunction)

val sPKI = x509certificate.subjectPublicKeyInfo
val rawKeyFingerprint = sPKI.getFingerprint(localFingerprintHashFunction)

val certificate = org.bouncycastle.tls.Certificate(
arrayOf(BcTlsCertificate(BC_TLS_CRYPTO, x509certificate))
)
val rawKeyCertificate = org.bouncycastle.tls.Certificate(
CertificateType.RawPublicKey,
null,
arrayOf(CertificateEntry(BcTlsRawKeyCertificate(BC_TLS_CRYPTO, sPKI), null))
)
return CertificateInfo(
keyPair,
certificate,
rawKeyCertificate,
localFingerprintHashFunction,
localFingerprint,
rawKeyFingerprint,
System.currentTimeMillis()
)
}
Expand Down Expand Up @@ -158,14 +173,26 @@ class DtlsUtils {
*/
fun verifyAndValidateCertificate(
certificateInfo: org.bouncycastle.tls.Certificate,
remoteFingerprints: Map<String, String>
remoteFingerprints: Map<String, String>,
remoteRawKeyFingerprints: Map<String, String>
) {
if (certificateInfo.certificateList.isEmpty()) {
throw DtlsException("No remote fingerprints.")
}
val type = certificateInfo.certificateType
for (currCertificate in certificateInfo.certificateList) {
val x509Cert = Certificate.getInstance(currCertificate.encoded)
verifyAndValidateCertificate(x509Cert, remoteFingerprints)
when (type) {
CertificateType.X509 -> {
val x509Cert = Certificate.getInstance(currCertificate.encoded)
verifyAndValidateCertificate(x509Cert, remoteFingerprints)
}
CertificateType.RawPublicKey -> {
val sPKI = SubjectPublicKeyInfo.getInstance(currCertificate.encoded)
verifyAndValidateRawPublicKey(sPKI, remoteRawKeyFingerprints)
}
else ->
throw DtlsException("Invalid certificate type")
}
}
}

Expand Down Expand Up @@ -229,6 +256,27 @@ class DtlsUtils {
}
}

private fun verifyAndValidateRawPublicKey(
sPKI: SubjectPublicKeyInfo,
remoteRawKeyFingerprints: Map<String, String>
) {
if (!remoteRawKeyFingerprints.any { (hash, fingerprint) ->
try {
sPKI.getFingerprint(hash) == fingerprint
} catch (e: Exception) {
// Swallow exception for unknown hash functions
false
}
}
) {
val expected = sPKI.getFingerprint("sha-256")
throw DtlsException(
"No remote raw key fingerprint matches SubjectPublicKeyInfo. " +
"Expected: sha-256 $expected. Seen: ${remoteRawKeyFingerprints.entries.joinToString()}"
)
}
}

/**
* Determine and return the hash function (as a [String]) used by this certificate
*/
Expand All @@ -242,11 +290,12 @@ class DtlsUtils {
}

/**
* Computes the fingerprint of a [org.bouncycastle.asn1.x509.Certificate] using [hashFunction] and returns it
* as a [String]
* Computes the fingerprint of a [ASN1Object] (e.g. a [org.bouncycastle.asn1.x509.Certificate])
* using [hashFunction] and returns it as a [String]
*/
private fun Certificate.getFingerprint(hashFunction: String): String {
private fun ASN1Object.getFingerprint(hashFunction: String): String {
val digAlgId = DefaultDigestAlgorithmIdentifierFinder().find(hashFunction.uppercase())
?: throw DtlsException("digest algorithm $hashFunction unknown")
val digest = BcDefaultDigestProvider.INSTANCE.get(digAlgId)
val input: ByteArray = getEncoded(ASN1Encoding.DER)
val output = ByteArray(digest.digestSize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ package org.jitsi.nlj.dtls

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
import org.bouncycastle.crypto.util.PrivateKeyFactory
import org.bouncycastle.tls.AlertDescription
import org.bouncycastle.tls.Certificate
import org.bouncycastle.tls.CertificateRequest
import org.bouncycastle.tls.CertificateType
import org.bouncycastle.tls.DefaultTlsClient
import org.bouncycastle.tls.ExporterLabel
import org.bouncycastle.tls.ExtensionType
Expand All @@ -28,6 +30,7 @@ import org.bouncycastle.tls.SignatureAlgorithm
import org.bouncycastle.tls.SignatureAndHashAlgorithm
import org.bouncycastle.tls.TlsAuthentication
import org.bouncycastle.tls.TlsCredentials
import org.bouncycastle.tls.TlsFatalAlert
import org.bouncycastle.tls.TlsSRTPUtils
import org.bouncycastle.tls.TlsServerCertificate
import org.bouncycastle.tls.TlsSession
Expand Down Expand Up @@ -81,12 +84,20 @@ class TlsClientImpl(
return object : TlsAuthentication {
override fun getClientCredentials(certificateRequest: CertificateRequest): TlsCredentials {
// NOTE: can't set clientCredentials when it is declared because 'context' won't be set yet
val cert = when (context.securityParametersHandshake.clientCertificateType) {
CertificateType.RawPublicKey ->
certificateInfo.rawKeyCertificate
CertificateType.X509 ->
certificateInfo.certificate
else ->
throw TlsFatalAlert(AlertDescription.internal_error)
}
if (clientCredentials == null) {
clientCredentials = BcDefaultTlsCredentialedSigner(
TlsCryptoParameters(context),
(context.crypto as BcTlsCrypto),
PrivateKeyFactory.createKey(certificateInfo.keyPair.private.encoded),
certificateInfo.certificate,
cert,
if (TlsUtils.isSignatureAlgorithmsExtensionAllowed(context.serverVersion)) {
SignatureAndHashAlgorithm(
HashAlgorithm.sha256,
Expand Down Expand Up @@ -123,6 +134,20 @@ class TlsClientImpl(
return clientExtensions
}

override fun getAllowedClientCertificateTypes(): ShortArray? {
if (DtlsConfig.config.negotiateRawKeyFingerprints) {
return shortArrayOf(CertificateType.X509, CertificateType.RawPublicKey)
}
return null
}

override fun getAllowedServerCertificateTypes(): ShortArray? {
if (DtlsConfig.config.negotiateRawKeyFingerprints) {
return shortArrayOf(CertificateType.X509, CertificateType.RawPublicKey)
}
return null
}

override fun processServerExtensions(serverExtensions: Hashtable<*, *>?) {
// TODO: a few cases we should be throwing alerts for in here. see old TlsClientImpl
val useSRTPData = TlsSRTPUtils.getUseSRTPExtension(serverExtensions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
import org.bouncycastle.crypto.util.PrivateKeyFactory
import org.bouncycastle.tls.Certificate
import org.bouncycastle.tls.CertificateRequest
import org.bouncycastle.tls.CertificateType
import org.bouncycastle.tls.ClientCertificateType
import org.bouncycastle.tls.DefaultTlsServer
import org.bouncycastle.tls.ExporterLabel
Expand All @@ -29,6 +30,7 @@ import org.bouncycastle.tls.SignatureAlgorithm
import org.bouncycastle.tls.SignatureAndHashAlgorithm
import org.bouncycastle.tls.TlsCredentialedDecryptor
import org.bouncycastle.tls.TlsCredentialedSigner
import org.bouncycastle.tls.TlsExtensionsUtils
import org.bouncycastle.tls.TlsSRTPUtils
import org.bouncycastle.tls.TlsSession
import org.bouncycastle.tls.TlsUtils
Expand Down Expand Up @@ -64,6 +66,8 @@ class TlsServerImpl(

private var session: TlsSession? = null

private var useRawKeys: Boolean = false

/**
* Only set after a handshake has completed
*/
Expand Down Expand Up @@ -96,6 +100,17 @@ class TlsServerImpl(
val protectionProfiles = useSRTPData.protectionProfiles
chosenSrtpProtectionProfile =
DtlsUtils.chooseSrtpProtectionProfile(SrtpConfig.protectionProfiles, protectionProfiles.asIterable())

if (DtlsConfig.config.negotiateRawKeyFingerprints) {
val remoteServerCertTypes = TlsExtensionsUtils.getServerCertificateTypeExtensionClient(clientExtensions)
val remoteClientCertTypes = TlsExtensionsUtils.getClientCertificateTypeExtensionClient(clientExtensions)

if (remoteServerCertTypes?.contains(CertificateType.RawPublicKey) == true &&
remoteClientCertTypes?.contains(CertificateType.RawPublicKey) == true
) {
useRawKeys = true
}
}
}

override fun getCipherSuites() = DtlsConfig.config.cipherSuites.toIntArray()
Expand All @@ -109,15 +124,27 @@ class TlsServerImpl(
}

override fun getECDSASignerCredentials(): TlsCredentialedSigner {
val cert = if (useRawKeys) {
certificateInfo.rawKeyCertificate
} else {
certificateInfo.certificate
}
return BcDefaultTlsCredentialedSigner(
TlsCryptoParameters(context),
(context.crypto as BcTlsCrypto),
PrivateKeyFactory.createKey(certificateInfo.keyPair.private.encoded),
certificateInfo.certificate,
cert,
SignatureAndHashAlgorithm(HashAlgorithm.sha256, SignatureAlgorithm.ecdsa)
)
}

override fun getAllowedClientCertificateTypes(): ShortArray? {
if (useRawKeys) {
return shortArrayOf(CertificateType.RawPublicKey)
}
return null
}

override fun getCertificateRequest(): CertificateRequest {
val signatureAlgorithms = Vector<SignatureAndHashAlgorithm>(1)
signatureAlgorithms.add(SignatureAndHashAlgorithm(HashAlgorithm.sha256, SignatureAlgorithm.ecdsa))
Expand Down
3 changes: 3 additions & 0 deletions jitsi-media-transform/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ jmt {
// TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
// TLS_DHE_RSA_WITH_AES_256_GCM_SHA384
]

// Whether to send and recognize fingerprints for raw keys
negotiate-raw-key-fingerprints = false
}
srtp {
// The maximum number of packets that can be discarded early (without going through the SRTP stack for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import kotlin.concurrent.thread
class DtlsTest : ShouldSpec() {
override fun isolationMode(): IsolationMode? = IsolationMode.InstancePerLeaf
private val debugEnabled = true
private val pcapEnabled = false
private val pcapEnabled = true
private val logger = StdoutLogger(_level = Level.OFF)

fun debug(s: String) {
Expand All @@ -50,9 +50,15 @@ class DtlsTest : ShouldSpec() {
dtlsClient.remoteFingerprints = mapOf(
dtlsServer.localFingerprintHashFunction to dtlsServer.localFingerprint
)
dtlsClient.remoteRawKeyFingerprints = mapOf(
dtlsServer.localFingerprintHashFunction to dtlsServer.localRawKeyFingerprint
)
dtlsServer.remoteFingerprints = mapOf(
dtlsClient.localFingerprintHashFunction to dtlsClient.localFingerprint
)
dtlsServer.remoteRawKeyFingerprints = mapOf(
dtlsClient.localFingerprintHashFunction to dtlsClient.localRawKeyFingerprint
)

// The DTLS server's send is wired directly to the DTLS client's receive
dtlsServer.outgoingDataHandler = object : DtlsStack.OutgoingDataHandler {
Expand Down
14 changes: 14 additions & 0 deletions jvb/src/main/kotlin/org/jitsi/videobridge/Endpoint.kt
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ import org.jitsi.videobridge.util.looksLikeDtls
import org.jitsi.videobridge.websocket.colibriWebSocketServiceSupplier
import org.jitsi.xmpp.extensions.colibri.WebSocketPacketExtension
import org.jitsi.xmpp.extensions.jingle.DtlsFingerprintPacketExtension
import org.jitsi.xmpp.extensions.jingle.DtlsRawKeyFingerprintPacketExtension
import org.jitsi.xmpp.extensions.jingle.IceUdpTransportPacketExtension
import org.jitsi.xmpp.util.XmlStringBuilderUtil.Companion.toStringOpt
import org.jitsi_modified.sctp4j.SctpDataCallback
Expand Down Expand Up @@ -803,6 +804,19 @@ class Endpoint @JvmOverloads constructor(
val setup = fingerprintExtensions.first().setup
dtlsTransport.setSetupAttribute(setup)
}

val remoteRawKeyFingerprints = mutableMapOf<String, String>()
val rawKeyFingerprintExtensions =
transportInfo.getChildExtensionsOfType(DtlsRawKeyFingerprintPacketExtension::class.java)
rawKeyFingerprintExtensions.forEach { rawKeyFingerprintExtension ->
if (rawKeyFingerprintExtension.hash != null && rawKeyFingerprintExtension.fingerprint != null) {
remoteRawKeyFingerprints[rawKeyFingerprintExtension.hash] = rawKeyFingerprintExtension.fingerprint
} else {
logger.info("Ignoring empty DtlsRawKeyFingerprint extension: ${transportInfo.toStringOpt()}")
}
}
dtlsTransport.setRemoteRawKeyFingerprints(remoteRawKeyFingerprints)

iceTransport.startConnectivityEstablishment(transportInfo)
}

Expand Down
Loading
Loading