Skip to content

Commit

Permalink
Implemented new context and let (#4)
Browse files Browse the repository at this point in the history
* Implemented stack

* Implemented let

* Updated let implementation

---------

Co-authored-by: laurenzlevi <[email protected]>
  • Loading branch information
laurenzlevi and laurenzlevi authored Apr 17, 2024
1 parent ce55571 commit 82d537e
Show file tree
Hide file tree
Showing 98 changed files with 17,837 additions and 76,535 deletions.
175 changes: 124 additions & 51 deletions src/main/kotlin/tools/aqua/konstraints/parser/Context.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
package tools.aqua.konstraints.parser

import tools.aqua.konstraints.smt.*
import tools.aqua.konstraints.theories.CoreContext
import tools.aqua.konstraints.util.Stack

abstract class SortDecl<T : Sort>(
val name: Symbol,
Expand All @@ -42,26 +44,43 @@ abstract class SortDecl<T : Sort>(
abstract fun getSort(bindings: Bindings): T
}

class Context {
// store the sort of numeral expressions either (NUMERAL Int) or (NUMERAL Real) depending on the
// loaded logic
var numeralSort: Sort? = null
/**
* Context class manages the currently loaded logic/theory and all the Assertion-Levels (including
* global eventually but this option is currently not supported)
*/
class Context(theory: Theory) {
// theory setter is private to disallow changing the theory manually
// this should only be changed when (set-logic) is used or when reset is called
var theory: Theory = theory
private set

val assertionLevels = Stack<Subcontext>()

init {
// core theory is always loaded QF_UF and UF are the only logics that load core "manually"
// as it is the only theory they rely on, for all other logics core is loaded before
// the other theory is loaded
if (theory != CoreContext) {
assertionLevels.push(CoreContext)
}

fun contains(expression: Expression<*>): Boolean =
getFunction(expression.symbol.toString(), expression.subexpressions) != null
assertionLevels.push(theory)
assertionLevels.push(AssertionLevel())
}

fun registerTheory(other: TheoryContext) {
other.functions.forEach { func ->
if (func.name.toString() in functionLookup) {
functionLookup[func.name.toString()]?.add(func)
} else {
functionLookup[func.name.toString()] = mutableListOf(func)
}
}
var numeralSort: Sort? = null

fun let(varBindings: List<VarBinding>, block: (Context) -> Expression<*>): Expression<Sort> {
assertionLevels.push(LetLevel(varBindings))
val result = block(this)
assertionLevels.pop()

other.sorts.forEach { registerSort(it.value) }
return result as Expression<Sort>
}

fun contains(expression: Expression<*>): Boolean =
getFunction(expression.symbol.toString(), expression.subexpressions) != null

fun registerFunction(function: DeclareConst) {
registerFunction(
FunctionDecl(
Expand All @@ -87,26 +106,13 @@ class Context {
}

fun registerFunction(function: FunctionDecl<*>) {
val conflicts = functionLookup[function.name.toString()]

if (conflicts != null) {
val conflictParams = conflicts.filter { it.accepts(function.params, emptySet()) }

if (conflictParams.isNotEmpty()) {
val conflictReturns =
conflictParams.filter { it.signature.bindReturnOrNull(function.sort) != null }

if (conflictReturns.isNotEmpty()) {
throw FunctionAlreadyDeclaredException(function)
} else {
conflicts.add(function)
}
} else {
conflicts.add(function)
}
} else {
functionLookup[function.name.toString()] = mutableListOf(function)
if (theory?.contains(function) == true) {
throw IllegalFunctionOverloadException(
function.name.toString(), "Can not overload theory symbols")
}

// TODO enforce all overloading/shadowing rules
assertionLevels.peek().add(function)
}

internal fun registerFunction(const: ProtoDeclareConst, sort: Sort) {
Expand Down Expand Up @@ -140,16 +146,18 @@ class Context {
}

fun registerSort(sort: SortDecl<*>) {
if (sorts.containsKey(sort.name.toString()))
throw SortAlreadyDeclaredException(sort.name, sort.signature.sortParameter.size)
if (theory?.contains(sort) == true) {
throw SortAlreadyDeclaredException(sort.name, sort.signature.sortParameter.size)
}

sorts[sort.name.toString()] = sort
// TODO enforce all overloading/shadowing rules
assertionLevels.peek().add(sort)
}

fun registerSort(name: Symbol, arity: Int) {
if (sorts.containsKey(name.toString())) throw SortAlreadyDeclaredException(name, arity)
val sort = UserDefinedSortDecl(name, arity)

sorts[name.toString()] = UserDefinedSortDecl(name, arity)
registerSort(sort)
}

/**
Expand All @@ -167,33 +175,98 @@ class Context {
* @throws IllegalArgumentException if the function specified by name and args is ambiguous
*/
fun getFunction(name: String, args: List<Expression<*>>): FunctionDecl<*>? {
return functionLookup[name]?.single { func -> func.accepts(args.map { it.sort }, emptySet()) }
return assertionLevels.find { it.contains(name, args) }?.get(name, args)
}

internal fun getSort(protoSort: ProtoSort): Sort {
// build all sort parameters first
val parameters = protoSort.sorts.map { getSort(it) }
val sort =
assertionLevels.find { it.containsSort(protoSort.name) }?.sorts?.get(protoSort.name)
?: throw NoSuchElementException()

return sorts[protoSort.name]?.buildSort(protoSort.identifier, parameters)
?: throw Exception("Unknown sort ${protoSort.identifier.symbol}")
return sort.buildSort(protoSort.identifier, parameters)
}
}

private val sorts: MutableMap<String, SortDecl<*>> = mutableMapOf()
/**
* Parent class of all assertion levels (this includes the default assertion levels and binder
* assertion levels, as well as theory objects)
*/
interface Subcontext {
fun contains(function: FunctionDecl<*>) = functions.contains(function)

/*
* Lookup for all simple functions
* excludes indexed functions of the form e.g. ((_ extract i j) (_ BitVec m) (_ BitVec n))
*/
val functionLookup: MutableMap<String, MutableList<FunctionDecl<*>>> = mutableMapOf()
}
fun contains(function: String, args: List<Expression<*>>) = get(function, args) != null

interface TheoryContext {
val functions: HashSet<FunctionDecl<*>>
fun get(function: String, args: List<Expression<*>>) =
functions.find { it.name.toString() == function && it.acceptsExpressions(args, emptySet()) }

fun contains(sort: SortDecl<*>) = sorts.containsKey(sort.name.toString())

fun contains(sort: Sort) = sorts.containsKey(sort.name.toString())

fun containsSort(sort: String) = sorts.containsKey(sort)

fun add(function: FunctionDecl<*>): Boolean

fun add(sort: SortDecl<*>): SortDecl<*>?

val functions: List<FunctionDecl<*>>
val sorts: Map<String, SortDecl<*>>
}

/** Represents a single assertion level */
class AssertionLevel : Subcontext {
override fun add(function: FunctionDecl<*>) = functions.add(function)

override fun add(sort: SortDecl<*>) = sorts.put(sort.name.toString(), sort)

override val functions: MutableList<FunctionDecl<*>> = mutableListOf()
override val sorts: MutableMap<String, SortDecl<*>> = mutableMapOf()
}

class VarBinding(symbol: Symbol, val term: Expression<Sort>) :
FunctionDecl0<Sort>(symbol, emptySet(), emptySet(), term.sort) {
override fun buildExpression(bindings: Bindings): Expression<Sort> =
LocalExpression(name, sort, term)
}

class LetLevel(varBindings: List<VarBinding>) : Subcontext {
override fun add(function: FunctionDecl<*>): Boolean =
throw IllegalOperationException(
"LetLevel.add", "Can not add new functions to let assertion level")

override fun add(sort: SortDecl<*>): SortDecl<*> =
throw IllegalOperationException(
"LetLevel.add", "Can not add new sorts to let assertion level")

override val functions: List<FunctionDecl<*>> = varBindings
override val sorts: Map<String, SortDecl<*>> = emptyMap()
}

interface Theory : Subcontext {
override fun add(function: FunctionDecl<*>) =
throw IllegalOperationException("Theory.add", "Can not add new functions to SMT theories")

override fun add(sort: SortDecl<*>) =
throw IllegalOperationException("Theory.add", "Can not add new sorts to SMT theories")

override val functions: List<FunctionDecl<*>>
override val sorts: Map<String, SortDecl<*>>
}

class IllegalFunctionOverloadException(func: String, msg: String) :
RuntimeException("Illegal overload of $func: $msg.")

class FunctionAlreadyDeclaredException(func: FunctionDecl<*>) :
RuntimeException("Function $func has already been declared")

class SortAlreadyDeclaredException(sort: Symbol, arity: Int) :
RuntimeException("Sort ($sort $arity) has already been declared")

class TheoryAlreadySetException :
RuntimeException(
"Theory has already been set, use the smt-command (reset) before using a new logic or theory")

class IllegalOperationException(operation: String, reason: String) :
RuntimeException("Illegal Operation $operation: $reason.")
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ open class FunctionDecl<S : Sort>(
override fun toString() = "($name (${params.joinToString(" ")}) $sort)"
}

// TODO are indices necessary here (dont think so)
abstract class FunctionDecl0<S : Sort>(
name: Symbol,
parametricSorts: Set<Sort>,
Expand Down Expand Up @@ -242,7 +241,7 @@ abstract class FunctionDecl4<P1 : Sort, P2 : Sort, P3 : Sort, P4 : Sort, S : Sor
args: List<Expression<*>>,
functionIndices: Set<NumeralIndex>
): Expression<S> {
require(args.size == 4)
require(args.size == 4) { "$name expected 4 arguments but got ${args.size}: $args" }
val bindings = bindParametersToExpressions(args, functionIndices)

// TODO suppress unchecked cast warning
Expand Down
54 changes: 27 additions & 27 deletions src/main/kotlin/tools/aqua/konstraints/parser/ParseTreeVisitor.kt
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,12 @@ import jdk.jshell.spi.ExecutionControl.NotImplementedException
import tools.aqua.konstraints.smt.*
import tools.aqua.konstraints.theories.*
import tools.aqua.konstraints.theories.BitVectorExpressionContext
import tools.aqua.konstraints.theories.CoreContext
import tools.aqua.konstraints.theories.IntsContext

internal class ParseTreeVisitor :
ProtoCommandVisitor, ProtoTermVisitor, ProtoSortVisitor, SpecConstantVisitor {

val context = Context()

init {
// always load core theory
context.registerTheory(CoreContext)
}
var context: Context? = null

override fun visit(protoAssert: ProtoAssert): Assert {
val term = visit(protoAssert.term)
Expand All @@ -47,7 +41,7 @@ internal class ParseTreeVisitor :
override fun visit(protoDeclareConst: ProtoDeclareConst): DeclareConst {
val sort = visit(protoDeclareConst.sort)

context.registerFunction(protoDeclareConst, sort)
context?.registerFunction(protoDeclareConst, sort)

return DeclareConst(Symbol(protoDeclareConst.name), sort)
}
Expand All @@ -56,30 +50,31 @@ internal class ParseTreeVisitor :
val sort = visit(protoDeclareFun.sort)
val parameters = protoDeclareFun.parameters.map { visit(it) }

context.registerFunction(protoDeclareFun, parameters, sort)
context?.registerFunction(protoDeclareFun, parameters, sort)

return DeclareFun(protoDeclareFun.name.symbol(), parameters, sort)
}

override fun visit(protoSetLogic: ProtoSetLogic): SetLogic {
when (protoSetLogic.logic) {
QF_BV -> context.registerTheory(BitVectorExpressionContext)
QF_BV -> context = Context(BitVectorExpressionContext)
QF_IDL -> {
context.registerTheory(IntsContext)
context.numeralSort = IntSort
context = Context(IntsContext)
context?.numeralSort = IntSort
}
QF_RDL -> {
context.registerTheory(RealsContext)
context.numeralSort = RealSort
context = Context(RealsContext)
context?.numeralSort = RealSort
}
QF_FP -> context.registerTheory(FloatingPointContext)
QF_FP -> context = Context(FloatingPointContext)
// QF_AX uses only ArrayEx with free function and sort symbols, as free sorts are not yet
// supported
// load int theory as well for testing purposes
QF_AX -> {
context.registerTheory(ArrayExContext)
context.registerTheory(IntsContext)
context.numeralSort = IntSort
context = Context(ArrayExContext)
}
QF_UF -> {
context = Context(CoreContext)
}
else -> throw NotImplementedException("${protoSetLogic.logic} not yet supported")
}
Expand All @@ -88,7 +83,7 @@ internal class ParseTreeVisitor :
}

override fun visit(protoDeclareSort: ProtoDeclareSort): DeclareSort {
context.registerSort(protoDeclareSort.symbol, protoDeclareSort.arity)
context?.registerSort(protoDeclareSort.symbol, protoDeclareSort.arity)

return DeclareSort(protoDeclareSort.symbol, protoDeclareSort.arity)
}
Expand All @@ -113,7 +108,7 @@ internal class ParseTreeVisitor :
SortedVar(protoSortedVar.symbol, visit(protoSortedVar.sort))

override fun visit(simpleQualIdentifier: SimpleQualIdentifier): Expression<*> {
val op = context.getFunction(simpleQualIdentifier.identifier, listOf())
val op = context?.getFunction(simpleQualIdentifier.identifier, listOf())

if (op != null) {
return op.buildExpression(listOf(), emptySet())
Expand All @@ -135,7 +130,7 @@ internal class ParseTreeVisitor :
val terms = bracketedProtoTerm.terms.map { visit(it) }

val op =
context.getFunction(bracketedProtoTerm.qualIdentifier.identifier.symbol.toString(), terms)
context?.getFunction(bracketedProtoTerm.qualIdentifier.identifier.symbol.toString(), terms)

val functionIndices =
if (bracketedProtoTerm.qualIdentifier.identifier is IndexedIdentifier) {
Expand All @@ -157,7 +152,12 @@ internal class ParseTreeVisitor :
}

override fun visit(protoLet: ProtoLet): Expression<*> {
TODO("Implement visit ProtoLet")
val bindings =
protoLet.bindings.map { VarBinding(it.symbol, visit(it.term) as Expression<Sort>) }

val inner = context?.let(bindings) { visit(protoLet.term) }!!

return LetExpression("xyz".symbol(), inner.sort, bindings, inner)
}

override fun visit(protoForAll: ProtoForAll): Expression<*> {
Expand All @@ -173,22 +173,22 @@ internal class ParseTreeVisitor :
}

override fun visit(protoAnnotation: ProtoAnnotation): Expression<*> {
TODO("Implement visit ProtoExclamation")
TODO("Implement visit ProtoAnnotation")
}

override fun visit(protoSort: ProtoSort): Sort {
return context.getSort(protoSort)
return context!!.getSort(protoSort)
}

override fun visit(stringConstant: StringConstant): Expression<*> {
TODO("Not yet implemented")
}

override fun visit(numeralConstant: NumeralConstant): Expression<*> {
if (context.numeralSort == IntSort) return IntLiteral(numeralConstant.numeral)
else if (context.numeralSort == RealSort)
if (context?.numeralSort == IntSort) return IntLiteral(numeralConstant.numeral)
else if (context?.numeralSort == RealSort)
return RealLiteral(BigDecimal(numeralConstant.numeral))
else throw RuntimeException("Unsupported numeral literal sort ${context.numeralSort}")
else throw RuntimeException("Unsupported numeral literal sort ${context?.numeralSort}")
}

override fun visit(binaryConstant: BinaryConstant): Expression<*> {
Expand Down
Loading

0 comments on commit 82d537e

Please sign in to comment.