Skip to content

Commit

Permalink
Map @param annotation to androidx.benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
kx412764776 committed Sep 13, 2024
1 parent 8502fd6 commit 5874eee
Show file tree
Hide file tree
Showing 3 changed files with 212 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ private fun Project.createSetupAndroidProjectTask(target: AndroidBenchmarkTarget
val unpackedDir = getUnpackAarDir(compilation)
val newText = it.readText().replace(
"<<BENCHMARK_CLASSES_JAR_PATH>>",
unpackedDir.resolve("classes.jar").absolutePath
unpackedDir.resolve("classes.jar").absolutePath.replace("\\", "/")
)
it.writeText(newText)
}
Expand Down
249 changes: 193 additions & 56 deletions plugin/main/src/kotlinx/benchmark/gradle/AndroidSourceGenerator.kt
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
@file:OptIn(RequiresKotlinCompilerEmbeddable::class)

package kotlinx.benchmark.gradle

import com.squareup.kotlinpoet.*
import com.squareup.kotlinpoet.ParameterizedTypeName.Companion.parameterizedBy
import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.measureAnnotationFQN
import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.paramAnnotationFQN
import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.setupAnnotationFQN
import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.teardownAnnotationFQN
import kotlinx.benchmark.gradle.SuiteSourceGenerator.Companion.warmupAnnotationFQN
import kotlinx.benchmark.gradle.internal.generator.RequiresKotlinCompilerEmbeddable
import java.io.File
import java.util.*

Expand All @@ -10,7 +19,11 @@ internal fun generateBenchmarkSourceFiles(
) {
classDescriptors.forEach { descriptor ->
if (descriptor.visibility == Visibility.PUBLIC && !descriptor.isAbstract) {
generateDescriptorFile(descriptor, targetDir)
if (descriptor.getSpecificField(paramAnnotationFQN).isNotEmpty()) {
generateParameterizedDescriptorFile(descriptor, targetDir)
} else {
generateDescriptorFile(descriptor, targetDir)
}
}
}
}
Expand All @@ -27,6 +40,12 @@ private fun generateDescriptorFile(descriptor: ClassAnnotationsDescriptor, andro
.addImport("androidx.benchmark", "BenchmarkState")
.addImport("androidx.benchmark", "ExperimentalBenchmarkStateApi")

if (descriptor.hasSetupOrTeardownMethods()) {
fileSpecBuilder
.addImport("org.junit", "Before")
.addImport("org.junit", "After")
}

val typeSpecBuilder = TypeSpec.classBuilder(descriptorName)
.addAnnotation(
AnnotationSpec.builder(ClassName("org.junit.runner", "RunWith"))
Expand All @@ -40,7 +59,122 @@ private fun generateDescriptorFile(descriptor: ClassAnnotationsDescriptor, andro
fileSpecBuilder.build().writeTo(androidTestDir)
}

private fun addBenchmarkMethods(typeSpecBuilder: TypeSpec.Builder, descriptor: ClassAnnotationsDescriptor) {
private fun generateParameterizedDescriptorFile(descriptor: ClassAnnotationsDescriptor, androidTestDir: File) {
val descriptorName = "${descriptor.name}_Descriptor"
val packageName = descriptor.packageName
val fileSpecBuilder = FileSpec.builder(packageName, descriptorName)
.addImport("org.junit.runner", "RunWith")
.addImport("org.junit.runners", "Parameterized")
.addImport("androidx.benchmark", "BenchmarkState")
.addImport("androidx.benchmark", "ExperimentalBenchmarkStateApi")
.addImport("org.junit", "Test")

if (descriptor.hasSetupOrTeardownMethods()) {
fileSpecBuilder
.addImport("org.junit", "Before")
.addImport("org.junit", "After")
}

fileSpecBuilder.addAnnotation(
AnnotationSpec.builder(ClassName("org.junit.runner", "RunWith"))
.addMember("%T::class", ClassName("org.junit.runners", "Parameterized"))
.build()
)

// Generate constructor
val constructorSpec = FunSpec.constructorBuilder()
val paramFields = descriptor.getSpecificField(paramAnnotationFQN)
paramFields.forEach { param ->
constructorSpec.addParameter(param.name, getTypeName(param.type))
}

val typeSpecBuilder = TypeSpec.classBuilder(descriptorName)
.primaryConstructor(constructorSpec.build())
.addProperties(paramFields.map { param ->
PropertySpec.builder(param.name, getTypeName(param.type))
.initializer(param.name)
.addModifiers(KModifier.PRIVATE)
.build()
})

addBenchmarkMethods(typeSpecBuilder, descriptor, true)

// Generate companion object with parameters
val companionSpec = TypeSpec.companionObjectBuilder()
.addFunction(generateParametersFunction(paramFields))
.build()

typeSpecBuilder.addType(companionSpec)

fileSpecBuilder.addType(typeSpecBuilder.build())
fileSpecBuilder.build().writeTo(androidTestDir)
}

private fun generateParametersFunction(paramFields: List<FieldAnnotationsDescriptor>): FunSpec {
val dataFunctionBuilder = FunSpec.builder("data")
.addAnnotation(JvmStatic::class)
.returns(
ClassName("java.util", "Collection")
.parameterizedBy(
ClassName("kotlin", "Array")
.parameterizedBy(ANY)
)
)

val paramNameAndIndex = paramFields.mapIndexed { index, param ->
"${param.name}={${index}}"
}.joinToString(", ")

val paramAnnotationValue = "{index}: $paramNameAndIndex"

dataFunctionBuilder.addAnnotation(
AnnotationSpec.builder(ClassName("org.junit.runners", "Parameterized.Parameters"))
.addMember("name = \"%L\"", paramAnnotationValue)
.build()
)

val paramValueLists = paramFields.map { param ->
val values = param.annotations
.find { it.name == paramAnnotationFQN }
?.parameters?.get("value") as List<*>

values.map { value ->
if (param.type == "java.lang.String") {
"\"\"\"$value\"\"\""
} else {
value.toString()
}
}
}

val cartesianProduct = cartesianProduct(paramValueLists as List<List<Any>>)

val returnStatement = StringBuilder("return listOf(\n")
cartesianProduct.forEachIndexed { index, combination ->
val arrayContent = combination.joinToString(", ")
returnStatement.append(" arrayOf($arrayContent)")
if (index != cartesianProduct.size - 1) {
returnStatement.append(",\n")
}
}
returnStatement.append("\n)")
dataFunctionBuilder.addStatement(returnStatement.toString())

return dataFunctionBuilder.build()
}

private fun cartesianProduct(lists: List<List<Any>>): List<List<Any>> {
if (lists.isEmpty()) return emptyList()
return lists.fold(listOf(listOf<Any>())) { acc, list ->
acc.flatMap { prefix -> list.map { value -> prefix + value } }
}
}

private fun addBenchmarkMethods(
typeSpecBuilder: TypeSpec.Builder,
descriptor: ClassAnnotationsDescriptor,
isParameterized: Boolean = false
) {
val className = "${descriptor.packageName}.${descriptor.name}"
val propertyName = descriptor.name.decapitalize(Locale.getDefault())

Expand All @@ -55,70 +189,106 @@ private fun addBenchmarkMethods(typeSpecBuilder: TypeSpec.Builder, descriptor: C
descriptor.methods
.filter { it.visibility == Visibility.PUBLIC && it.parameters.isEmpty() }
.filterNot { method ->
method.annotations.any { annotation -> annotation.name == "kotlinx.benchmark.Param" }
method.annotations.any { annotation -> annotation.name == paramAnnotationFQN }
}
.forEach { method ->
when {
method.annotations.any { it.name == "kotlinx.benchmark.Setup" || it.name == "kotlinx.benchmark.TearDown" } -> {
method.annotations.any { it.name == setupAnnotationFQN || it.name == teardownAnnotationFQN } -> {
generateNonMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder)
}

isParameterized && descriptor.getSpecificField(paramAnnotationFQN).isNotEmpty() -> {
generateParameterizedMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder)
}

else -> {
generateMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder)
}
}
}
}

private fun generateMeasurableMethod(
private fun generateCommonMeasurableMethod(
descriptor: ClassAnnotationsDescriptor,
method: MethodAnnotationsDescriptor,
propertyName: String,
typeSpecBuilder: TypeSpec.Builder
typeSpecBuilder: TypeSpec.Builder,
isParameterized: Boolean
) {
val measurementIterations = descriptor.annotations
.find { it.name == "kotlinx.benchmark.Measurement" }
.find { it.name == measureAnnotationFQN }
?.parameters?.get("iterations") as? Int ?: 5
val warmupIterations = descriptor.annotations
.find { it.name == "kotlinx.benchmark.Warmup" }
.find { it.name == warmupAnnotationFQN }
?.parameters?.get("iterations") as? Int ?: 5

val methodName = "${descriptor.packageName}.${descriptor.name}.${method.name}"

val methodSpecBuilder = FunSpec.builder("benchmark_${descriptor.name}_${method.name}")
.addAnnotation(ClassName("org.junit", "Test"))
.addAnnotation(
AnnotationSpec.builder(ClassName("kotlin", "OptIn"))
.addMember("%T::class", ClassName("androidx.benchmark", "ExperimentalBenchmarkStateApi"))
.build()
)
// TODO: Add warmupCount and repeatCount parameters

if (isParameterized) {
descriptor.getSpecificField(paramAnnotationFQN).forEach { field ->
methodSpecBuilder.addStatement("$propertyName.${field.name} = ${field.name}")
}
}

methodSpecBuilder
.addStatement(
"val state = %T(warmupCount = $warmupIterations, repeatCount = $measurementIterations)",
ClassName("androidx.benchmark", "BenchmarkState")
)
.addStatement("println(\"Android: $methodName\")")
.beginControlFlow("while (state.keepRunning())")
.addStatement("$propertyName.${method.name}()")
.endControlFlow()
.addStatement("val measurementResult = state.getMeasurementTimeNs()")
.beginControlFlow("measurementResult.forEachIndexed { index, time ->")
.addStatement("println(\"Iteration \${index + 1}: \$time ns\")")
.endControlFlow()

typeSpecBuilder.addFunction(methodSpecBuilder.build())
}

private fun generateParameterizedMeasurableMethod(
descriptor: ClassAnnotationsDescriptor,
method: MethodAnnotationsDescriptor,
propertyName: String,
typeSpecBuilder: TypeSpec.Builder
) {
generateCommonMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder, isParameterized = true)
}

private fun generateMeasurableMethod(
descriptor: ClassAnnotationsDescriptor,
method: MethodAnnotationsDescriptor,
propertyName: String,
typeSpecBuilder: TypeSpec.Builder
) {
generateCommonMeasurableMethod(descriptor, method, propertyName, typeSpecBuilder, isParameterized = false)
}


private fun generateNonMeasurableMethod(
descriptor: ClassAnnotationsDescriptor,
method: MethodAnnotationsDescriptor,
propertyName: String,
typeSpecBuilder: TypeSpec.Builder
) {
when (method.annotations.first().name) {
"kotlinx.benchmark.Setup" -> {
setupAnnotationFQN -> {
val methodSpecBuilder = FunSpec.builder("benchmark_${descriptor.name}_setUp")
.addAnnotation(ClassName("org.junit", "Before"))
.addStatement("$propertyName.${method.name}()")
typeSpecBuilder.addFunction(methodSpecBuilder.build())
}

"kotlinx.benchmark.TearDown" -> {
teardownAnnotationFQN -> {
val methodSpecBuilder = FunSpec.builder("benchmark_${descriptor.name}_tearDown")
.addAnnotation(ClassName("org.junit", "After"))
.addStatement("$propertyName.${method.name}()")
Expand All @@ -127,49 +297,16 @@ private fun generateNonMeasurableMethod(
}
}

private fun updateAndroidDependencies(buildGradleFile: File, dependencies: List<Pair<String, String?>>) {
if (buildGradleFile.exists()) {
val buildGradleContent = buildGradleFile.readText()

if (buildGradleContent.contains("android {")) {
val androidBlockStart = buildGradleContent.indexOf("android {")
val androidBlockEnd = buildGradleContent.lastIndexOf("}") + 1
val androidBlockContent = buildGradleContent.substring(androidBlockStart, androidBlockEnd)

val newDependencies = dependencies.filterNot { (dependency, version) ->
val dependencyString = version?.let { """$dependency:$version""" } ?: dependency
androidBlockContent.contains(dependencyString)
}
if (newDependencies.isNotEmpty()) {
val updatedAndroidBlockContent = if (androidBlockContent.contains("dependencies {")) {
val dependenciesBlockStart = androidBlockContent.indexOf("dependencies {")
val dependenciesBlockEnd = androidBlockContent.indexOf("}", dependenciesBlockStart) + 1
val dependenciesBlockContent =
androidBlockContent.substring(dependenciesBlockStart, dependenciesBlockEnd)

val newDependenciesString = newDependencies.joinToString("\n ") { (dependency, version) ->
version?.let { """androidTestImplementation("$dependency:$version")""" }
?: """androidTestImplementation(files("$dependency"))"""
}
androidBlockContent.replace(
dependenciesBlockContent,
dependenciesBlockContent.replace(
"dependencies {",
"dependencies {\n $newDependenciesString"
)
)
} else {
val newDependenciesString = newDependencies.joinToString("\n ") { (dependency, version) ->
version?.let { """androidTestImplementation("$dependency:$version")""" }
?: """androidTestImplementation(files("$dependency"))"""
}
androidBlockContent.replace("{", "{\n dependencies {\n $newDependenciesString\n }\n")
}

val updatedBuildGradleContent =
buildGradleContent.replace(androidBlockContent, updatedAndroidBlockContent)
buildGradleFile.writeText(updatedBuildGradleContent)
}
}
private fun getTypeName(type: String): TypeName {
return when (type) {
"int" -> Int::class.asTypeName()
"long" -> Long::class.asTypeName()
"boolean" -> Boolean::class.asTypeName()
"float" -> Float::class.asTypeName()
"double" -> Double::class.asTypeName()
"char" -> Char::class.asTypeName()
"byte" -> Byte::class.asTypeName()
"short" -> Short::class.asTypeName()
else -> ClassName.bestGuess(type)
}
}
}
Loading

0 comments on commit 5874eee

Please sign in to comment.