From 76fd02b0a3e9fe7fd49a9b1955a56ca25dbed2ac Mon Sep 17 00:00:00 2001 From: breandan Date: Fri, 17 Jan 2020 17:19:20 -0500 Subject: [PATCH] initial support for heterogeneous input types #5 --- .../kotlingrad/experimental/ToyExample.kt | 232 ++++++++++-------- .../experimental/ToyMatrixExample.kt | 110 +++++---- .../experimental/ToyVectorExample.kt | 164 ++++++------- .../experimental/VariableCapture.kt | 38 +-- .../calculus/ExpressionGenerator.kt | 18 +- .../kotlingrad/evaluation/TestSymbolic.kt | 4 +- .../samples/MultilayerPerceptron.kt | 8 +- .../kotlingrad/samples/VisualizeDFG.kt | 4 +- .../samples/physics/DoublePendulum.kt | 20 +- 9 files changed, 314 insertions(+), 284 deletions(-) diff --git a/core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyExample.kt b/core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyExample.kt index f007c7af..d9d02797 100644 --- a/core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyExample.kt +++ b/core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyExample.kt @@ -57,63 +57,82 @@ interface Field> : Group { fun ln(): X } +interface Fun> { + val inputs: Inputs +} + +/** + * TODO: Figure out how to merge with + * @see Bindings + */ + +data class Inputs>( + val sVars: Set> = emptySet(), + val vVars: Set> = emptySet(), + val mVars: Set> = emptySet() +) { + constructor(inputs: List>): this( + inputs.flatMap { it.sVars }.toSet(), + inputs.flatMap { it.vVars }.toSet(), + inputs.flatMap { it.mVars }.toSet() + ) + + constructor(vararg funs: Fun): this(funs.map { it.inputs }) +} + /** * Scalar function. */ -sealed class Fun>(open val sVars: Set> = emptySet()) : Field>, (Bindings) -> Fun { - constructor(fn: Fun) : this(fn.sVars) - constructor(vararg fns: Fun) : this(fns.flatMap { it.sVars }.toSet()) +sealed class SFun>(override val inputs: Inputs): Fun, Field>, (Bindings) -> SFun { + constructor(vararg funs: Fun) : this(Inputs(*funs)) - override operator fun plus(addend: Fun): Fun = Sum(this, addend) - override operator fun times(multiplicand: Fun): Fun = Prod(this, multiplicand) - override operator fun div(divisor: Fun): Fun = this * divisor.pow(-One()) + override operator fun plus(addend: SFun): SFun = Sum(this, addend) + override operator fun times(multiplicand: SFun): SFun = Prod(this, multiplicand) + override operator fun div(divisor: SFun): SFun = this * divisor.pow(-One()) open operator fun times(multiplicand: VFun): VFun = SVProd(this, multiplicand) open operator fun times(multiplicand: MFun): MFun = SMProd(this, multiplicand) - override operator fun invoke(bnds: Bindings): Fun = + override operator fun invoke(bnds: Bindings): SFun = Composition(this, bnds).run { if (bnds.isReassignmentFree) evaluate else this } - open operator fun invoke(): Fun = invoke(Bindings()) + open operator fun invoke(): SFun = invoke(Bindings()) - open fun d(v1: Var): Fun = Derivative(this, v1) + open fun d(v1: Var): SFun = Derivative(this, v1) open fun d(v1: Var, v2: Var): Vec = Vec(Derivative(this, v1), Derivative(this, v2)) open fun d(v1: Var, v2: Var, v3: Var): Vec = Vec(Derivative(this, v1), Derivative(this, v2), Derivative(this, v3)) - open fun d(vararg vars: Var): Map, Fun> = vars.map { it to Derivative(this, it) }.toMap() + open fun d(vararg vars: Var): Map, SFun> = vars.map { it to Derivative(this, it) }.toMap() open fun d(vVar: VVar): VFun = Gradient(this, vVar) open fun d(mVar: MVar): MFun = TODO() - open fun d(mv: Mat) = d(*mv.sVars.toTypedArray()).values.foldIndexed(mutableListOf()) { index, acc: MutableList>>, p -> + open fun d(mv: Mat) = d(*mv.inputs.sVars.toTypedArray()).values.foldIndexed(mutableListOf()) { index, acc: MutableList>>, p -> if (index % mv.numCols == 0) acc.add(mutableListOf(p)) else acc.last().add(p) acc }.map { Vec(it) }.let { Mat(it) } - open fun grad(): Map, Fun> = sVars.map { it to Derivative(this, it) }.toMap() - - override fun ln(): Fun = Log(this) - override fun pow(exp: Fun): Fun = Power(this, exp) - override fun unaryMinus(): Fun = Negative(this) - open fun sqrt(): Fun = this pow (One() / (Two())) - - override fun toString(): String = when { - this is Log -> "ln($logarithmand)" - this is Negative -> "-($value)" - this is Power -> "($base) pow ($exponent)" -// this is Prod && right is Sum -> "$left * ($right)" -// this is Prod && left is Sum -> "($left) * $right" - this is Prod -> "($left) * ($right)" - this is Sum && right is Negative -> "$left - ${right.value}" - this is Sum -> "$left + $right" - this is Var -> name - this is Derivative -> "d($fn) / d($vrb)" - this is Zero -> "0" //"\uD835\uDFD8" // 𝟘 - this is One -> "1" //"\uD835\uDFD9" // 𝟙 - this is Two -> "2" //"\uD835\uDFDA" // 𝟚 - this is E -> "\u2147" // ⅇ - this is VMagnitude -> "|$value|" - this is DProd -> "($left) dot ($right)" - this is Composition -> "($fn) comp $bindings" + open fun grad(): Map, SFun> = inputs.sVars.map { it to Derivative(this, it) }.toMap() + + override fun ln(): SFun = Log(this) + override fun pow(exp: SFun): SFun = Power(this, exp) + override fun unaryMinus(): SFun = Negative(this) + open fun sqrt(): SFun = this pow (One() / Two()) + + override fun toString(): String = when (this) { + is Log -> "ln($logarithmand)" + is Negative -> "- ($value)" + is Power -> "($base) pow ($exponent)" + is Prod -> "($left) * ($right)" + is Sum -> if(right is Negative) "$left $right" else "$left + $right" + is Var -> name + is Derivative -> "d($fn) / d($vrb)" + is Zero -> "0" //"\uD835\uDFD8" // 𝟘 + is One -> "1" //"\uD835\uDFD9" // 𝟙 + is Two -> "2" //"\uD835\uDFDA" // 𝟚 + is E -> "E()" //"\u2147" // ⅇ + is VMagnitude -> "|$value|" + is DProd -> "($left) dot ($right)" + is Composition -> "($fn)($bindings)" else -> super.toString() } @@ -128,7 +147,7 @@ sealed class Fun>(open val sVars: Set> = emptySet()) : Field name is Negative -> { value.toGraph() - this; add(Label.of("neg")) } is Derivative -> { fn.toGraph() - this; mutNode("$this").apply { add(Label.of(vrb.toString())) } - this; add(Label.of("d")) } @@ -137,7 +156,7 @@ sealed class Fun>(open val sVars: Set> = emptySet()) : Field add(Label.of("one")) is Zero -> add(Label.of("zero")) is Composition -> { bindings.sMap.entries.map { entry -> mutNode(entry.hashCode().toString()).also { compNode -> entry.key.toGraph() - compNode; entry.value.toGraph() - compNode; compNode.add(Label.of("comp")) } }.map { it - this; add(Label.of("bindings")) } } - else -> TODO(this@Fun.javaClass.toString()) + else -> TODO(this@SFun.javaClass.toString()) } } } @@ -146,16 +165,16 @@ sealed class Fun>(open val sVars: Set> = emptySet()) : Field>(val left: Fun, val right: Fun): Fun(left, right) +sealed class BiFun>(val left: SFun, val right: SFun): SFun(left, right) -class Sum>(val addend: Fun, val augend: Fun): BiFun(addend, augend) +class Sum>(val addend: SFun, val augend: SFun): BiFun(addend, augend) -class Negative>(val value: Fun) : Fun(value) -class Prod>(val multiplicand: Fun, val multiplicator: Fun): BiFun(multiplicand, multiplicator) -class Power> internal constructor(val base: Fun, val exponent: Fun) : BiFun(base, exponent) -class Log> internal constructor(val logarithmand: Fun, val base: Fun = E()) : BiFun(logarithmand, base) -class Derivative>(val fn: Fun, val vrb: Var) : Fun(fn, vrb) { - fun Fun.df(): Fun = when (this) { +class Negative>(val value: SFun) : SFun(value) +class Prod>(val multiplicand: SFun, val multiplicator: SFun): BiFun(multiplicand, multiplicator) +class Power> internal constructor(val base: SFun, val exponent: SFun) : BiFun(base, exponent) +class Log> internal constructor(val logarithmand: SFun, val base: SFun = E()) : BiFun(logarithmand, base) +class Derivative>(val fn: SFun, val vrb: Var) : SFun(fn, vrb) { + fun SFun.df(): SFun = when (this@df) { is Var -> if (this == vrb) One() else Zero() is SConst -> Zero() is Sum -> left.df() + right.df() @@ -166,15 +185,15 @@ class Derivative>(val fn: Fun, val vrb: Var) : Fun(fn, vrb) is Derivative -> fn.df() is DProd -> this().df() is VMagnitude -> this().df() - is Composition -> bindings.curried().fold(One()) { acc: Fun, binding -> + is Composition -> bindings.curried().fold(One()) { acc: SFun, binding -> acc * fn.df()(binding) * binding.sMap.entries.first().value.df() } } } -class Composition>(val fn: Fun, val bindings: Bindings) : Fun() { - val evaluate by lazy { apply() } - override val sVars: Set> by lazy { evaluate.sVars } +class Composition>(val fn: SFun, val bindings: Bindings) : SFun(fn) { + val evaluate by lazy { call() } + override val inputs: Inputs by lazy { evaluate.inputs } // private fun calculateFixpoint(): Fun { // var result = apply() @@ -187,64 +206,63 @@ class Composition>(val fn: Fun, val bindings: Bindings) : Fun.apply(): Fun = - bindings.sMap.getOrElse(this) { - when (this) { - is Zero -> bindings.zero - is One -> bindings.one - is Two -> bindings.two - is E -> bindings.e - is Var -> this - is SConst -> this - is Prod -> left.apply() * right.apply() - is Sum -> left.apply() + right.apply() - is Power -> base.apply() pow exponent.apply() - is Negative -> -value.apply() - is Log -> logarithmand.apply().ln() - is Derivative -> df().apply() - is DProd -> left(bindings) as Vec dot right(bindings) as Vec - is VMagnitude -> value(bindings).magnitude() - is Composition -> fn.apply().apply() - } - } + fun SFun.call(): SFun = bindings.sMap.getOrElse(this@call) { bind() } + + fun SFun.bind() = when (this@bind) { + is Zero -> bindings.zero + is One -> bindings.one + is Two -> bindings.two + is E -> bindings.e + is Var -> this + is SConst -> this + is Prod -> left.call() * right.call() + is Sum -> left.call() + right.call() + is Power -> base.call() pow exponent.call() + is Negative -> -value.call() + is Log -> logarithmand.call().ln() + is Derivative -> df().call() + is DProd -> left(bindings) as Vec dot right(bindings) as Vec + is VMagnitude -> value(bindings).magnitude() + is Composition -> fn.call().call() + } } -data class Bindings>( - val sMap: Map, Fun> = mapOf(), - val zero: Fun = Zero(), - val one: Fun = One(), - val two: Fun = Two(), - val e: Fun = E()) { +data class Bindings>( + val sMap: Map, SFun> = mapOf(), + val zero: SFun = Zero(), + val one: SFun = One(), + val two: SFun = Two(), + val e: SFun = E()) { // constructor(sMap: Map, Fun>, // vMap: Map, VFun>, // zero: Fun, // one: Fun, // two: Fun, // E: Fun): this(sMap, zero, one, two, E) - val isReassignmentFree = sMap.values.all { it.sVars.isEmpty() } - fun determines(fn: Fun) = fn.sVars.all { it in sMap } + val isReassignmentFree = sMap.values.all { it.inputs.sVars.isEmpty() } + fun fullyDetermines(fn: SFun) = fn.inputs.sVars.all { it in sMap } override fun toString() = sMap.toString() operator fun contains(v: Var) = v in sMap fun curried() = sMap.entries.map { Bindings(mapOf(it.key to it.value), zero, one, two, e) } } -class DProd>(val left: VFun, val right: VFun): Fun(left.sVars + right.sVars)//, left.vVars + right.vVars) +class DProd>(val left: VFun, val right: VFun): SFun(left, right) -class VMagnitude>(val value: VFun): Fun(value.sVars)//, value.vVars) +class VMagnitude>(val value: VFun): SFun(value)//, value.vVars) interface Variable { val name: String } -class Var>(override val name: String = "") : Variable, Fun() { - override val sVars: Set> = setOf(this) +class Var>(override val name: String = "") : Variable, SFun() { + override val inputs: Inputs = Inputs(setOf(this)) } -open class SConst> : Fun() -class Zero> : SConst() -class One> : SConst() -class Two> : SConst() -class E> : SConst() +open class SConst> : SFun() +class Zero> : SConst() +class One> : SConst() +class Two> : SConst() +class E> : SConst() -abstract class RealNumber>(open val value: Number) : SConst() +abstract class RealNumber>(open val value: Number) : SConst() class DReal(override val value: Double) : RealNumber(value) { override fun unaryMinus() = DReal(-value) @@ -255,17 +273,17 @@ class DReal(override val value: Double) : RealNumber(value) { * Constant propagation. */ - override fun plus(addend: Fun) = when (addend) { + override fun plus(addend: SFun) = when (addend) { is DReal -> DReal(value + addend.value) else -> super.plus(addend) } - override fun times(multiplicand: Fun) = when (multiplicand) { + override fun times(multiplicand: SFun) = when (multiplicand) { is DReal -> DReal(value * multiplicand.value) else -> super.times(multiplicand) } - override fun pow(exp: Fun) = when (exp) { + override fun pow(exp: SFun) = when (exp) { is DReal -> DReal(value.pow(exp.value)) else -> super.pow(exp) } @@ -290,27 +308,27 @@ class DReal(override val value: Double) : RealNumber(value) { */ sealed class Protocol> { - class IndVar> constructor(val fn: Fun) + class IndVar> constructor(val fn: SFun) - class Differential>(private val fx: Fun) { - // TODO: make sure this notation works for arbitrary nested functions using the Chain rule - infix operator fun div(arg: Differential) = fx.d(arg.fx.sVars.first()) + class Differential>(private val fx: SFun) { + // TODO: ensure correctness for arbitrary nested functions using the Chain rule + infix operator fun div(arg: Differential) = fx.d(arg.fx.inputs.sVars.first()) } - fun > d(fn: Fun) = Differential(fn) + fun > d(fn: SFun) = Differential(fn) abstract fun wrap(default: Number): X - operator fun Number.times(multiplicand: Fun) = multiplicand * wrap(this) - operator fun Fun.times(multiplicand: Number) = wrap(multiplicand) * this + operator fun Number.times(multiplicand: SFun) = multiplicand * wrap(this) + operator fun SFun.times(multiplicand: Number) = wrap(multiplicand) * this - operator fun Number.plus(addend: Fun) = addend + wrap(this) - operator fun Fun.plus(addend: Number) = wrap(addend) + this + operator fun Number.plus(addend: SFun) = addend + wrap(this) + operator fun SFun.plus(addend: Number) = wrap(addend) + this - operator fun Number.minus(subtrahend: Fun) = -subtrahend + wrap(this) - operator fun Fun.minus(subtrahend: Number) = -wrap(subtrahend) + this + operator fun Number.minus(subtrahend: SFun) = -subtrahend + wrap(this) + operator fun SFun.minus(subtrahend: Number) = -wrap(subtrahend) + this - fun Number.pow(exp: Fun) = wrap(this) pow exp - infix fun Fun.pow(exp: Number) = this pow wrap(exp) + fun Number.pow(exp: SFun) = wrap(this) pow exp + infix fun SFun.pow(exp: Number) = this pow wrap(exp) } object DoublePrecision : Protocol() { @@ -323,9 +341,9 @@ object DoublePrecision : Protocol() { fun vrb(name: String) = Var(name) - @JvmName("ValBnd") operator fun Fun.invoke(vararg pairs: Pair, Number>) = + @JvmName("ValBnd") operator fun SFun.invoke(vararg pairs: Pair, Number>) = this(Bindings(pairs.map { (it.first to wrap(it.second)) }.toMap(), zero, one, two, e)) - @JvmName("FunBnd") operator fun Fun.invoke(vararg pairs: Pair, Fun>) = + @JvmName("FunBnd") operator fun SFun.invoke(vararg pairs: Pair, SFun>) = this(Bindings(pairs.map { (it.first to it.second) }.toMap(), zero, one, two, e)) operator fun VFun.invoke(vararg sPairs: Pair, Number>) = @@ -334,7 +352,7 @@ object DoublePrecision : Protocol() { operator fun MFun.invoke(vararg sPairs: Pair, Number>) = this(Bindings(sPairs.map { (it.first to wrap(it.second)) }.toMap(), zero, one, two, e)) - fun Fun.asDouble() = (this as DReal).value + fun SFun.asDouble() = (this as DReal).value val x = vrb("X") val y = vrb("Y") diff --git a/core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyMatrixExample.kt b/core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyMatrixExample.kt index 16acd951..7897af6b 100644 --- a/core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyMatrixExample.kt +++ b/core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyMatrixExample.kt @@ -31,7 +31,6 @@ fun main() { y * y, x * y) - val mf2 = Mat1x2(vf2) val qr = mf2 * Vec(x, y) @@ -57,11 +56,8 @@ fun main() { * Matrix function. */ -open class MFun, R: D1, C: D1>( - open val sVars: Set> = emptySet() -): (Bindings) -> MFun { - constructor(left: MFun, right: MFun): this(left.sVars + right.sVars) - constructor(mFun: MFun): this(mFun.sVars) +open class MFun, R: D1, C: D1>(override val inputs: Inputs): Fun, (Bindings) -> MFun { + constructor(vararg funs: Fun): this(Inputs(*funs)) open val ᵀ: MFun by lazy { MTranspose(this) } @@ -76,7 +72,7 @@ open class MFun, R: D1, C: D1>( is SMProd -> left(bnds) * right(bnds) is MConst -> MZero() is Mat -> Mat(rows.map { it(bnds) as Vec }) - else -> throw IllegalArgumentException("Type ${this::class.java.name} unknown") + else -> TODO(this::class.java.name) } // Materializes the concrete matrix from the dataflow graph @@ -84,7 +80,7 @@ open class MFun, R: D1, C: D1>( open operator fun unaryMinus(): MFun = MNegative(this) open operator fun plus(addend: MFun): MFun = MSum(this, addend) - open operator fun times(multiplicand: Fun): MFun = MSProd(this, multiplicand) + open operator fun times(multiplicand: SFun): MFun = MSProd(this, multiplicand) open operator fun times(multiplicand: VFun): VFun = MVProd(this, multiplicand) // The Hadamard product @@ -100,24 +96,24 @@ open class MFun, R: D1, C: D1>( is HProd -> "$left ʘ $right" is MSProd -> "$left * $right" is SMProd -> "$left * $right" - is MConst -> "TODO()" + is MConst -> "${javaClass.name}()" is Mat -> "Mat${numRows}x$numCols(${rows.joinToString(", ") { it.contents.joinToString(", ") }})" is MDerivative -> "d($mFun) / d($v1)" - else -> throw IllegalArgumentException("Type ${this::class.java.name} unknown") + else -> TODO(this::class.java.name) } } -class MNegative, R: D1, C: D1>(val value: MFun): MFun(value) -class MTranspose, R: D1, C: D1>(val value: MFun): MFun(value.sVars) -class MSum, R: D1, C: D1>(val left: MFun, val right: MFun): MFun(left, right) -class MMProd, R: D1, C1: D1, C2: D1>(val left: MFun, val right: MFun): MFun(left, right) -class HProd, R: D1, C: D1>(val left: MFun, val right: MFun): MFun(left, right) -class MSProd, R: D1, C: D1>(val left: MFun, val right: Fun): MFun(left) -class SMProd, R: D1, C: D1>(val left: Fun, val right: MFun): MFun(right) +class MNegative, R: D1, C: D1>(val value: MFun): MFun(value) +class MTranspose, R: D1, C: D1>(val value: MFun): MFun(value) +class MSum, R: D1, C: D1>(val left: MFun, val right: MFun): MFun(left, right) +class MMProd, R: D1, C1: D1, C2: D1>(val left: MFun, val right: MFun): MFun(left, right) +class HProd, R: D1, C: D1>(val left: MFun, val right: MFun): MFun(left, right) +class MSProd, R: D1, C: D1>(val left: MFun, val right: SFun): MFun(left) +class SMProd, R: D1, C: D1>(val left: SFun, val right: MFun): MFun(right) // TODO: Generalize tensor derivatives? https://en.wikipedia.org/wiki/Tensor_derivative_(continuum_mechanics) -class MDerivative, R: D1, C: D1> internal constructor(val mFun: VFun, numCols: Nat, val v1: Var): MFun(mFun.sVars) { - fun MFun.df(): MFun = when (this) { +class MDerivative, R: D1, C: D1> internal constructor(val mFun: VFun, numCols: Nat, val v1: Var): MFun(mFun) { + fun MFun.df(): MFun = when (this@df) { is MConst -> MZero() is MVar -> MZero() is MNegative -> -value.df() @@ -129,22 +125,38 @@ class MDerivative, R: D1, C: D1> internal constructor(val mFun: VFun left.d(v1) * right + left * right.df() is HProd -> left.df() ʘ right + left ʘ right.df() is Mat -> Mat(rows.map { it.d(v1)() }) - else -> throw IllegalArgumentException("Unable to differentiate expression of type ${this::class.java.name}") + else -> TODO(this@df::class.java.name) + } +} + +class MGradient, R: D1, C: D1>(val fn: SFun, val mVar: MVar): MFun(fn) { + fun df() = fn.df() + fun SFun.df(): MFun = when (this@df) { + is MVar<*, *, *> -> if (this == mVar) MOne() else MZero() + is Var -> MZero() + is SConst -> MZero() + is Sum -> left.df() + right.df() + is Prod -> left.df() * right + left * right.df() + is Power -> this * (exponent * Log(base)).df() + is Negative -> -value.df() + is Log -> (logarithmand pow -One()) * logarithmand.df() +// is Derivative -> fn.df() + is DProd -> this().df() + is VMagnitude -> this().df() + else -> TODO(this@df::class.java.name) } } -class MVar, R: D1, C: D1>(override val name: String = ""): Variable, MFun() -open class MConst, R: D1, C: D1>: MFun() +class MVar, R: D1, C: D1>(override val name: String = ""): Variable, MFun() +open class MConst, R: D1, C: D1>: MFun() -class MZero, R: D1, C: D1>: MConst() -class MOne, R: D1, C: D1>: MConst() +class MZero, R: D1, C: D1>: MConst() +class MOne, R: D1, C: D1>: MConst() -open class Mat, R: D1, C: D1>(override val sVars: Set> = emptySet(), - val rows: List>): MFun() { - constructor(rows: List>): this(rows.flatMap { it.sVars }.toSet(), rows) - constructor(vararg rows: Vec): this(rows.flatMap { it.sVars }.toSet(), rows.asList()) +open class Mat, R: D1, C: D1>(val rows: List>): MFun(*rows.toTypedArray()) { + constructor(vararg rows: Vec): this(rows.asList()) - val flatContents: List> by lazy { rows.flatMap { it.contents } } + val flatContents: List> by lazy { rows.flatMap { it.contents } } val indices = rows.indices val cols by lazy { indices.map { i -> Vec(rows.map { it[i] }) } } @@ -169,7 +181,7 @@ open class Mat, R: D1, C: D1>(override val sVars: Set> = emptyS operator fun get(i: Int): VFun = rows[i] - override operator fun times(multiplicand: Fun): Mat = Mat(rows.map { it * multiplicand }) + override operator fun times(multiplicand: SFun): Mat = Mat(rows.map { it * multiplicand }) override operator fun times(multiplicand: VFun): VFun = when (multiplicand) { @@ -188,22 +200,22 @@ open class Mat, R: D1, C: D1>(override val sVars: Set> = emptyS } } -fun > Mat1x1(v0: Vec): Mat = Mat(v0) -fun > Mat2x1(v0: Vec, v1: Vec): Mat = Mat(v0, v1) -fun > Mat3x1(v0: Vec, v1: Vec, v2: Vec): Mat = Mat(v0, v1, v2) -fun > Mat1x2(v0: Vec): Mat = Mat(v0) -fun > Mat2x2(v0: Vec, v1: Vec): Mat = Mat(v0, v1) -fun > Mat3x2(v0: Vec, v1: Vec, v2: Vec): Mat = Mat(v0, v1, v2) -fun > Mat1x3(v0: Vec): Mat = Mat(v0) -fun > Mat2x3(v0: Vec, v1: Vec): Mat = Mat(v0, v1) -fun > Mat3x3(v0: Vec, v1: Vec, v2: Vec): Mat = Mat(v0, v1, v2) - -fun > Mat1x1(d0: Fun): Mat = Mat(Vec(d0)) -fun > Mat1x2(d0: Fun, d1: Fun): Mat = Mat(Vec(d0, d1)) -fun > Mat1x3(d0: Fun, d1: Fun, d2: Fun): Mat = Mat(Vec(d0, d1, d2)) -fun > Mat2x1(d0: Fun, d1: Fun): Mat = Mat(Vec(d0), Vec(d1)) -fun > Mat2x2(d0: Fun, d1: Fun, d2: Fun, d3: Fun): Mat = Mat(Vec(d0, d1), Vec(d2, d3)) -fun > Mat2x3(d0: Fun, d1: Fun, d2: Fun, d3: Fun, d4: Fun, d5: Fun): Mat = Mat(Vec(d0, d1, d2), Vec(d3, d4, d5)) -fun > Mat3x1(d0: Fun, d1: Fun, d2: Fun): Mat = Mat(Vec(d0), Vec(d1), Vec(d2)) -fun > Mat3x2(d0: Fun, d1: Fun, d2: Fun, d3: Fun, d4: Fun, d5: Fun): Mat = Mat(Vec(d0, d1), Vec(d2, d3), Vec(d4, d5)) -fun > Mat3x3(d0: Fun, d1: Fun, d2: Fun, d3: Fun, d4: Fun, d5: Fun, d6: Fun, d7: Fun, d8: Fun): Mat = Mat(Vec(d0, d1, d2), Vec(d3, d4, d5), Vec(d6, d7, d8)) \ No newline at end of file +fun > Mat1x1(v0: Vec): Mat = Mat(v0) +fun > Mat2x1(v0: Vec, v1: Vec): Mat = Mat(v0, v1) +fun > Mat3x1(v0: Vec, v1: Vec, v2: Vec): Mat = Mat(v0, v1, v2) +fun > Mat1x2(v0: Vec): Mat = Mat(v0) +fun > Mat2x2(v0: Vec, v1: Vec): Mat = Mat(v0, v1) +fun > Mat3x2(v0: Vec, v1: Vec, v2: Vec): Mat = Mat(v0, v1, v2) +fun > Mat1x3(v0: Vec): Mat = Mat(v0) +fun > Mat2x3(v0: Vec, v1: Vec): Mat = Mat(v0, v1) +fun > Mat3x3(v0: Vec, v1: Vec, v2: Vec): Mat = Mat(v0, v1, v2) + +fun > Mat1x1(d0: SFun): Mat = Mat(Vec(d0)) +fun > Mat1x2(d0: SFun, d1: SFun): Mat = Mat(Vec(d0, d1)) +fun > Mat1x3(d0: SFun, d1: SFun, d2: SFun): Mat = Mat(Vec(d0, d1, d2)) +fun > Mat2x1(d0: SFun, d1: SFun): Mat = Mat(Vec(d0), Vec(d1)) +fun > Mat2x2(d0: SFun, d1: SFun, d2: SFun, d3: SFun): Mat = Mat(Vec(d0, d1), Vec(d2, d3)) +fun > Mat2x3(d0: SFun, d1: SFun, d2: SFun, d3: SFun, d4: SFun, d5: SFun): Mat = Mat(Vec(d0, d1, d2), Vec(d3, d4, d5)) +fun > Mat3x1(d0: SFun, d1: SFun, d2: SFun): Mat = Mat(Vec(d0), Vec(d1), Vec(d2)) +fun > Mat3x2(d0: SFun, d1: SFun, d2: SFun, d3: SFun, d4: SFun, d5: SFun): Mat = Mat(Vec(d0, d1), Vec(d2, d3), Vec(d4, d5)) +fun > Mat3x3(d0: SFun, d1: SFun, d2: SFun, d3: SFun, d4: SFun, d5: SFun, d6: SFun, d7: SFun, d8: SFun): Mat = Mat(Vec(d0, d1, d2), Vec(d3, d4, d5), Vec(d6, d7, d8)) \ No newline at end of file diff --git a/core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyVectorExample.kt b/core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyVectorExample.kt index cf1c78cf..90d42528 100644 --- a/core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyVectorExample.kt +++ b/core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyVectorExample.kt @@ -39,9 +39,10 @@ fun main() { * Vector function. */ -sealed class VFun, E: D1>( - override val sVars: Set> = emptySet()): MFun(sVars) { - constructor(vararg vFns: VFun): this(vFns.flatMap { it.sVars }.toSet()) + + +sealed class VFun, E: D1>(override val inputs: Inputs): Fun, MFun(inputs) { + constructor(vararg funs: Fun): this(Inputs(*funs)) override operator fun invoke(bnds: Bindings): VFun = when (this) { @@ -52,16 +53,17 @@ sealed class VFun, E: D1>( is SVProd -> left(bnds) * right(bnds) is VSProd -> left(bnds) * right(bnds) is VDerivative -> df()(bnds) + is Gradient -> df()(bnds) is MVProd -> left(bnds) as Mat * (right as Vec)(bnds) is VMProd -> (left as Vec)(bnds) * (right as Mat)(bnds) is VMap -> value(bnds).map(ef) - else -> throw IllegalArgumentException("Type ${this::class.java.name} unknown") + else -> TODO(this::class.java.name) } // Materializes the concrete vector from the dataflow graph operator fun invoke(): Vec = invoke(Bindings()) as Vec - open fun map(ef: (Fun) -> Fun): VFun = VMap(this, ef) + open fun map(ef: (SFun) -> SFun): VFun = VMap(this, ef) fun d(v1: Var) = VDerivative(this, v1) fun d(v1: Var, v2: Var) = Jacobian(this, v1, v2) @@ -74,70 +76,71 @@ sealed class VFun, E: D1>( fun d(v1: Var, v2: Var, v3: Var, v4: Var, v5: Var, v6: Var, v7: Var, v8: Var, v9: Var) = Jacobian(this, v1, v2, v3, v4, v5, v6, v7, v8, v9) //... fun d(vararg vars: Var): Map, VFun> = vars.map { it to VDerivative(this, it) }.toMap() - fun grad(): Map, VFun> = sVars.map { it to VDerivative(this, it) }.toMap() + fun grad(): Map, VFun> = inputs.sVars.map { it to VDerivative(this, it) }.toMap() override operator fun unaryMinus(): VFun = VNegative(this) open operator fun plus(addend: VFun): VFun = VSum(this, addend) open operator fun minus(subtrahend: VFun): VFun = VSum(this, -subtrahend) open infix fun ʘ(multiplicand: VFun): VFun = VVProd(this, multiplicand) - override operator fun times(multiplicand: Fun): VFun = VSProd(this, multiplicand) + override operator fun times(multiplicand: SFun): VFun = VSProd(this, multiplicand) open operator fun times(multiplicand: MFun): VFun = VMProd(this, multiplicand)//(expand * multiplicand).rows.first() - open infix fun dot(multiplicand: VFun): Fun = DProd(this, multiplicand) - - open fun magnitude(): Fun = VMagnitude(this) - - override fun toString() = - when (this) { - is Vec -> contents.joinToString(", ", "[", "]") - is VSum -> "$left + $right" - is VVProd -> "$left ʘ $right" - is SVProd -> "$left * $right" - is VSProd -> "$left * $right" - is VNegative -> "-($value)" - is VDerivative -> "d($vFun) / d($v1)"//d(${v1.joinToString(", ")})" - is MVProd -> "$left * $right" - is VMProd -> "$left * $right" - is Gradient -> "($fn).d(${vrbs.joinToString(", ")})" - is VMap -> "$value.map { $ef }" - is VVar -> "VVar($name)" - } + open infix fun dot(multiplicand: VFun): SFun = DProd(this, multiplicand) + + open fun magnitude(): SFun = VMagnitude(this) + + override fun toString() = when (this) { + is Vec -> contents.joinToString(", ", "[", "]") + is VSum -> "$left + $right" + is VVProd -> "$left ʘ $right" + is SVProd -> "$left * $right" + is VSProd -> "$left * $right" + is VNegative -> "-($value)" + is VDerivative -> "d($vFun) / d($v1)"//d(${v1.joinToString(", ")})" + is MVProd -> "$left * $right" + is VMProd -> "$left * $right" + is Gradient -> "($fn).d($vVar)" + is VMap -> "$value.map { $ef }" + is VVar -> "VVar($name)" + } } -class VNegative, E: D1>(val value: VFun): VFun(value) -class VMap, E: D1>(val value: VFun, val ef: (Fun) -> Fun): VFun(value) - -class VSum, E: D1>(val left: VFun, val right: VFun): VFun(left, right) - -class VVProd, E: D1>(val left: VFun, val right: VFun): VFun(left, right) -class SVProd, E: D1>(val left: Fun, val right: VFun): VFun(right.sVars + right.sVars) -class VSProd, E: D1>(val left: VFun, val right: Fun): VFun(left.sVars + right.sVars) -class MVProd, R: D1, C: D1>(val left: MFun, val right: VFun): VFun(left.sVars + right.sVars) -class VMProd, R: D1, C: D1>(val left: VFun, val right: MFun): VFun(left.sVars + right.sVars) - -class Gradient, E: D1>: VFun { - - val fn: Fun - val vrbs: Array> - - constructor(fn: Fun, vararg vrbs: Var): super(fn.sVars) { - this.fn = fn - this.vrbs = vrbs +class VNegative, E: D1>(val value: VFun): VFun(value) +class VMap, E: D1>(val value: VFun, val ef: (SFun) -> SFun): VFun(value) + +class VSum, E: D1>(val left: VFun, val right: VFun): VFun(left, right) + +class VVProd, E: D1>(val left: VFun, val right: VFun): VFun(left, right) +class SVProd, E: D1>(val left: SFun, val right: VFun): VFun(left, right) +class VSProd, E: D1>(val left: VFun, val right: SFun): VFun(left, right) +class MVProd, R: D1, C: D1>(val left: MFun, val right: VFun): VFun(left, right) +class VMProd, R: D1, C: D1>(val left: VFun, val right: MFun): VFun(left, right) + +class Gradient, E: D1>(val fn: SFun, val vVar: VVar): VFun(fn) { + fun df() = fn.df() + fun SFun.df(): VFun = when (this@df) { + is VVar<*, *> -> if (this == vVar) VOne() else VZero() + is Var -> VZero() + is SConst -> VZero() + is Sum -> left.df() + right.df() + is Prod -> left.df() * right + left * right.df() + is Power -> this * (exponent * Log(base)).df() + is Negative -> -value.df() + is Log -> (logarithmand pow -One()) * logarithmand.df() +// is Derivative -> fn.df() + is DProd -> this().df() + is VMagnitude -> this().df() + else -> TODO(this@df::class.java.name) } - - // constructor(fn: Fun, vVar: VVar): super(fn.sVars, vVar) -override fun invoke(bnds: Bindings) = Vec(vrbs.map { Derivative(fn, it)() })(bnds) - // override fun invoke(bnds: Bindings) = - // if(Vec(vrbs.map { Derivative(fn, it)() })(bnds) } -class VVar, E: D1>(override val name: String = ""): Variable, VFun() -class Jacobian, R: D1, C: D1>(val vfn: VFun, vararg val vrbs: Var): MFun(vfn.sVars) { +class VVar, E: D1>(override val name: String = ""): Variable, VFun() +class Jacobian, R: D1, C: D1>(val vfn: VFun, vararg val vrbs: Var): MFun(vfn) { override fun invoke(bnds: Bindings) = Mat(vrbs.map { VDerivative(vfn, it)() }).ᵀ(bnds) } -class VDerivative, E: D1> internal constructor(val vFun: VFun, val v1: Var) : VFun(vFun) { - fun VFun.df(): VFun = when (this) { +class VDerivative, E: D1> internal constructor(val vFun: VFun, val v1: Var) : VFun(vFun) { + fun VFun.df(): VFun = when (this@df) { is VConst -> VZero() is VVar -> VZero() is VSum -> left.df() + right.df() @@ -154,15 +157,12 @@ class VDerivative, E: D1> internal constructor(val vFun: VFun, } } -open class VConst, E: D1>(vararg contents: SConst): Vec(emptySet(), contents.asList()) - -class VZero, E: D1>: VConst() -class VOne, E: D1>: VConst() +open class VConst, E: D1>(vararg contents: SConst): Vec(contents.asList()) -open class Vec, E: D1>(override val sVars: Set> = emptySet(), - val contents: List>): VFun() { - constructor(contents: List>): this(contents.flatMap { it.sVars }.toSet(), contents) +class VZero, E: D1>: VConst() +class VOne, E: D1>: VConst() +open class Vec, E: D1>(val contents: List>): VFun(*contents.toTypedArray()) { val size = contents.size val indices = contents.indices @@ -185,39 +185,39 @@ open class Vec, E: D1>(override val sVars: Set> = emptySet(), else -> super.ʘ(multiplicand) } - override fun times(multiplicand: Fun): Vec = Vec(contents.map { it * multiplicand }) + override fun times(multiplicand: SFun): Vec = Vec(contents.map { it * multiplicand }) override fun dot(multiplicand: VFun) = when(multiplicand) { is Vec -> contents.reduceIndexed { index, acc, element -> acc + element * multiplicand[index] } else -> super.dot(multiplicand) } - override fun map(ef: (Fun) -> Fun): Vec = Vec(contents.map { ef(it) }) + override fun map(ef: (SFun) -> SFun): Vec = Vec(contents.map { ef(it) }) override fun magnitude() = contents.map { it * it }.reduce { acc, p -> acc + p }.sqrt() override fun unaryMinus(): Vec = Vec(contents.map { -it }) companion object { - operator fun > invoke(s0: SConst): VConst = VConst(s0) - operator fun > invoke(s0: SConst, s1: SConst): VConst = VConst(s0, s1) - operator fun > invoke(s0: SConst, s1: SConst, s2: SConst): VConst = VConst(s0, s1, s2) - operator fun > invoke(s0: SConst, s1: SConst, s2: SConst, s3: SConst): VConst = VConst(s0, s1, s2, s3) - operator fun > invoke(s0: SConst, s1: SConst, s2: SConst, s3: SConst, s4: SConst): VConst = VConst(s0, s1, s2, s3, s4) - operator fun > invoke(s0: SConst, s1: SConst, s2: SConst, s3: SConst, s4: SConst, s5: SConst): VConst = VConst(s0, s1, s2, s3, s4, s5) - operator fun > invoke(s0: SConst, s1: SConst, s2: SConst, s3: SConst, s4: SConst, s5: SConst, s6: SConst): VConst = VConst(s0, s1, s2, s3, s4, s5, s6) - operator fun > invoke(s0: SConst, s1: SConst, s2: SConst, s3: SConst, s4: SConst, s5: SConst, s6: SConst, s7: SConst): VConst = VConst(s0, s1, s2, s3, s4, s5, s6, s7) - operator fun > invoke(s0: SConst, s1: SConst, s2: SConst, s3: SConst, s4: SConst, s5: SConst, s6: SConst, s7: SConst, s8: SConst): VConst = VConst(s0, s1, s2, s3, s4, s5, s6, s7, s8) - - operator fun > invoke(t0: Fun): Vec = Vec(arrayListOf(t0)) - operator fun > invoke(t0: Fun, t1: Fun): Vec = Vec(arrayListOf(t0, t1)) - operator fun > invoke(t0: Fun, t1: Fun, t2: Fun): Vec = Vec(arrayListOf(t0, t1, t2)) - operator fun > invoke(t0: Fun, t1: Fun, t2: Fun, t3: Fun): Vec = Vec(arrayListOf(t0, t1, t2, t3)) - operator fun > invoke(t0: Fun, t1: Fun, t2: Fun, t3: Fun, t4: Fun): Vec = Vec(arrayListOf(t0, t1, t2, t3, t4)) - operator fun > invoke(t0: Fun, t1: Fun, t2: Fun, t3: Fun, t4: Fun, t5: Fun): Vec = Vec(arrayListOf(t0, t1, t2, t3, t4, t5)) - operator fun > invoke(t0: Fun, t1: Fun, t2: Fun, t3: Fun, t4: Fun, t5: Fun, t6: Fun): Vec = Vec(arrayListOf(t0, t1, t2, t3, t4, t5, t6)) - operator fun > invoke(t0: Fun, t1: Fun, t2: Fun, t3: Fun, t4: Fun, t5: Fun, t6: Fun, t7: Fun): Vec = Vec(arrayListOf(t0, t1, t2, t3, t4, t5, t6, t7)) - operator fun > invoke(t0: Fun, t1: Fun, t2: Fun, t3: Fun, t4: Fun, t5: Fun, t6: Fun, t7: Fun, t8: Fun): Vec = Vec(arrayListOf(t0, t1, t2, t3, t4, t5, t6, t7, t8)) + operator fun > invoke(s0: SConst): VConst = VConst(s0) + operator fun > invoke(s0: SConst, s1: SConst): VConst = VConst(s0, s1) + operator fun > invoke(s0: SConst, s1: SConst, s2: SConst): VConst = VConst(s0, s1, s2) + operator fun > invoke(s0: SConst, s1: SConst, s2: SConst, s3: SConst): VConst = VConst(s0, s1, s2, s3) + operator fun > invoke(s0: SConst, s1: SConst, s2: SConst, s3: SConst, s4: SConst): VConst = VConst(s0, s1, s2, s3, s4) + operator fun > invoke(s0: SConst, s1: SConst, s2: SConst, s3: SConst, s4: SConst, s5: SConst): VConst = VConst(s0, s1, s2, s3, s4, s5) + operator fun > invoke(s0: SConst, s1: SConst, s2: SConst, s3: SConst, s4: SConst, s5: SConst, s6: SConst): VConst = VConst(s0, s1, s2, s3, s4, s5, s6) + operator fun > invoke(s0: SConst, s1: SConst, s2: SConst, s3: SConst, s4: SConst, s5: SConst, s6: SConst, s7: SConst): VConst = VConst(s0, s1, s2, s3, s4, s5, s6, s7) + operator fun > invoke(s0: SConst, s1: SConst, s2: SConst, s3: SConst, s4: SConst, s5: SConst, s6: SConst, s7: SConst, s8: SConst): VConst = VConst(s0, s1, s2, s3, s4, s5, s6, s7, s8) + + operator fun > invoke(t0: SFun): Vec = Vec(arrayListOf(t0)) + operator fun > invoke(t0: SFun, t1: SFun): Vec = Vec(arrayListOf(t0, t1)) + operator fun > invoke(t0: SFun, t1: SFun, t2: SFun): Vec = Vec(arrayListOf(t0, t1, t2)) + operator fun > invoke(t0: SFun, t1: SFun, t2: SFun, t3: SFun): Vec = Vec(arrayListOf(t0, t1, t2, t3)) + operator fun > invoke(t0: SFun, t1: SFun, t2: SFun, t3: SFun, t4: SFun): Vec = Vec(arrayListOf(t0, t1, t2, t3, t4)) + operator fun > invoke(t0: SFun, t1: SFun, t2: SFun, t3: SFun, t4: SFun, t5: SFun): Vec = Vec(arrayListOf(t0, t1, t2, t3, t4, t5)) + operator fun > invoke(t0: SFun, t1: SFun, t2: SFun, t3: SFun, t4: SFun, t5: SFun, t6: SFun): Vec = Vec(arrayListOf(t0, t1, t2, t3, t4, t5, t6)) + operator fun > invoke(t0: SFun, t1: SFun, t2: SFun, t3: SFun, t4: SFun, t5: SFun, t6: SFun, t7: SFun): Vec = Vec(arrayListOf(t0, t1, t2, t3, t4, t5, t6, t7)) + operator fun > invoke(t0: SFun, t1: SFun, t2: SFun, t3: SFun, t4: SFun, t5: SFun, t6: SFun, t7: SFun, t8: SFun): Vec = Vec(arrayListOf(t0, t1, t2, t3, t4, t5, t6, t7, t8)) } } diff --git a/core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/VariableCapture.kt b/core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/VariableCapture.kt index 273a0d36..cf033665 100644 --- a/core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/VariableCapture.kt +++ b/core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/VariableCapture.kt @@ -77,7 +77,7 @@ open class X>(override val left: BiFn<*>, op(left(X), right(X)) private operator fun BiFn<*>.invoke(X: XBnd

): P = - when (this) { + when (this@invoke) { is Const<*, *> -> this as P is X<*> -> (this as X

)(X) else -> throw IllegalStateException(toString()) @@ -128,7 +128,7 @@ open class Y>(override val left: BiFn<*>, op(left(Y), right(Y)) private operator fun BiFn<*>.invoke(Y: YBnd

): P = - when (this) { + when (this@invoke) { is Const<*, *> -> this as P is Y<*> -> (this as Y

)(Y) else -> throw IllegalStateException(toString()) @@ -178,7 +178,7 @@ open class Z>(override val left: BiFn<*>, open operator fun invoke(Z: ZBnd

): P = op(left(Z), right(Z)) private operator fun BiFn<*>.invoke(Z: ZBnd

): P = - when (this) { + when (this@invoke) { is Const<*, *> -> this as P is Z<*> -> (this as Z

)(Z) else -> throw IllegalStateException(toString()) @@ -226,7 +226,7 @@ class XY>(override val left: BiFn<*>, operator fun times(that: XYZ

): XYZ

= XYZ(this, that, mul) private operator fun BiFn<*>.invoke(X: XBnd

, Y: YBnd

): P = - when (this) { + when (this@invoke) { is Const<*, *> -> this as P is XY<*> -> (this as XY

)(X, Y) is X<*> -> (this as X

)(X) @@ -237,7 +237,7 @@ class XY>(override val left: BiFn<*>, operator fun invoke(X: XBnd

, Y: YBnd

): P = op(left(X, Y), right(X, Y)) private operator fun BiFn<*>.invoke(X: XBnd

): Y

= - when (this) { + when (this@invoke) { is Const<*, *> -> Y(this as P) is XY<*> -> (this as XY

)(X) is X<*> -> Y((this as X

)(X)) @@ -246,7 +246,7 @@ class XY>(override val left: BiFn<*>, } private operator fun BiFn<*>.invoke(Y: YBnd

): X

= - when (this) { + when (this@invoke) { is Const<*, *> -> X(this as P) is XY<*> -> (this as XY

)(Y) is Y<*> -> X((this as Y

)(Y)) @@ -299,7 +299,7 @@ class XZ>(override val left: BiFn<*>, operator fun times(that: XYZ

): XYZ

= XYZ(this, that, mul) private operator fun BiFn<*>.invoke(X: XBnd

, Z: ZBnd

): P = - when (this) { + when (this@invoke) { is Const<*, *> -> this as P is XZ<*> -> (this as XZ

)(X, Z) is X<*> -> (this as X

)(X) @@ -310,7 +310,7 @@ class XZ>(override val left: BiFn<*>, operator fun invoke(X: XBnd

, Z: ZBnd

): P = op(left(X, Z), right(X, Z)) private operator fun BiFn<*>.invoke(X: XBnd

): Z

= - when (this) { + when (this@invoke) { is Const<*, *> -> Z(this as P) is XZ<*> -> (this as XZ

)(X) is X<*> -> Z((this as X

)(X)) @@ -319,7 +319,7 @@ class XZ>(override val left: BiFn<*>, } private operator fun BiFn<*>.invoke(Z: ZBnd

): X

= - when (this) { + when (this@invoke) { is Const<*, *> -> X(this as P) is XZ<*> -> (this as XZ

)(Z) is Z<*> -> X((this as Z

)(Z)) @@ -372,7 +372,7 @@ class YZ>(override val left: BiFn<*>, operator fun times(that: XYZ

): XYZ

= XYZ(this, that, mul) private operator fun BiFn<*>.invoke(Y: YBnd

, Z: ZBnd

): P = - when (this) { + when (this@invoke) { is Const<*, *> -> this as P is YZ<*> -> (this as YZ

)(Y, Z) is Y<*> -> (this as Y

)(Y) @@ -384,7 +384,7 @@ class YZ>(override val left: BiFn<*>, op(left(Y, Z), right(Y, Z)) private operator fun BiFn<*>.invoke(Y: YBnd

): Z

= - when (this) { + when (this@invoke) { is Const<*, *> -> Z(this as P) is YZ<*> -> (this as YZ

)(Y) is Y<*> -> Z((this as Y

)(Y)) @@ -393,7 +393,7 @@ class YZ>(override val left: BiFn<*>, } private operator fun BiFn<*>.invoke(Z: ZBnd

): Y

= - when (this) { + when (this@invoke) { is Const<*, *> -> Y(this as P) is YZ<*> -> (this as YZ

)(Z) is Z<*> -> Y((this as Z

)(Z)) @@ -444,7 +444,7 @@ class XYZ>(override val left: BiFn<*>, operator fun times(that: XYZ

): XYZ

= XYZ(this, that, mul) private operator fun BiFn<*>.invoke(X: XBnd

, Y: YBnd

, Z: ZBnd

): P = - when (this) { + when (this@invoke) { is Const<*, *> -> this as P is XYZ<*> -> (this as XYZ

)(X, Y, Z) is XY<*> -> (this as XY

)(X, Y) @@ -460,7 +460,7 @@ class XYZ>(override val left: BiFn<*>, op(left(X, Y, Z), right(X, Y, Z)) private operator fun BiFn<*>.invoke(X: XBnd

, Z: ZBnd

): Y

= - when (this) { + when (this@invoke) { is Const<*, *> -> Y(this as P) is XYZ<*> -> (this as XYZ

)(X, Z) is XY<*> -> (this as XY

)(X) @@ -476,7 +476,7 @@ class XYZ>(override val left: BiFn<*>, Y(left(X, Z), right(X, Z), op) private operator fun BiFn<*>.invoke(X: XBnd

, Y: YBnd

): Z

= - when (this) { + when (this@invoke) { is Const<*, *> -> Z(this as P) is XYZ<*> -> (this as XYZ

)(X, Y) is XY<*> -> Z((this as XY

)(X, Y)) @@ -492,7 +492,7 @@ class XYZ>(override val left: BiFn<*>, Z(left(X, Y), right(X, Y), op) private operator fun BiFn<*>.invoke(Y: YBnd

, Z: ZBnd

): X

= - when (this) { + when (this@invoke) { is Const<*, *> -> X(this as P) is XYZ<*> -> (this as XYZ

)(Y, Z) is XY<*> -> (this as XY

)(Y) @@ -508,7 +508,7 @@ class XYZ>(override val left: BiFn<*>, X(left(Y, Z), right(Y, Z), op) private operator fun BiFn<*>.invoke(X: XBnd

): YZ

= - when (this) { + when (this@invoke) { is Const<*, *> -> YZ(this as P) is XYZ<*> -> (this as XYZ

)(X) is XY<*> -> YZ((this as XY

)(X)) @@ -523,7 +523,7 @@ class XYZ>(override val left: BiFn<*>, operator fun invoke(X: XBnd

): YZ

= YZ(left(X), right(X), op) private operator fun BiFn<*>.invoke(Y: YBnd

): XZ

= - when (this) { + when (this@invoke) { is Const<*, *> -> XZ(this as P) is XYZ<*> -> (this as XYZ

)(Y) is XY<*> -> XZ((this as XY

)(Y)) @@ -538,7 +538,7 @@ class XYZ>(override val left: BiFn<*>, operator fun invoke(Y: YBnd

): XZ

= XZ(left(Y), right(Y), op) private operator fun BiFn<*>.invoke(Z: ZBnd

): XY

= - when (this) { + when (this@invoke) { is Const<*, *> -> XY(this as P) is XYZ<*> -> (this as XYZ

)(Z) is XY<*> -> this as XY

diff --git a/core/src/test/kotlin/edu/umontreal/kotlingrad/calculus/ExpressionGenerator.kt b/core/src/test/kotlin/edu/umontreal/kotlingrad/calculus/ExpressionGenerator.kt index 984e025c..5ac62460 100644 --- a/core/src/test/kotlin/edu/umontreal/kotlingrad/calculus/ExpressionGenerator.kt +++ b/core/src/test/kotlin/edu/umontreal/kotlingrad/calculus/ExpressionGenerator.kt @@ -7,24 +7,24 @@ import edu.umontreal.kotlingrad.experimental.DoublePrecision.z import io.kotlintest.properties.Gen import io.kotlintest.properties.shrinking.Shrinker -abstract class ExpressionGenerator>: Gen> { +abstract class ExpressionGenerator>: Gen> { companion object: ExpressionGenerator() { override val variables: List> = listOf(x, y, z) } - val sum = { x: Fun, y: Fun -> Sum(x, y) } - val mul = { x: Fun, y: Fun -> Prod(x, y) } + val sum = { x: SFun, y: SFun -> Sum(x, y) } + val mul = { x: SFun, y: SFun -> Prod(x, y) } - val operators: List<(Fun, Fun) -> Fun> = listOf(sum, mul) + val operators: List<(SFun, SFun) -> SFun> = listOf(sum, mul) val constants: List> = listOf(Zero(), One(), Two()) open val variables: List> = listOf(Var("x"), Var("y"), Var("z")) - override fun constants(): Iterable> = constants + override fun constants(): Iterable> = constants - override fun random(): Sequence> = generateSequence { randomBiTree() } + override fun random(): Sequence> = generateSequence { randomBiTree() } - override fun shrinker() = object: Shrinker> { - override fun shrink(failure: Fun): List> = + override fun shrinker() = object: Shrinker> { + override fun shrink(failure: SFun): List> = when(failure) { is Sum -> listOf(failure.left, failure.right) is Prod -> listOf(failure.left, failure.right) @@ -32,7 +32,7 @@ abstract class ExpressionGenerator>: Gen> { } } - private fun randomBiTree(level: Int = 1): Fun = + private fun randomBiTree(level: Int = 1): SFun = if(5 < level) if(Math.random() < 0.5) constants.random() else variables.random() else operators.random()(randomBiTree(level + 1), randomBiTree(level + 1)) diff --git a/core/src/test/kotlin/edu/umontreal/kotlingrad/evaluation/TestSymbolic.kt b/core/src/test/kotlin/edu/umontreal/kotlingrad/evaluation/TestSymbolic.kt index 5953b461..917c046e 100644 --- a/core/src/test/kotlin/edu/umontreal/kotlingrad/evaluation/TestSymbolic.kt +++ b/core/src/test/kotlin/edu/umontreal/kotlingrad/evaluation/TestSymbolic.kt @@ -14,7 +14,7 @@ import javax.script.SimpleBindings class TestSymbolic : StringSpec({ val engine = ScriptEngineManager().getEngineByExtension("kts") - fun ktEval(f: Fun, vararg kgBnds: Pair, Number>) = + fun ktEval(f: SFun, vararg kgBnds: Pair, Number>) = engine.run { val bindings = kgBnds.map { it.first.name to it.second.toDouble() }.toMap() setBindings(SimpleBindings(bindings), ScriptContext.ENGINE_SCOPE) @@ -23,7 +23,7 @@ class TestSymbolic : StringSpec({ with(DoublePrecision) { "test symbolic evaluation" { - ExpressionGenerator.assertAll(10) { f: Fun -> + ExpressionGenerator.assertAll(10) { f: SFun -> try { DoubleGenerator.assertAll(10) { ẋ, ẏ, ż -> f(x to ẋ, y to ẏ, z to ż) shouldBeAbout ktEval(f, x to ẋ, y to ẏ, z to ż) diff --git a/samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/MultilayerPerceptron.kt b/samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/MultilayerPerceptron.kt index b1ad29c6..f747e402 100644 --- a/samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/MultilayerPerceptron.kt +++ b/samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/MultilayerPerceptron.kt @@ -4,7 +4,7 @@ import edu.umontreal.kotlingrad.experimental.* import kotlin.random.Random @Suppress("NonAsciiCharacters") -class MultilayerPerceptron>( +class MultilayerPerceptron>( val x: Var = Var(), val y: Var = Var(), val p1v: VVar = VVar(), @@ -12,12 +12,12 @@ class MultilayerPerceptron>( val p3v: VVar = VVar() ): (VFun, MFun, - VFun) -> Fun { + VFun) -> SFun { override operator fun invoke( p1: VFun, p2: MFun, p3: VFun - ): Fun { + ): SFun { val layer1 = layer(p1 * x) val layer2 = layer(p2 * layer1) val output = layer2 dot p3 @@ -25,7 +25,7 @@ class MultilayerPerceptron>( return lossFun } - private fun sigmoid(x: Fun) = One() / (One() + E().pow(-x)) + private fun sigmoid(x: SFun) = One() / (One() + E().pow(-x)) private fun layer(x: VFun): VFun = x.map { sigmoid(it) } diff --git a/samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/VisualizeDFG.kt b/samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/VisualizeDFG.kt index bd8d298c..0620ae3e 100644 --- a/samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/VisualizeDFG.kt +++ b/samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/VisualizeDFG.kt @@ -1,7 +1,7 @@ package edu.umontreal.kotlingrad.samples import edu.umontreal.kotlingrad.experimental.DoublePrecision -import edu.umontreal.kotlingrad.experimental.Fun +import edu.umontreal.kotlingrad.experimental.SFun import guru.nidi.graphviz.* import guru.nidi.graphviz.attribute.* import guru.nidi.graphviz.attribute.Rank.RankDir.LEFT_TO_RIGHT @@ -20,7 +20,7 @@ fun main() { const val DARKMODE = false const val THICKNESS = 2 -fun Fun<*>.render(filename: String? = null) { +fun SFun<*>.render(filename: String? = null) { val image = graph(directed = true) { val color = if (DARKMODE) Color.WHITE else Color.BLACK diff --git a/samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/physics/DoublePendulum.kt b/samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/physics/DoublePendulum.kt index 077f44f6..199a1fc8 100644 --- a/samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/physics/DoublePendulum.kt +++ b/samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/physics/DoublePendulum.kt @@ -19,19 +19,19 @@ import kotlin.math.* @Suppress("NonAsciiCharacters", "LocalVariableName") class DoublePendulum(private val len: Double = 900.0) : Application(), EventHandler { - var ω1: Fun = DoublePrecision.wrap(0.0) // Angular velocities - var ω2: Fun = DoublePrecision.wrap(0.0) + var ω1: SFun = DoublePrecision.wrap(0.0) // Angular velocities + var ω2: SFun = DoublePrecision.wrap(0.0) val m1 = 2.0 // Masses val m2 = 2.0 - var G: Fun = DoublePrecision.wrap(9.81) // Gravity - var µ: Fun = DoublePrecision.wrap(0.01) // Friction + var G: SFun = DoublePrecision.wrap(9.81) // Gravity + var µ: SFun = DoublePrecision.wrap(0.01) // Friction val Gp = 0.01 // Simulate measurement error val µp = -0.01 var r1 = DoublePrecision.Vec(1.0, 0.0) // Polar vector var r2 = DoublePrecision.Vec(1.0, 0.0) val observationSteps = 30 var priorVal = 5.0 - fun step(obs: Fun? = null, groundTruth: Pair, VConst>? = null) = with(DoublePrecision) { + fun step(obs: SFun? = null, groundTruth: Pair, VConst>? = null) = with(DoublePrecision) { val isObserving = false // val priorVal = if(G is Var) G.asDouble() if (isObserving) { @@ -69,10 +69,10 @@ class DoublePendulum(private val len: Double = 900.0) : Application(), EventHand } val r1a = (r1.angle + (ω1 * dt + .5 * α1 * dt * dt)).run { - if(G is Var) this(G as Var to priorVal).asDouble() else try { asDouble() } catch(e: Exception) {println(this); this(sVars.first() to priorVal).asDouble() } + if(G is Var) this(G as Var to priorVal).asDouble() else try { asDouble() } catch(e: Exception) {println(this); this(inputs.sVars.first() to priorVal).asDouble() } } val r2a = (r2.angle + - (ω2 * dt + .5 * α2 * dt * dt)).run { - if(G is Var) this(G as Var to priorVal).asDouble() else try { asDouble() } catch(e: Exception) {println(this); this(sVars.first() to priorVal).asDouble() } + if(G is Var) this(G as Var to priorVal).asDouble() else try { asDouble() } catch(e: Exception) {println(this); this(inputs.sVars.first() to priorVal).asDouble() } } if(G is Var) { @@ -90,11 +90,11 @@ class DoublePendulum(private val len: Double = 900.0) : Application(), EventHand } } - fun Fun.descend(steps: Int, vinit: Double, gamma: Double, α: Double = 0.1, map: Pair, DReal>): Fun { + fun SFun.descend(steps: Int, vinit: Double, gamma: Double, α: Double = 0.1, map: Pair, DReal>): SFun { with(DoublePrecision) { val d_dg = this@descend.d(map.first) - var G1P: Fun = map.second - var velocity: Fun = wrap(vinit) + var G1P: SFun = map.second + var velocity: SFun = wrap(vinit) var i = 0 do { velocity = gamma * velocity + d_dg(map.first to G1P) * α