From f862f83ff8fedc3b7b63814391c4626d79008adf Mon Sep 17 00:00:00 2001 From: James O'Leary <65884233+jpohhhh@users.noreply.github.com> Date: Sun, 8 Dec 2024 21:22:13 -0500 Subject: [PATCH] feat(pyannote): Android x all platforms working, min sdk 24 --- android/build.gradle | 4 +- .../kotlin/com/telosnex/fonnx/FonnxPlugin.kt | 44 +++++++++++++++++++ .../kotlin/com/telosnex/fonnx/OrtPyannote.kt | 35 +++++++++------ example/android/app/build.gradle | 2 +- example/lib/pyannote_widget.dart | 2 +- 5 files changed, 69 insertions(+), 18 deletions(-) diff --git a/android/build.gradle b/android/build.gradle index 37d7fc4..656f85c 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -2,7 +2,7 @@ group 'com.telosnex.fonnx' version '1.0-SNAPSHOT' buildscript { - ext.kotlin_version = '1.7.10' + ext.kotlin_version = '1.9.21' repositories { google() mavenCentral() @@ -47,7 +47,7 @@ android { defaultConfig { // Upgraded to 21 for ONNX ORT 1.16 - minSdkVersion 21 + minSdkVersion 24 } dependencies { diff --git a/android/src/main/kotlin/com/telosnex/fonnx/FonnxPlugin.kt b/android/src/main/kotlin/com/telosnex/fonnx/FonnxPlugin.kt index cb2ea41..eb26cf0 100644 --- a/android/src/main/kotlin/com/telosnex/fonnx/FonnxPlugin.kt +++ b/android/src/main/kotlin/com/telosnex/fonnx/FonnxPlugin.kt @@ -19,6 +19,8 @@ class FonnxPlugin : FlutterPlugin, MethodCallHandler { var cachedWhisper: OrtWhisper? = null var cachedSileroVadPath: String? = null var cachedSileroVad: OrtVad? = null + var cachedPyannotePath: String? = null + var cachedPyannote: OrtPyannote? = null private lateinit var channel: MethodChannel @@ -115,6 +117,48 @@ class FonnxPlugin : FlutterPlugin, MethodCallHandler { } else { result.error("SileroVad", "Could not instantiate model", null) } + } else if (call.method == "pyannote") { + val list = call.arguments as List + val modelPath = list[0] as String + val audioBytes = list[1] as FloatArray + + if (cachedPyannotePath != modelPath) { + cachedPyannote = OrtPyannote(modelPath) + cachedPyannotePath = modelPath + } + val pyannote = cachedPyannote + + if (pyannote != null) { + launch(Dispatchers.Default) { + try { + // Process the audio data + val diarization = pyannote.process(audioBytes) + if (diarization != null) { + launch(Dispatchers.Main) { + result.success(diarization) + } + } else { + launch(Dispatchers.Main) { + result.error( + "Pyannote", + "Processing failed", + null + ) + } + } + } catch (e: Exception) { + launch(Dispatchers.Main) { + result.error( + "Pyannote", + "Error processing audio: ${e.message}", + null + ) + } + } + } + } else { + result.error("Pyannote", "Could not instantiate model", null) + } } else { result.notImplemented() } diff --git a/android/src/main/kotlin/com/telosnex/fonnx/OrtPyannote.kt b/android/src/main/kotlin/com/telosnex/fonnx/OrtPyannote.kt index 522818d..31d3f1d 100644 --- a/android/src/main/kotlin/com/telosnex/fonnx/OrtPyannote.kt +++ b/android/src/main/kotlin/com/telosnex/fonnx/OrtPyannote.kt @@ -37,7 +37,7 @@ class OrtPyannote(private val modelPath: String) { session?.let { for ((windowIndex, window) in windows.withIndex()) { val (windowSize, windowData) = window - + // Log.d("[OrtPyannote.kt]", "Processing window $windowIndex of ${windows.size}; this window has ${windowData.size} samples") // Prepare input tensor val inputs = mutableMapOf() inputs["input_values"] = OnnxTensor.createTensor( @@ -48,11 +48,14 @@ class OrtPyannote(private val modelPath: String) { val startTime = System.currentTimeMillis() val rawResult = it.run(inputs, setOf("logits")) val endTime = System.currentTimeMillis() - Log.d("[OrtPyannote.kt]", "Inference time: ${endTime - startTime} ms") + // Log.d("[OrtPyannote.kt]", "Inference time: ${endTime - startTime} ms") // Process output val logits = rawResult.get("logits").get().value as Array> - var frameOutputs = processOutputData(logits[0]) + // Log.d("[OrtPyannote.kt]", "Logits shape: ${logits.size} x ${logits[0].size} x ${logits[0][0].size}") + // Log.d("[OrtPyannote.kt]", "Logits: ${logits[0][0].joinToString(", ")}") + var frameOutputs = processOutputData(logits) + // Log.d("[OrtPyannote.kt]", "Frame outputs: ${frameOutputs.size}") // Handle overlapping if (windowIndex > 0) { @@ -102,7 +105,6 @@ class OrtPyannote(private val modelPath: String) { } return results - } catch (e: Exception) { Log.e("[OrtPyannote.kt]", "Error in process: ${e.message}") e.printStackTrace() @@ -151,22 +153,27 @@ class OrtPyannote(private val modelPath: String) { return windows } - private fun processOutputData(logits: FloatArray): List { + private fun processOutputData(logits: Array>): List { val frameOutputs = mutableListOf() - val numCompleteFrames = logits.size / 7 - - for (frame in 0 until numCompleteFrames) { - val i = frame * 7 - val probs = logits.slice(i until i + 7).map { exp(it.toDouble()) } + val batchLogits = logits[0] // First batch + + // Process each frame (589 frames) + for (frame in batchLogits.indices) { + val frameLogits = batchLogits[frame] // Get the 7 logits for this frame + val probs = frameLogits.map { exp(it.toDouble()) } val speakerProbs = DoubleArray(NUM_SPEAKERS) - speakerProbs[0] = probs[1] + probs[4] + probs[5] // spk1 - speakerProbs[1] = probs[2] + probs[4] + probs[6] // spk2 - speakerProbs[2] = probs[3] + probs[5] + probs[6] // spk3 + // Combine probabilities for each speaker + // spk1: solo (1) + with_spk2 (4) + with_spk3 (5) + speakerProbs[0] = probs[1] + probs[4] + probs[5] + // spk2: solo (2) + with_spk1 (4) + with_spk3 (6) + speakerProbs[1] = probs[2] + probs[4] + probs[6] + // spk3: solo (3) + with_spk1 (5) + with_spk2 (6) + speakerProbs[2] = probs[3] + probs[5] + probs[6] frameOutputs.add(speakerProbs) } - + return frameOutputs } diff --git a/example/android/app/build.gradle b/example/android/app/build.gradle index 9193675..6704507 100644 --- a/example/android/app/build.gradle +++ b/example/android/app/build.gradle @@ -45,7 +45,7 @@ android { applicationId "com.telosnex.fonnx.example" // You can update the following values to match your application needs. // For more information, see: https://docs.flutter.dev/deployment/android#reviewing-the-gradle-build-configuration. - minSdkVersion 23 // needed 21 for ONNX ORT 1.16; in June 2024 need 23 for record_android + minSdkVersion 24 // needed 21 for ONNX ORT 1.16; in June 2024 need 23 for record_android targetSdkVersion flutter.targetSdkVersion versionCode flutterVersionCode.toInteger() versionName flutterVersionName diff --git a/example/lib/pyannote_widget.dart b/example/lib/pyannote_widget.dart index 81e39ba..1e18e91 100644 --- a/example/lib/pyannote_widget.dart +++ b/example/lib/pyannote_widget.dart @@ -109,7 +109,7 @@ class _PyannoteWidgetState extends State { lastStop = stop; } - final golden = kIsWeb + final golden = kIsWeb || Platform.isAndroid ? [ {"speaker": 1, "start": 0.8044375, "stop": 4.4494375} ]