Skip to content

Commit

Permalink
Merge pull request #63 from davidhiendl/allow-explicit-class-references
Browse files Browse the repository at this point in the history
Allow explicit class references when building routes
  • Loading branch information
Wicpar authored Jul 22, 2020
2 parents e4fe15b + ca1f46d commit cb716fd
Show file tree
Hide file tree
Showing 19 changed files with 409 additions and 109 deletions.
15 changes: 6 additions & 9 deletions src/main/kotlin/com/papsign/ktor/openapigen/KTypeUtil.kt
Original file line number Diff line number Diff line change
@@ -1,27 +1,24 @@
package com.papsign.ktor.openapigen

import com.papsign.ktor.openapigen.annotations.mapping.openAPIName
import java.lang.reflect.Field
import kotlin.reflect.*
import kotlin.reflect.full.createType
import kotlin.reflect.full.declaredMemberProperties
import kotlin.reflect.full.memberProperties
import kotlin.reflect.jvm.javaField
import kotlin.reflect.jvm.jvmErasure

val unitKType = getKType<Unit>()

inline fun <reified T> isNullable(): Boolean {
internal inline fun <reified T> isNullable(): Boolean {
return null is T
}

inline fun <reified T> getKType() = typeOf<T>()
@PublishedApi
internal inline fun <reified T> getKType() = typeOf<T>()

fun KType.strip(nullable: Boolean = isMarkedNullable): KType {
internal fun KType.strip(nullable: Boolean = isMarkedNullable): KType {
return jvmErasure.createType(arguments, nullable)
}

fun KType.deepStrip(nullable: Boolean = isMarkedNullable): KType {
internal fun KType.deepStrip(nullable: Boolean = isMarkedNullable): KType {
return jvmErasure.createType(arguments.map { it.copy(type = it.type?.deepStrip()) }, nullable)
}

Expand All @@ -44,4 +41,4 @@ val KType.memberProperties: List<KTypeProperty>
}
}

val KClass<*>.isInterface get() = java.isInterface
internal val KClass<*>.isInterface get() = java.isInterface
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ package com.papsign.ktor.openapigen.content.type
import io.ktor.application.ApplicationCall
import io.ktor.http.ContentType
import io.ktor.util.pipeline.PipelineContext
import kotlin.reflect.KClass
import kotlin.reflect.KType

interface BodyParser: ContentTypeProvider {
fun <T: Any> getParseableContentTypes(clazz: KClass<T>): List<ContentType>
fun <T: Any> getParseableContentTypes(type: KType): List<ContentType>
suspend fun <T: Any> parseBody(clazz: KType, request: PipelineContext<Unit, ApplicationCall>): T
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ import io.ktor.application.ApplicationCall
import io.ktor.http.ContentType
import io.ktor.http.HttpStatusCode
import io.ktor.util.pipeline.PipelineContext
import kotlin.reflect.KClass
import kotlin.reflect.KType

interface ResponseSerializer: ContentTypeProvider {
/**
* used to determine which registered response serializer is used, based on the accept header
*/
fun <T: Any> getSerializableContentTypes(clazz: KClass<T>): List<ContentType>
fun <T: Any> getSerializableContentTypes(type: KType): List<ContentType>
suspend fun <T: Any> respond(response: T, request: PipelineContext<Unit, ApplicationCall>, contentType: ContentType)
suspend fun <T: Any> respond(statusCode: HttpStatusCode, response: T, request: PipelineContext<Unit, ApplicationCall>, contentType: ContentType)
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ object BinaryContentTypeParser: BodyParser, ResponseSerializer, OpenAPIGenModule
return constructors.first { it.parameters.size == 1 && acceptedTypes.contains(it.parameters[0].type) }
}

override fun <T : Any> getParseableContentTypes(clazz: KClass<T>): List<ContentType> {
return clazz.findAnnotation<BinaryRequest>()?.contentTypes?.map(ContentType.Companion::parse) ?: listOf()
override fun <T : Any> getParseableContentTypes(type: KType): List<ContentType> {
return type.jvmErasure.findAnnotation<BinaryRequest>()?.contentTypes?.map(ContentType.Companion::parse) ?: listOf()
}

override fun <T: Any> getSerializableContentTypes(clazz: KClass<T>): List<ContentType> {
return clazz.findAnnotation<BinaryResponse>()?.contentTypes?.map(ContentType.Companion::parse) ?: listOf()
override fun <T: Any> getSerializableContentTypes(type: KType): List<ContentType> {
return type.jvmErasure.findAnnotation<BinaryResponse>()?.contentTypes?.map(ContentType.Companion::parse) ?: listOf()
}

override suspend fun <T : Any> respond(response: T, request: PipelineContext<Unit, ApplicationCall>, contentType: ContentType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import io.ktor.http.HttpStatusCode
import io.ktor.request.receive
import io.ktor.response.respond
import io.ktor.util.pipeline.PipelineContext
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.full.findAnnotation
import kotlin.reflect.jvm.jvmErasure
Expand Down Expand Up @@ -63,15 +62,15 @@ object KtorContentProvider : ContentTypeProvider, BodyParser, ResponseSerializer
return contentTypes.associateWith { media.copy() }
}

override fun <T : Any> getParseableContentTypes(clazz: KClass<T>): List<ContentType> {
override fun <T : Any> getParseableContentTypes(type: KType): List<ContentType> {
return contentTypes!!.toList()
}

override suspend fun <T: Any> parseBody(clazz: KType, request: PipelineContext<Unit, ApplicationCall>): T {
return request.call.receive(clazz)
}

override fun <T: Any> getSerializableContentTypes(clazz: KClass<T>): List<ContentType> {
override fun <T: Any> getSerializableContentTypes(type: KType): List<ContentType> {
return contentTypes!!.toList()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import kotlin.reflect.jvm.jvmErasure

object MultipartFormDataContentProvider : BodyParser, OpenAPIGenModuleExtension {

override fun <T : Any> getParseableContentTypes(clazz: KClass<T>): List<ContentType> {
override fun <T : Any> getParseableContentTypes(type: KType): List<ContentType> {
return listOf(ContentType.MultiPart.FormData)
}

Expand Down Expand Up @@ -68,7 +68,7 @@ object MultipartFormDataContentProvider : BodyParser, OpenAPIGenModuleExtension
private val typeContentTypes = HashMap<KType, Map<String, MediaTypeEncodingModel>>()


override suspend fun <T : Any> parseBody(clazz: KType, request: PipelineContext<Unit, ApplicationCall>): T {
override suspend fun <T : Any> parseBody(type: KType, request: PipelineContext<Unit, ApplicationCall>): T {
val objectMap = HashMap<String, Any>()
request.context.receiveMultipart().forEachPart {
val name = it.name
Expand All @@ -86,7 +86,8 @@ object MultipartFormDataContentProvider : BodyParser, OpenAPIGenModuleExtension
}
}
}
val ctor = (clazz.classifier as KClass<T>).primaryConstructor!!
@Suppress("UNCHECKED_CAST")
val ctor = (type.classifier as KClass<T>).primaryConstructor!!
return ctor.callBy(ctor.parameters.associateWith {
val raw = objectMap[it.openAPIName]
if ((raw == null || (raw !is InputStream && streamTypes.contains(it.type))) && it.type.isMarkedNullable) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.papsign.ktor.openapigen.modules.handlers

import com.papsign.ktor.openapigen.getKType
import com.papsign.ktor.openapigen.OpenAPIGen
import com.papsign.ktor.openapigen.annotations.Request
import com.papsign.ktor.openapigen.classLogger
Expand All @@ -14,12 +13,11 @@ import com.papsign.ktor.openapigen.modules.ofType
import com.papsign.ktor.openapigen.modules.openapi.OperationModule
import com.papsign.ktor.openapigen.modules.providers.ParameterProvider
import com.papsign.ktor.openapigen.modules.registerModule
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.full.findAnnotation
import kotlin.reflect.jvm.jvmErasure

class RequestHandlerModule<T : Any>(
val requestClass: KClass<T>,
val requestType: KType,
val requestExample: T? = null
) : OperationModule {
Expand All @@ -34,7 +32,7 @@ class RequestHandlerModule<T : Any>(
mediaType.map { Pair(it.key.toString(), it.value) }
}.flatten().associate { it }

val requestMeta = requestClass.findAnnotation<Request>()
val requestMeta = requestType.jvmErasure.findAnnotation<Request>()

val parameters = provider.ofType<ParameterProvider>().flatMap { it.getParameters(apiGen, provider) }
operation.parameters = operation.parameters?.let { (it + parameters).distinct() } ?: parameters
Expand All @@ -51,7 +49,6 @@ class RequestHandlerModule<T : Any>(
}

companion object {
inline fun <reified T : Any> create(requestExample: T? = null) = RequestHandlerModule(T::class,
getKType<T>(), requestExample)
fun <T : Any> create(tType: KType, requestExample: T? = null) = RequestHandlerModule(tType, requestExample)
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.papsign.ktor.openapigen.modules.handlers

import com.papsign.ktor.openapigen.getKType
import com.papsign.ktor.openapigen.OpenAPIGen
import com.papsign.ktor.openapigen.annotations.Response
import com.papsign.ktor.openapigen.classLogger
Expand Down Expand Up @@ -46,6 +45,6 @@ class ResponseHandlerModule<T>(val responseType: KType, val responseExample: T?
}

companion object {
inline fun <reified T : Any> create(responseExample: T? = null) = ResponseHandlerModule(getKType<T>(), responseExample)
fun <T : Any> create(tType: KType, responseExample: T? = null) = ResponseHandlerModule(tType, responseExample)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ object ThrowOperationHandler : OperationModule {
private val log = classLogger()
override fun configure(apiGen: OpenAPIGen, provider: ModuleProvider<*>, operation: OperationModel) {

val exceptions = provider.ofType<ThrowInfoProvider>().flatMap { it.exceptions }
exceptions.groupBy { it.status }.forEach { exceptions ->
provider
.ofType<ThrowInfoProvider>()
.flatMap { it.exceptions }
.groupBy { it.status }
.forEach { exceptions ->
val map: MutableMap<String, MediaTypeModel<*>> = exceptions.value.flatMap { ex ->
provider.ofType<ResponseSerializer>().mapNotNull {
if (ex.contentType == unitKType) return@mapNotNull null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,29 @@ import com.papsign.ktor.openapigen.parameters.handlers.ModularParameterHandler
import com.papsign.ktor.openapigen.parameters.handlers.ParameterHandler
import com.papsign.ktor.openapigen.parameters.handlers.UnitParameterHandler
import com.papsign.ktor.openapigen.parameters.parsers.builders.Builder
import kotlin.reflect.KFunction
import kotlin.reflect.KParameter
import kotlin.reflect.KType
import kotlin.reflect.full.findAnnotation
import kotlin.reflect.full.primaryConstructor
import kotlin.reflect.jvm.jvmErasure


inline fun <reified T : Any> buildParameterHandler(): ParameterHandler<T> {
if (Unit is T) return UnitParameterHandler as ParameterHandler<T>
val t = T::class
assert(t.isData) { "API route with ${t.simpleName} must be a data class." }
val constructor = t.primaryConstructor ?: error("API routes with ${t.simpleName} must have a primary constructor.")
fun <T : Any> buildParameterHandler(tType: KType): ParameterHandler<T> {
@Suppress("UNCHECKED_CAST")
if (tType.classifier == Unit::class) return UnitParameterHandler as ParameterHandler<T>
val tClass = tType.jvmErasure
assert(tClass.isData) { "API route with ${tClass.simpleName} must be a data class." }
val constructor = tClass.primaryConstructor ?: error("API routes with ${tClass.simpleName} must have a primary constructor.")
val parsers: Map<KParameter, Builder<*>> = constructor.parameters.associateWith { param ->
val type = param.type
param.findAnnotation<HeaderParam>()?.let { a -> a.style.factory.buildBuilderForced(type, a.explode) } ?:
param.findAnnotation<PathParam>()?.let { a -> a.style.factory.buildBuilderForced(type, a.explode) } ?:
param.findAnnotation<QueryParam>()?.let { a -> a.style.factory.buildBuilderForced(type, a.explode) } ?:
error("Parameters must be annotated with @PathParam or @QueryParam")
}
@Suppress("UNCHECKED_CAST")
return ModularParameterHandler(
parsers,
constructor
constructor as KFunction<T>
)
}
45 changes: 35 additions & 10 deletions src/main/kotlin/com/papsign/ktor/openapigen/route/Functions.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,34 @@ import io.ktor.http.HttpMethod
import io.ktor.routing.HttpMethodRouteSelector
import io.ktor.routing.createRouteFromPath
import io.ktor.util.pipeline.ContextDsl
import kotlin.reflect.KType
import kotlin.reflect.full.findAnnotation
import kotlin.reflect.jvm.jvmErasure
import kotlin.reflect.typeOf

fun <T: OpenAPIRoute<T>> T.route(path: String): T {
fun <T : OpenAPIRoute<T>> T.route(path: String): T {
return child(ktorRoute.createRouteFromPath(path)).apply {
provider.registerModule(PathProviderModule(path))
}
}

@ContextDsl
inline fun <T: OpenAPIRoute<T>> T.route(path: String, crossinline fn: T.() -> Unit) {
inline fun <T : OpenAPIRoute<T>> T.route(path: String, crossinline fn: T.() -> Unit) {
route(path).fn()
}

fun <T: OpenAPIRoute<T>> T.method(method: HttpMethod): T {
fun <T : OpenAPIRoute<T>> T.method(method: HttpMethod): T {
return child(ktorRoute.createChild(HttpMethodRouteSelector(method))).apply {
provider.registerModule(HttpMethodProviderModule(method))
}
}

@ContextDsl
inline fun <T: OpenAPIRoute<T>> T.method(method: HttpMethod, crossinline fn: T.() -> Unit) {
inline fun <T : OpenAPIRoute<T>> T.method(method: HttpMethod, crossinline fn: T.() -> Unit) {
method(method).fn()
}

fun <T: OpenAPIRoute<T>> T.provider(vararg content: ContentTypeProvider): T {
fun <T : OpenAPIRoute<T>> T.provider(vararg content: ContentTypeProvider): T {
return child().apply {
content.forEach {
provider.registerModule(it)
Expand All @@ -45,38 +48,60 @@ fun <T: OpenAPIRoute<T>> T.provider(vararg content: ContentTypeProvider): T {
}

@ContextDsl
inline fun <T: OpenAPIRoute<T>> T.provider(vararg content: ContentTypeProvider, crossinline fn: T.() -> Unit) {
inline fun <T : OpenAPIRoute<T>> T.provider(vararg content: ContentTypeProvider, crossinline fn: T.() -> Unit) {
provider(*content).fn()
}


fun <T: OpenAPIRoute<T>> T.tag(tag: APITag): T {
fun <T : OpenAPIRoute<T>> T.tag(tag: APITag): T {
return child().apply {
provider.registerModule(TagModule(listOf(tag)))
}
}


@ContextDsl
inline fun <T: OpenAPIRoute<T>> T.tag(tag: APITag, crossinline fn: T.() -> Unit) {
inline fun <T : OpenAPIRoute<T>> T.tag(tag: APITag, crossinline fn: T.() -> Unit) {
tag(tag).fn()
}

inline fun <reified P : Any, reified R : Any, reified B : Any, T: OpenAPIRoute<T>> T.preHandle(
inline fun <reified P : Any, reified R : Any, reified B : Any, T : OpenAPIRoute<T>> T.preHandle(
exampleResponse: R? = null,
exampleRequest: B? = null,
noinline handle: T.() -> Unit
) {
preHandle<P, R, B, T>(
typeOf<P>(),
typeOf<R>(),
typeOf<B>(),
exampleResponse,
exampleRequest,
handle
)
}

// hide this function from public api as it can be "misused" easily but make it accessible to inlined functions from this package
@PublishedApi
internal fun <P : Any, R : Any, B : Any, T : OpenAPIRoute<T>> T.preHandle(
pType: KType,
rType: KType,
bType: KType,
exampleResponse: R? = null,
exampleRequest: B? = null,
handle: T.() -> Unit
) {
val path = P::class.findAnnotation<Path>()
val path = pType.jvmErasure.findAnnotation<Path>()
val new = if (path != null) child(ktorRoute.createRouteFromPath(path.path)) else child()
new.apply {
provider.registerModule(
RequestHandlerModule.create(
bType,
exampleRequest
)
)
provider.registerModule(
ResponseHandlerModule.create(
rType,
exampleResponse
)
)
Expand Down
Loading

0 comments on commit cb716fd

Please sign in to comment.