Skip to content

Commit

Permalink
Data/Model: Support flexible tensors in filter description
Browse files Browse the repository at this point in the history
This patch supports flexible tensors in filter description.

Signed-off-by: Yelin Jeong <[email protected]>
  • Loading branch information
niley7464 authored and wooksong committed Sep 19, 2024
1 parent a1c9c09 commit 3aa41d3
Showing 1 changed file with 56 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ data class Model(
uid: Int,
name: String,
jsonObject: JSONObject = JSONObject(),
optionalJsonObject: JSONObject = JSONObject()
optionalJsonObject: JSONObject = JSONObject(),
) : this(uid, name) {
// TODO: The following is the parser for "single", which is a legacy conf
runCatching {
Expand Down Expand Up @@ -66,7 +66,8 @@ data class Model(
runCatching {
val infoMap = mutableMapOf<String, MutableList<String>>(
"type" to mutableListOf(),
"dimension" to mutableListOf()
"dimension" to mutableListOf(),
"format" to mutableListOf()
)

val info = jsonObject.getJSONArray(prop)
Expand Down Expand Up @@ -113,6 +114,44 @@ data class Model(

var optionalInfo: Map<String, String> = mapOf()

private fun getFormat(list: List<String>): String {
return list.let {
if (it.isNotEmpty()) {
it.joinToString(",")
} else {
"static"
}
}
}

private fun getTensors(list: List<String>?, format: String): String {
return if (format == "static") {
val numTensors = list?.size ?: 1
",num_tensors=$numTensors"
} else {
""
}
}

private fun getType(list: List<String>): String {
val filtered = list.filterNot{ it.isEmpty() or it.isBlank() }
return if (filtered.isEmpty()) {
""
} else {
",types=${filtered.joinToString(",")}"
}
}

private fun getDimension(list: List<String>): String {
return list.let {
if (it.isNotEmpty()) {
",dimensions=(string)${it.joinToString(",")}"
} else {
""
}
}
}

/**
* Get NNS filter description string for the given model. This is used to create a new NNS pipeline instance.
* @return NNS filter description string. It includes paths of the model files and other information.
Expand All @@ -125,22 +164,23 @@ data class Model(
transform = { basePath?.resolve(it).toString() }
)
}
val inTypes = inputInfo["type"]?.let {
"types=${inputInfo["type"]?.joinToString(",")}"
} ?: ""
val inDims = inputInfo["dimension"]?.let {
"dimensions=(string)${inputInfo["dimension"]?.joinToString(",")}"
} ?: ""
val outTypes = outputInfo["type"]?.let {
"types=${outputInfo["type"]?.joinToString(",")}"
} ?: ""
val outDims = outputInfo["dimension"]?.let {
"dimensions=(string)${outputInfo["dimension"]?.joinToString(",")}"
} ?: ""

val inFormat = inputInfo["format"]?.let { getFormat(it) } ?: "static"
val outFormat = outputInfo["format"]?.let { getFormat(it) } ?: "static"

val inTensors = getTensors(inputInfo["type"], inFormat)
val outTensors = getTensors(outputInfo["type"], inFormat)

val inTypes = inputInfo["type"]?.let { getType(it) } ?: ""
val outTypes = outputInfo["type"]?.let { getType(it) } ?: ""

val inDims = inputInfo["dimension"]?.let { getDimension(it) } ?: ""
val outDims = outputInfo["dimension"]?.let { getDimension(it) } ?: ""

val filter =
"other/tensors,num_tensors=${inputInfo["type"]?.size ?: 1},format=static,${inDims},${inTypes},framerate=0/1 ! " +
"other/tensors,format=${inFormat}${inTensors}${inDims}${inTypes},framerate=0/1 ! " +
"tensor_filter framework=tensorflow-lite model=${modelPaths} ! " +
"other/tensors,num_tensors=${outputInfo["type"]?.size ?: 1},format=static,${outDims},${outTypes},framerate=0/1"
"other/tensors,format=${outFormat}${outTensors}${outDims}${outTypes},framerate=0/1"

return filter
}
Expand Down

0 comments on commit 3aa41d3

Please sign in to comment.