Skip to content

Commit

Permalink
initial support for heterogeneous input types #5
Browse files Browse the repository at this point in the history
  • Loading branch information
breandan committed Jan 17, 2020
1 parent 05f3c7b commit 76fd02b
Show file tree
Hide file tree
Showing 9 changed files with 314 additions and 284 deletions.
232 changes: 125 additions & 107 deletions core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyExample.kt

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ fun main() {
y * y,
x * y)


val mf2 = Mat1x2(vf2)

val qr = mf2 * Vec(x, y)
Expand All @@ -57,11 +56,8 @@ fun main() {
* Matrix function.
*/

open class MFun<X: Fun<X>, R: D1, C: D1>(
open val sVars: Set<Var<X>> = emptySet()
): (Bindings<X>) -> MFun<X, R, C> {
constructor(left: MFun<X, R, *>, right: MFun<X, *, C>): this(left.sVars + right.sVars)
constructor(mFun: MFun<X, R, C>): this(mFun.sVars)
open class MFun<X: SFun<X>, R: D1, C: D1>(override val inputs: Inputs<X>): Fun<X>, (Bindings<X>) -> MFun<X, R, C> {
constructor(vararg funs: Fun<X>): this(Inputs(*funs))

open val: MFun<X, C, R> by lazy { MTranspose(this) }

Expand All @@ -76,15 +72,15 @@ open class MFun<X: Fun<X>, R: D1, C: D1>(
is SMProd -> left(bnds) * right(bnds)
is MConst -> MZero()
is Mat -> Mat(rows.map { it(bnds) as Vec<X, C> })
else -> throw IllegalArgumentException("Type ${this::class.java.name} unknown")
else -> TODO(this::class.java.name)
}

// Materializes the concrete matrix from the dataflow graph
fun coalesce(): Mat<X, R, C> = this(Bindings()) as Mat<X, R, C>

open operator fun unaryMinus(): MFun<X, R, C> = MNegative(this)
open operator fun plus(addend: MFun<X, R, C>): MFun<X, R, C> = MSum(this, addend)
open operator fun times(multiplicand: Fun<X>): MFun<X, R, C> = MSProd(this, multiplicand)
open operator fun times(multiplicand: SFun<X>): MFun<X, R, C> = MSProd(this, multiplicand)
open operator fun times(multiplicand: VFun<X, C>): VFun<X, R> = MVProd(this, multiplicand)

// The Hadamard product
Expand All @@ -100,24 +96,24 @@ open class MFun<X: Fun<X>, R: D1, C: D1>(
is HProd -> "$left ʘ $right"
is MSProd -> "$left * $right"
is SMProd -> "$left * $right"
is MConst -> "TODO()"
is MConst -> "${javaClass.name}()"
is Mat -> "Mat${numRows}x$numCols(${rows.joinToString(", ") { it.contents.joinToString(", ") }})"
is MDerivative -> "d($mFun) / d($v1)"
else -> throw IllegalArgumentException("Type ${this::class.java.name} unknown")
else -> TODO(this::class.java.name)
}
}

class MNegative<X: Fun<X>, R: D1, C: D1>(val value: MFun<X, R, C>): MFun<X, R, C>(value)
class MTranspose<X: Fun<X>, R: D1, C: D1>(val value: MFun<X, R, C>): MFun<X, C, R>(value.sVars)
class MSum<X: Fun<X>, R: D1, C: D1>(val left: MFun<X, R, C>, val right: MFun<X, R, C>): MFun<X, R, C>(left, right)
class MMProd<X: Fun<X>, R: D1, C1: D1, C2: D1>(val left: MFun<X, R, C1>, val right: MFun<X, C1, C2>): MFun<X, R, C2>(left, right)
class HProd<X: Fun<X>, R: D1, C: D1>(val left: MFun<X, R, C>, val right: MFun<X, R, C>): MFun<X, R, C>(left, right)
class MSProd<X: Fun<X>, R: D1, C: D1>(val left: MFun<X, R, C>, val right: Fun<X>): MFun<X, R, C>(left)
class SMProd<X: Fun<X>, R: D1, C: D1>(val left: Fun<X>, val right: MFun<X, R, C>): MFun<X, R, C>(right)
class MNegative<X: SFun<X>, R: D1, C: D1>(val value: MFun<X, R, C>): MFun<X, R, C>(value)
class MTranspose<X: SFun<X>, R: D1, C: D1>(val value: MFun<X, R, C>): MFun<X, C, R>(value)
class MSum<X: SFun<X>, R: D1, C: D1>(val left: MFun<X, R, C>, val right: MFun<X, R, C>): MFun<X, R, C>(left, right)
class MMProd<X: SFun<X>, R: D1, C1: D1, C2: D1>(val left: MFun<X, R, C1>, val right: MFun<X, C1, C2>): MFun<X, R, C2>(left, right)
class HProd<X: SFun<X>, R: D1, C: D1>(val left: MFun<X, R, C>, val right: MFun<X, R, C>): MFun<X, R, C>(left, right)
class MSProd<X: SFun<X>, R: D1, C: D1>(val left: MFun<X, R, C>, val right: SFun<X>): MFun<X, R, C>(left)
class SMProd<X: SFun<X>, R: D1, C: D1>(val left: SFun<X>, val right: MFun<X, R, C>): MFun<X, R, C>(right)

// TODO: Generalize tensor derivatives? https://en.wikipedia.org/wiki/Tensor_derivative_(continuum_mechanics)
class MDerivative<X: Fun<X>, R: D1, C: D1> internal constructor(val mFun: VFun<X, R>, numCols: Nat<C>, val v1: Var<X>): MFun<X, R, C>(mFun.sVars) {
fun MFun<X, R, C>.df(): MFun<X, R, C> = when (this) {
class MDerivative<X: SFun<X>, R: D1, C: D1> internal constructor(val mFun: VFun<X, R>, numCols: Nat<C>, val v1: Var<X>): MFun<X, R, C>(mFun) {
fun MFun<X, R, C>.df(): MFun<X, R, C> = when (this@df) {
is MConst -> MZero()
is MVar -> MZero()
is MNegative -> -value.df()
Expand All @@ -129,22 +125,38 @@ class MDerivative<X: Fun<X>, R: D1, C: D1> internal constructor(val mFun: VFun<X
is SMProd -> left.d(v1) * right + left * right.df()
is HProd -> left.df() ʘ right + left ʘ right.df()
is Mat -> Mat(rows.map { it.d(v1)() })
else -> throw IllegalArgumentException("Unable to differentiate expression of type ${this::class.java.name}")
else -> TODO(this@df::class.java.name)
}
}

class MGradient<X : SFun<X>, R: D1, C: D1>(val fn: SFun<X>, val mVar: MVar<X, R, C>): MFun<X, R, C>(fn) {
fun df() = fn.df()
fun SFun<X>.df(): MFun<X, R, C> = when (this@df) {
is MVar<*, *, *> -> if (this == mVar) MOne() else MZero()
is Var -> MZero()
is SConst -> MZero()
is Sum -> left.df() + right.df()
is Prod -> left.df() * right + left * right.df()
is Power -> this * (exponent * Log(base)).df()
is Negative -> -value.df()
is Log -> (logarithmand pow -One<X>()) * logarithmand.df()
// is Derivative -> fn.df()
is DProd -> this().df()
is VMagnitude -> this().df()
else -> TODO(this@df::class.java.name)
}
}

class MVar<X: Fun<X>, R: D1, C: D1>(override val name: String = ""): Variable, MFun<X, R, C>()
open class MConst<X: Fun<X>, R: D1, C: D1>: MFun<X, R, C>()
class MVar<X: SFun<X>, R: D1, C: D1>(override val name: String = ""): Variable, MFun<X, R, C>()
open class MConst<X: SFun<X>, R: D1, C: D1>: MFun<X, R, C>()

class MZero<X: Fun<X>, R: D1, C: D1>: MConst<X, R, C>()
class MOne<X: Fun<X>, R: D1, C: D1>: MConst<X, R, C>()
class MZero<X: SFun<X>, R: D1, C: D1>: MConst<X, R, C>()
class MOne<X: SFun<X>, R: D1, C: D1>: MConst<X, R, C>()

open class Mat<X: Fun<X>, R: D1, C: D1>(override val sVars: Set<Var<X>> = emptySet(),
val rows: List<Vec<X, C>>): MFun<X, R, C>() {
constructor(rows: List<Vec<X, C>>): this(rows.flatMap { it.sVars }.toSet(), rows)
constructor(vararg rows: Vec<X, C>): this(rows.flatMap { it.sVars }.toSet(), rows.asList())
open class Mat<X: SFun<X>, R: D1, C: D1>(val rows: List<Vec<X, C>>): MFun<X, R, C>(*rows.toTypedArray()) {
constructor(vararg rows: Vec<X, C>): this(rows.asList())

val flatContents: List<Fun<X>> by lazy { rows.flatMap { it.contents } }
val flatContents: List<SFun<X>> by lazy { rows.flatMap { it.contents } }

val indices = rows.indices
val cols by lazy { indices.map { i -> Vec<X, R>(rows.map { it[i] }) } }
Expand All @@ -169,7 +181,7 @@ open class Mat<X: Fun<X>, R: D1, C: D1>(override val sVars: Set<Var<X>> = emptyS

operator fun get(i: Int): VFun<X, C> = rows[i]

override operator fun times(multiplicand: Fun<X>): Mat<X, R, C> = Mat(rows.map { it * multiplicand })
override operator fun times(multiplicand: SFun<X>): Mat<X, R, C> = Mat(rows.map { it * multiplicand })

override operator fun times(multiplicand: VFun<X, C>): VFun<X, R> =
when (multiplicand) {
Expand All @@ -188,22 +200,22 @@ open class Mat<X: Fun<X>, R: D1, C: D1>(override val sVars: Set<Var<X>> = emptyS
}
}

fun <X: Fun<X>> Mat1x1(v0: Vec<X, D1>): Mat<X, D1, D1> = Mat(v0)
fun <X: Fun<X>> Mat2x1(v0: Vec<X, D1>, v1: Vec<X, D1>): Mat<X, D2, D1> = Mat(v0, v1)
fun <X: Fun<X>> Mat3x1(v0: Vec<X, D1>, v1: Vec<X, D1>, v2: Vec<X, D1>): Mat<X, D3, D1> = Mat(v0, v1, v2)
fun <X: Fun<X>> Mat1x2(v0: Vec<X, D2>): Mat<X, D1, D2> = Mat(v0)
fun <X: Fun<X>> Mat2x2(v0: Vec<X, D2>, v1: Vec<X, D2>): Mat<X, D2, D2> = Mat(v0, v1)
fun <X: Fun<X>> Mat3x2(v0: Vec<X, D2>, v1: Vec<X, D2>, v2: Vec<X, D2>): Mat<X, D3, D2> = Mat(v0, v1, v2)
fun <X: Fun<X>> Mat1x3(v0: Vec<X, D3>): Mat<X, D1, D3> = Mat(v0)
fun <X: Fun<X>> Mat2x3(v0: Vec<X, D3>, v1: Vec<X, D3>): Mat<X, D2, D3> = Mat(v0, v1)
fun <X: Fun<X>> Mat3x3(v0: Vec<X, D3>, v1: Vec<X, D3>, v2: Vec<X, D3>): Mat<X, D3, D3> = Mat(v0, v1, v2)

fun <X: Fun<X>> Mat1x1(d0: Fun<X>): Mat<X, D1, D1> = Mat(Vec(d0))
fun <X: Fun<X>> Mat1x2(d0: Fun<X>, d1: Fun<X>): Mat<X, D1, D2> = Mat(Vec(d0, d1))
fun <X: Fun<X>> Mat1x3(d0: Fun<X>, d1: Fun<X>, d2: Fun<X>): Mat<X, D1, D3> = Mat(Vec(d0, d1, d2))
fun <X: Fun<X>> Mat2x1(d0: Fun<X>, d1: Fun<X>): Mat<X, D2, D1> = Mat(Vec(d0), Vec(d1))
fun <X: Fun<X>> Mat2x2(d0: Fun<X>, d1: Fun<X>, d2: Fun<X>, d3: Fun<X>): Mat<X, D2, D2> = Mat(Vec(d0, d1), Vec(d2, d3))
fun <X: Fun<X>> Mat2x3(d0: Fun<X>, d1: Fun<X>, d2: Fun<X>, d3: Fun<X>, d4: Fun<X>, d5: Fun<X>): Mat<X, D2, D3> = Mat(Vec(d0, d1, d2), Vec(d3, d4, d5))
fun <X: Fun<X>> Mat3x1(d0: Fun<X>, d1: Fun<X>, d2: Fun<X>): Mat<X, D3, D1> = Mat(Vec(d0), Vec(d1), Vec(d2))
fun <X: Fun<X>> Mat3x2(d0: Fun<X>, d1: Fun<X>, d2: Fun<X>, d3: Fun<X>, d4: Fun<X>, d5: Fun<X>): Mat<X, D3, D2> = Mat(Vec(d0, d1), Vec(d2, d3), Vec(d4, d5))
fun <X: Fun<X>> Mat3x3(d0: Fun<X>, d1: Fun<X>, d2: Fun<X>, d3: Fun<X>, d4: Fun<X>, d5: Fun<X>, d6: Fun<X>, d7: Fun<X>, d8: Fun<X>): Mat<X, D3, D3> = Mat(Vec(d0, d1, d2), Vec(d3, d4, d5), Vec(d6, d7, d8))
fun <X: SFun<X>> Mat1x1(v0: Vec<X, D1>): Mat<X, D1, D1> = Mat(v0)
fun <X: SFun<X>> Mat2x1(v0: Vec<X, D1>, v1: Vec<X, D1>): Mat<X, D2, D1> = Mat(v0, v1)
fun <X: SFun<X>> Mat3x1(v0: Vec<X, D1>, v1: Vec<X, D1>, v2: Vec<X, D1>): Mat<X, D3, D1> = Mat(v0, v1, v2)
fun <X: SFun<X>> Mat1x2(v0: Vec<X, D2>): Mat<X, D1, D2> = Mat(v0)
fun <X: SFun<X>> Mat2x2(v0: Vec<X, D2>, v1: Vec<X, D2>): Mat<X, D2, D2> = Mat(v0, v1)
fun <X: SFun<X>> Mat3x2(v0: Vec<X, D2>, v1: Vec<X, D2>, v2: Vec<X, D2>): Mat<X, D3, D2> = Mat(v0, v1, v2)
fun <X: SFun<X>> Mat1x3(v0: Vec<X, D3>): Mat<X, D1, D3> = Mat(v0)
fun <X: SFun<X>> Mat2x3(v0: Vec<X, D3>, v1: Vec<X, D3>): Mat<X, D2, D3> = Mat(v0, v1)
fun <X: SFun<X>> Mat3x3(v0: Vec<X, D3>, v1: Vec<X, D3>, v2: Vec<X, D3>): Mat<X, D3, D3> = Mat(v0, v1, v2)

fun <X: SFun<X>> Mat1x1(d0: SFun<X>): Mat<X, D1, D1> = Mat(Vec(d0))
fun <X: SFun<X>> Mat1x2(d0: SFun<X>, d1: SFun<X>): Mat<X, D1, D2> = Mat(Vec(d0, d1))
fun <X: SFun<X>> Mat1x3(d0: SFun<X>, d1: SFun<X>, d2: SFun<X>): Mat<X, D1, D3> = Mat(Vec(d0, d1, d2))
fun <X: SFun<X>> Mat2x1(d0: SFun<X>, d1: SFun<X>): Mat<X, D2, D1> = Mat(Vec(d0), Vec(d1))
fun <X: SFun<X>> Mat2x2(d0: SFun<X>, d1: SFun<X>, d2: SFun<X>, d3: SFun<X>): Mat<X, D2, D2> = Mat(Vec(d0, d1), Vec(d2, d3))
fun <X: SFun<X>> Mat2x3(d0: SFun<X>, d1: SFun<X>, d2: SFun<X>, d3: SFun<X>, d4: SFun<X>, d5: SFun<X>): Mat<X, D2, D3> = Mat(Vec(d0, d1, d2), Vec(d3, d4, d5))
fun <X: SFun<X>> Mat3x1(d0: SFun<X>, d1: SFun<X>, d2: SFun<X>): Mat<X, D3, D1> = Mat(Vec(d0), Vec(d1), Vec(d2))
fun <X: SFun<X>> Mat3x2(d0: SFun<X>, d1: SFun<X>, d2: SFun<X>, d3: SFun<X>, d4: SFun<X>, d5: SFun<X>): Mat<X, D3, D2> = Mat(Vec(d0, d1), Vec(d2, d3), Vec(d4, d5))
fun <X: SFun<X>> Mat3x3(d0: SFun<X>, d1: SFun<X>, d2: SFun<X>, d3: SFun<X>, d4: SFun<X>, d5: SFun<X>, d6: SFun<X>, d7: SFun<X>, d8: SFun<X>): Mat<X, D3, D3> = Mat(Vec(d0, d1, d2), Vec(d3, d4, d5), Vec(d6, d7, d8))
Loading

0 comments on commit 76fd02b

Please sign in to comment.