Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(header-params): Add support for header parameter extraction #29

Merged
merged 2 commits into from
Oct 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ internal data class QueryParamSpec(
val isRequired: Boolean = false
) : ParamSpec

internal data class HeaderParamSpec(
override val name: String,
override val description: String? = null,
val isRequired: Boolean = false
) : ParamSpec


internal data class KtorRouteSpec(
val path: String,
val parameters: List<ParamSpec>?,
Expand Down
25 changes: 19 additions & 6 deletions create-plugin/src/main/kotlin/io/github/tabilzad/ktor/Utils.kt
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ internal fun List<KtorRouteSpec>.convertToSpec(): Map<String, Map<String, OpenAp
summary = it.summary,
description = it.description,
tags = it.tags?.toList()?.sorted(),
parameters = addPathParams(it) merge addQueryParams(it),
parameters = mapPathParams(it) merge mapQueryParams(it) merge mapHeaderParams(it),
requestBody = addPostBody(it),
responses = it.responses
)
Expand All @@ -108,7 +108,7 @@ infix fun <T> List<T>?.merge(params: List<T>?): List<T>? = this?.plus(params ?:

infix fun <T> Set<T>?.merge(params: Set<T>?): Set<T>? = this?.plus(params ?: emptyList()) ?: params

private fun addPathParams(spec: KtorRouteSpec): List<OpenApiSpec.PathParam>? {
private fun mapPathParams(spec: KtorRouteSpec): List<OpenApiSpec.Parameter>? {
val params = "\\{([^}]*)}".toRegex().findAll(spec.path).toList()
return if (params.isNotEmpty()) {
params.mapNotNull {
Expand All @@ -118,7 +118,7 @@ private fun addPathParams(spec: KtorRouteSpec): List<OpenApiSpec.PathParam>? {
?.any { it.name == pathParamName } == true
) {
spec.parameters?.find { it.name == pathParamName }?.let {
OpenApiSpec.PathParam(
OpenApiSpec.Parameter(
name = it.name,
`in` = "path",
required = pathParamName?.contains("?") != true,
Expand All @@ -127,7 +127,7 @@ private fun addPathParams(spec: KtorRouteSpec): List<OpenApiSpec.PathParam>? {
)
}
} else {
OpenApiSpec.PathParam(
OpenApiSpec.Parameter(
name = pathParamName.replace("?", ""),
`in` = "path",
required = !pathParamName.contains("?"),
Expand All @@ -140,9 +140,9 @@ private fun addPathParams(spec: KtorRouteSpec): List<OpenApiSpec.PathParam>? {
}
}

private fun addQueryParams(it: KtorRouteSpec): List<OpenApiSpec.PathParam>? {
private fun mapQueryParams(it: KtorRouteSpec): List<OpenApiSpec.Parameter>? {
return it.parameters?.filterIsInstance<QueryParamSpec>()?.map {
OpenApiSpec.PathParam(
OpenApiSpec.Parameter(
name = it.name,
`in` = "query",
required = it.isRequired,
Expand All @@ -152,6 +152,19 @@ private fun addQueryParams(it: KtorRouteSpec): List<OpenApiSpec.PathParam>? {
}
}

private fun mapHeaderParams(it: KtorRouteSpec): List<OpenApiSpec.Parameter>? {
return it.parameters?.filterIsInstance<HeaderParamSpec>()?.map {
OpenApiSpec.Parameter(
name = it.name,
`in` = "header",
required = it.isRequired,
schema = OpenApiSpec.SchemaType("string"),
description = it.description
)
}
}


private fun addPostBody(it: KtorRouteSpec): OpenApiSpec.RequestBody? {
return if (it.method != "get" && it.body.contentBodyRef != null) {
OpenApiSpec.RequestBody(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ object ClassIds {
val KTOR_QUERY_PARAM = FqName("io.ktor.server.request.ApplicationRequest.queryParameters")
val KTOR_RAW_QUERY_PARAM = FqName("io.ktor.server.request.ApplicationRequest.rawQueryParameters")

val KTOR_HEADER_PARAM = FqName("io.ktor.server.request.ApplicationRequest.headers")

val KTOR_TAGS_ANNOTATION = ClassId(FqName("io.github.tabilzad.ktor.annotations"), FqName("Tag"), false)
val KTOR_GENERATE_ANNOTATION = ClassId(FqName("io.github.tabilzad.ktor.annotations"), FqName("GenerateOpenApi"), false)
val KTOR_DOCS_ANNOTATION = ClassId(FqName("io.github.tabilzad.ktor.annotations"), FqName("KtorDocs"), false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,12 @@ internal class ExpressionsVisitorK2(
if (parent is EndPoint && parent.body == null) {

val receiveCall = block.statements.findReceiveCallExpression()

val queryParam = block.statements.findQueryParameterExpression()
if (queryParam.isNotEmpty()) parent.parameters = parent.parameters merge queryParam.toSet()

if (queryParam.isNotEmpty()) {
parent.parameters = parent.parameters merge queryParam.toSet()
}
val headerParam = block.statements.findHeaderParameterExpression()
if (headerParam.isNotEmpty()) parent.parameters = parent.parameters merge headerParam.toSet()

if (receiveCall != null) {
val kotlinType = receiveCall.resolvedType
Expand Down Expand Up @@ -141,6 +142,13 @@ internal class ExpressionsVisitorK2(
return queryParams.map { QueryParamSpec(it) }
}

private fun List<FirStatement>.findHeaderParameterExpression(): List<ParamSpec> {
val headerParams = mutableListOf<String>()
flatMap { it.allChildren }.filterIsInstance<FirFunctionCall>()
.forEach { it.accept(HeaderParamsVisitor(session), headerParams) }
return headerParams.map { HeaderParamSpec(it) }
}

private fun List<FirStatement>.findReceiveCallExpression(): FirFunctionCall? {

val receiveFunctionCall = filterIsInstance<FirFunctionCall>()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ import org.jetbrains.kotlin.fir.declarations.FirDeclaration
import org.jetbrains.kotlin.fir.declarations.FirFunction
import org.jetbrains.kotlin.fir.declarations.FirProperty
import org.jetbrains.kotlin.fir.declarations.hasAnnotation
import org.jetbrains.kotlin.fir.declarations.utils.isEnumClass
import org.jetbrains.kotlin.fir.declarations.utils.visibility
import org.jetbrains.kotlin.fir.expressions.FirAnnotation
import org.jetbrains.kotlin.fir.expressions.FirPropertyAccessExpression
import org.jetbrains.kotlin.fir.expressions.FirStatement
import org.jetbrains.kotlin.fir.expressions.toResolvedCallableSymbol
import org.jetbrains.kotlin.fir.render
import org.jetbrains.kotlin.fir.resolve.fqName
import org.jetbrains.kotlin.fir.resolve.toClassSymbol
Expand Down Expand Up @@ -211,5 +214,10 @@ fun FirRegularClassSymbol.findEnumParamValue(value: String): List<String> {
return declarationSymbols.filterIsInstance<FirEnumEntrySymbol>().map { it.name.asString() }
}

fun FirPropertyAccessExpression.isEnum(session: FirSession): Boolean = this.dispatchReceiver
?.toResolvedCallableSymbol(session)
?.resolvedReturnType
?.toRegularClassSymbol(session)?.isEnumClass == true



Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package io.github.tabilzad.ktor.k2.visitors

import io.github.tabilzad.ktor.k2.ClassIds
import org.jetbrains.kotlin.fir.FirElement
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.declarations.FirProperty
import org.jetbrains.kotlin.fir.expressions.*
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
import org.jetbrains.kotlin.fir.symbols.SymbolInternals
import org.jetbrains.kotlin.fir.visitors.FirDefaultVisitor

class HeaderParamsVisitor(private val session: FirSession) : FirDefaultVisitor<Unit, MutableList<String>>() {

override fun visitElement(element: FirElement, data: MutableList<String>) {
// no-op
}

override fun visitStringConcatenationCall(
stringConcatenationCall: FirStringConcatenationCall,
data: MutableList<String>
) {
data.add(stringConcatenationCall.argumentList.arguments.flatMap { acc ->
buildList {
acc.accept(this@HeaderParamsVisitor, this)
}
}.joinToString(""))
}

override fun visitFunctionCall(functionCall: FirFunctionCall, data: MutableList<String>) {
val functionFqName = functionCall.dispatchReceiver?.toResolvedCallableSymbol(session)?.callableId?.asSingleFqName()
if (functionFqName == ClassIds.KTOR_HEADER_PARAM) {
functionCall.acceptChildren(this, data)
} else {
// skip
}
}

override fun visitLiteralExpression(literalExpression: FirLiteralExpression, data: MutableList<String>) {
val element = literalExpression.value
element?.let { data.add(it.toString()) }
}

@OptIn(SymbolInternals::class)
override fun visitResolvedNamedReference(
resolvedNamedReference: FirResolvedNamedReference,
data: MutableList<String>
) {
val fir = resolvedNamedReference.resolvedSymbol.fir
if (fir is FirProperty) {
val init = fir.initializer

if (init is FirLiteralExpression) {
init.accept(this, data)
}
}
}

@OptIn(PrivateConstantEvaluatorAPI::class)
override fun visitArgumentList(argumentList: FirArgumentList, data: MutableList<String>) {
visitArgumentList(session, argumentList, data)
}

@OptIn(SymbolInternals::class)
override fun visitPropertyAccessExpression(
propertyAccessExpression: FirPropertyAccessExpression,
data: MutableList<String>
) {
visitPropertyAccessExpression(session, propertyAccessExpression, data)
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.github.tabilzad.ktor.k2.visitors

import io.github.tabilzad.ktor.k2.ClassIds
import io.github.tabilzad.ktor.k2.isEnum
import org.jetbrains.kotlin.fir.FirElement
import org.jetbrains.kotlin.fir.FirEvaluatorResult
import org.jetbrains.kotlin.fir.FirSession
Expand All @@ -9,15 +10,13 @@ import org.jetbrains.kotlin.fir.declarations.EnumValueArgumentInfo
import org.jetbrains.kotlin.fir.declarations.FirProperty
import org.jetbrains.kotlin.fir.declarations.collectEnumEntries
import org.jetbrains.kotlin.fir.declarations.extractEnumValueArgumentInfo
import org.jetbrains.kotlin.fir.declarations.utils.isEnumClass
import org.jetbrains.kotlin.fir.expressions.*
import org.jetbrains.kotlin.fir.expressions.impl.FirResolvedArgumentList
import org.jetbrains.kotlin.fir.references.FirResolvedNamedReference
import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
import org.jetbrains.kotlin.fir.resolve.toClassSymbol
import org.jetbrains.kotlin.fir.symbols.SymbolInternals
import org.jetbrains.kotlin.fir.types.toLookupTag
import org.jetbrains.kotlin.fir.types.toRegularClassSymbol
import org.jetbrains.kotlin.fir.visitors.FirDefaultVisitor

class QueryParamsVisitor(private val session: FirSession) : FirDefaultVisitor<Unit, MutableList<String>>() {
Expand Down Expand Up @@ -71,66 +70,75 @@ class QueryParamsVisitor(private val session: FirSession) : FirDefaultVisitor<Un
@OptIn(PrivateConstantEvaluatorAPI::class)
// TODO(Look into evaluatePropertyInitializer instead of evaluateExpression)
override fun visitArgumentList(argumentList: FirArgumentList, data: MutableList<String>) {

if (argumentList is FirResolvedArgumentList) {
val g = argumentList.mapping.keys
.filterIsInstance<FirFunctionCall>()
.map {
FirExpressionEvaluator.evaluateExpression(it, session)
}.filterIsInstance<FirEvaluatorResult.Evaluated>().map {
it.result
}.filterIsInstance<FirLiteralExpression>()

g.forEach { it.accept(this, data) }

}

argumentList.acceptChildren(this, data)
visitArgumentList(session, argumentList, data)
}

@OptIn(SymbolInternals::class)
override fun visitPropertyAccessExpression(
propertyAccessExpression: FirPropertyAccessExpression,
data: MutableList<String>
) {
visitPropertyAccessExpression(session, propertyAccessExpression, data)
}
}

val isEnum =
propertyAccessExpression.dispatchReceiver?.toResolvedCallableSymbol(session)?.resolvedReturnType?.toRegularClassSymbol(
session
)?.isEnumClass == true
val enumInfo: EnumValueArgumentInfo? = propertyAccessExpression.dispatchReceiver?.extractEnumValueArgumentInfo()
val enumEntryAccessor = propertyAccessExpression.calleeReference.toResolvedCallableSymbol()?.name

if (isEnum) {

val entries = enumInfo?.enumClassId?.toLookupTag()?.toClassSymbol(session)?.collectEnumEntries()
val v = entries?.find { it.name.asString() == enumInfo.enumEntryName.asString() }
?.initializerObjectSymbol
?.primaryConstructorSymbol(session)
?.fir?.delegatedConstructor
@OptIn(PrivateConstantEvaluatorAPI::class)
fun FirDefaultVisitor<Unit, MutableList<String>>.visitArgumentList(
session: FirSession,
argumentList: FirArgumentList,
data: MutableList<String>,
) {
if (argumentList is FirResolvedArgumentList) {
val g = argumentList.mapping.keys
.filterIsInstance<FirFunctionCall>()
.map {
FirExpressionEvaluator.evaluateExpression(it, session)
}.filterIsInstance<FirEvaluatorResult.Evaluated>().map {
it.result
}.filterIsInstance<FirLiteralExpression>()

g.forEach { it.accept(this, data) }
}

val paramName =
v?.resolvedArgumentMapping?.values?.find { it.name.asString() == enumEntryAccessor?.asString() }
val paramLiteral = v?.resolvedArgumentMapping?.entries?.find { it.value == paramName }?.key
argumentList.acceptChildren(this, data)
}

val queryParam = (paramLiteral as? FirLiteralExpression)?.value
queryParam?.let {
data.add(queryParam.toString())
}
} else {
val calleeReference = propertyAccessExpression.calleeReference
if (calleeReference is FirResolvedNamedReference) {
val fir = calleeReference.resolvedSymbol.fir
if (fir is FirProperty) {
val init = fir.initializer

if (init is FirLiteralExpression) {
init.accept(this, data)
}
@OptIn(SymbolInternals::class)
fun FirDefaultVisitor<Unit, MutableList<String>>.visitPropertyAccessExpression(
session: FirSession,
propertyAccessExpression: FirPropertyAccessExpression,
data: MutableList<String>
) {
val enumInfo: EnumValueArgumentInfo? = propertyAccessExpression.dispatchReceiver?.extractEnumValueArgumentInfo()
val enumEntryAccessor = propertyAccessExpression.calleeReference.toResolvedCallableSymbol()?.name

if (propertyAccessExpression.isEnum(session)) {

val entries = enumInfo?.enumClassId?.toLookupTag()?.toClassSymbol(session)?.collectEnumEntries()
val v = entries?.find { it.name.asString() == enumInfo.enumEntryName.asString() }
?.initializerObjectSymbol
?.primaryConstructorSymbol(session)
?.fir?.delegatedConstructor

val paramName =
v?.resolvedArgumentMapping?.values?.find { it.name.asString() == enumEntryAccessor?.asString() }
val paramLiteral = v?.resolvedArgumentMapping?.entries?.find { it.value == paramName }?.key

val queryParam = (paramLiteral as? FirLiteralExpression)?.value
queryParam?.let {
data.add(queryParam.toString())
}
} else {
val calleeReference = propertyAccessExpression.calleeReference
if (calleeReference is FirResolvedNamedReference) {
val fir = calleeReference.resolvedSymbol.fir
if (fir is FirProperty) {
val init = fir.initializer

if (init is FirLiteralExpression) {
init.accept(this, data)
}
}

}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ data class OpenApiSpec(
val description: String? = null,
val tags: List<String>? = null,
val responses: Map<String, ResponseDetails>? = null,
val parameters: List<PathParam>? = null,
val parameters: List<Parameter>? = null,
val requestBody: RequestBody? = null
)

Expand Down Expand Up @@ -58,7 +58,7 @@ data class OpenApiSpec(
var required: MutableList<String>? = null
) : NamedObject

data class PathParam(
data class Parameter(
override val name: String,
override val `in`: String,
override val required: Boolean = true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,14 @@ class K2StabilityTest {
result.assertWith(expected)
}

@Test
fun `should generate correct swagger definitions for endpoint with header parameters `() {
val (source, expected) = loadSourceAndExpected("HeaderParameters")
generateCompilerTest(testFile, source)
val result = testFile.readText()
result.assertWith(expected)
}

@Test
fun `should include private fields or ones annotated with @Transient`() {
val (source, expected) = loadSourceAndExpected("PrivateFieldsNegation")
Expand Down
Loading
Loading