From bc3af372f0f4d5cc135db205ee23c313e335347e Mon Sep 17 00:00:00 2001 From: Ilia Vologin Date: Fri, 1 Nov 2024 15:46:12 +0100 Subject: [PATCH] [js] Fix JS tests --- gradle/libs.versions.toml | 4 +++- .../operators/tensor/Squeeze.kt | 2 +- .../tfjs/operators/tensor/SqueezeTest.kt | 8 ++++---- utils/utils-testing/build.gradle.kts | 1 + .../kotlin/io.kinference/utils/TestRunner.kt | 19 ++++++++++++++++--- .../kotlin/io/kinference/utils/TestRunner.kt | 13 ------------- .../kotlin/io/kinference/utils/TestRunner.kt | 11 ----------- 7 files changed, 25 insertions(+), 33 deletions(-) delete mode 100644 utils/utils-testing/src/jsMain/kotlin/io/kinference/utils/TestRunner.kt delete mode 100644 utils/utils-testing/src/jvmMain/kotlin/io/kinference/utils/TestRunner.kt diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 14051b424..5c2b56b6d 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -5,7 +5,7 @@ apache-commons = "4.0-beta1" aws = "1.12.761" kmath = "0.3.1" kotlin-atomicfu = "0.23.2" -kotlin-coroutines = "1.8.0-RC" +kotlin-coroutines = "1.9.0" multik = "0.2.2" okio = "3.6.0" onnxruntime = "1.17.0.patched-1" @@ -40,3 +40,5 @@ slf4j-api = { module = "org.slf4j:slf4j-api", version.ref = "slf4j" } slf4j-simple = { module = "org.slf4j:slf4j-simple", version.ref = "slf4j" } wire-runtime = { module = "com.squareup.wire:wire-runtime", version.ref = "wire" } fastutil-core = { module = "it.unimi.dsi:fastutil-core", version.ref = "fastutil" } + +kotlinx-coroutines-test = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-test", version.ref = "kotlin-coroutines" } diff --git a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/Squeeze.kt b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/Squeeze.kt index 345374716..521b89f55 100644 --- a/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/Squeeze.kt +++ b/inference/inference-tfjs/src/jsMain/kotlin/io.kinference.tfjs/operators/tensor/Squeeze.kt @@ -16,7 +16,7 @@ sealed class Squeeze(name: String, info: OperatorInfo, attributes: Map>, inputs: List, outputs: List) = when (version ?: DEFAULT_VERSION.sinceVersion) { in SqueezeVer1.VERSION.asRange() -> SqueezeVer1(name, attributes, inputs, outputs) - else -> error("Unsupported version of Constant operator: $version") + else -> error("Unsupported version of Squeeze operator: $version") } } } diff --git a/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/operators/tensor/SqueezeTest.kt b/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/operators/tensor/SqueezeTest.kt index 57a431159..d9ddd6f07 100644 --- a/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/operators/tensor/SqueezeTest.kt +++ b/inference/inference-tfjs/src/jsTest/kotlin/io/kinference/tfjs/operators/tensor/SqueezeTest.kt @@ -8,12 +8,12 @@ class SqueezeTest { private fun getTargetPath(dirName: String) = "squeeze/$dirName/" @Test - fun test_squeeze() = TestRunner.runTest { - TFJSAccuracyRunner.runFromResources(getTargetPath("test_squeeze")) + fun test_v1_squeeze() = TestRunner.runTest { + TFJSAccuracyRunner.runFromResources(getTargetPath("v1/test_squeeze")) } @Test - fun test_squeeze_with_negative_axes() = TestRunner.runTest { - TFJSAccuracyRunner.runFromResources(getTargetPath("test_squeeze_negative_axes")) + fun test_v1_squeeze_with_negative_axes() = TestRunner.runTest { + TFJSAccuracyRunner.runFromResources(getTargetPath("v1/test_squeeze_negative_axes")) } } diff --git a/utils/utils-testing/build.gradle.kts b/utils/utils-testing/build.gradle.kts index 8bf0c480a..7d9fa7b76 100644 --- a/utils/utils-testing/build.gradle.kts +++ b/utils/utils-testing/build.gradle.kts @@ -28,6 +28,7 @@ kotlin { api(kotlin("test-annotations-common")) api(libs.kinference.primitives.annotations) + api(libs.kotlinx.coroutines.test) } } diff --git a/utils/utils-testing/src/commonMain/kotlin/io.kinference/utils/TestRunner.kt b/utils/utils-testing/src/commonMain/kotlin/io.kinference/utils/TestRunner.kt index bafff7325..8613adf6f 100644 --- a/utils/utils-testing/src/commonMain/kotlin/io.kinference/utils/TestRunner.kt +++ b/utils/utils-testing/src/commonMain/kotlin/io.kinference/utils/TestRunner.kt @@ -1,7 +1,20 @@ package io.kinference.utils -import kotlinx.coroutines.CoroutineScope +import io.kinference.utils.time.Timer +import kotlinx.coroutines.* +import kotlinx.coroutines.test.TestResult +import kotlin.time.Duration +import kotlin.time.Duration.Companion.minutes -expect object TestRunner { - fun runTest(platform: Platform? = null, block: suspend CoroutineScope.() -> Unit) +object TestRunner { + private val logger = LoggerFactory.create("io.kinference.utils.TestRunner") + + fun runTest(platform: Platform? = null, timeout: Duration = 5.minutes, block: suspend CoroutineScope.() -> Unit): TestResult { + if (platform != null && platform != PlatformUtils.platform) return kotlinx.coroutines.test.runTest {} + + val mark = Timer.start() + val res = kotlinx.coroutines.test.runTest(timeout = timeout, testBody = block) + logger.info { "[${PlatformUtils.platform}] Test took ${mark.elapsed().millis}ms" } + return res + } } diff --git a/utils/utils-testing/src/jsMain/kotlin/io/kinference/utils/TestRunner.kt b/utils/utils-testing/src/jsMain/kotlin/io/kinference/utils/TestRunner.kt deleted file mode 100644 index fd910645c..000000000 --- a/utils/utils-testing/src/jsMain/kotlin/io/kinference/utils/TestRunner.kt +++ /dev/null @@ -1,13 +0,0 @@ -package io.kinference.utils - -import kotlinx.coroutines.* - -actual object TestRunner { - actual fun runTest(platform: Platform?, block: suspend CoroutineScope.() -> Unit) { - if (platform == null || platform == Platform.JS) { - GlobalScope.promise { - block() - } - } - } -} diff --git a/utils/utils-testing/src/jvmMain/kotlin/io/kinference/utils/TestRunner.kt b/utils/utils-testing/src/jvmMain/kotlin/io/kinference/utils/TestRunner.kt deleted file mode 100644 index 52041516e..000000000 --- a/utils/utils-testing/src/jvmMain/kotlin/io/kinference/utils/TestRunner.kt +++ /dev/null @@ -1,11 +0,0 @@ -package io.kinference.utils - -import kotlinx.coroutines.* - -actual object TestRunner { - actual fun runTest(platform: Platform?, block: suspend CoroutineScope.() -> Unit) { - if (platform == null || platform == Platform.JVM) { - runBlocking(Dispatchers.Default) { block() } - } - } -}