Skip to content

Commit

Permalink
feat(pyannote): Android x all platforms working, min sdk 24
Browse files Browse the repository at this point in the history
  • Loading branch information
jpohhhh committed Dec 9, 2024
1 parent e86a43c commit f862f83
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 18 deletions.
4 changes: 2 additions & 2 deletions android/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ group 'com.telosnex.fonnx'
version '1.0-SNAPSHOT'

buildscript {
ext.kotlin_version = '1.7.10'
ext.kotlin_version = '1.9.21'
repositories {
google()
mavenCentral()
Expand Down Expand Up @@ -47,7 +47,7 @@ android {

defaultConfig {
// Upgraded to 21 for ONNX ORT 1.16
minSdkVersion 21
minSdkVersion 24
}

dependencies {
Expand Down
44 changes: 44 additions & 0 deletions android/src/main/kotlin/com/telosnex/fonnx/FonnxPlugin.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class FonnxPlugin : FlutterPlugin, MethodCallHandler {
var cachedWhisper: OrtWhisper? = null
var cachedSileroVadPath: String? = null
var cachedSileroVad: OrtVad? = null
var cachedPyannotePath: String? = null
var cachedPyannote: OrtPyannote? = null

private lateinit var channel: MethodChannel

Expand Down Expand Up @@ -115,6 +117,48 @@ class FonnxPlugin : FlutterPlugin, MethodCallHandler {
} else {
result.error("SileroVad", "Could not instantiate model", null)
}
} else if (call.method == "pyannote") {
val list = call.arguments as List<Any>
val modelPath = list[0] as String
val audioBytes = list[1] as FloatArray

if (cachedPyannotePath != modelPath) {
cachedPyannote = OrtPyannote(modelPath)
cachedPyannotePath = modelPath
}
val pyannote = cachedPyannote

if (pyannote != null) {
launch(Dispatchers.Default) {
try {
// Process the audio data
val diarization = pyannote.process(audioBytes)
if (diarization != null) {
launch(Dispatchers.Main) {
result.success(diarization)
}
} else {
launch(Dispatchers.Main) {
result.error(
"Pyannote",
"Processing failed",
null
)
}
}
} catch (e: Exception) {
launch(Dispatchers.Main) {
result.error(
"Pyannote",
"Error processing audio: ${e.message}",
null
)
}
}
}
} else {
result.error("Pyannote", "Could not instantiate model", null)
}
} else {
result.notImplemented()
}
Expand Down
35 changes: 21 additions & 14 deletions android/src/main/kotlin/com/telosnex/fonnx/OrtPyannote.kt
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class OrtPyannote(private val modelPath: String) {
session?.let {
for ((windowIndex, window) in windows.withIndex()) {
val (windowSize, windowData) = window

// Log.d("[OrtPyannote.kt]", "Processing window $windowIndex of ${windows.size}; this window has ${windowData.size} samples")
// Prepare input tensor
val inputs = mutableMapOf<String, OnnxTensor>()
inputs["input_values"] = OnnxTensor.createTensor(
Expand All @@ -48,11 +48,14 @@ class OrtPyannote(private val modelPath: String) {
val startTime = System.currentTimeMillis()
val rawResult = it.run(inputs, setOf("logits"))
val endTime = System.currentTimeMillis()
Log.d("[OrtPyannote.kt]", "Inference time: ${endTime - startTime} ms")
// Log.d("[OrtPyannote.kt]", "Inference time: ${endTime - startTime} ms")

// Process output
val logits = rawResult.get("logits").get().value as Array<Array<FloatArray>>
var frameOutputs = processOutputData(logits[0])
// Log.d("[OrtPyannote.kt]", "Logits shape: ${logits.size} x ${logits[0].size} x ${logits[0][0].size}")
// Log.d("[OrtPyannote.kt]", "Logits: ${logits[0][0].joinToString(", ")}")
var frameOutputs = processOutputData(logits)
// Log.d("[OrtPyannote.kt]", "Frame outputs: ${frameOutputs.size}")

// Handle overlapping
if (windowIndex > 0) {
Expand Down Expand Up @@ -102,7 +105,6 @@ class OrtPyannote(private val modelPath: String) {
}

return results

} catch (e: Exception) {
Log.e("[OrtPyannote.kt]", "Error in process: ${e.message}")
e.printStackTrace()
Expand Down Expand Up @@ -151,22 +153,27 @@ class OrtPyannote(private val modelPath: String) {
return windows
}

private fun processOutputData(logits: FloatArray): List<DoubleArray> {
private fun processOutputData(logits: Array<Array<FloatArray>>): List<DoubleArray> {
val frameOutputs = mutableListOf<DoubleArray>()
val numCompleteFrames = logits.size / 7

for (frame in 0 until numCompleteFrames) {
val i = frame * 7
val probs = logits.slice(i until i + 7).map { exp(it.toDouble()) }
val batchLogits = logits[0] // First batch

// Process each frame (589 frames)
for (frame in batchLogits.indices) {
val frameLogits = batchLogits[frame] // Get the 7 logits for this frame
val probs = frameLogits.map { exp(it.toDouble()) }

val speakerProbs = DoubleArray(NUM_SPEAKERS)
speakerProbs[0] = probs[1] + probs[4] + probs[5] // spk1
speakerProbs[1] = probs[2] + probs[4] + probs[6] // spk2
speakerProbs[2] = probs[3] + probs[5] + probs[6] // spk3
// Combine probabilities for each speaker
// spk1: solo (1) + with_spk2 (4) + with_spk3 (5)
speakerProbs[0] = probs[1] + probs[4] + probs[5]
// spk2: solo (2) + with_spk1 (4) + with_spk3 (6)
speakerProbs[1] = probs[2] + probs[4] + probs[6]
// spk3: solo (3) + with_spk1 (5) + with_spk2 (6)
speakerProbs[2] = probs[3] + probs[5] + probs[6]

frameOutputs.add(speakerProbs)
}

return frameOutputs
}

Expand Down
2 changes: 1 addition & 1 deletion example/android/app/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ android {
applicationId "com.telosnex.fonnx.example"
// You can update the following values to match your application needs.
// For more information, see: https://docs.flutter.dev/deployment/android#reviewing-the-gradle-build-configuration.
minSdkVersion 23 // needed 21 for ONNX ORT 1.16; in June 2024 need 23 for record_android
minSdkVersion 24 // needed 21 for ONNX ORT 1.16; in June 2024 need 23 for record_android
targetSdkVersion flutter.targetSdkVersion
versionCode flutterVersionCode.toInteger()
versionName flutterVersionName
Expand Down
2 changes: 1 addition & 1 deletion example/lib/pyannote_widget.dart
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class _PyannoteWidgetState extends State<PyannoteWidget> {
lastStop = stop;
}

final golden = kIsWeb
final golden = kIsWeb || Platform.isAndroid
? [
{"speaker": 1, "start": 0.8044375, "stop": 4.4494375}
]
Expand Down

0 comments on commit f862f83

Please sign in to comment.