Skip to content

Commit

Permalink
Implemented structure classes for expression tree
Browse files Browse the repository at this point in the history
  • Loading branch information
laurenzlevi committed Mar 21, 2024
1 parent a6f02dd commit c00a7dd
Show file tree
Hide file tree
Showing 14 changed files with 561 additions and 660 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ open class FunctionDecl<S : Sort>(
): Expression<S> {
bindParametersToExpressions(args, functionIndices)

return NAryExpression(name, sort, args)
return UserDefinedExpression(name, sort, args)
}

open fun bindParametersToExpressions(args: List<Expression<*>>, indices: Set<NumeralIndex>) =
Expand Down
153 changes: 122 additions & 31 deletions src/main/kotlin/tools/aqua/konstraints/smt/Expression.kt
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,15 @@ package tools.aqua.konstraints.smt

import tools.aqua.konstraints.util.reduceOrDefault

abstract class Expression<T : Sort> {
sealed interface Expression<T : Sort> {
abstract val symbol: Symbol
abstract val sort: T

override fun toString() = symbol.toSMTString()

open val subexpressions = emptyList<Expression<*>>()

/** Recursive all implementation */
fun all(predicate: (Expression<*>) -> Boolean): Boolean {
return predicate(this) and
subexpressions.map { it.all(predicate) }.reduceOrDefault(true) { t1, t2 -> t1 and t2 }
}
/**
* Recursive all implementation fun all(predicate: (Expression<*>) -> Boolean): Boolean { return
* predicate(this) and subexpressions.map { it.all(predicate) }.reduceOrDefault(true) { t1, t2 ->
* t1 and t2 } }
*/

// TODO implement more operations like filter, filterIsInstance, filterIsSort, forEach, onEach
// etc.
Expand All @@ -48,40 +44,135 @@ abstract class Expression<T : Sort> {

@Suppress("UNCHECKED_CAST") return this as Expression<S>
}

fun all(predicate: (Expression<*>) -> Boolean): Boolean =
when (this) {
is UnaryExpression<*, *> -> predicate(this) and this.inner().all(predicate)
is BasicExpression -> predicate(this)
is BinaryExpression<*, *, *> ->
predicate(this) and this.lhs().all(predicate) and this.rhs().all(predicate)
is HomogenousExpression<*, *> ->
predicate(this) and
this.subexpressions()
.map { it.all(predicate) }
.reduceOrDefault(true) { t1, t2 -> t1 and t2 }
is Ite -> TODO()
is Literal -> TODO()
is NAryExpression ->
predicate(this) and
this.subexpressions()
.map { it.all(predicate) }
.reduceOrDefault(true) { t1, t2 -> t1 and t2 }
is TernaryExpression<*, *, *, *> -> TODO()
}

val subexpressions: List<Expression<*>>
}

class BasicExpression<T : Sort>(override val symbol: Symbol, override val sort: T) :
Expression<T>() {
// TODO this should be variable
class BasicExpression<T : Sort>(override val symbol: Symbol, override val sort: T) : Expression<T> {
override val subexpressions: List<Expression<*>> = emptyList()

override fun toString() = "$symbol"
}

class UnaryExpression<T : Sort>(symbol: Symbol, override val sort: T, val other: Expression<T>) :
Expression<T>() {
override val symbol: Symbol = symbol
open class Literal<T : Sort>(override val symbol: Symbol, override val sort: T) : Expression<T> {
override val subexpressions: List<Expression<*>> = emptyList()

override fun toString() = "($symbol ${other})"
override fun toString() = "$symbol"
}

class BinaryExpression<T : Sort>(
symbol: Symbol,
override val sort: T,
val left: Expression<T>,
val right: Expression<T>
) : Expression<T>() {
override val symbol: Symbol = symbol
abstract class UnaryExpression<T : Sort, S : Sort>(
override val symbol: Symbol,
override val sort: T
) : Expression<T> {

abstract fun inner(): Expression<S>

override val subexpressions: List<Expression<*>>
get() = listOf(inner())

override fun toString() = "($symbol ${left} ${right})"
override fun toString() = "($symbol ${inner()})"
}

class NAryExpression<T : Sort>(
symbol: Symbol,
override val sort: T,
val tokens: List<Expression<*>>
) : Expression<T>() {
override val symbol: Symbol = symbol
abstract class BinaryExpression<T : Sort, S1 : Sort, S2 : Sort>(
override val symbol: Symbol,
override val sort: T
) : Expression<T> {

abstract fun lhs(): Expression<S1>

abstract fun rhs(): Expression<S2>

override val subexpressions: List<Expression<*>>
get() = listOf(lhs(), rhs())

override fun toString() = "($symbol ${lhs()} ${rhs()})"
}

abstract class TernaryExpression<T : Sort, S1 : Sort, S2 : Sort, S3 : Sort>(
override val symbol: Symbol,
override val sort: T
) : Expression<T> {
abstract fun lhs(): Expression<S1>

abstract fun mid(): Expression<S2>

abstract fun rhs(): Expression<S3>

override val subexpressions: List<Expression<*>>
get() = listOf(lhs(), mid(), rhs())

override fun toString() = "($symbol ${lhs()} ${mid()} ${rhs()})"
}

abstract class HomogenousExpression<T : Sort, S : Sort>(
override val symbol: Symbol,
override val sort: T
) : Expression<T> {
abstract fun subexpressions(): List<Expression<S>>

override val subexpressions: List<Expression<*>>
get() = subexpressions()

override fun toString() =
if (subexpressions().isNotEmpty()) "($symbol ${subexpressions().joinToString(" ")})"
else symbol.toSMTString()
}

/**
* Implements ite according to Core theory (par (A) (ite Bool A A A))
*
* @param statement indicates whether [then] or [els] should be returned
* @param then value to be returned if [statement] is true
* @param els value to be returned if [statement] is false
*/
class Ite(val statement: Expression<BoolSort>, val then: Expression<*>, val els: Expression<*>) :
Expression<Sort> {
override val sort: BoolSort = BoolSort
override val symbol: Symbol = "ite".symbol()

override val subexpressions: List<Expression<*>> = listOf(statement, then, els)

override fun toString(): String = "(ite $statement $then $els)"
}

abstract class NAryExpression<T : Sort>(override val symbol: Symbol, override val sort: T) :
Expression<T> {

abstract fun subexpressions(): List<Expression<*>>

override val subexpressions: List<Expression<*>>
get() = subexpressions()

override fun toString() =
if (tokens.isNotEmpty()) "($symbol ${tokens.joinToString(" ")})" else symbol.toSMTString()
if (subexpressions().isNotEmpty()) "($symbol ${subexpressions().joinToString(" ")})"
else symbol.toSMTString()
}

class UserDefinedExpression<T : Sort>(name: Symbol, sort: T, val args: List<Expression<*>>) :
NAryExpression<T>(name, sort) {
override fun subexpressions(): List<Expression<*>> = args
}

class ExpressionCastException(from: Sort, to: String) :
Expand Down
12 changes: 6 additions & 6 deletions src/main/kotlin/tools/aqua/konstraints/smt/SMTProgram.kt
Original file line number Diff line number Diff line change
Expand Up @@ -78,18 +78,18 @@ class MutableSMTProgram(commands: List<Command>, context: Context) : SMTProgram(
* Inserts [command] at the end of the program Checks if [command] is legal w.r.t. the [context]
*/
fun add(command: Command) {
if (command is Assert) {
require(command.expression.all { context.contains(it) })
}

updateContext(command)
_commands.add(command)
add(command, _commands.size)
}

/**
* Inserts [command] at [index] into the program Checks if [command] is legal w.r.t. the [context]
*/
fun add(command: Command, index: Int) {
if (command is Assert) {
require(command.expression.all { context.contains(it) })
}

updateContext(command)
_commands.add(index, command)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ fun Z3Sort.aquaify(): Sort =
@JvmName("aquaifyAny")
fun Expr<*>.aquaify(): Expression<*> =
when (this.sort) {
is Z3BoolSort -> (this as Expr<Z3BoolSort>).aquaify()
is Z3IntSort -> (this as Expr<Z3IntSort>).aquaify()
is BitVecSort -> (this as Expr<BitVecSort>).aquaify()
is Z3BoolSort -> (this as Expr<Z3BoolSort>).aquaify() as Expression<Sort>
is Z3IntSort -> (this as Expr<Z3IntSort>).aquaify() as Expression<Sort>
is BitVecSort -> (this as Expr<BitVecSort>).aquaify() as Expression<Sort>
else -> throw RuntimeException("Unknown or unsupported Z3 sort ${this.sort}")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ fun Expression<IntSort>.z3ify(context: Z3Context): Expr<Z3IntSort> =
require(this is NAryExpression)
context.context.mkApp(
context.functions[this.symbol.toString()]!!,
*this.tokens.map { it.z3ify(context) }.toTypedArray()) as Expr<Z3IntSort>
*this.subexpressions().map { it.z3ify(context) }.toTypedArray()) as Expr<Z3IntSort>
} else {
throw IllegalArgumentException("Z3 can not visit expression $this!")
}
Expand Down
17 changes: 9 additions & 8 deletions src/main/kotlin/tools/aqua/konstraints/theories/ArraysEx.kt
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,15 @@ internal object ArraySortDecl :
ArraySort(bindings[SortParameter("X")], bindings[SortParameter("Y")])
}

class ArraySelect(val array: Expression<ArraySort>, val index: Expression<*>) : Expression<Sort>() {
override val symbol: Symbol = "select".symbol()
override val sort: Sort = array.sort.y

class ArraySelect(val array: Expression<ArraySort>, val index: Expression<*>) :
BinaryExpression<Sort, ArraySort, Sort>("select".symbol(), array.sort.y) {
init {
require(array.sort.x == index.sort)
}

override fun lhs(): Expression<ArraySort> = array

override fun rhs(): Expression<Sort> = index as Expression<Sort>
}

object ArraySelectDecl :
Expand All @@ -69,14 +71,13 @@ class ArrayStore(
val array: Expression<ArraySort>,
val index: Expression<*>,
val value: Expression<*>
) : Expression<ArraySort>() {
override val symbol: Symbol = "store".symbol()
override val sort: ArraySort = array.sort

) : NAryExpression<ArraySort>("store".symbol(), array.sort) {
init {
require(array.sort.x == index.sort)
require(array.sort.y == value.sort)
}

override fun subexpressions(): List<Expression<*>> = listOf(array, index, value)
}

object ArrayStoreDecl :
Expand Down
Loading

0 comments on commit c00a7dd

Please sign in to comment.