Skip to content

Commit

Permalink
Rework token invalidation
Browse files Browse the repository at this point in the history
  • Loading branch information
Кирилл committed Sep 22, 2023
1 parent 85880b4 commit a702eaa
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 104 deletions.
22 changes: 12 additions & 10 deletions src/main/kotlin/io/emeraldpay/dshackle/auth/AuthContext.kt
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
package io.emeraldpay.dshackle.auth

import com.github.benmanes.caffeine.cache.Caffeine
import org.springframework.stereotype.Component
import java.time.Duration
import java.time.Instant
import java.util.concurrent.ConcurrentHashMap

@Component
class AuthContext {
private val sessions = Caffeine.newBuilder()
.expireAfterAccess(Duration.ofDays(1))
.build<String, Boolean>()

companion object {
val sessions = ConcurrentHashMap<String, TokenWrapper>()

fun putTokenInContext(tokenWrapper: TokenWrapper) {
sessions[tokenWrapper.sessionId] = tokenWrapper
}
fun putSessionInContext(tokenWrapper: TokenWrapper) {
sessions.put(tokenWrapper.sessionId, true)
}

fun removeToken(sessionId: String) {
sessions.remove(sessionId)
}
fun containsSession(sessionId: String): Boolean {
return sessions.asMap()[sessionId] != null
}

data class TokenWrapper(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ const val AUTH_METHOD_NAME = "emerald.Auth/Authenticate"
const val REFLECT_METHOD_NAME = "grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo"

@Component
class AuthInterceptor : ServerInterceptor {
class AuthInterceptor(
private val authContext: AuthContext
) : ServerInterceptor {
private val specialMethods = setOf(AUTH_METHOD_NAME, REFLECT_METHOD_NAME)

override fun <ReqT : Any, RespT : Any> interceptCall(
Expand All @@ -26,7 +28,7 @@ class AuthInterceptor : ServerInterceptor {
)
val isOrdinaryMethod = !specialMethods.contains(call.methodDescriptor.fullMethodName)

if (isOrdinaryMethod && (sessionId == null || !AuthContext.sessions.containsKey(sessionId))) {
if (isOrdinaryMethod && (sessionId == null || !authContext.containsSession(sessionId))) {
val cause = if (sessionId == null) "sessionId is not passed" else "Session $sessionId does not exist"
throw Status.UNAUTHENTICATED
.withDescription(cause)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package io.emeraldpay.dshackle.auth.processor

import com.auth0.jwt.JWT
import com.auth0.jwt.JWTVerifier
import com.auth0.jwt.RegisteredClaims
import com.auth0.jwt.algorithms.Algorithm
import io.emeraldpay.dshackle.auth.AuthContext
import io.emeraldpay.dshackle.auth.service.KeyReader
Expand All @@ -13,6 +14,7 @@ import java.security.PublicKey
import java.security.interfaces.RSAPrivateKey
import java.security.interfaces.RSAPublicKey
import java.time.Instant
import java.time.temporal.ChronoUnit
import java.util.UUID

const val SESSION_ID = "sessionId"
Expand All @@ -37,6 +39,9 @@ abstract class AuthProcessor(
try {
val verifier: JWTVerifier = JWT.require(verifyingAlgorithm(keys.externalPublicKey))
.withIssuer(authorizationConfig.publicKeyOwner)
.withClaim(RegisteredClaims.ISSUED_AT) { claim, _ ->
claim.asInstant().plus(1, ChronoUnit.MINUTES).isAfter(Instant.now())
}
.build()
verifier.verify(token)
} catch (e: Exception) {
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import org.springframework.stereotype.Service
class AuthService(
private val authorizationConfig: AuthorizationConfig,
private val rsaKeyReader: KeyReader,
private val authProcessorResolver: AuthProcessorResolver
private val authProcessorResolver: AuthProcessorResolver,
private val authContext: AuthContext
) {

fun authenticate(token: String): String {
Expand All @@ -31,7 +32,7 @@ class AuthService(
.getAuthProcessor(decodedJwt)
.process(keys, token)
.run {
AuthContext.putTokenInContext(this)
authContext.putSessionInContext(this)
this.token
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@ import java.io.StringReader
import java.nio.file.Files
import java.nio.file.Paths
import java.security.KeyFactory
import java.security.PrivateKey
import java.security.PublicKey
import java.security.interfaces.RSAPrivateKey
import java.security.interfaces.RSAPublicKey
import java.security.spec.PKCS8EncodedKeySpec
import java.security.spec.X509EncodedKeySpec
import java.time.Instant

class AuthProcessorV1Test {
private val processor = AuthProcessorV1(
Expand All @@ -31,12 +35,13 @@ class AuthProcessorV1Test {
)
private val rsaKeyReader = RsaKeyReader()
private val privProviderPath = ResourceUtils.getFile("classpath:keys/priv.p8.key").path
private val privDrpcPath = ResourceUtils.getFile("classpath:keys/priv-drpc.p8.key").path
private val publicDrpcPath = ResourceUtils.getFile("classpath:keys/public-drpc.pem").path
private val token = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJkcnBjIiwiaWF0IjoxNjkyMTg1OTMxLCJ2ZXJzaW9uI" +
"joiVjEifQ.BZILN0GQ7JzXGFz-GZIbFTT9E5L-miB4Nga0v4o_cQThk8gbDelBRzEfdsqxCq_ppPr3v_Own8M-vR9yQElx5nEdlI4xe5QAMdIvr3g" +
"12fMckydX9IsW4sVQ1kJJY8RrHb-WL-uI0WSWqoMSwf-Psb-UyiEHAjc3oK7fA72lBaGT4waPHOxRBPvezwg7N934vCZvZMAftFfVgmeEtbCeD7bF" +
"umEr0uEmkIKPTg4QwP-VMvqoLBYpMiJVzP_Ipg_wRHJ7fUN0BGEPjjMvhQ_6TWByiQUBz1kTMd0Ebf_kEuXFQeiwA-FXHJpWczzh66CbbmmWAWsi" +
"ehKw3KPZeBj0oQ"
private val token = JWT.create()
.withIssuedAt(Instant.now())
.withIssuer("drpc")
.withClaim(VERSION, AuthVersion.V1.toString())
.sign(Algorithm.RSA256(generatePrivateKey(privDrpcPath) as RSAPrivateKey))
private val keyPair = rsaKeyReader.getKeyPair(privProviderPath, publicDrpcPath)

@Test
Expand Down Expand Up @@ -88,4 +93,14 @@ class AuthProcessorV1Test {

return KeyFactory.getInstance("RSA").generatePublic(publicKeySpec)
}

private fun generatePrivateKey(path: String): PrivateKey {
val privateKeyReader = StringReader(Files.readString(Paths.get(path)))

val privatePem = PEMParser(privateKeyReader).readPemObject()

val privateKeySpec = PKCS8EncodedKeySpec(privatePem.content)

return KeyFactory.getInstance("RSA").generatePrivate(privateKeySpec)
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.util.concurrent.CompletableFuture
class AuthServiceTest {
private val rsaKeyReader = mock(KeyReader::class.java)
private val mockV1Processor = mock(AuthProcessor::class.java)
private val authContext = AuthContext()
private val factory = AuthProcessorResolver(mockV1Processor)

private val token = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJkcnBjIiwiaWF0IjoxNjkyMTg1OTMxLCJ2ZXJzaW9uI" +
Expand All @@ -31,7 +32,7 @@ class AuthServiceTest {

@Test
fun `unimplemented error if auth is disabled`() {
val authService = AuthService(AuthorizationConfig.default(), rsaKeyReader, factory)
val authService = AuthService(AuthorizationConfig.default(), rsaKeyReader, factory, authContext)

val e = assertThrows(StatusException::class.java) { authService.authenticate("") }
assertEquals("UNIMPLEMENTED: Authentication process is not enabled", e.message)
Expand All @@ -48,7 +49,7 @@ class AuthServiceTest {
AuthorizationConfig.ServerConfig("privPath", "pubPath"),
AuthorizationConfig.ClientConfig.default()
),
rsaKeyReader, factory
rsaKeyReader, factory, authContext
)
val pair = KeyReader.Keys(mock(PrivateKey::class.java), mock(PublicKey::class.java))

Expand All @@ -59,7 +60,7 @@ class AuthServiceTest {
authService.authenticate(token)
verify(rsaKeyReader).getKeyPair("privPath", "pubPath")
verify(mockV1Processor).process(pair, token)
assertTrue(AuthContext.sessions.containsKey(tokenWrapper.sessionId))
assertTrue(authContext.containsSession(tokenWrapper.sessionId))
}

@Test
Expand All @@ -77,7 +78,7 @@ class AuthServiceTest {
AuthorizationConfig.ServerConfig("privPath", "pubPath"),
AuthorizationConfig.ClientConfig.default()
),
rsaKeyReader, factory
rsaKeyReader, factory, authContext
)

`when`(rsaKeyReader.getKeyPair("privPath", "pubPath")).thenReturn(pair)
Expand All @@ -93,7 +94,7 @@ class AuthServiceTest {

verify(rsaKeyReader, times(2)).getKeyPair("privPath", "pubPath")
verify(mockV1Processor, times(2)).process(pair, token)
assertTrue(AuthContext.sessions.containsKey(tokenWrapper.sessionId))
assertTrue(AuthContext.sessions.containsKey(tokenWrapper1.sessionId))
assertTrue(authContext.containsSession(tokenWrapper.sessionId))
assertTrue(authContext.containsSession(tokenWrapper1.sessionId))
}
}

0 comments on commit a702eaa

Please sign in to comment.