Skip to content

Commit

Permalink
Merge pull request #77 from Dwolla/hidden-recipients
Browse files Browse the repository at this point in the history
support decryption of messages with a hidden recipient
  • Loading branch information
bpholt authored Jun 6, 2022
2 parents 86706cc + e8a321d commit e5e75ae
Show file tree
Hide file tree
Showing 7 changed files with 469 additions and 54 deletions.
9 changes: 9 additions & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,15 @@ lazy val `fs2-pgp` = (project in file("core"))
)
},
unusedCompileDependenciesFilter -= moduleFilter("org.scala-lang.modules", "scala-collection-compat"),
mimaBinaryIssueFilters ++= {
import com.typesafe.tools.mima.core._
Seq(
// the CanCreateDecryptorFactory filters ignore a class and companion object that should have been package private
// and have been replaced with an alternative implementation that is package private
ProblemFilters.exclude[MissingClassProblem]("com.dwolla.security.crypto.CanCreateDecryptorFactory"),
ProblemFilters.exclude[MissingClassProblem]("com.dwolla.security.crypto.CanCreateDecryptorFactory$"),
)
},
)
.settings(commonSettings: _*)

Expand Down

This file was deleted.

15 changes: 10 additions & 5 deletions core/src/main/scala/com/dwolla/security/crypto/CryptoAlg.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import cats.effect._
import cats.effect.syntax.all._
import cats.syntax.all._
import com.dwolla.security.crypto.Compression._
import com.dwolla.security.crypto.DecryptToInputStream._
import com.dwolla.security.crypto.Encryption._
import com.dwolla.security.crypto.PgpLiteralDataPacketFormat._
import eu.timepit.refined.auto._
Expand Down Expand Up @@ -153,8 +154,8 @@ object CryptoAlg extends CryptoAlgPlatform {

private val objectIteratorChunkSize: ChunkSize = tagChunkSize(1)

private def pgpInputStreamToByteStream[A : CanCreateDecryptorFactory[F, *]](keylike: A,
chunkSize: ChunkSize): InputStream => Stream[F, Byte] = {
private def pgpInputStreamToByteStream[A: DecryptToInputStream[F, *]](keylike: A,
chunkSize: ChunkSize): InputStream => Stream[F, Byte] = {
def pgpCompressedDataToBytes(pcd: PGPCompressedData): Stream[F, Byte] =
Logger[Stream[F, *]].trace("Found compressed data") >>
pgpInputStreamToByteStream(keylike, chunkSize).apply(pcd.getDataStream)
Expand All @@ -172,13 +173,17 @@ object CryptoAlg extends CryptoAlgPlatform {
Stream.fromBlockingIterator[F](pedl.iterator().asScala, objectIteratorChunkSize)
.evalMap {
case pbe: PGPPublicKeyEncryptedData =>
CanCreateDecryptorFactory[F, A]
.publicKeyDataDecryptorFactory(keylike, pbe.getKeyID)
.flatMap(factory => Sync[F].blocking(pbe.getDataStream(factory)))
// a key ID of 0L indicates a "hidden" recipient,
// and we can't use that key ID to lookup the key
val recipientKeyId = Option(pbe.getKeyID).filterNot(_ == 0)

pbe.decryptToInputStream(keylike, recipientKeyId)

case other =>
Logger[F].error(EncryptionTypeError)(s"found wrong type of encrypted data: $other") >>
EncryptionTypeError.raiseError[F, InputStream]
}
.head // TODO what happens if pedl contains multiple recipients?
.flatMap(pgpInputStreamToByteStream(keylike, chunkSize))
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package com.dwolla.security.crypto

import cats._
import cats.effect._
import cats.syntax.all._
import fs2._
import org.bouncycastle.openpgp._
import org.bouncycastle.openpgp.operator.bc.{BcPBESecretKeyDecryptorBuilder, BcPGPDigestCalculatorProvider, BcPublicKeyDataDecryptorFactory}

import java.io.InputStream
import scala.jdk.CollectionConverters._

private[crypto] sealed trait DecryptToInputStream[F[_], A] {
def decryptToInputStream(input: A, maybeKeyId: Option[Long])
(pbed: PGPPublicKeyEncryptedData): F[InputStream]
}

private[crypto] object DecryptToInputStream {
@inline final def apply[F[_], A](implicit DTIS: DecryptToInputStream[F, A]): DTIS.type = DTIS

private def attemptDecrypt[F[_] : Sync](pbed: PGPPublicKeyEncryptedData,
factory: BcPublicKeyDataDecryptorFactory): F[InputStream] =
Sync[F].blocking {
pbed.getDataStream(factory)
}

/**
* Tries to decrypt the `PGPPublicKeyEncryptedData` using each secret key in the given list of keys.
* The first key that doesn't throw an exception when doing `pbed.getDataStream` will be the one
* whose `InputStream` is read downstream. Once it finds a key that works, it will stop trying
* any subsequent keys.
*/
private def decryptWithKeys[F[_] : Sync](input: List[PGPSecretKey],
passphrase: Array[Char],
pbed: PGPPublicKeyEncryptedData,
keyId: Option[Long],
): F[InputStream] =
Stream
.emits(input)
.evalMap { secretKey =>
Sync[F].blocking {
val digestCalculatorProvider = new BcPGPDigestCalculatorProvider()
val decryptor = new BcPBESecretKeyDecryptorBuilder(digestCalculatorProvider).build(passphrase)
val key = secretKey.extractPrivateKey(decryptor)
new BcPublicKeyDataDecryptorFactory(key)
}
}
.evalMap {
attemptDecrypt(pbed, _)
.map(_.some)
.handleError(_ => None) // TODO should we log these failures at the TRACE level?
}
.unNone
.head
.compile
.lastOrError
.adaptErr {
case _: NoSuchElementException => KeyRingMissingKeyException(keyId)
}

implicit def PGPSecretKeyRingCollectionInstance[F[_] : Sync]: DecryptToInputStream[F, (PGPSecretKeyRingCollection, Array[Char])] =
new DecryptToInputStream[F, (PGPSecretKeyRingCollection, Array[Char])] {
override def decryptToInputStream(input: (PGPSecretKeyRingCollection, Array[Char]),
maybeKeyId: Option[Long])
(pbed: PGPPublicKeyEncryptedData): F[InputStream] =
maybeKeyId
.toOptionT
.semiflatMap { keyId =>
ApplicativeThrow[F].catchNonFatal {
input._1.getSecretKey(keyId).pure[List]
}
}
.getOrElse(input._1.getKeyRings.asScala.toList.flatMap(_.getSecretKeys.asScala))
.flatMap(decryptWithKeys(_, input._2, pbed, maybeKeyId))
}

implicit def PGPSecretKeyRingInstance[F[_] : Sync]: DecryptToInputStream[F, (PGPSecretKeyRing, Array[Char])] =
new DecryptToInputStream[F, (PGPSecretKeyRing, Array[Char])] {
override def decryptToInputStream(input: (PGPSecretKeyRing, Array[Char]),
maybeKeyId: Option[Long])
(pbed: PGPPublicKeyEncryptedData): F[InputStream] = {
val keys = maybeKeyId.fold(input._1.getSecretKeys.asScala.toList) { keyId =>
input._1.getSecretKey(keyId).pure[List]
}

decryptWithKeys(keys, input._2, pbed, maybeKeyId)
}
}

implicit def PGPPrivateKeyInstance[F[_] : Sync]: DecryptToInputStream[F, PGPPrivateKey] =
new DecryptToInputStream[F, PGPPrivateKey] {
override def decryptToInputStream(input: PGPPrivateKey,
maybeKeyId: Option[Long])
(pbed: PGPPublicKeyEncryptedData): F[InputStream] =
if (maybeKeyId.exists(_ != input.getKeyID)) KeyMismatchException(maybeKeyId, input.getKeyID).raiseError
else
Sync[F].blocking(new BcPublicKeyDataDecryptorFactory(input))
.flatMap(attemptDecrypt(pbed, _))
}

implicit def toPGPPublicKeyEncryptedDataOps(pbed: PGPPublicKeyEncryptedData): PGPPublicKeyEncryptedDataOps = new PGPPublicKeyEncryptedDataOps(pbed)
}

class PGPPublicKeyEncryptedDataOps(val pbed: PGPPublicKeyEncryptedData) extends AnyVal {
def decryptToInputStream[F[_], A](input: A, maybeKeyId: Option[Long])
(implicit D: DecryptToInputStream[F, A]): F[InputStream] =
DecryptToInputStream[F, A].decryptToInputStream(input, maybeKeyId)(pbed)
}
106 changes: 101 additions & 5 deletions core/src/main/scala/com/dwolla/security/crypto/exceptions.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,105 @@
package com.dwolla.security.crypto

case class KeyRingMissingKeyException(expectedKeyId: Long)
extends RuntimeException(s"Cannot decrypt message with the passed keyring because it requires key $expectedKeyId, but the ring does not contain that key", null, true, false)
import cats.syntax.all._

case class KeyMismatchException(expectedKeyId: Long, actualKeyId: Long)
extends RuntimeException(s"Cannot decrypt message with key $actualKeyId because it requires key $expectedKeyId", null, true, false)
import scala.annotation.nowarn
import scala.runtime.{AbstractFunction1, AbstractFunction2}
import scala.util.control.NoStackTrace

case object EncryptionTypeError extends RuntimeException("encrypted data was not PGPPublicKeyEncryptedData", null, true, false)
class KeyRingMissingKeyException(expectedKeyId: Option[Long])
extends RuntimeException(s"Cannot decrypt message with the passed keyring because ${expectedKeyId.fold("it does not contain a compatible key and the message recipient is hidden")(id => s"it requires key $id, but the ring does not contain that key")}")
with NoStackTrace
with Product
with Equals
with Serializable {
@deprecated("only maintained for bincompat reasons", "0.4.0")
def this(keyId: Long) = this(keyId.some)

@deprecated("only maintained for bincompat reasons", "0.4.0")
override def productArity: Int = 1

@deprecated("only maintained for bincompat reasons", "0.4.0")
override def productElement(n: Int): Any = if (n == 0) expectedKeyId.getOrElse(0) else throw new IndexOutOfBoundsException()

@deprecated("only maintained for bincompat reasons", "0.4.0")
override def canEqual(that: Any): Boolean = that.isInstanceOf[KeyRingMissingKeyException]

@deprecated("only maintained for bincompat reasons", "0.4.0")
def expectedKeyId(): Long = expectedKeyId.getOrElse(0)

@nowarn
@deprecated("only maintained for bincompat reasons", "0.4.0")
def copy(keyId: Long = expectedKeyId()): KeyRingMissingKeyException = KeyRingMissingKeyException(keyId)
}

object KeyRingMissingKeyException extends AbstractFunction1[Long, KeyRingMissingKeyException] {
def apply(expectedKeyId: Option[Long]) = new KeyRingMissingKeyException(expectedKeyId)
@deprecated("only maintained for bincompat reasons", "0.4.0")
override def apply(keyId: Long): KeyRingMissingKeyException = KeyRingMissingKeyException(keyId.some)

@deprecated("only maintained for bincompat reasons", "0.4.0")
def unapply(arg: KeyRingMissingKeyException): Option[Long] =
arg.expectedKeyId().some
}

class KeyMismatchException(expectedKeyId: Option[Long], val actualKeyId: Long)
extends RuntimeException(s"Cannot decrypt message with key $actualKeyId${expectedKeyId.fold(". (The message recipient is hidden.)")(id => s" because it requires key $id")}")
with NoStackTrace
with Product
with Equals
with Serializable {
@deprecated("only maintained for bincompat reasons", "0.4.0")
def this(expectedKeyId: Long, actualKeyId: Long) = this(Option(expectedKeyId).filter(_ == 0), actualKeyId)

@deprecated("only maintained for bincompat reasons", "0.4.0")
override def productArity: Int = 2

@deprecated("only maintained for bincompat reasons", "0.4.0")
override def productElement(n: Int): Any = n match {
case 0 => expectedKeyId.getOrElse(0)
case 1 => actualKeyId
case _ => throw new IndexOutOfBoundsException()
}

@deprecated("only maintained for bincompat reasons", "0.4.0")
override def canEqual(that: Any): Boolean = that.isInstanceOf[KeyMismatchException]

@deprecated("only maintained for bincompat reasons", "0.4.0")
def expectedKeyId(): Long = expectedKeyId.getOrElse(0)

@nowarn
@deprecated("only maintained for bincompat reasons", "0.4.0")
def copy(expectedKeyId: Long = this.expectedKeyId(), actualKeyId: Long = actualKeyId) =
new KeyMismatchException(Option(expectedKeyId).filter(_ == 0), actualKeyId)
}

object KeyMismatchException extends AbstractFunction2[Long, Long, KeyMismatchException] {
@deprecated("only maintained for bincompat reasons", "0.4.0")
def unapply(arg: KeyMismatchException): Option[(Long, Long)] =
(arg.expectedKeyId(), arg.actualKeyId).some

def apply(expectedKeyId: Option[Long], actualKeyId: Long) = new KeyMismatchException(expectedKeyId, actualKeyId)

override def apply(v1: Long, v2: Long): KeyMismatchException =
KeyMismatchException(v1.some, v2)

}

object EncryptionTypeError
extends RuntimeException("encrypted data was not PGPPublicKeyEncryptedData")
with NoStackTrace
with Product
with Equals
with Serializable {

@deprecated("only maintained for bincompat reasons", "0.4.0")
override def productArity: Int = 0

@deprecated("only maintained for bincompat reasons", "0.4.0")
override def productElement(n: Int): Any = throw new IndexOutOfBoundsException()

@deprecated("only maintained for bincompat reasons", "0.4.0")
override def canEqual(that: Any): Boolean = this == EncryptionTypeError

override def hashCode(): Int = super.hashCode()
}
Loading

0 comments on commit e5e75ae

Please sign in to comment.