Skip to content

Commit

Permalink
Allow modifying response headers for Spring Boot and Micronaut apps (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
rjaros committed Mar 2, 2024
1 parent 6b05c75 commit 903bc38
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright (c) 2017-present Robert Jaros
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/

package io.kvision.remote

import io.micronaut.context.annotation.Bean
import io.micronaut.context.annotation.Factory
import io.micronaut.http.MutableHttpResponse

/**
* A helper class for holding http response mutator function.
*/
class HttpResponseMutator {
internal var responseMutator: (MutableHttpResponse<String>.() -> Unit)? = null
fun mutate(responseMutator: MutableHttpResponse<String>.() -> Unit) {
this.responseMutator = responseMutator
}
}

internal object ResponseMutatorHolder {
val threadLocalResponseMutator = ThreadLocal<HttpResponseMutator>()
}

/**
* Helper factory for the HttpResponseMutator bean.
*/
@Factory
open class HttpResponseMutatorBeanFactory {
@Bean
open fun httpResponseMutator(): HttpResponseMutator {
return ResponseMutatorHolder.threadLocalResponseMutator.get() ?: HttpResponseMutator()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,12 @@ open class KVController {
val handler = kvManagers.services.asSequence().mapNotNull {
it.routeMapRegistry.findHandler(method, "/$path")
}.firstOrNull() ?: return HttpResponse.notFound()
return handler(request, RequestHolder.threadLocalRequest, applicationContext)
return handler(
request,
RequestHolder.threadLocalRequest,
ResponseMutatorHolder.threadLocalResponseMutator,
applicationContext
)
}

private fun handleSse(path: String?, request: HttpRequest<*>): Publisher<Event<String>> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ import java.nio.charset.StandardCharsets
import kotlin.reflect.KClass

typealias RequestHandler =
suspend (HttpRequest<*>, ThreadLocal<HttpRequest<*>>, ApplicationContext) -> HttpResponse<String>
suspend (HttpRequest<*>, ThreadLocal<HttpRequest<*>>, ThreadLocal<HttpResponseMutator>, ApplicationContext) -> HttpResponse<String>

typealias WebsocketHandler = suspend (
WebSocketSession, ThreadLocal<WebSocketSession>, ApplicationContext, ReceiveChannel<String>, SendChannel<String>
Expand Down Expand Up @@ -73,10 +73,13 @@ actual open class KVServiceManager<out T : Any> actual constructor(private val s
serializerFactory: () -> KSerializer<RET>
): RequestHandler {
val serializer by lazy { serializerFactory() }
return { req, tlReq, ctx ->
return { req, tlReq, tlResponseMutator, ctx ->
val httpResponseMutator = HttpResponseMutator()
tlReq.set(req)
tlResponseMutator.set(httpResponseMutator)
val service = ctx.getBean(serviceClass.java)
tlReq.remove()
tlResponseMutator.remove()
val jsonRpcRequest = if (method == HttpMethod.GET) {
val parameters = (0..<numberOfParams).map {
req.parameters["p$it"]?.let {
Expand Down Expand Up @@ -114,7 +117,7 @@ actual open class KVServiceManager<out T : Any> actual constructor(private val s
)
}
)
)
).also { mutableHttpResponse -> httpResponseMutator.responseMutator?.let { mutableHttpResponse.it() } }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ import org.springframework.http.MediaType.TEXT_HTML
import org.springframework.stereotype.Component
import org.springframework.web.reactive.function.server.ServerRequest
import org.springframework.web.reactive.function.server.ServerResponse
import org.springframework.web.reactive.function.server.ServerResponse.BodyBuilder
import org.springframework.web.reactive.function.server.ServerResponse.HeadersBuilder
import org.springframework.web.reactive.function.server.buildAndAwait
import org.springframework.web.reactive.function.server.coRouter
import org.springframework.web.reactive.function.server.router
Expand Down Expand Up @@ -76,6 +78,8 @@ open class KVHandler(val services: List<KVServiceManager<*>>, val applicationCon

private val threadLocalRequest = ThreadLocal<ServerRequest>()

private val threadLocalHeadersBuilder = ThreadLocal<HeadersBuilder<BodyBuilder>>()

@PostConstruct
open fun init() {
services.forEach { it.deSerializer = kotlinxObjectDeSerializer(serializersModules) }
Expand All @@ -87,6 +91,12 @@ open class KVHandler(val services: List<KVServiceManager<*>>, val applicationCon
return threadLocalRequest.get() ?: KVServerRequest()
}

@Bean
@Scope(BeanDefinition.SCOPE_PROTOTYPE)
open fun headersBuilder(): HeadersBuilder<BodyBuilder> {
return threadLocalHeadersBuilder.get() ?: ServerResponse.ok()
}

open suspend fun handle(request: ServerRequest): ServerResponse {

fun getHandler(): RequestHandler? {
Expand All @@ -101,6 +111,7 @@ open class KVHandler(val services: List<KVServiceManager<*>>, val applicationCon
return (getHandler() ?: return ServerResponse.notFound().buildAndAwait())(
request,
threadLocalRequest,
threadLocalHeadersBuilder,
applicationContext
)
}
Expand All @@ -117,6 +128,7 @@ open class KVHandler(val services: List<KVServiceManager<*>>, val applicationCon
return (getSseHandler() ?: return ServerResponse.notFound().buildAndAwait())(
request,
threadLocalRequest,
threadLocalHeadersBuilder,
applicationContext
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ import org.springframework.http.codec.ServerSentEvent
import org.springframework.web.reactive.function.BodyInserters
import org.springframework.web.reactive.function.server.ServerRequest
import org.springframework.web.reactive.function.server.ServerResponse
import org.springframework.web.reactive.function.server.ServerResponse.BodyBuilder
import org.springframework.web.reactive.function.server.ServerResponse.HeadersBuilder
import org.springframework.web.reactive.function.server.awaitBody
import org.springframework.web.reactive.function.server.bodyValueAndAwait
import org.springframework.web.reactive.function.server.json
Expand All @@ -49,7 +51,7 @@ import java.nio.charset.StandardCharsets
import kotlin.jvm.optionals.getOrNull
import kotlin.reflect.KClass

typealias RequestHandler = suspend (ServerRequest, ThreadLocal<ServerRequest>, ApplicationContext) -> ServerResponse
typealias RequestHandler = suspend (ServerRequest, ThreadLocal<ServerRequest>, ThreadLocal<HeadersBuilder<BodyBuilder>>, ApplicationContext) -> ServerResponse
typealias WebsocketHandler = suspend (
WebSocketSession, ThreadLocal<WebSocketSession>, ApplicationContext, ReceiveChannel<String>, SendChannel<String>
) -> Unit
Expand All @@ -76,10 +78,13 @@ actual open class KVServiceManager<out T : Any> actual constructor(private val s
serializerFactory: () -> KSerializer<RET>
): RequestHandler {
val serializer by lazy { serializerFactory() }
return { req, tlReq, ctx ->
return { req, tlReq, tlHeadersBuilder, ctx ->
val bodyBuilder = ServerResponse.ok().json()
tlReq.set(req)
tlHeadersBuilder.set(bodyBuilder)
val service = ctx.getBean(serviceClass.java)
tlReq.remove()
tlHeadersBuilder.remove()
val jsonRpcRequest = if (method == HttpMethod.GET) {
val parameters = (0..<numberOfParams).map {
req.queryParam("p$it").getOrNull()?.let {
Expand All @@ -90,7 +95,7 @@ actual open class KVServiceManager<out T : Any> actual constructor(private val s
} else {
req.awaitBody()
}
ServerResponse.ok().json().bodyValueAndAwait(
bodyBuilder.bodyValueAndAwait(
deSerializer.serializeNonNull(
try {
val result = function.invoke(service, jsonRpcRequest.params)
Expand Down Expand Up @@ -148,10 +153,13 @@ actual open class KVServiceManager<out T : Any> actual constructor(private val s
serializerFactory: () -> KSerializer<PAR>
): SseHandler {
val serializer by lazy { serializerFactory() }
return { req, tlReq, ctx ->
return { req, tlReq, tlHeadersBuilder, ctx ->
val bodyBuilder = ServerResponse.ok().contentType(MediaType.TEXT_EVENT_STREAM)
tlReq.set(req)
tlHeadersBuilder.set(bodyBuilder)
val service = ctx.getBean(serviceClass.java)
tlReq.remove()
tlHeadersBuilder.remove()
val channel = Channel<String>()
val events = flux {
for (item in channel) {
Expand All @@ -174,9 +182,7 @@ actual open class KVServiceManager<out T : Any> actual constructor(private val s
function = function
)
}
ServerResponse
.ok()
.contentType(MediaType.TEXT_EVENT_STREAM)
bodyBuilder
.body(BodyInserters.fromServerSentEvents(events)).awaitSingle()
}
}
Expand Down

0 comments on commit 903bc38

Please sign in to comment.