From e69dd81564a077d4d9fb4716b08098762e026447 Mon Sep 17 00:00:00 2001 From: dmitriy sokolov Date: Mon, 24 Jul 2023 23:43:52 +0300 Subject: [PATCH 01/12] cvc5 forking solver, tests --- .../kotlin/io/ksmt/solver/KForkingSolver.kt | 14 + .../io/ksmt/solver/KForkingSolverManager.kt | 16 + .../io/ksmt/solver/cvc5/KCvc5Context.kt | 195 +++++++----- .../io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt | 126 ++++++++ .../solver/cvc5/KCvc5ForkingSolverManager.kt | 30 ++ .../kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt | 4 + .../kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt | 222 +------------- .../io/ksmt/solver/cvc5/KCvc5SolverBase.kt | 225 ++++++++++++++ .../solver/cvc5/KCvc5SolverConfiguration.kt | 23 ++ .../kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt | 121 ++++++++ .../kotlin/io/ksmt/test/KForkingSolverTest.kt | 278 ++++++++++++++++++ 11 files changed, 962 insertions(+), 292 deletions(-) create mode 100644 ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolver.kt create mode 100644 ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolverManager.kt create mode 100644 ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt create mode 100644 ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt create mode 100644 ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverBase.kt create mode 100644 ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt create mode 100644 ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt diff --git a/ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolver.kt b/ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolver.kt new file mode 100644 index 000000000..86a7ba498 --- /dev/null +++ b/ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolver.kt @@ -0,0 +1,14 @@ +package io.ksmt.solver + +/** + * A solver capable of creating forks (copies) of itself, preserving assertions and assertion scopes + * + * @see KForkingSolverManager + */ +interface KForkingSolver : KSolver { + + /** + * Creates forked solver + */ + fun fork(): KForkingSolver +} diff --git a/ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolverManager.kt b/ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolverManager.kt new file mode 100644 index 000000000..7df85068a --- /dev/null +++ b/ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolverManager.kt @@ -0,0 +1,16 @@ +package io.ksmt.solver + +/** + * Responsible for creation of [KForkingSolver] and managing its lifetime + * + * @see KForkingSolver + */ +interface KForkingSolverManager : AutoCloseable { + + fun mkForkingSolver(): KForkingSolver + + /** + * Closes the manager and all opened solvers ([KForkingSolver]) managed by this + */ + override fun close() +} diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt index 0368b6118..91069e915 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt @@ -32,19 +32,28 @@ import io.ksmt.sort.KSortVisitor import io.ksmt.sort.KUninterpretedSort import java.util.TreeMap -class KCvc5Context( +class KCvc5Context private constructor( private val solver: Solver, - private val ctx: KContext + private val ctx: KContext, + parent: KCvc5Context?, + isForking: Boolean ) : AutoCloseable { + constructor(solver: Solver, ctx: KContext, isForking: Boolean = false) : this(solver, ctx, null, isForking) + private var isClosed = false + private val isChild = parent != null private val uninterpretedSortCollector = KUninterpretedSortCollector(this) private var exprCurrentLevelCacheRestorer = KCurrentScopeExprCacheRestorer(uninterpretedSortCollector, ctx) + private val uninterpretedSorts: ScopedFrame> + private val declarations: ScopedFrame>> + + /** * We use double-scoped expression internalization cache: - * * current (before pop operation) - [currentScopeExpressions] - * * global (after pop operation) - [expressions] + * * current accumulated (before pop operation) - [currentAccumulatedScopeExpressions] + * * global (current + all previously met)- [expressions] * * Due to incremental collect of declarations and uninterpreted sorts for the model, * we collect them during internalizing. @@ -57,39 +66,79 @@ class KCvc5Context( * * **solution**: Recollect sorts / decls for each expression * that is in global cache, but whose sorts / decls have been erased after pop() - * (and put this expr to the current scope cache) + * (and put this expr to the cache of current accumulated scope) */ - private val currentScopeExpressions = HashMap, Term>() - private val expressions = HashMap, Term>() - // we can't use HashMap with Term and Sort (hashcode is not implemented) - private val cvc5Expressions = TreeMap>() - private val sorts = HashMap() - private val cvc5Sorts = TreeMap() - private val decls = HashMap, Term>() - private val cvc5Decls = TreeMap>() + private val currentAccumulatedScopeExpressions: HashMap, Term> + private val expressions: HashMap, Term> - private var currentLevelUninterpretedSorts = hashSetOf() - private val uninterpretedSorts = mutableListOf(currentLevelUninterpretedSorts) - - private var currentLevelDeclarations = hashSetOf>() - private val declarations = mutableListOf(currentLevelDeclarations) + // we can't use HashMap with Term and Sort (hashcode is not implemented) + private val cvc5Expressions: TreeMap> + private val sorts: HashMap + private val cvc5Sorts: TreeMap + private val decls: HashMap, Term> + private val cvc5Decls: TreeMap> - fun addUninterpretedSort(sort: KUninterpretedSort) { currentLevelUninterpretedSorts += sort } + private val uninterpretedSortValueDescriptors: ArrayList + private val uninterpretedSortValueInterpreter: HashMap /** - * uninterpreted sorts of active push-levels + * Uninterpreted sort values and universe are shared for whole forking hierarchy (from parent to children) + * due to shared expressions cache, + * that's why once [registerUninterpretedSortValue] and [saveUninterpretedSortValue] are called, + * each solver in hierarchy should assert newly internalized uninterpreted sort values via [assertPendingAxioms] + * + * @see KCvc5Model.uninterpretedSortUniverse */ - fun uninterpretedSorts(): List> = uninterpretedSorts + private val uninterpretedSortValues: HashMap>> + + init { + if (isForking) { + uninterpretedSorts = (parent?.uninterpretedSorts as? ScopedLinkedFrame)?.fork() + ?: ScopedLinkedFrame(::HashSet, ::HashSet) + declarations = (parent?.declarations as? ScopedLinkedFrame)?.fork() + ?: ScopedLinkedFrame(::HashSet, ::HashSet) + } else { + uninterpretedSorts = ScopedArrayFrame(::HashSet) + declarations = ScopedArrayFrame(::HashSet) + } + + if (parent != null) { + currentAccumulatedScopeExpressions = parent.currentAccumulatedScopeExpressions.toMap(HashMap()) + expressions = parent.expressions + cvc5Expressions = parent.cvc5Expressions + sorts = parent.sorts + cvc5Sorts = parent.cvc5Sorts + decls = parent.decls + cvc5Decls = parent.cvc5Decls + uninterpretedSortValueDescriptors = parent.uninterpretedSortValueDescriptors + uninterpretedSortValueInterpreter = parent.uninterpretedSortValueInterpreter + uninterpretedSortValues = parent.uninterpretedSortValues + } else { + currentAccumulatedScopeExpressions = HashMap() + expressions = HashMap() + cvc5Expressions = TreeMap() + sorts = HashMap() + cvc5Sorts = TreeMap() + decls = HashMap() + cvc5Decls = TreeMap() + uninterpretedSortValueDescriptors = arrayListOf() + uninterpretedSortValueInterpreter = hashMapOf() + uninterpretedSortValues = hashMapOf() + } + } + + fun addUninterpretedSort(sort: KUninterpretedSort) { + uninterpretedSorts.currentFrame += sort + } + + fun uninterpretedSorts(): Set = uninterpretedSorts.flatten { this += it } fun addDeclaration(decl: KDecl<*>) { - currentLevelDeclarations += decl + declarations.currentFrame += decl uninterpretedSortCollector.collect(decl) } - /** - * declarations of active push-levels - */ - fun declarations(): List>> = declarations + fun declarations(): Set> = declarations.flatten { this += it } val nativeSolver: Solver get() = solver @@ -97,40 +146,37 @@ class KCvc5Context( val isActive: Boolean get() = !isClosed + fun fork(solver: Solver): KCvc5Context = KCvc5Context(solver, ctx, this, true).also { forkCtx -> + repeat(assertedConstraintLevels.size) { + forkCtx.pushAssertionLevel() + } + } + fun push() { - currentLevelDeclarations = hashSetOf() - declarations.add(currentLevelDeclarations) - currentLevelUninterpretedSorts = hashSetOf() - uninterpretedSorts.add(currentLevelUninterpretedSorts) + declarations.push() + uninterpretedSorts.push() pushAssertionLevel() } fun pop(n: UInt) { - repeat(n.toInt()) { - declarations.removeLast() - uninterpretedSorts.removeLast() + declarations.pop(n) + uninterpretedSorts.pop(n) - popAssertionLevel() - } + repeat(n.toInt()) { popAssertionLevel() } - currentLevelDeclarations = declarations.last() - currentLevelUninterpretedSorts = uninterpretedSorts.last() - - expressions += currentScopeExpressions - currentScopeExpressions.clear() + currentAccumulatedScopeExpressions.clear() // recreate cache restorer to avoid KNonRecursiveTransformer cache exprCurrentLevelCacheRestorer = KCurrentScopeExprCacheRestorer(uninterpretedSortCollector, ctx) } - // expr - fun findInternalizedExpr(expr: KExpr<*>): Term? = currentScopeExpressions[expr] + fun findInternalizedExpr(expr: KExpr<*>): Term? = currentAccumulatedScopeExpressions[expr] ?: expressions[expr]?.also { /* - * expr is not in current scope cache, but in global cache. + * expr is not in cache of current accumulated scope, but in global cache. * Recollect declarations and uninterpreted sorts - * and add entire expression tree to the current scope cache from the global - * to avoid re-internalizing with native calls + * and add entire expression tree to the current accumulated scope cache from the global + * to avoid re-internalization */ exprCurrentLevelCacheRestorer.apply(expr) } @@ -138,17 +184,19 @@ class KCvc5Context( fun findConvertedExpr(expr: Term): KExpr<*>? = cvc5Expressions[expr] fun saveInternalizedExpr(expr: KExpr<*>, internalized: Term): Term = - internalizeAst(currentScopeExpressions, cvc5Expressions, expr) { internalized } + internalizeAst(currentAccumulatedScopeExpressions, cvc5Expressions, expr) { internalized } + .also { expressions[expr] = internalized } /** * save expr, which is in global cache, to the current scope cache */ - fun savePreviouslyInternalizedExpr(expr: KExpr<*>): Term = saveInternalizedExpr(expr, expressions[expr]!!) + fun saveInternalizedExprToCurrentAccumulatedScope(expr: KExpr<*>): Term = + currentAccumulatedScopeExpressions.getOrPut(expr) { expressions[expr]!! } fun saveConvertedExpr(expr: Term, converted: KExpr<*>): KExpr<*> = - convertAst(currentScopeExpressions, cvc5Expressions, expr) { converted } + convertAst(currentAccumulatedScopeExpressions, cvc5Expressions, expr) { converted } + .also { expressions[converted] = expr } - // sort fun findInternalizedSort(sort: KSort): Sort? = sorts[sort] fun findConvertedSort(sort: Sort): KSort? = cvc5Sorts[sort] @@ -165,7 +213,6 @@ class KCvc5Context( inline fun convertSort(sort: Sort, converter: () -> KSort): KSort = findOrSave(sort, converter, ::findConvertedSort, ::saveConvertedSort) - // decl fun findInternalizedDecl(decl: KDecl<*>): Term? = decls[decl] fun findConvertedDecl(decl: Term): KDecl<*>? = cvc5Decls[decl] @@ -211,17 +258,20 @@ class KCvc5Context( val nativeValueTerm: Term ) + /** + * Uninterpreted sort value axioms will not be lost for [KCvc5ForkingSolver] on [fork]. + * + * On child initialization, "[currentValueConstraintsLevel] = 0" + * will be pushed to [assertedConstraintLevels] for each push-level ([currentValueConstraintsLevel] times). + * At the first call of [assertPendingAxioms] each descriptor from [uninterpretedSortValueDescriptors] + * will be asserted to the child [KCvc5ForkingSolver] + */ private var currentValueConstraintsLevel = 0 private val assertedConstraintLevels = arrayListOf() - private val uninterpretedSortValueDescriptors = arrayListOf() - private val uninterpretedSortValueInterpreter = hashMapOf() - - private val uninterpretedSortValues = - hashMapOf>>() fun saveUninterpretedSortValue(nativeValue: Term, value: KUninterpretedSortValue): Term { val sortValues = uninterpretedSortValues.getOrPut(value.sort) { arrayListOf() } - sortValues.add(nativeValue to value) + sortValues += nativeValue to value return nativeValue } @@ -266,7 +316,7 @@ class KCvc5Context( uninterpretedSortValues[sort] ?: emptyList() private fun pushAssertionLevel() { - assertedConstraintLevels.add(currentValueConstraintsLevel) + assertedConstraintLevels += currentValueConstraintsLevel } private fun popAssertionLevel() { @@ -354,20 +404,18 @@ class KCvc5Context( if (isClosed) return isClosed = true - currentScopeExpressions.clear() - expressions.clear() - cvc5Expressions.clear() + currentAccumulatedScopeExpressions.clear() - uninterpretedSorts.clear() - currentLevelUninterpretedSorts.clear() + if (isChild) { + expressions.clear() + cvc5Expressions.clear() - declarations.clear() - currentLevelDeclarations.clear() + sorts.clear() + cvc5Sorts.clear() - sorts.clear() - cvc5Sorts.clear() - decls.clear() - cvc5Decls.clear() + decls.clear() + cvc5Decls.clear() + } } class KUninterpretedSortCollector(private val cvc5Ctx: KCvc5Context) : KSortVisitor { @@ -416,7 +464,8 @@ class KCvc5Context( ctx: KContext ) : KNonRecursiveTransformer(ctx) { - override fun exprTransformationRequired(expr: KExpr): Boolean = expr !in currentScopeExpressions + override fun exprTransformationRequired(expr: KExpr): Boolean = + expr !in currentAccumulatedScopeExpressions override fun transform(expr: KFunctionApp): KExpr = cacheIfNeed(expr) { this@KCvc5Context.addDeclaration(expr.decl) @@ -426,7 +475,7 @@ class KCvc5Context( override fun transform(expr: KConst): KExpr = cacheIfNeed(expr) { this@KCvc5Context.addDeclaration(expr.decl) uninterpretedSortCollector.collect(expr.decl) - this@KCvc5Context.savePreviouslyInternalizedExpr(expr) + saveInternalizedExprToCurrentAccumulatedScope(expr) } override fun , R : KSort> transform(expr: KFunctionAsArray): KExpr = @@ -446,7 +495,7 @@ class KCvc5Context( override fun transform(expr: KUniversalQuantifier): KExpr = cacheIfNeed(expr) { transformQuantifier(expr.bounds.toSet(), expr.body) - this@KCvc5Context.savePreviouslyInternalizedExpr(expr) + saveInternalizedExprToCurrentAccumulatedScope(expr) } private fun transformQuantifier(bounds: Set>, body: KExpr<*>) { @@ -455,11 +504,11 @@ class KCvc5Context( } private fun > cacheIfNeed(expr: E, transform: E.() -> Unit): KExpr { - if (expr in currentScopeExpressions) + if (expr in currentAccumulatedScopeExpressions) return expr expr.transform() - this@KCvc5Context.savePreviouslyInternalizedExpr(expr) + saveInternalizedExprToCurrentAccumulatedScope(expr) return expr } } diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt new file mode 100644 index 000000000..cd24c385a --- /dev/null +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt @@ -0,0 +1,126 @@ +package io.ksmt.solver.cvc5 + +import io.github.cvc5.Term +import io.ksmt.KContext +import io.ksmt.expr.KExpr +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KSolver +import io.ksmt.solver.KSolverStatus +import io.ksmt.sort.KBoolSort +import java.util.TreeMap +import java.util.TreeSet +import kotlin.time.Duration + +open class KCvc5ForkingSolver internal constructor( + ctx: KContext, + private val manager: KCvc5ForkingSolverManager, + parent: KCvc5ForkingSolver? +) : KCvc5SolverBase(ctx), KForkingSolver, KSolver { + + final override val cvc5Ctx: KCvc5Context + private val isChild = parent != null + private var assertionsInitiated = !isChild + + private val _trackedAssertions: ScopedLinkedFrame>> + + override val trackedAssertions: ScopedFrame>> + get() = _trackedAssertions + + private val cvc5Assertions: ScopedLinkedFrame> + + init { + if (parent != null) { + cvc5Ctx = parent.cvc5Ctx.fork(solver) + _trackedAssertions = parent._trackedAssertions.fork() + cvc5Assertions = parent.cvc5Assertions.fork() + } else { + cvc5Ctx = KCvc5Context(solver, ctx, true) + _trackedAssertions = ScopedLinkedFrame(::TreeMap, ::TreeMap) + cvc5Assertions = ScopedLinkedFrame(::TreeSet, ::TreeSet) + } + } + + private val config: KCvc5ForkingSolverConfigurationImpl by lazy { + parent?.config?.fork(solver) ?: KCvc5ForkingSolverConfigurationImpl(solver) + } + + override fun configure(configurator: KCvc5SolverConfiguration.() -> Unit) { + config.configurator() + } + + override fun fork(): KForkingSolver = manager.mkForkingSolver(this) + + private fun ensureAssertionsInitiated() { + if (assertionsInitiated) return + + cvc5Assertions.stacked() + .zip(_trackedAssertions.stacked()) + .asReversed() + .forEachIndexed { scope, (cvc5AssertionFrame, trackedFrame) -> + if (scope > 0) solver.push() + + cvc5AssertionFrame.forEach(solver::assertFormula) + trackedFrame.forEach { (track, _) -> solver.assertFormula(track) } + } + + assertionsInitiated = true + } + + override fun assert(expr: KExpr): Unit = cvc5Try { + ctx.ensureContextMatch(expr) + ensureAssertionsInitiated() + + val cvc5Expr = with(exprInternalizer) { expr.internalizeExpr() } + solver.assertFormula(cvc5Expr) + cvc5Ctx.assertPendingAxioms(solver) + cvc5Assertions.currentFrame.add(cvc5Expr) + } + + override fun assertAndTrack(expr: KExpr) { + cvc5Try { ensureAssertionsInitiated() } + super.assertAndTrack(expr) + } + + override fun push() { + cvc5Try { ensureAssertionsInitiated() } + super.push() + cvc5Assertions.push() + } + + override fun pop(n: UInt) { + cvc5Try { ensureAssertionsInitiated() } + super.pop(n) + cvc5Assertions.pop(n) + } + + override fun check(timeout: Duration): KSolverStatus { + cvc5Try { ensureAssertionsInitiated() } + cvc5Ctx.assertPendingAxioms(solver) + return super.check(timeout) + } + + override fun checkWithAssumptions(assumptions: List>, timeout: Duration): KSolverStatus { + cvc5Try { ensureAssertionsInitiated() } + cvc5Ctx.assertPendingAxioms(solver) + return super.checkWithAssumptions(assumptions, timeout) + } + + override fun unsatCore(): List> { + val cvc5FullCore = cvc5UnsatCore() + + val unsatCore = mutableListOf>() + + cvc5FullCore.forEach { unsatCoreTerm -> + lastCvc5Assumptions?.get(unsatCoreTerm)?.also { unsatCore += it } + ?: trackedAssertions.find { trackedAssertion -> + trackedAssertion[unsatCoreTerm]?.let { unsatCore += it; true } ?: false + } + } + return unsatCore + } + + override fun close() { + manager.close(this) + super.close() + } +} diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt new file mode 100644 index 000000000..533a86f4f --- /dev/null +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt @@ -0,0 +1,30 @@ +package io.ksmt.solver.cvc5 + +import io.ksmt.KContext +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KForkingSolverManager +import java.util.concurrent.ConcurrentHashMap + +open class KCvc5ForkingSolverManager(private val ctx: KContext) : KForkingSolverManager { + + private val solvers: MutableSet = ConcurrentHashMap.newKeySet() + + override fun mkForkingSolver(): KForkingSolver { + return KCvc5ForkingSolver(ctx, this, null).also { solvers += it } + } + + internal fun mkForkingSolver(parent: KCvc5ForkingSolver): KForkingSolver { + return KCvc5ForkingSolver(ctx, this, parent).also { solvers += it } + } + + /** + * unregister [solver] for this manager + */ + internal fun close(solver: KCvc5ForkingSolver) { + solvers -= solver + } + + override fun close() { + solvers.forEach(KCvc5ForkingSolver::close) + } +} diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt index 306706549..9ac8273ea 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt @@ -93,6 +93,10 @@ open class KCvc5Model( KFuncInterpVarsFree(decl = decl, entries = emptyList(), default = interp) } + /** + * In case of [KCvc5ForkingSolver.model] call, uninterpreted sort universe extends values of whole forking hierarchy + * @see KCvc5Context.getRegisteredSortValues + */ override fun uninterpretedSortUniverse(sort: KUninterpretedSort): Set? = getUninterpretedSortContext(sort).getSortUniverse() diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt index 1e92797fe..9fc53507f 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt @@ -1,230 +1,14 @@ package io.ksmt.solver.cvc5 -import io.github.cvc5.CVC5ApiException -import io.github.cvc5.Result -import io.github.cvc5.Solver import io.github.cvc5.Term import io.ksmt.KContext import io.ksmt.expr.KExpr -import io.ksmt.solver.KModel import io.ksmt.solver.KSolver -import io.ksmt.solver.KSolverException -import io.ksmt.solver.KSolverStatus import io.ksmt.sort.KBoolSort -import io.ksmt.utils.NativeLibraryLoader import java.util.TreeMap -import kotlin.time.Duration -import kotlin.time.DurationUnit -open class KCvc5Solver(private val ctx: KContext) : KSolver { - private val solver = Solver() - private val cvc5Ctx = KCvc5Context(solver, ctx) +open class KCvc5Solver(ctx: KContext) : KCvc5SolverBase(ctx), KSolver { - private val exprInternalizer by lazy { createExprInternalizer(cvc5Ctx) } - - private val currentScope: UInt - get() = cvc5TrackedAssertions.lastIndex.toUInt() - - private var lastCheckStatus = KSolverStatus.UNKNOWN - private var lastReasonOfUnknown: String? = null - private var lastModel: KCvc5Model? = null - // we need TreeMap here (hashcode not implemented in Term) - private var cvc5LastAssumptions: TreeMap>? = null - - private var cvc5CurrentLevelTrackedAssertions = TreeMap>() - private val cvc5TrackedAssertions = mutableListOf(cvc5CurrentLevelTrackedAssertions) - - init { - solver.setOption("produce-models", "true") - solver.setOption("produce-unsat-cores", "true") - /** - * Allow floating-point sorts of all sizes, rather than - * only Float32 (8/24) or Float64 (11/53) (experimental in cvc5 1.0.2) - */ - solver.setOption("fp-exp", "true") - } - - open fun createExprInternalizer(cvc5Ctx: KCvc5Context): KCvc5ExprInternalizer = KCvc5ExprInternalizer(cvc5Ctx) - - override fun configure(configurator: KCvc5SolverConfiguration.() -> Unit) { - KCvc5SolverConfigurationImpl(solver).configurator() - } - - override fun assert(expr: KExpr) = cvc5Try { - ctx.ensureContextMatch(expr) - - val cvc5Expr = with(exprInternalizer) { expr.internalizeExpr() } - solver.assertFormula(cvc5Expr) - - cvc5Ctx.assertPendingAxioms(solver) - } - - override fun assertAndTrack(expr: KExpr) { - ctx.ensureContextMatch(expr) - - val trackVarApp = ctx.mkFreshConst("track", ctx.boolSort) - val cvc5TrackVar = with(exprInternalizer) { trackVarApp.internalizeExpr() } - val trackedExpr = with(ctx) { trackVarApp implies expr } - - cvc5CurrentLevelTrackedAssertions[cvc5TrackVar] = expr - - assert(trackedExpr) - solver.assertFormula(cvc5TrackVar) - } - - override fun push() = solver.push().also { - cvc5CurrentLevelTrackedAssertions = TreeMap() - cvc5TrackedAssertions.add(cvc5CurrentLevelTrackedAssertions) - - cvc5Ctx.push() - } - - override fun pop(n: UInt) { - require(n <= currentScope) { - "Can not pop $n scope levels because current scope level is $currentScope" - } - - if (n == 0u) return - - repeat(n.toInt()) { - cvc5TrackedAssertions.removeLast() - } - cvc5CurrentLevelTrackedAssertions = cvc5TrackedAssertions.last() - - cvc5Ctx.pop(n) - solver.pop(n.toInt()) - } - - override fun check(timeout: Duration): KSolverStatus = cvc5TryCheck { - solver.updateTimeout(timeout) - solver.checkSat().processCheckResult() - } - - override fun checkWithAssumptions( - assumptions: List>, - timeout: Duration - ): KSolverStatus = cvc5TryCheck { - ctx.ensureContextMatch(assumptions) - - val lastAssumptions = TreeMap>().also { cvc5LastAssumptions = it } - val cvc5Assumptions = Array(assumptions.size) { idx -> - val assumedExpr = assumptions[idx] - with(exprInternalizer) { - assumedExpr.internalizeExpr().also { - lastAssumptions[it] = assumedExpr - } - } - } - - solver.updateTimeout(timeout) - solver.checkSatAssuming(cvc5Assumptions).processCheckResult() - } - - override fun reasonOfUnknown(): String = cvc5Try { - require(lastCheckStatus == KSolverStatus.UNKNOWN) { - "Unknown reason is only available after UNKNOWN checks" - } - lastReasonOfUnknown ?: "no explanation" - } - - override fun model(): KModel = cvc5Try { - require(lastCheckStatus == KSolverStatus.SAT) { "Models are only available after SAT checks" } - val model = lastModel ?: KCvc5Model( - ctx, - cvc5Ctx, - exprInternalizer, - cvc5Ctx.declarations().flatMapTo(hashSetOf()) { it }, - cvc5Ctx.uninterpretedSorts().flatMapTo(hashSetOf()) { it }, - ) - lastModel = model - - model - } - - override fun unsatCore(): List> = cvc5Try { - require(lastCheckStatus == KSolverStatus.UNSAT) { "Unsat cores are only available after UNSAT checks" } - - val cvc5FullCore = solver.unsatCore - - val trackedTerms = TreeMap>() - cvc5TrackedAssertions.forEach { frame -> - trackedTerms.putAll(frame) - } - cvc5LastAssumptions?.also { trackedTerms.putAll(it) } - - cvc5FullCore.mapNotNull { trackedTerms[it] } - } - - override fun close() { - cvc5CurrentLevelTrackedAssertions.clear() - cvc5TrackedAssertions.clear() - - cvc5Ctx.close() - solver.close() - } - - /* - there are no methods to interrupt cvc5, - but maybe CVC5ApiRecoverableException can be thrown in someway - */ - override fun interrupt() = Unit - - private fun Result.processCheckResult() = when { - isSat -> KSolverStatus.SAT - isUnsat -> KSolverStatus.UNSAT - isUnknown || isNull -> KSolverStatus.UNKNOWN - else -> KSolverStatus.UNKNOWN - }.also { - lastCheckStatus = it - if (it == KSolverStatus.UNKNOWN) { - lastReasonOfUnknown = this.unknownExplanation?.name - } - } - - private fun Solver.updateTimeout(timeout: Duration) { - val cvc5Timeout = if (timeout == Duration.INFINITE) 0 else timeout.toInt(DurationUnit.MILLISECONDS) - setOption("tlimit-per", cvc5Timeout.toString()) - } - - private inline fun cvc5Try(body: () -> T): T = try { - body() - } catch (ex: CVC5ApiException) { - throw KSolverException(ex) - } - - private inline fun cvc5TryCheck(body: () -> KSolverStatus): KSolverStatus = try { - invalidateSolverState() - body() - } catch (ex: CVC5ApiException) { - lastReasonOfUnknown = ex.message - KSolverStatus.UNKNOWN.also { lastCheckStatus = it } - } - - private fun invalidateSolverState() { - /** - * Cvc5 model is only valid until the next check-sat call. - * */ - lastModel?.markInvalid() - lastModel = null - - lastCheckStatus = KSolverStatus.UNKNOWN - lastReasonOfUnknown = null - - cvc5LastAssumptions = null - } - - companion object { - init { - if (System.getProperty("cvc5.skipLibraryLoad") != "true") { - NativeLibraryLoader.load { os -> - when (os) { - NativeLibraryLoader.OS.LINUX -> listOf("libcvc5", "libcvc5jni") - NativeLibraryLoader.OS.WINDOWS -> listOf("libcvc5", "libcvc5jni") - NativeLibraryLoader.OS.MACOS -> listOf("libcvc5", "libcvc5jni") - } - } - System.setProperty("cvc5.skipLibraryLoad", "true") - } - } - } + override val cvc5Ctx: KCvc5Context = KCvc5Context(solver, ctx) + override val trackedAssertions: ScopedFrame>> = ScopedArrayFrame { TreeMap() } } diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverBase.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverBase.kt new file mode 100644 index 000000000..f141969fb --- /dev/null +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverBase.kt @@ -0,0 +1,225 @@ +package io.ksmt.solver.cvc5 + +import io.github.cvc5.CVC5ApiException +import io.github.cvc5.Result +import io.github.cvc5.Solver +import io.github.cvc5.Term +import io.ksmt.KContext +import io.ksmt.expr.KApp +import io.ksmt.expr.KExpr +import io.ksmt.solver.KModel +import io.ksmt.solver.KSolver +import io.ksmt.solver.KSolverException +import io.ksmt.solver.KSolverStatus +import io.ksmt.sort.KBoolSort +import io.ksmt.utils.NativeLibraryLoader +import java.util.TreeMap +import kotlin.time.Duration +import kotlin.time.DurationUnit + +abstract class KCvc5SolverBase internal constructor( + protected val ctx: KContext +) : KSolver { + + protected abstract val trackedAssertions: ScopedFrame>> + + protected open val currentScope: UInt + get() = trackedAssertions.currentScope + + protected val solver = Solver().apply { configureInitially() } + protected abstract val cvc5Ctx: KCvc5Context + protected val exprInternalizer by lazy { createExprInternalizer(cvc5Ctx) } + + protected var lastCheckStatus = KSolverStatus.UNKNOWN + protected var lastReasonOfUnknown: String? = null + protected var lastModel: KCvc5Model? = null + + // use TreeMap for cvc5 Term (hashcode not implemented) + protected var lastCvc5Assumptions: TreeMap>? = null + + open fun createExprInternalizer(cvc5Ctx: KCvc5Context): KCvc5ExprInternalizer = KCvc5ExprInternalizer(cvc5Ctx) + + private fun Solver.configureInitially() { + setOption("produce-models", "true") + setOption("produce-unsat-cores", "true") + /** + * Allow floating-point sorts of all sizes, rather than + * only Float32 (8/24) or Float64 (11/53) + */ + setOption("fp-exp", "true") + } + + override fun configure(configurator: KCvc5SolverConfiguration.() -> Unit) { + KCvc5SolverConfigurationImpl(solver).configurator() + } + + override fun assert(expr: KExpr) = cvc5Try { + ctx.ensureContextMatch(expr) + + val cvc5Expr = with(exprInternalizer) { expr.internalizeExpr() } + solver.assertFormula(cvc5Expr) + cvc5Ctx.assertPendingAxioms(solver) + } + + override fun assertAndTrack(expr: KExpr) = cvc5Try { + ctx.ensureContextMatch(expr) + + val trackVarApp = createTrackVarApp() + val cvc5TrackVar = with(exprInternalizer) { trackVarApp.internalizeExpr() } + val trackedExpr = with(ctx) { trackVarApp implies expr } + assert(trackedExpr) + solver.assertFormula(cvc5TrackVar) + trackedAssertions.currentFrame[cvc5TrackVar] = expr + } + + override fun push() = cvc5Try { + solver.push() + cvc5Ctx.push() + trackedAssertions.push() + } + + override fun pop(n: UInt) = cvc5Try { + require(n <= currentScope) { + "Can not pop $n scope levels because current scope level is $currentScope" + } + + if (n == 0u) return + solver.pop(n.toInt()) + cvc5Ctx.pop(n) + trackedAssertions.pop(n) + } + + override fun check(timeout: Duration): KSolverStatus = cvc5TryCheck { + solver.updateTimeout(timeout) + solver.checkSat().processCheckResult() + } + + override fun checkWithAssumptions( + assumptions: List>, + timeout: Duration + ): KSolverStatus = cvc5TryCheck { + ctx.ensureContextMatch(assumptions) + + val lastAssumptions = TreeMap>().also { lastCvc5Assumptions = it } + val cvc5Assumptions = Array(assumptions.size) { idx -> + val assumedExpr = assumptions[idx] + with(exprInternalizer) { + assumedExpr.internalizeExpr().also { + lastAssumptions[it] = assumedExpr + } + } + } + + solver.updateTimeout(timeout) + solver.checkSatAssuming(cvc5Assumptions).processCheckResult() + } + + protected open fun freshModel(): KCvc5Model = KCvc5Model( + ctx, + cvc5Ctx, + exprInternalizer, + cvc5Ctx.declarations(), + cvc5Ctx.uninterpretedSorts(), + ) + + override fun model(): KModel = cvc5Try { + require(lastCheckStatus == KSolverStatus.SAT) { "Models are only available after SAT checks" } + val model = lastModel ?: freshModel() + model.also { lastModel = it } + } + + override fun reasonOfUnknown(): String = cvc5Try { + require(lastCheckStatus == KSolverStatus.UNKNOWN) { + "Unknown reason is only available after UNKNOWN checks" + } + lastReasonOfUnknown ?: "no explanation" + } + + override fun unsatCore(): List> { + val cvc5FullCore = cvc5UnsatCore() + + val unsatCore = mutableListOf>() + + cvc5FullCore.forEach { unsatCoreTerm -> + lastCvc5Assumptions?.get(unsatCoreTerm)?.also { unsatCore += it } + ?: trackedAssertions.find { trackedAssertion -> + trackedAssertion[unsatCoreTerm]?.also { unsatCore += it } != null + } + } + return unsatCore + } + + protected fun cvc5UnsatCore(): Array = cvc5Try { + require(lastCheckStatus == KSolverStatus.UNSAT) { "Unsat cores are only available after UNSAT checks" } + solver.unsatCore + } + + override fun close() { + cvc5Ctx.close() + solver.close() + } + + /* + there is no method to interrupt cvc5 + */ + override fun interrupt() = Unit + + protected fun createTrackVarApp(): KApp = ctx.mkFreshConst("track", ctx.boolSort) + + protected fun Result.processCheckResult() = when { + isSat -> KSolverStatus.SAT + isUnsat -> KSolverStatus.UNSAT + isUnknown || isNull -> KSolverStatus.UNKNOWN + else -> KSolverStatus.UNKNOWN + }.also { + lastCheckStatus = it + if (it == KSolverStatus.UNKNOWN) { + lastReasonOfUnknown = this.unknownExplanation?.name + } + } + + protected fun Solver.updateTimeout(timeout: Duration) { + val cvc5Timeout = if (timeout == Duration.INFINITE) 0 else timeout.toInt(DurationUnit.MILLISECONDS) + setOption("tlimit-per", cvc5Timeout.toString()) + } + + protected inline fun cvc5Try(body: () -> T): T = try { + body() + } catch (ex: CVC5ApiException) { + throw KSolverException(ex) + } + + protected inline fun cvc5TryCheck(body: () -> KSolverStatus): KSolverStatus = try { + invalidateSolverState() + body() + } catch (ex: CVC5ApiException) { + lastReasonOfUnknown = ex.message + KSolverStatus.UNKNOWN.also { lastCheckStatus = it } + } + + protected fun invalidateSolverState() { + /** + * Cvc5 model is only valid until the next check-sat call. + * */ + lastModel?.markInvalid() + lastModel = null + lastCheckStatus = KSolverStatus.UNKNOWN + lastReasonOfUnknown = null + lastCvc5Assumptions = null + } + + companion object { + init { + if (System.getProperty("cvc5.skipLibraryLoad") != "true") { + NativeLibraryLoader.load { os -> + when (os) { + NativeLibraryLoader.OS.LINUX -> listOf("libcvc5", "libcvc5jni") + NativeLibraryLoader.OS.WINDOWS -> listOf("libcvc5", "libcvc5jni") + NativeLibraryLoader.OS.MACOS -> listOf("libcvc5", "libcvc5jni") + } + } + System.setProperty("cvc5.skipLibraryLoad", "true") + } + } + } +} diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverConfiguration.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverConfiguration.kt index 7ef58b0cd..ec743a154 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverConfiguration.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverConfiguration.kt @@ -45,6 +45,29 @@ class KCvc5SolverConfigurationImpl(val solver: Solver) : KCvc5SolverConfiguratio } } +class KCvc5ForkingSolverConfigurationImpl(val solver: Solver) : KCvc5SolverConfiguration { + private val options = hashMapOf() + private lateinit var logic: String + override fun setCvc5Option(option: String, value: String) { + solver.setOption(option, value) + options[option] = value + } + + override fun setCvc5Logic(value: String) { + solver.setLogic(value) + logic = value + } + + fun fork(solver: Solver): KCvc5ForkingSolverConfigurationImpl = KCvc5ForkingSolverConfigurationImpl(solver).also { + if (::logic.isInitialized) { + it.setCvc5Logic(logic) + } + options.forEach { (option, value) -> + it.setCvc5Option(option, value) + } + } +} + class KCvc5SolverUniversalConfiguration( private val builder: KSolverUniversalConfigurationBuilder ) : KCvc5SolverConfiguration { diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt new file mode 100644 index 000000000..b6039d305 --- /dev/null +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt @@ -0,0 +1,121 @@ +package io.ksmt.solver.cvc5 + +interface ScopedFrame { + val currentScope: UInt + val currentFrame: T + + fun flatten(collect: T.(T) -> Unit): T + fun find(predicate: (T) -> Boolean): T? + + fun push() + fun pop(n: UInt = 1u) +} + +internal class ScopedArrayFrame( + currentFrame: T, + private val createNewFrame: () -> T +) : ScopedFrame { + constructor(createNewFrame: () -> T) : this(createNewFrame(), createNewFrame) + + private val frames = arrayListOf(currentFrame) + + override var currentFrame = currentFrame + private set + + override val currentScope: UInt + get() = frames.size.toUInt() + + override fun flatten(collect: T.(T) -> Unit) = createNewFrame().also { newFrame -> + frames.forEach { newFrame.collect(it) } + } + + override fun find(predicate: (T) -> Boolean) = frames.find(predicate) + + override fun push() { + currentFrame = createNewFrame() + frames += currentFrame + } + + override fun pop(n: UInt) { + repeat(n.toInt()) { frames.removeLast() } + currentFrame = frames.last() + } +} + +internal class ScopedLinkedFrame private constructor( + private var current: LinkedFrame, + private val createNewFrame: () -> T, + private val copyFrame: (T) -> T +) : ScopedFrame { + constructor( + currentFrame: T, + createNewFrame: () -> T, + copyFrame: (T) -> T + ) : this(LinkedFrame(currentFrame), createNewFrame, copyFrame) + + constructor( + createNewFrame: () -> T, + copyFrame: (T) -> T + ) : this(createNewFrame(), createNewFrame, copyFrame) + + override val currentFrame: T + get() = current.value + + override val currentScope: UInt + get() = current.scope + + override fun flatten(collect: T.(T) -> Unit): T = createNewFrame().also { newFrame -> + forEachReversed { frame -> + newFrame.collect(frame) + } + } + + fun stacked(): ArrayDeque = ArrayDeque().also { stack -> + forEachReversed { frame -> + stack.addLast(frame) + } + } + + override fun find(predicate: (T) -> Boolean): T? { + forEachReversed { frame -> + if (predicate(frame)) return frame + } + return null + } + + override fun push() { + current = LinkedFrame(createNewFrame(), current) + } + + override fun pop(n: UInt) { + current = current.previous ?: throw IllegalStateException("Can't pop the bottom scope") + recreateTopFrame() + } + + private fun recreateTopFrame() { + val newTopFrame = copyFrame(currentFrame) + current = LinkedFrame(newTopFrame, current.previous) + } + + fun fork(): ScopedLinkedFrame = ScopedLinkedFrame( + current, + createNewFrame, + copyFrame + ).also { it.recreateTopFrame() } + + private inline fun forEachReversed(action: (T) -> Unit) { + var cur: LinkedFrame? = current + while (cur != null) { + action(cur.value) + cur = cur.previous + } + } + + private class LinkedFrame( + val value: E, + val previous: LinkedFrame? = null + ) { + val scope: UInt = previous?.scope?.plus(1u) ?: 0u + } + +} diff --git a/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt b/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt new file mode 100644 index 000000000..a5cfec1f7 --- /dev/null +++ b/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt @@ -0,0 +1,278 @@ +package io.ksmt.test + +import io.ksmt.KContext +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KSolverStatus +import io.ksmt.solver.cvc5.KCvc5ForkingSolverManager +import io.ksmt.utils.getValue +import kotlin.test.assertContains +import kotlin.test.assertEquals +import kotlin.test.assertTrue +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test + +class KForkingSolverTest { + @Nested + inner class KForkingSolverTestCvc5 { + @Test + fun testCheckSat() = testCheckSat(::mkCvc5ForkingSolver) + + @Test + fun testModel() = testModel(::mkCvc5ForkingSolver) + + @Test + fun testUnsatCore() = testUnsatCore(::mkCvc5ForkingSolver) + + @Test + fun testUninterpretedSort() = testUninterpretedSort(::mkCvc5ForkingSolver) + + @Test + fun testScopedAssertions() = testScopedAssertions(::mkCvc5ForkingSolver) + + private fun mkCvc5ForkingSolver(ctx: KContext) = KCvc5ForkingSolverManager(ctx).mkForkingSolver() + } + + private fun testCheckSat(mkSolver: (KContext) -> KForkingSolver<*>) = + KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> + mkSolver(ctx).use { parentSolver -> + with(ctx) { + val a by boolSort + val b by boolSort + val f = a and b + val neg = !a + + parentSolver.push() + + // * check children's assertions do not change parent's state + parentSolver.assert(f) + require(parentSolver.check() == KSolverStatus.SAT) + require(parentSolver.checkWithAssumptions(listOf(neg)) == KSolverStatus.UNSAT) + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + fork.assert(neg) + assertEquals(KSolverStatus.UNSAT, fork.check()) + } + + assertEquals(KSolverStatus.SAT, parentSolver.check()) + // * + + // * check parent's assertions translated into child solver + parentSolver.push() + assertEquals(KSolverStatus.UNSAT, parentSolver.fork().checkWithAssumptions(listOf(neg))) + parentSolver.assert(neg) + require(parentSolver.check() == KSolverStatus.UNSAT) + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.UNSAT, fork.check()) + } + parentSolver.pop() + // * + + // * check children independence + assertEquals(KSolverStatus.SAT, parentSolver.check()) + parentSolver.fork().also { fork1 -> + val fork2 = parentSolver.fork() + fork2.assert(neg) + assertEquals(KSolverStatus.UNSAT, fork2.check()) + assertEquals(KSolverStatus.SAT, fork1.check()) + + fork1.assert(neg) + assertEquals(KSolverStatus.UNSAT, fork1.check()) + assertEquals(KSolverStatus.SAT, parentSolver.fork().check()) + } + assertEquals(KSolverStatus.SAT, parentSolver.check()) + // * + } + + } + } + + private fun testUnsatCore(mkSolver: (KContext) -> KForkingSolver<*>): Unit = + KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> + mkSolver(ctx).use { parentSolver -> + with(ctx) { + val a by boolSort + val b by boolSort + val f = a and b + val neg = !a + + // * check that unsat core is empty (non-tracked assertions) + parentSolver.push() + parentSolver.assert(f) + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + fork.assert(neg) + assertEquals(KSolverStatus.UNSAT, fork.check()) + assertTrue { fork.unsatCore().isEmpty() } + assertEquals(KSolverStatus.SAT, parentSolver.check()) // parent's state hasn't changed + } + parentSolver.pop() + // * + + // check tracked exprs are in unsat core + parentSolver.push() + parentSolver.assertAndTrack(f) + + parentSolver.fork().also { fork -> + fork.assertAndTrack(neg) + assertEquals(KSolverStatus.UNSAT, fork.check()) + assertContains(fork.unsatCore(), neg) + assertContains(fork.unsatCore(), f) + assertEquals(KSolverStatus.SAT, parentSolver.check()) // parent's state hasn't changed + } + // * + + // * check unsat core saves from parent to child + parentSolver.assert(neg) + require(parentSolver.check() == KSolverStatus.UNSAT) + require(neg !in parentSolver.unsatCore()) + require(f in parentSolver.unsatCore()) // only tracked f is in unsat core + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.UNSAT, fork.check()) + assertContains(fork.unsatCore(), f) + assertTrue { neg !in fork.unsatCore() } + } + } + } + } + + private fun testModel(mkSolver: (KContext) -> KForkingSolver<*>): Unit = + KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> + mkSolver(ctx).use { parentSolver -> + with(ctx) { + val a by boolSort + val b by boolSort + val f = a and !b + + parentSolver.assert(f) + + require(parentSolver.check() == KSolverStatus.SAT) + require(parentSolver.model().eval(a) == true.expr) + require(parentSolver.model().eval(b) == false.expr) + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + assertEquals(true.expr, fork.model().eval(a)) + assertEquals(false.expr, fork.model().eval(b)) + } + } + } + } + + private fun testScopedAssertions(mkSolver: (KContext) -> KForkingSolver<*>): Unit = + KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> + mkSolver(ctx).use { parent -> + with(ctx) { + val a by boolSort + val b by boolSort + val f = a and b + val neg = !a + + parent.push() + + parent.assertAndTrack(f) + require(parent.check() == KSolverStatus.SAT) + parent.push() + parent.assertAndTrack(neg) + + require(parent.check() == KSolverStatus.UNSAT) + + parent.fork().also { fork -> + assertEquals(KSolverStatus.UNSAT, fork.check()) + assertContains(fork.unsatCore(), f) + assertContains(fork.unsatCore(), neg) + + fork.pop() + assertEquals(KSolverStatus.SAT, fork.check()) + assertEquals(true.expr, fork.model().eval(a)) + assertEquals(true.expr, fork.model().eval(b)) + assertEquals(KSolverStatus.UNSAT, fork.checkWithAssumptions(listOf(neg))) + assertEquals(KSolverStatus.UNSAT, parent.check()) // check parent's state hasn't changed + + fork.fork().also { ffork -> + assertEquals(KSolverStatus.SAT, ffork.check()) + assertEquals(KSolverStatus.UNSAT, ffork.checkWithAssumptions(listOf(neg))) + + ffork.push() + ffork.assertAndTrack(neg) + assertEquals(KSolverStatus.UNSAT, ffork.check()) + assertContains(ffork.unsatCore(), f) + assertContains(ffork.unsatCore(), neg) + + assertEquals(KSolverStatus.SAT, fork.check()) // check parent's state hasn't changed + assertEquals(KSolverStatus.UNSAT, parent.check()) // check parent's state hasn't changed + + ffork.pop() + assertEquals(KSolverStatus.SAT, ffork.check()) + assertEquals(KSolverStatus.UNSAT, ffork.checkWithAssumptions(listOf(neg))) + } + } + + // check child's state is detached + val fork = parent.fork() + assertEquals(KSolverStatus.UNSAT, fork.check()) + parent.pop() + + assertEquals(KSolverStatus.SAT, parent.check()) + assertEquals(KSolverStatus.UNSAT, fork.check()) + + parent.pop() + + fork.pop() + fork.pop() + + fork.assert(neg) + assertEquals(KSolverStatus.SAT, fork.check()) + } + } + } + + private fun testUninterpretedSort(mkSolver: (KContext) -> KForkingSolver<*>): Unit = + KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> + mkSolver(ctx).use { parentSolver -> + with(ctx) { + val uSort = mkUninterpretedSort("u") + val u1 by uSort + val u2 by uSort + + val eq12 = u1 eq u2 + + parentSolver.push() + parentSolver.assert(eq12) + + require(parentSolver.check() == KSolverStatus.SAT) + val pu1v = parentSolver.model().eval(u1) + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + fork.assert(u1 eq pu1v) + assertEquals(KSolverStatus.SAT, fork.check()) + assertEquals(pu1v, fork.model().eval(u1)) + } + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + fork.assert(u1 eq pu1v) + assertEquals(KSolverStatus.SAT, fork.check()) + assertEquals(pu1v, fork.model().eval(u1)) + + fork.fork().also { ff -> + assertEquals(KSolverStatus.SAT, ff.check()) + assertEquals(pu1v, ff.model().eval(u1)) + ff.model().uninterpretedSortUniverse(uSort)?.also { universe -> + assertContains(universe, pu1v) + } + } + } + + parentSolver.model().uninterpretedSortUniverse(uSort)?.also { universe -> + assertContains(universe, pu1v) + } + + } + } + } +} From 10e45022f426f4e4e9597966cf4fb669067311f2 Mon Sep 17 00:00:00 2001 From: Dmitriy Sokolov Date: Wed, 26 Jul 2023 17:16:20 +0300 Subject: [PATCH 02/12] cvc5 expressions lifetime fix via native mkExprSolver; added lifetime test --- .../io/ksmt/solver/cvc5/KCvc5Context.kt | 38 +++++++------ .../ksmt/solver/cvc5/KCvc5DeclInternalizer.kt | 4 +- .../io/ksmt/solver/cvc5/KCvc5ExprConverter.kt | 25 ++++++++- .../ksmt/solver/cvc5/KCvc5ExprInternalizer.kt | 6 +-- .../io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt | 53 ++++++++++--------- .../solver/cvc5/KCvc5ForkingSolverManager.kt | 46 +++++++++++++++- .../kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt | 8 ++- .../kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt | 22 +++++++- .../io/ksmt/solver/cvc5/KCvc5SolverBase.kt | 22 ++++---- .../ksmt/solver/cvc5/KCvc5SortInternalizer.kt | 2 +- .../kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt | 19 +++++-- .../kotlin/io/ksmt/test/KForkingSolverTest.kt | 25 +++++++++ 12 files changed, 194 insertions(+), 76 deletions(-) diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt index 91069e915..0223042a1 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt @@ -34,14 +34,18 @@ import java.util.TreeMap class KCvc5Context private constructor( private val solver: Solver, + val mkExprSolver: Solver, private val ctx: KContext, parent: KCvc5Context?, - isForking: Boolean + val isForking: Boolean ) : AutoCloseable { - constructor(solver: Solver, ctx: KContext, isForking: Boolean = false) : this(solver, ctx, null, isForking) + constructor(solver: Solver, mkExprSolver: Solver, ctx: KContext, isForking: Boolean = false) + : this(solver, mkExprSolver, ctx, null, isForking) + + constructor(solver: Solver, ctx: KContext, isForking: Boolean = false) + : this(solver, solver, ctx, null, isForking) private var isClosed = false - private val isChild = parent != null private val uninterpretedSortCollector = KUninterpretedSortCollector(this) private var exprCurrentLevelCacheRestorer = KCurrentScopeExprCacheRestorer(uninterpretedSortCollector, ctx) @@ -71,7 +75,12 @@ class KCvc5Context private constructor( private val currentAccumulatedScopeExpressions: HashMap, Term> private val expressions: HashMap, Term> - // we can't use HashMap with Term and Sort (hashcode is not implemented) + /** + * We can't use HashMap with Term and Sort (hashcode is not implemented) + * + * Avoid to close cache explicitly due to its sharing between forking hierarchy. + * It will be garbage collected on last solver close in forking hierarchy + */ private val cvc5Expressions: TreeMap> private val sorts: HashMap private val cvc5Sorts: TreeMap @@ -146,11 +155,12 @@ class KCvc5Context private constructor( val isActive: Boolean get() = !isClosed - fun fork(solver: Solver): KCvc5Context = KCvc5Context(solver, ctx, this, true).also { forkCtx -> - repeat(assertedConstraintLevels.size) { - forkCtx.pushAssertionLevel() + fun fork(solver: Solver, mkExprSolver: Solver): KCvc5Context = + KCvc5Context(solver, mkExprSolver, ctx, this, true).also { forkCtx -> + repeat(assertedConstraintLevels.size) { + forkCtx.pushAssertionLevel() + } } - } fun push() { declarations.push() @@ -399,23 +409,11 @@ class KCvc5Context private constructor( return converted } - override fun close() { if (isClosed) return isClosed = true currentAccumulatedScopeExpressions.clear() - - if (isChild) { - expressions.clear() - cvc5Expressions.clear() - - sorts.clear() - cvc5Sorts.clear() - - decls.clear() - cvc5Decls.clear() - } } class KUninterpretedSortCollector(private val cvc5Ctx: KCvc5Context) : KSortVisitor { diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5DeclInternalizer.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5DeclInternalizer.kt index ba9c03ce2..2cb1fefdc 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5DeclInternalizer.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5DeclInternalizer.kt @@ -18,7 +18,7 @@ open class KCvc5DeclInternalizer( val domainSorts = decl.argSorts.map { it.accept(sortInternalizer) } val rangeSort = decl.sort.accept(sortInternalizer) - cvc5Ctx.nativeSolver.declareFun( + cvc5Ctx.mkExprSolver.declareFun( decl.name, domainSorts.toTypedArray(), rangeSort @@ -29,6 +29,6 @@ open class KCvc5DeclInternalizer( cvc5Ctx.addDeclaration(decl) val sort = decl.sort.accept(sortInternalizer) - cvc5Ctx.nativeSolver.mkConst(sort, decl.name) + cvc5Ctx.mkExprSolver.mkConst(sort, decl.name) } } diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ExprConverter.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ExprConverter.kt index e7f69e186..b79d970eb 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ExprConverter.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ExprConverter.kt @@ -24,9 +24,10 @@ import io.ksmt.expr.KExpr import io.ksmt.expr.KFpRoundingMode import io.ksmt.expr.KInterpretedValue import io.ksmt.solver.KSolverUnsupportedFeatureException -import io.ksmt.solver.util.KExprConverterBase import io.ksmt.solver.util.ExprConversionResult +import io.ksmt.solver.util.KExprConverterBase import io.ksmt.solver.util.KExprConverterUtils.argumentsConversionRequired +import io.ksmt.solver.util.conversionLoop import io.ksmt.sort.KArithSort import io.ksmt.sort.KArray2Sort import io.ksmt.sort.KArray3Sort @@ -42,7 +43,7 @@ import io.ksmt.sort.KRealSort import io.ksmt.sort.KSort import io.ksmt.utils.asExpr import io.ksmt.utils.uncheckedCast -import java.util.* +import java.util.TreeMap @Suppress("LargeClass") open class KCvc5ExprConverter( @@ -734,6 +735,26 @@ open class KCvc5ExprConverter( } } + fun Term.convertExprWithMkExprSolver(): KExpr = convertExprWithoutCacheSave().also { + with(internalizer) { + it.internalizeExpr() + } + } + + fun Term.convertExprWithoutCacheSave(): KExpr { + val wasteCache = TreeMap>() + return conversionLoop( + stack = exprStack, + native = this, + stackPush = { stack, element -> stack.add(element) }, + stackPop = { stack -> stack.removeLast() }, + stackIsNotEmpty = { stack -> stack.isNotEmpty() }, + convertNative = { expr -> convertNativeExpr(expr) }, + findConverted = { expr -> wasteCache[expr] ?: findConvertedNative(expr) }, + saveConverted = { expr, converted -> wasteCache[expr] = converted } + ) + } + fun Term.convertExpr(): KExpr = convertFromNative() @Suppress("UNCHECKED_CAST") diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ExprInternalizer.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ExprInternalizer.kt index 7a369f3dd..c5fa7837b 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ExprInternalizer.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ExprInternalizer.kt @@ -160,11 +160,11 @@ import io.ksmt.expr.KXorExpr import io.ksmt.expr.rewrite.simplify.rewriteBvAddNoOverflowExpr import io.ksmt.expr.rewrite.simplify.rewriteBvAddNoUnderflowExpr import io.ksmt.expr.rewrite.simplify.rewriteBvDivNoOverflowExpr +import io.ksmt.expr.rewrite.simplify.rewriteBvMulNoOverflowExpr +import io.ksmt.expr.rewrite.simplify.rewriteBvMulNoUnderflowExpr import io.ksmt.expr.rewrite.simplify.rewriteBvNegNoOverflowExpr import io.ksmt.expr.rewrite.simplify.rewriteBvSubNoOverflowExpr import io.ksmt.expr.rewrite.simplify.rewriteBvSubNoUnderflowExpr -import io.ksmt.expr.rewrite.simplify.rewriteBvMulNoOverflowExpr -import io.ksmt.expr.rewrite.simplify.rewriteBvMulNoUnderflowExpr import io.ksmt.expr.rewrite.simplify.simplifyBvRotateLeftExpr import io.ksmt.expr.rewrite.simplify.simplifyBvRotateRightExpr import io.ksmt.solver.KSolverUnsupportedFeatureException @@ -203,7 +203,7 @@ class KCvc5ExprInternalizer( } private val nsolver: Solver - get() = cvc5Ctx.nativeSolver + get() = cvc5Ctx.mkExprSolver private val zeroIntValueTerm: Term by lazy { nsolver.mkInteger(0L) } diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt index cd24c385a..3e8f98dd3 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt @@ -1,5 +1,6 @@ package io.ksmt.solver.cvc5 +import io.github.cvc5.Solver import io.github.cvc5.Term import io.ksmt.KContext import io.ksmt.expr.KExpr @@ -14,28 +15,29 @@ import kotlin.time.Duration open class KCvc5ForkingSolver internal constructor( ctx: KContext, private val manager: KCvc5ForkingSolverManager, - parent: KCvc5ForkingSolver? + /** store reference on Solver to separate lifetime of native expressions */ + private val mkExprSolver: Solver, + parent: KCvc5ForkingSolver? = null ) : KCvc5SolverBase(ctx), KForkingSolver, KSolver { final override val cvc5Ctx: KCvc5Context private val isChild = parent != null private var assertionsInitiated = !isChild - private val _trackedAssertions: ScopedLinkedFrame>> - - override val trackedAssertions: ScopedFrame>> - get() = _trackedAssertions - + private val trackedAssertions: ScopedLinkedFrame>> private val cvc5Assertions: ScopedLinkedFrame> + override val currentScope: UInt + get() = trackedAssertions.currentScope + init { if (parent != null) { - cvc5Ctx = parent.cvc5Ctx.fork(solver) - _trackedAssertions = parent._trackedAssertions.fork() + cvc5Ctx = parent.cvc5Ctx.fork(solver, this.mkExprSolver) + trackedAssertions = parent.trackedAssertions.fork() cvc5Assertions = parent.cvc5Assertions.fork() } else { - cvc5Ctx = KCvc5Context(solver, ctx, true) - _trackedAssertions = ScopedLinkedFrame(::TreeMap, ::TreeMap) + cvc5Ctx = KCvc5Context(solver, this.mkExprSolver, ctx, true) + trackedAssertions = ScopedLinkedFrame(::TreeMap, ::TreeMap) cvc5Assertions = ScopedLinkedFrame(::TreeSet, ::TreeSet) } } @@ -44,6 +46,10 @@ open class KCvc5ForkingSolver internal constructor( parent?.config?.fork(solver) ?: KCvc5ForkingSolverConfigurationImpl(solver) } + init { + if (isChild) config // initialize child config + } + override fun configure(configurator: KCvc5SolverConfiguration.() -> Unit) { config.configurator() } @@ -54,7 +60,7 @@ open class KCvc5ForkingSolver internal constructor( if (assertionsInitiated) return cvc5Assertions.stacked() - .zip(_trackedAssertions.stacked()) + .zip(trackedAssertions.stacked()) .asReversed() .forEachIndexed { scope, (cvc5AssertionFrame, trackedFrame) -> if (scope > 0) solver.push() @@ -66,6 +72,12 @@ open class KCvc5ForkingSolver internal constructor( assertionsInitiated = true } + override fun saveTrackedAssertion(track: Term, trackedExpr: KExpr) { + trackedAssertions.currentFrame[track] = trackedExpr + } + + override fun findTrackedExprByTrack(track: Term): KExpr? = trackedAssertions.find { it[track] } + override fun assert(expr: KExpr): Unit = cvc5Try { ctx.ensureContextMatch(expr) ensureAssertionsInitiated() @@ -84,12 +96,14 @@ open class KCvc5ForkingSolver internal constructor( override fun push() { cvc5Try { ensureAssertionsInitiated() } super.push() + trackedAssertions.push() cvc5Assertions.push() } override fun pop(n: UInt) { cvc5Try { ensureAssertionsInitiated() } super.pop(n) + trackedAssertions.pop(n) cvc5Assertions.pop(n) } @@ -105,22 +119,9 @@ open class KCvc5ForkingSolver internal constructor( return super.checkWithAssumptions(assumptions, timeout) } - override fun unsatCore(): List> { - val cvc5FullCore = cvc5UnsatCore() - - val unsatCore = mutableListOf>() - - cvc5FullCore.forEach { unsatCoreTerm -> - lastCvc5Assumptions?.get(unsatCoreTerm)?.also { unsatCore += it } - ?: trackedAssertions.find { trackedAssertion -> - trackedAssertion[unsatCoreTerm]?.let { unsatCore += it; true } ?: false - } - } - return unsatCore - } - override fun close() { manager.close(this) - super.close() + solver.close() + cvc5Ctx.close() } } diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt index 533a86f4f..fc28f02e0 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt @@ -1,20 +1,39 @@ package io.ksmt.solver.cvc5 +import io.github.cvc5.Solver import io.ksmt.KContext import io.ksmt.solver.KForkingSolver import io.ksmt.solver.KForkingSolverManager +import java.util.IdentityHashMap import java.util.concurrent.ConcurrentHashMap open class KCvc5ForkingSolverManager(private val ctx: KContext) : KForkingSolverManager { private val solvers: MutableSet = ConcurrentHashMap.newKeySet() + /** + * for each parent to child hierarchy created only one mkExprSolver, + * which is responsible for native expressions lifetime + */ + private val forkingSolverToMkExprSolver = IdentityHashMap() + private val mkExprSolverReferences = IdentityHashMap() + override fun mkForkingSolver(): KForkingSolver { - return KCvc5ForkingSolver(ctx, this, null).also { solvers += it } + val mkExprSolver = Solver() + incRef(mkExprSolver) + return KCvc5ForkingSolver(ctx, this, mkExprSolver, null).also { + solvers += it + forkingSolverToMkExprSolver[it] = mkExprSolver + } } internal fun mkForkingSolver(parent: KCvc5ForkingSolver): KForkingSolver { - return KCvc5ForkingSolver(ctx, this, parent).also { solvers += it } + val mkExprSolver = forkingSolverToMkExprSolver.getValue(parent) + incRef(mkExprSolver) + return KCvc5ForkingSolver(ctx, this, mkExprSolver, parent).also { + solvers += it + forkingSolverToMkExprSolver[it] = mkExprSolver + } } /** @@ -22,9 +41,32 @@ open class KCvc5ForkingSolverManager(private val ctx: KContext) : KForkingSolver */ internal fun close(solver: KCvc5ForkingSolver) { solvers -= solver + val mkExprSolver = forkingSolverToMkExprSolver.getValue(solver) + forkingSolverToMkExprSolver -= solver + decRef(mkExprSolver) } override fun close() { solvers.forEach(KCvc5ForkingSolver::close) } + + private fun incRef(mkExprSolver: Solver) { + mkExprSolverReferences[mkExprSolver] = mkExprSolverReferences.getOrDefault(mkExprSolver, 0) + 1 + } + + private fun decRef(mkExprSolver: Solver) { + val referencesAfterDec = mkExprSolverReferences.getValue(mkExprSolver) - 1 + if (referencesAfterDec == 0) { + mkExprSolverReferences -= mkExprSolver + mkExprSolver.close() + } else { + mkExprSolverReferences[mkExprSolver] = referencesAfterDec + } + } + + companion object { + init { + KCvc5SolverBase.ensureCvc5LibLoaded() + } + } } diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt index 9ac8273ea..2c2c28d5c 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt @@ -81,14 +81,18 @@ open class KCvc5Model( val cvc5InterpArgs = cvc5Interp.getChild(0).getChildren() val cvc5FreshVarsInterp = cvc5Interp.substitute(cvc5InterpArgs, cvc5Vars) - val defaultBody = cvc5FreshVarsInterp.getChild(1).convertExpr() + // in case of forking solver, save in cache mkExprSolver's terms + val defaultBody = cvc5FreshVarsInterp.getChild(1).let { + if (cvc5Ctx.isForking) it.convertExprWithMkExprSolver() else it.convertExpr() + } KFuncInterpWithVars(decl, vars.map { it.decl }, emptyList(), defaultBody) } private fun constInterp(decl: KDecl, const: Term): KFuncInterp = with(converter) { val cvc5Interp = cvc5Ctx.nativeSolver.getValue(const) - val interp = cvc5Interp.convertExpr() + // in case of forking solver, save in cache mkExprSolver's terms + val interp = if (cvc5Ctx.isForking) cvc5Interp.convertExprWithMkExprSolver() else cvc5Interp.convertExpr() KFuncInterpVarsFree(decl = decl, entries = emptyList(), default = interp) } diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt index 9fc53507f..c2d64c3e9 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt @@ -8,7 +8,25 @@ import io.ksmt.sort.KBoolSort import java.util.TreeMap open class KCvc5Solver(ctx: KContext) : KCvc5SolverBase(ctx), KSolver { - override val cvc5Ctx: KCvc5Context = KCvc5Context(solver, ctx) - override val trackedAssertions: ScopedFrame>> = ScopedArrayFrame { TreeMap() } + private val trackedAssertions = ScopedArrayFrame>> { TreeMap() } + + override val currentScope: UInt + get() = trackedAssertions.currentScope + + override fun saveTrackedAssertion(track: Term, trackedExpr: KExpr) { + trackedAssertions.currentFrame[track] = trackedExpr + } + + override fun findTrackedExprByTrack(track: Term): KExpr? = trackedAssertions.find { it[track] } + + override fun push() { + super.push() + trackedAssertions.push() + } + + override fun pop(n: UInt) { + super.pop(n) + trackedAssertions.pop(n) + } } diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverBase.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverBase.kt index f141969fb..7a8ce5639 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverBase.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverBase.kt @@ -21,10 +21,7 @@ abstract class KCvc5SolverBase internal constructor( protected val ctx: KContext ) : KSolver { - protected abstract val trackedAssertions: ScopedFrame>> - - protected open val currentScope: UInt - get() = trackedAssertions.currentScope + protected abstract val currentScope: UInt protected val solver = Solver().apply { configureInitially() } protected abstract val cvc5Ctx: KCvc5Context @@ -61,6 +58,9 @@ abstract class KCvc5SolverBase internal constructor( cvc5Ctx.assertPendingAxioms(solver) } + protected abstract fun saveTrackedAssertion(track: Term, trackedExpr: KExpr) + protected abstract fun findTrackedExprByTrack(track: Term): KExpr? + override fun assertAndTrack(expr: KExpr) = cvc5Try { ctx.ensureContextMatch(expr) @@ -69,13 +69,12 @@ abstract class KCvc5SolverBase internal constructor( val trackedExpr = with(ctx) { trackVarApp implies expr } assert(trackedExpr) solver.assertFormula(cvc5TrackVar) - trackedAssertions.currentFrame[cvc5TrackVar] = expr + saveTrackedAssertion(cvc5TrackVar, expr) } override fun push() = cvc5Try { solver.push() cvc5Ctx.push() - trackedAssertions.push() } override fun pop(n: UInt) = cvc5Try { @@ -86,7 +85,6 @@ abstract class KCvc5SolverBase internal constructor( if (n == 0u) return solver.pop(n.toInt()) cvc5Ctx.pop(n) - trackedAssertions.pop(n) } override fun check(timeout: Duration): KSolverStatus = cvc5TryCheck { @@ -142,9 +140,7 @@ abstract class KCvc5SolverBase internal constructor( cvc5FullCore.forEach { unsatCoreTerm -> lastCvc5Assumptions?.get(unsatCoreTerm)?.also { unsatCore += it } - ?: trackedAssertions.find { trackedAssertion -> - trackedAssertion[unsatCoreTerm]?.also { unsatCore += it } != null - } + ?: findTrackedExprByTrack(unsatCoreTerm)?.also { unsatCore += it } } return unsatCore } @@ -209,7 +205,7 @@ abstract class KCvc5SolverBase internal constructor( } companion object { - init { + internal fun ensureCvc5LibLoaded() { if (System.getProperty("cvc5.skipLibraryLoad") != "true") { NativeLibraryLoader.load { os -> when (os) { @@ -221,5 +217,9 @@ abstract class KCvc5SolverBase internal constructor( System.setProperty("cvc5.skipLibraryLoad", "true") } } + + init { + ensureCvc5LibLoaded() + } } } diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SortInternalizer.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SortInternalizer.kt index 17bfd918b..48153212b 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SortInternalizer.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SortInternalizer.kt @@ -20,7 +20,7 @@ import io.ksmt.sort.KUninterpretedSort open class KCvc5SortInternalizer( private val cvc5Ctx: KCvc5Context ) : KSortVisitor { - private val nSolver: Solver = cvc5Ctx.nativeSolver + private val nSolver: Solver = cvc5Ctx.mkExprSolver override fun visit(sort: KBoolSort): Sort = cvc5Ctx.internalizeSort(sort) { nSolver.booleanSort diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt index b6039d305..147f60c96 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt @@ -1,11 +1,15 @@ package io.ksmt.solver.cvc5 -interface ScopedFrame { +internal interface ScopedFrame { val currentScope: UInt val currentFrame: T fun flatten(collect: T.(T) -> Unit): T - fun find(predicate: (T) -> Boolean): T? + + /** + * find value [V] in frame [T], and return it or null + */ + fun find(predicate: (T) -> V?): V? fun push() fun pop(n: UInt = 1u) @@ -29,7 +33,12 @@ internal class ScopedArrayFrame( frames.forEach { newFrame.collect(it) } } - override fun find(predicate: (T) -> Boolean) = frames.find(predicate) + override fun find(predicate: (T) -> V?): V? { + frames.forEach { frame -> + predicate(frame)?.let { return it } + } + return null + } override fun push() { currentFrame = createNewFrame() @@ -76,9 +85,9 @@ internal class ScopedLinkedFrame private constructor( } } - override fun find(predicate: (T) -> Boolean): T? { + override fun find(predicate: (T) -> V?): V? { forEachReversed { frame -> - if (predicate(frame)) return frame + predicate(frame)?.let { return it } } return null } diff --git a/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt b/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt index a5cfec1f7..fd59fa00f 100644 --- a/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt +++ b/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt @@ -29,6 +29,9 @@ class KForkingSolverTest { @Test fun testScopedAssertions() = testScopedAssertions(::mkCvc5ForkingSolver) + @Test + fun testLifeTime() = testLifeTime(::mkCvc5ForkingSolver) + private fun mkCvc5ForkingSolver(ctx: KContext) = KCvc5ForkingSolverManager(ctx).mkForkingSolver() } @@ -275,4 +278,26 @@ class KForkingSolverTest { } } } + + fun testLifeTime(mkSolver: (KContext) -> KForkingSolver<*>): Unit = + KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> + with(ctx) { + val parent = mkSolver(ctx) + val x by intSort + val f = x ge 100.expr + + parent.assert(f) + parent.check().also { require(it == KSolverStatus.SAT) } + + val xVal = parent.model().eval(x) + + val fork = parent.fork().fork().fork() + parent.close() + + fork.assert(f and (x eq xVal)) + fork.check().also { assertEquals(KSolverStatus.SAT, it) } + assertEquals(fork.model().eval(x), xVal) + } + + } } From c0e4be554af9e1889dd8d619fb563fa44bf033b8 Mon Sep 17 00:00:00 2001 From: Dmitriy Sokolov Date: Thu, 27 Jul 2023 18:41:59 +0300 Subject: [PATCH 03/12] cvc5: Extracted global cache for forking context and delegated to KCvc5ForkingSolverManager --- .../io/ksmt/solver/cvc5/KCvc5Context.kt | 77 ++++++++++------- .../io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt | 16 ++-- .../solver/cvc5/KCvc5ForkingSolverManager.kt | 86 ++++++++++++++++++- .../kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt | 9 +- .../kotlin/io/ksmt/test/KForkingSolverTest.kt | 4 +- 5 files changed, 145 insertions(+), 47 deletions(-) diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt index 0223042a1..8f72a81e7 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt @@ -32,20 +32,24 @@ import io.ksmt.sort.KSortVisitor import io.ksmt.sort.KUninterpretedSort import java.util.TreeMap -class KCvc5Context private constructor( +class KCvc5Context internal constructor( private val solver: Solver, + /** + * Used as context for expressions lifetime separation. + * Exprs which stored in [KCvc5Context], created with [mkExprSolver] + */ val mkExprSolver: Solver, private val ctx: KContext, - parent: KCvc5Context?, - val isForking: Boolean + forkingSolverManager: KCvc5ForkingSolverManager? = null ) : AutoCloseable { - constructor(solver: Solver, mkExprSolver: Solver, ctx: KContext, isForking: Boolean = false) - : this(solver, mkExprSolver, ctx, null, isForking) + constructor(solver: Solver, mkExprSolver: Solver, ctx: KContext) + : this(solver, mkExprSolver, ctx, null) - constructor(solver: Solver, ctx: KContext, isForking: Boolean = false) - : this(solver, solver, ctx, null, isForking) + constructor(solver: Solver, ctx: KContext) + : this(solver, solver, ctx, null) private var isClosed = false + val isForking = forkingSolverManager != null private val uninterpretedSortCollector = KUninterpretedSortCollector(this) private var exprCurrentLevelCacheRestorer = KCurrentScopeExprCacheRestorer(uninterpretedSortCollector, ctx) @@ -72,14 +76,11 @@ class KCvc5Context private constructor( * that is in global cache, but whose sorts / decls have been erased after pop() * (and put this expr to the cache of current accumulated scope) */ - private val currentAccumulatedScopeExpressions: HashMap, Term> + private val currentAccumulatedScopeExpressions = HashMap, Term>() private val expressions: HashMap, Term> /** * We can't use HashMap with Term and Sort (hashcode is not implemented) - * - * Avoid to close cache explicitly due to its sharing between forking hierarchy. - * It will be garbage collected on last solver close in forking hierarchy */ private val cvc5Expressions: TreeMap> private val sorts: HashMap @@ -102,28 +103,26 @@ class KCvc5Context private constructor( init { if (isForking) { - uninterpretedSorts = (parent?.uninterpretedSorts as? ScopedLinkedFrame)?.fork() - ?: ScopedLinkedFrame(::HashSet, ::HashSet) - declarations = (parent?.declarations as? ScopedLinkedFrame)?.fork() - ?: ScopedLinkedFrame(::HashSet, ::HashSet) + uninterpretedSorts = ScopedLinkedFrame(::HashSet, ::HashSet) + declarations = ScopedLinkedFrame(::HashSet, ::HashSet) } else { uninterpretedSorts = ScopedArrayFrame(::HashSet) declarations = ScopedArrayFrame(::HashSet) } - if (parent != null) { - currentAccumulatedScopeExpressions = parent.currentAccumulatedScopeExpressions.toMap(HashMap()) - expressions = parent.expressions - cvc5Expressions = parent.cvc5Expressions - sorts = parent.sorts - cvc5Sorts = parent.cvc5Sorts - decls = parent.decls - cvc5Decls = parent.cvc5Decls - uninterpretedSortValueDescriptors = parent.uninterpretedSortValueDescriptors - uninterpretedSortValueInterpreter = parent.uninterpretedSortValueInterpreter - uninterpretedSortValues = parent.uninterpretedSortValues + if (forkingSolverManager != null) { + with(forkingSolverManager) { + expressions = findExpressionsCache() + cvc5Expressions = findExpressionsReversedCache() + sorts = findSortsCache() + cvc5Sorts = findSortsReversedCache() + decls = findDeclsCache() + cvc5Decls = findDeclsReversedCache() + uninterpretedSortValueDescriptors = findUninterpretedSortsValueDescriptors() + uninterpretedSortValueInterpreter = findUninterpretedSortsValueInterpretersCache() + uninterpretedSortValues = findUninterpretedSortValues() + } } else { - currentAccumulatedScopeExpressions = HashMap() expressions = HashMap() cvc5Expressions = TreeMap() sorts = HashMap() @@ -155,12 +154,18 @@ class KCvc5Context private constructor( val isActive: Boolean get() = !isClosed - fun fork(solver: Solver, mkExprSolver: Solver): KCvc5Context = - KCvc5Context(solver, mkExprSolver, ctx, this, true).also { forkCtx -> + fun fork(solver: Solver, forkingSolverManager: KCvc5ForkingSolverManager): KCvc5Context { + require(isForking) { "Can't fork non-forking context" } + return KCvc5Context(solver, mkExprSolver, ctx, forkingSolverManager).also { forkCtx -> + forkCtx.currentAccumulatedScopeExpressions += currentAccumulatedScopeExpressions + (forkCtx.uninterpretedSorts as ScopedLinkedFrame).fork(uninterpretedSorts as ScopedLinkedFrame) + (forkCtx.declarations as ScopedLinkedFrame).fork(declarations as ScopedLinkedFrame) + repeat(assertedConstraintLevels.size) { forkCtx.pushAssertionLevel() } } + } fun push() { declarations.push() @@ -262,7 +267,7 @@ class KCvc5Context private constructor( * * todo: precise uninterpreted sort values tracking * */ - private data class UninterpretedSortValueDescriptor( + internal data class UninterpretedSortValueDescriptor( val value: KUninterpretedSortValue, val nativeUniqueValueDescriptor: Term, val nativeValueTerm: Term @@ -414,6 +419,18 @@ class KCvc5Context private constructor( isClosed = true currentAccumulatedScopeExpressions.clear() + + if (!isForking) { + expressions.clear() + cvc5Expressions.clear() + sorts.clear() + cvc5Sorts.clear() + decls.clear() + cvc5Decls.clear() + uninterpretedSortValueDescriptors.clear() + uninterpretedSortValueInterpreter.clear() + uninterpretedSortValues.clear() + } } class KUninterpretedSortCollector(private val cvc5Ctx: KCvc5Context) : KSortVisitor { diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt index 3e8f98dd3..d49f52187 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt @@ -16,7 +16,7 @@ open class KCvc5ForkingSolver internal constructor( ctx: KContext, private val manager: KCvc5ForkingSolverManager, /** store reference on Solver to separate lifetime of native expressions */ - private val mkExprSolver: Solver, + mkExprSolver: Solver, parent: KCvc5ForkingSolver? = null ) : KCvc5SolverBase(ctx), KForkingSolver, KSolver { @@ -24,21 +24,19 @@ open class KCvc5ForkingSolver internal constructor( private val isChild = parent != null private var assertionsInitiated = !isChild - private val trackedAssertions: ScopedLinkedFrame>> - private val cvc5Assertions: ScopedLinkedFrame> + private val trackedAssertions = ScopedLinkedFrame>>(::TreeMap, ::TreeMap) + private val cvc5Assertions = ScopedLinkedFrame>(::TreeSet, ::TreeSet) override val currentScope: UInt get() = trackedAssertions.currentScope init { if (parent != null) { - cvc5Ctx = parent.cvc5Ctx.fork(solver, this.mkExprSolver) - trackedAssertions = parent.trackedAssertions.fork() - cvc5Assertions = parent.cvc5Assertions.fork() + cvc5Ctx = parent.cvc5Ctx.fork(solver, manager) + trackedAssertions.fork(parent.trackedAssertions) + cvc5Assertions.fork(parent.cvc5Assertions) } else { - cvc5Ctx = KCvc5Context(solver, this.mkExprSolver, ctx, true) - trackedAssertions = ScopedLinkedFrame(::TreeMap, ::TreeMap) - cvc5Assertions = ScopedLinkedFrame(::TreeSet, ::TreeSet) + cvc5Ctx = KCvc5Context(solver, mkExprSolver, ctx, manager) } } diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt index fc28f02e0..303083f2c 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt @@ -1,10 +1,18 @@ package io.ksmt.solver.cvc5 import io.github.cvc5.Solver +import io.github.cvc5.Sort +import io.github.cvc5.Term import io.ksmt.KContext +import io.ksmt.decl.KDecl +import io.ksmt.expr.KExpr +import io.ksmt.expr.KUninterpretedSortValue import io.ksmt.solver.KForkingSolver import io.ksmt.solver.KForkingSolverManager +import io.ksmt.sort.KSort +import io.ksmt.sort.KUninterpretedSort import java.util.IdentityHashMap +import java.util.TreeMap import java.util.concurrent.ConcurrentHashMap open class KCvc5ForkingSolverManager(private val ctx: KContext) : KForkingSolverManager { @@ -12,12 +20,67 @@ open class KCvc5ForkingSolverManager(private val ctx: KContext) : KForkingSolver private val solvers: MutableSet = ConcurrentHashMap.newKeySet() /** - * for each parent to child hierarchy created only one mkExprSolver, + * for each parent-to-child hierarchy created only one mkExprSolver, * which is responsible for native expressions lifetime */ private val forkingSolverToMkExprSolver = IdentityHashMap() private val mkExprSolverReferences = IdentityHashMap() + // shared cache + private val expressionsCache = IdentityHashMap() + private val expressionsReversedCache = IdentityHashMap() + private val sortsCache = IdentityHashMap() + private val sortsReversedCache = IdentityHashMap() + private val declsCache = IdentityHashMap() + private val declsReversedCache = IdentityHashMap() + + private val uninterpretedSortValueDescriptors = IdentityHashMap() + private val uninterpretedSortValueInterpretersCache = + IdentityHashMap() + private val uninterpretedSortValues = IdentityHashMap() + + private fun Solver.ensureRegisteredAsMkExprSolver() = require(this in mkExprSolverReferences) { + "Solver is not registered by this manager" + } + + internal fun KCvc5Context.findExpressionsCache() = mkExprSolver.ensureRegisteredAsMkExprSolver().let { + expressionsCache.getOrPut(mkExprSolver) { ExpressionsCache() } + } + + internal fun KCvc5Context.findExpressionsReversedCache() = mkExprSolver.ensureRegisteredAsMkExprSolver().let { + expressionsReversedCache.getOrPut(mkExprSolver) { ExpressionsReversedCache() } + } + + internal fun KCvc5Context.findSortsCache() = mkExprSolver.ensureRegisteredAsMkExprSolver().let { + sortsCache.getOrPut(mkExprSolver) { SortsCache() } + } + + internal fun KCvc5Context.findSortsReversedCache() = mkExprSolver.ensureRegisteredAsMkExprSolver().let { + sortsReversedCache.getOrPut(mkExprSolver) { SortsReversedCache() } + } + + internal fun KCvc5Context.findDeclsCache() = mkExprSolver.ensureRegisteredAsMkExprSolver().let { + declsCache.getOrPut(mkExprSolver) { DeclsCache() } + } + + internal fun KCvc5Context.findDeclsReversedCache() = mkExprSolver.ensureRegisteredAsMkExprSolver().let { + declsReversedCache.getOrPut(mkExprSolver) { DeclsReversedCache() } + } + + internal fun KCvc5Context.findUninterpretedSortsValueDescriptors() = mkExprSolver.ensureRegisteredAsMkExprSolver() + .let { + uninterpretedSortValueDescriptors.getOrPut(mkExprSolver) { UninterpretedSortValueDescriptors() } + } + + internal fun KCvc5Context.findUninterpretedSortsValueInterpretersCache() = mkExprSolver + .ensureRegisteredAsMkExprSolver().let { + uninterpretedSortValueInterpretersCache.getOrPut(mkExprSolver) { UninterpretedSortValueInterpretersCache() } + } + + internal fun KCvc5Context.findUninterpretedSortValues() = mkExprSolver.ensureRegisteredAsMkExprSolver().let { + uninterpretedSortValues.getOrPut(mkExprSolver) { UninterpretedSortValues() } + } + override fun mkForkingSolver(): KForkingSolver { val mkExprSolver = Solver() incRef(mkExprSolver) @@ -58,6 +121,16 @@ open class KCvc5ForkingSolverManager(private val ctx: KContext) : KForkingSolver val referencesAfterDec = mkExprSolverReferences.getValue(mkExprSolver) - 1 if (referencesAfterDec == 0) { mkExprSolverReferences -= mkExprSolver + expressionsCache -= mkExprSolver + expressionsReversedCache -= mkExprSolver + sortsCache -= mkExprSolver + sortsReversedCache -= mkExprSolver + declsCache -= mkExprSolver + declsReversedCache -= mkExprSolver + uninterpretedSortValueDescriptors -= mkExprSolver + uninterpretedSortValueInterpretersCache -= mkExprSolver + uninterpretedSortValues -= mkExprSolver + mkExprSolver.close() } else { mkExprSolverReferences[mkExprSolver] = referencesAfterDec @@ -70,3 +143,14 @@ open class KCvc5ForkingSolverManager(private val ctx: KContext) : KForkingSolver } } } + +private typealias ExpressionsCache = HashMap, Term> +private typealias ExpressionsReversedCache = TreeMap> +private typealias SortsCache = HashMap +private typealias SortsReversedCache = TreeMap +private typealias DeclsCache = HashMap, Term> +private typealias DeclsReversedCache = TreeMap> +private typealias UninterpretedSortValueDescriptors = ArrayList +private typealias UninterpretedSortValueInterpretersCache = HashMap +@Suppress("MaxLineLength") +private typealias UninterpretedSortValues = HashMap>> diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt index 147f60c96..e8fdcd351 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt @@ -106,11 +106,10 @@ internal class ScopedLinkedFrame private constructor( current = LinkedFrame(newTopFrame, current.previous) } - fun fork(): ScopedLinkedFrame = ScopedLinkedFrame( - current, - createNewFrame, - copyFrame - ).also { it.recreateTopFrame() } + fun fork(parent: ScopedLinkedFrame) { + current = parent.current + recreateTopFrame() + } private inline fun forEachReversed(action: (T) -> Unit) { var cur: LinkedFrame? = current diff --git a/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt b/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt index fd59fa00f..16bc66bb1 100644 --- a/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt +++ b/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt @@ -284,7 +284,7 @@ class KForkingSolverTest { with(ctx) { val parent = mkSolver(ctx) val x by intSort - val f = x ge 100.expr + val f = x gt 100.expr parent.assert(f) parent.check().also { require(it == KSolverStatus.SAT) } @@ -296,7 +296,7 @@ class KForkingSolverTest { fork.assert(f and (x eq xVal)) fork.check().also { assertEquals(KSolverStatus.SAT, it) } - assertEquals(fork.model().eval(x), xVal) + assertEquals(xVal, fork.model().eval(x)) } } From d88500ec53bbf2718f88b5bcd6f200246cd74e26 Mon Sep 17 00:00:00 2001 From: Dmitriy Sokolov Date: Sat, 29 Jul 2023 22:17:42 +0300 Subject: [PATCH 04/12] z3: forking solver; Wrapped missing native throwable functions; Tracks of assertion clear on pop --- .../kotlin/io/ksmt/test/KForkingSolverTest.kt | 494 ++++++++++-------- .../main/kotlin/com/microsoft/z3/UnsafeApi.kt | 6 + .../ExpressionUninterpretedValuesTracker.kt | 23 +- .../kotlin/io/ksmt/solver/z3/KZ3Context.kt | 137 +++-- .../io/ksmt/solver/z3/KZ3ForkingSolver.kt | 131 +++++ .../ksmt/solver/z3/KZ3ForkingSolverManager.kt | 172 ++++++ .../kotlin/io/ksmt/solver/z3/KZ3Solver.kt | 246 +-------- .../kotlin/io/ksmt/solver/z3/KZ3SolverBase.kt | 256 +++++++++ .../ksmt/solver/z3/KZ3SolverConfiguration.kt | 39 ++ .../kotlin/io/ksmt/solver/z3/ScopedFrame.kt | 129 +++++ 10 files changed, 1133 insertions(+), 500 deletions(-) create mode 100644 ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt create mode 100644 ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt create mode 100644 ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverBase.kt create mode 100644 ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ScopedFrame.kt diff --git a/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt b/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt index 16bc66bb1..3a7e46595 100644 --- a/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt +++ b/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt @@ -1,303 +1,371 @@ package io.ksmt.test import io.ksmt.KContext -import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KForkingSolverManager import io.ksmt.solver.KSolverStatus import io.ksmt.solver.cvc5.KCvc5ForkingSolverManager +import io.ksmt.solver.z3.KZ3ForkingSolverManager import io.ksmt.utils.getValue +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertDoesNotThrow import kotlin.test.assertContains import kotlin.test.assertEquals +import kotlin.test.assertNotEquals import kotlin.test.assertTrue -import org.junit.jupiter.api.Nested -import org.junit.jupiter.api.Test class KForkingSolverTest { @Nested inner class KForkingSolverTestCvc5 { @Test - fun testCheckSat() = testCheckSat(::mkCvc5ForkingSolver) + fun testCheckSat() = testCheckSat(::mkCvc5ForkingSolverManager) @Test - fun testModel() = testModel(::mkCvc5ForkingSolver) + fun testModel() = testModel(::mkCvc5ForkingSolverManager) @Test - fun testUnsatCore() = testUnsatCore(::mkCvc5ForkingSolver) + fun testUnsatCore() = testUnsatCore(::mkCvc5ForkingSolverManager) @Test - fun testUninterpretedSort() = testUninterpretedSort(::mkCvc5ForkingSolver) + fun testUninterpretedSort() = testUninterpretedSort(::mkCvc5ForkingSolverManager) @Test - fun testScopedAssertions() = testScopedAssertions(::mkCvc5ForkingSolver) + fun testScopedAssertions() = testScopedAssertions(::mkCvc5ForkingSolverManager) @Test - fun testLifeTime() = testLifeTime(::mkCvc5ForkingSolver) + fun testLifeTime() = testLifeTime(::mkCvc5ForkingSolverManager) - private fun mkCvc5ForkingSolver(ctx: KContext) = KCvc5ForkingSolverManager(ctx).mkForkingSolver() + private fun mkCvc5ForkingSolverManager(ctx: KContext) = KCvc5ForkingSolverManager(ctx) } - private fun testCheckSat(mkSolver: (KContext) -> KForkingSolver<*>) = - KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> - mkSolver(ctx).use { parentSolver -> - with(ctx) { - val a by boolSort - val b by boolSort - val f = a and b - val neg = !a - - parentSolver.push() + @Nested + inner class KForkingSolverTestZ3 { + @Test + fun testCheckSat() = testCheckSat(::mkZ3ForkingSolverManager) - // * check children's assertions do not change parent's state - parentSolver.assert(f) - require(parentSolver.check() == KSolverStatus.SAT) - require(parentSolver.checkWithAssumptions(listOf(neg)) == KSolverStatus.UNSAT) + @Test + fun testModel() = testModel(::mkZ3ForkingSolverManager) - parentSolver.fork().also { fork -> - assertEquals(KSolverStatus.SAT, fork.check()) - fork.assert(neg) - assertEquals(KSolverStatus.UNSAT, fork.check()) - } + @Test + fun testUnsatCore() = testUnsatCore(::mkZ3ForkingSolverManager) - assertEquals(KSolverStatus.SAT, parentSolver.check()) - // * + @Test + fun testUninterpretedSort() = testUninterpretedSort(::mkZ3ForkingSolverManager) - // * check parent's assertions translated into child solver - parentSolver.push() - assertEquals(KSolverStatus.UNSAT, parentSolver.fork().checkWithAssumptions(listOf(neg))) - parentSolver.assert(neg) - require(parentSolver.check() == KSolverStatus.UNSAT) + @Test + fun testScopedAssertions() = testScopedAssertions(::mkZ3ForkingSolverManager) - parentSolver.fork().also { fork -> - assertEquals(KSolverStatus.UNSAT, fork.check()) - } - parentSolver.pop() - // * - - // * check children independence - assertEquals(KSolverStatus.SAT, parentSolver.check()) - parentSolver.fork().also { fork1 -> - val fork2 = parentSolver.fork() - fork2.assert(neg) - assertEquals(KSolverStatus.UNSAT, fork2.check()) - assertEquals(KSolverStatus.SAT, fork1.check()) - - fork1.assert(neg) - assertEquals(KSolverStatus.UNSAT, fork1.check()) - assertEquals(KSolverStatus.SAT, parentSolver.fork().check()) - } - assertEquals(KSolverStatus.SAT, parentSolver.check()) - // * - } + @Test + fun testLifeTime() = testLifeTime(::mkZ3ForkingSolverManager) - } - } + private fun mkZ3ForkingSolverManager(ctx: KContext) = KZ3ForkingSolverManager(ctx) + } - private fun testUnsatCore(mkSolver: (KContext) -> KForkingSolver<*>): Unit = + private fun testCheckSat(mkForkingSolverManager: (KContext) -> KForkingSolverManager<*>) = KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> - mkSolver(ctx).use { parentSolver -> - with(ctx) { - val a by boolSort - val b by boolSort - val f = a and b - val neg = !a - - // * check that unsat core is empty (non-tracked assertions) - parentSolver.push() - parentSolver.assert(f) + mkForkingSolverManager(ctx).use { man -> + man.mkForkingSolver().use { parentSolver -> + with(ctx) { + val a by boolSort + val b by boolSort + val f = a and b + val neg = !a + + parentSolver.push() + + // * check children's assertions do not change parent's state + parentSolver.assert(f) + require(parentSolver.check() == KSolverStatus.SAT) + require(parentSolver.checkWithAssumptions(listOf(neg)) == KSolverStatus.UNSAT) + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + fork.assert(neg) + assertEquals(KSolverStatus.UNSAT, fork.check()) + } - parentSolver.fork().also { fork -> - assertEquals(KSolverStatus.SAT, fork.check()) - fork.assert(neg) - assertEquals(KSolverStatus.UNSAT, fork.check()) - assertTrue { fork.unsatCore().isEmpty() } - assertEquals(KSolverStatus.SAT, parentSolver.check()) // parent's state hasn't changed - } - parentSolver.pop() - // * + assertEquals(KSolverStatus.SAT, parentSolver.check()) + // * - // check tracked exprs are in unsat core - parentSolver.push() - parentSolver.assertAndTrack(f) + // * check parent's assertions translated into child solver + parentSolver.push() + assertEquals(KSolverStatus.UNSAT, parentSolver.fork().checkWithAssumptions(listOf(neg))) + parentSolver.assert(neg) + require(parentSolver.check() == KSolverStatus.UNSAT) - parentSolver.fork().also { fork -> - fork.assertAndTrack(neg) - assertEquals(KSolverStatus.UNSAT, fork.check()) - assertContains(fork.unsatCore(), neg) - assertContains(fork.unsatCore(), f) - assertEquals(KSolverStatus.SAT, parentSolver.check()) // parent's state hasn't changed + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.UNSAT, fork.check()) + } + parentSolver.pop() + // * + + // * check children independence + assertEquals(KSolverStatus.SAT, parentSolver.check()) + parentSolver.fork().also { fork1 -> + val fork2 = parentSolver.fork() + fork2.assert(neg) + assertEquals(KSolverStatus.UNSAT, fork2.check()) + assertEquals(KSolverStatus.SAT, fork1.check()) + + fork1.assert(neg) + assertEquals(KSolverStatus.UNSAT, fork1.check()) + assertEquals(KSolverStatus.SAT, parentSolver.fork().check()) + } + assertEquals(KSolverStatus.SAT, parentSolver.check()) + // * } - // * - - // * check unsat core saves from parent to child - parentSolver.assert(neg) - require(parentSolver.check() == KSolverStatus.UNSAT) - require(neg !in parentSolver.unsatCore()) - require(f in parentSolver.unsatCore()) // only tracked f is in unsat core + } + } + } - parentSolver.fork().also { fork -> - assertEquals(KSolverStatus.UNSAT, fork.check()) - assertContains(fork.unsatCore(), f) - assertTrue { neg !in fork.unsatCore() } + private fun testUnsatCore(mkForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = + KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> + mkForkingSolverManager(ctx).use { man -> + man.mkForkingSolver().use { parentSolver -> + with(ctx) { + val a by boolSort + val b by boolSort + val f = a and b + val neg = !a + + // * check that unsat core is empty (non-tracked assertions) + parentSolver.push() + parentSolver.assert(f) + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + fork.assert(neg) + assertEquals(KSolverStatus.UNSAT, fork.check()) + assertTrue { fork.unsatCore().isEmpty() } + assertEquals(KSolverStatus.SAT, parentSolver.check()) // parent's state hasn't changed + } + parentSolver.pop() + // * + + // check tracked exprs are in unsat core + parentSolver.push() + parentSolver.assertAndTrack(f) + + parentSolver.fork().also { fork -> + fork.assertAndTrack(neg) + assertEquals(KSolverStatus.UNSAT, fork.check()) + assertContains(fork.unsatCore(), neg) + assertContains(fork.unsatCore(), f) + assertEquals(KSolverStatus.SAT, parentSolver.check()) // parent's state hasn't changed + } + // * + + // * check unsat core saves from parent to child + parentSolver.assert(neg) + require(parentSolver.check() == KSolverStatus.UNSAT) + require(neg !in parentSolver.unsatCore()) + require(f in parentSolver.unsatCore()) // only tracked f is in unsat core + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.UNSAT, fork.check()) + assertContains(fork.unsatCore(), f) + assertTrue { neg !in fork.unsatCore() } + } } } } } - private fun testModel(mkSolver: (KContext) -> KForkingSolver<*>): Unit = + private fun testModel(mkForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> - mkSolver(ctx).use { parentSolver -> - with(ctx) { - val a by boolSort - val b by boolSort - val f = a and !b - - parentSolver.assert(f) - - require(parentSolver.check() == KSolverStatus.SAT) - require(parentSolver.model().eval(a) == true.expr) - require(parentSolver.model().eval(b) == false.expr) - - parentSolver.fork().also { fork -> - assertEquals(KSolverStatus.SAT, fork.check()) - assertEquals(true.expr, fork.model().eval(a)) - assertEquals(false.expr, fork.model().eval(b)) + mkForkingSolverManager(ctx).use { man -> + man.mkForkingSolver().use { parentSolver -> + with(ctx) { + val a by boolSort + val b by boolSort + val f = a and !b + + parentSolver.assert(f) + + require(parentSolver.check() == KSolverStatus.SAT) + require(parentSolver.model().eval(a) == true.expr) + require(parentSolver.model().eval(b) == false.expr) + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + assertEquals(true.expr, fork.model().eval(a)) + assertEquals(false.expr, fork.model().eval(b)) + } } } } } - private fun testScopedAssertions(mkSolver: (KContext) -> KForkingSolver<*>): Unit = + private fun testScopedAssertions(mkForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> - mkSolver(ctx).use { parent -> - with(ctx) { - val a by boolSort - val b by boolSort - val f = a and b - val neg = !a - - parent.push() - - parent.assertAndTrack(f) - require(parent.check() == KSolverStatus.SAT) - parent.push() - parent.assertAndTrack(neg) + mkForkingSolverManager(ctx).use { man -> + man.mkForkingSolver().use { parent -> + with(ctx) { + val a by boolSort + val b by boolSort + val f = a and b + val neg = !a + + parent.push() + + parent.assertAndTrack(f) + require(parent.check() == KSolverStatus.SAT) + parent.push() + parent.assertAndTrack(neg) + + require(parent.check() == KSolverStatus.UNSAT) + + parent.fork().also { fork -> + assertEquals(KSolverStatus.UNSAT, fork.check()) + assertContains(fork.unsatCore(), f) + assertContains(fork.unsatCore(), neg) + + fork.pop() + assertEquals(KSolverStatus.SAT, fork.check()) + assertEquals(true.expr, fork.model().eval(a)) + assertEquals(true.expr, fork.model().eval(b)) + assertEquals(KSolverStatus.UNSAT, fork.checkWithAssumptions(listOf(neg))) + assertEquals(KSolverStatus.UNSAT, parent.check()) // check parent's state hasn't changed - require(parent.check() == KSolverStatus.UNSAT) + fork.fork().also { ffork -> + assertEquals(KSolverStatus.SAT, ffork.check()) + assertEquals(KSolverStatus.UNSAT, ffork.checkWithAssumptions(listOf(neg))) - parent.fork().also { fork -> - assertEquals(KSolverStatus.UNSAT, fork.check()) - assertContains(fork.unsatCore(), f) - assertContains(fork.unsatCore(), neg) + ffork.push() + ffork.assertAndTrack(neg) + assertEquals(KSolverStatus.UNSAT, ffork.check()) + assertContains(ffork.unsatCore(), f) + assertContains(ffork.unsatCore(), neg) - fork.pop() - assertEquals(KSolverStatus.SAT, fork.check()) - assertEquals(true.expr, fork.model().eval(a)) - assertEquals(true.expr, fork.model().eval(b)) - assertEquals(KSolverStatus.UNSAT, fork.checkWithAssumptions(listOf(neg))) - assertEquals(KSolverStatus.UNSAT, parent.check()) // check parent's state hasn't changed - - fork.fork().also { ffork -> - assertEquals(KSolverStatus.SAT, ffork.check()) - assertEquals(KSolverStatus.UNSAT, ffork.checkWithAssumptions(listOf(neg))) - - ffork.push() - ffork.assertAndTrack(neg) - assertEquals(KSolverStatus.UNSAT, ffork.check()) - assertContains(ffork.unsatCore(), f) - assertContains(ffork.unsatCore(), neg) - - assertEquals(KSolverStatus.SAT, fork.check()) // check parent's state hasn't changed - assertEquals(KSolverStatus.UNSAT, parent.check()) // check parent's state hasn't changed + assertEquals(KSolverStatus.SAT, fork.check()) // check parent's state hasn't changed + assertEquals(KSolverStatus.UNSAT, parent.check()) // check parent's state hasn't changed - ffork.pop() - assertEquals(KSolverStatus.SAT, ffork.check()) - assertEquals(KSolverStatus.UNSAT, ffork.checkWithAssumptions(listOf(neg))) + ffork.pop() + assertEquals(KSolverStatus.SAT, ffork.check()) + assertEquals(KSolverStatus.UNSAT, ffork.checkWithAssumptions(listOf(neg))) + } } - } - // check child's state is detached - val fork = parent.fork() - assertEquals(KSolverStatus.UNSAT, fork.check()) - parent.pop() + // check child's state is detached + val fork = parent.fork() + assertEquals(KSolverStatus.UNSAT, fork.check()) + parent.pop() - assertEquals(KSolverStatus.SAT, parent.check()) - assertEquals(KSolverStatus.UNSAT, fork.check()) + assertEquals(KSolverStatus.SAT, parent.check()) + assertEquals(KSolverStatus.UNSAT, fork.check()) - parent.pop() + parent.pop() - fork.pop() - fork.pop() + fork.pop() + fork.pop() - fork.assert(neg) - assertEquals(KSolverStatus.SAT, fork.check()) + fork.assert(neg) + assertEquals(KSolverStatus.SAT, fork.check()) + } } } } - private fun testUninterpretedSort(mkSolver: (KContext) -> KForkingSolver<*>): Unit = + @Suppress("LongMethod") + private fun testUninterpretedSort(mkForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> - mkSolver(ctx).use { parentSolver -> - with(ctx) { - val uSort = mkUninterpretedSort("u") - val u1 by uSort - val u2 by uSort + mkForkingSolverManager(ctx).use { man -> + man.mkForkingSolver().use { parentSolver -> + with(ctx) { + val uSort = mkUninterpretedSort("u") + val u1 by uSort + val u2 by uSort - val eq12 = u1 eq u2 + val eq12 = u1 eq u2 - parentSolver.push() - parentSolver.assert(eq12) + parentSolver.push() - require(parentSolver.check() == KSolverStatus.SAT) - val pu1v = parentSolver.model().eval(u1) + parentSolver.fork().also { fork -> + assertDoesNotThrow { fork.pop() } // check assertion levels saved + fork.assert(u1 neq u2) + assertEquals(KSolverStatus.SAT, fork.check()) + } - parentSolver.fork().also { fork -> - assertEquals(KSolverStatus.SAT, fork.check()) - fork.assert(u1 eq pu1v) - assertEquals(KSolverStatus.SAT, fork.check()) - assertEquals(pu1v, fork.model().eval(u1)) - } + parentSolver.assert(eq12) - parentSolver.fork().also { fork -> - assertEquals(KSolverStatus.SAT, fork.check()) - fork.assert(u1 eq pu1v) - assertEquals(KSolverStatus.SAT, fork.check()) - assertEquals(pu1v, fork.model().eval(u1)) + require(parentSolver.check() == KSolverStatus.SAT) + val pu1v = parentSolver.model().eval(u1) + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + fork.assert(u1 eq pu1v) + assertEquals(KSolverStatus.SAT, fork.check()) + assertEquals(pu1v, fork.model().eval(u1)) + } + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + fork.assert(u1 neq pu1v) + assertEquals(KSolverStatus.SAT, fork.check()) + assertNotEquals(pu1v, fork.model().eval(u1)) + } + + parentSolver.push().also { + val u5 by uSort + val pu5v = mkUninterpretedSortValue(uSort, 5) + parentSolver.assert(u5 eq pu5v) + + parentSolver.assert(u1 eq pu1v) - fork.fork().also { ff -> - assertEquals(KSolverStatus.SAT, ff.check()) - assertEquals(pu1v, ff.model().eval(u1)) - ff.model().uninterpretedSortUniverse(uSort)?.also { universe -> + parentSolver.check() + parentSolver.model().uninterpretedSortUniverse(uSort)?.also { universe -> assertContains(universe, pu1v) + assertContains(universe, pu5v) } + + parentSolver.pop() } - } - parentSolver.model().uninterpretedSortUniverse(uSort)?.also { universe -> - assertContains(universe, pu1v) - } + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + fork.assert(u1 eq pu1v) + assertEquals(KSolverStatus.SAT, fork.check()) + assertEquals(pu1v, fork.model().eval(u1)) + + fork.fork().also { ff -> + assertEquals(KSolverStatus.SAT, ff.check()) + assertEquals(pu1v, ff.model().eval(u1)) + ff.model().uninterpretedSortUniverse(uSort)?.also { universe -> + assertContains(universe, pu1v) + } + } + } + assertEquals(KSolverStatus.SAT, parentSolver.check()) + parentSolver.model().uninterpretedSortUniverse(uSort)?.also { universe -> + assertContains(universe, pu1v) + } + + } } } } - fun testLifeTime(mkSolver: (KContext) -> KForkingSolver<*>): Unit = + fun testLifeTime(mkForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> - with(ctx) { - val parent = mkSolver(ctx) - val x by intSort - val f = x gt 100.expr + mkForkingSolverManager(ctx).use { man -> + with(ctx) { + val parent = man.mkForkingSolver() + val x by intSort + val f = x gt 100.expr - parent.assert(f) - parent.check().also { require(it == KSolverStatus.SAT) } + parent.assert(f) + parent.check().also { require(it == KSolverStatus.SAT) } - val xVal = parent.model().eval(x) + val xVal = parent.model().eval(x) - val fork = parent.fork().fork().fork() - parent.close() + val fork = parent.fork().fork().fork() + parent.close() - fork.assert(f and (x eq xVal)) - fork.check().also { assertEquals(KSolverStatus.SAT, it) } - assertEquals(xVal, fork.model().eval(x)) + fork.assert(f and (x eq xVal)) + fork.check().also { assertEquals(KSolverStatus.SAT, it) } + assertEquals(xVal, fork.model().eval(x)) + } } - } } diff --git a/ksmt-z3/src/main/kotlin/com/microsoft/z3/UnsafeApi.kt b/ksmt-z3/src/main/kotlin/com/microsoft/z3/UnsafeApi.kt index 3d745114e..87ea03ce5 100644 --- a/ksmt-z3/src/main/kotlin/com/microsoft/z3/UnsafeApi.kt +++ b/ksmt-z3/src/main/kotlin/com/microsoft/z3/UnsafeApi.kt @@ -1,5 +1,7 @@ package com.microsoft.z3 +import it.unimi.dsi.fastutil.longs.LongSet + fun incRefUnsafe(ctx: Long, ast: Long) { // Invoke incRef directly without status check Native.INTERNALincRef(ctx, ast) @@ -9,3 +11,7 @@ fun decRefUnsafe(ctx: Long, ast: Long) { // Invoke decRef directly without status check Native.INTERNALdecRef(ctx, ast) } + +fun LongSet.decRefUnsafeAll(ctx: Long) = longIterator().forEachRemaining { + decRefUnsafe(ctx, it) +} diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesTracker.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesTracker.kt index 8c8a74e2f..89f560420 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesTracker.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesTracker.kt @@ -19,7 +19,18 @@ import io.ksmt.sort.KUninterpretedSort * 2. Assert distinct constraints ([assertPendingUninterpretedValueConstraints]) * that may be introduced during internalization. * */ -class ExpressionUninterpretedValuesTracker(val ctx: KContext, val z3Ctx: KZ3Context) { +class ExpressionUninterpretedValuesTracker private constructor( + val ctx: KContext, + val z3Ctx: KZ3Context, + private val registeredUninterpretedSortValues: HashMap +) { + constructor(ctx: KContext, z3Ctx: KZ3Context) : this(ctx, z3Ctx, hashMapOf()) + constructor(ctx: KContext, z3Ctx: KZ3Context, forkingSolverManager: KZ3ForkingSolverManager) : this( + ctx, + z3Ctx, + with(forkingSolverManager) { z3Ctx.findRegisteredUninterpretedSortValues() } + ) + private val expressionLevels = Object2IntOpenHashMap>().apply { defaultReturnValue(Int.MAX_VALUE) // Level which is greater than any possible level } @@ -32,9 +43,6 @@ class ExpressionUninterpretedValuesTracker(val ctx: KContext, val z3Ctx: KZ3Cont private val valueTrackerFrames = arrayListOf(currentFrame) - private val registeredUninterpretedSortValues = - hashMapOf() - /** * Skip any value tracking related actions until * we have uninterpreted values registered. @@ -49,6 +57,11 @@ class ExpressionUninterpretedValuesTracker(val ctx: KContext, val z3Ctx: KZ3Cont body() } + fun fork(parent: ExpressionUninterpretedValuesTracker) = also { + expressionLevels += parent.expressionLevels + repeat(parent.valueTrackerFrames.size - 1) { pushAssertionLevel() } + } + fun expressionUse(expr: KExpr<*>) = ifTrackingEnabled { currentFrame.analyzeUsedExpression(expr) } @@ -121,7 +134,7 @@ class ExpressionUninterpretedValuesTracker(val ctx: KContext, val z3Ctx: KZ3Cont z3Ctx.releaseTemporaryAst(constraintLhs) } - private data class UninterpretedSortValueDescriptor( + internal data class UninterpretedSortValueDescriptor( val value: KUninterpretedSortValue, val nativeUniqueValueDescriptor: Long, val nativeValueExpr: Long diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt index 834c28b7c..4001591e6 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt @@ -2,48 +2,88 @@ package io.ksmt.solver.z3 import com.microsoft.z3.Context import com.microsoft.z3.Solver +import com.microsoft.z3.Z3Exception import com.microsoft.z3.decRefUnsafe +import com.microsoft.z3.decRefUnsafeAll import com.microsoft.z3.incRefUnsafe import io.ksmt.KContext import io.ksmt.decl.KDecl import io.ksmt.expr.KExpr import io.ksmt.expr.KUninterpretedSortValue +import io.ksmt.solver.KSolverException import io.ksmt.solver.util.KExprLongInternalizerBase.Companion.NOT_INTERNALIZED import io.ksmt.sort.KSort import io.ksmt.sort.KUninterpretedSort import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap import it.unimi.dsi.fastutil.longs.LongOpenHashSet -import it.unimi.dsi.fastutil.longs.LongSet import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap @Suppress("TooManyFunctions") -class KZ3Context( +class KZ3Context internal constructor( ksmtCtx: KContext, - private val ctx: Context + private val ctx: Context, + forkingSolverManager: KZ3ForkingSolverManager?, ) : AutoCloseable { - constructor(ksmtCtx: KContext) : this(ksmtCtx, Context()) + constructor(ksmtCtx: KContext, ctx: Context) : this(ksmtCtx, ctx, null) + constructor(ksmtCtx: KContext) : this(ksmtCtx, Context(), null) private var isClosed = false - - private val expressions = Object2LongOpenHashMap>().apply { - defaultReturnValue(NOT_INTERNALIZED) - } - - private val sorts = Object2LongOpenHashMap().apply { - defaultReturnValue(NOT_INTERNALIZED) - } - - private val decls = Object2LongOpenHashMap>().apply { - defaultReturnValue(NOT_INTERNALIZED) + private val isForking = forkingSolverManager != null + + // common for parent and child structures + private val expressions: Object2LongOpenHashMap> + private val sorts: Object2LongOpenHashMap + private val decls: Object2LongOpenHashMap> + + private val z3Expressions: Long2ObjectOpenHashMap> + private val z3Sorts: Long2ObjectOpenHashMap + private val z3Decls: Long2ObjectOpenHashMap> + private val tmpNativeObjects: LongOpenHashSet + private val converterNativeObjects: LongOpenHashSet + + private val uninterpretedSortValueInterpreter: HashMap + private val uninterpretedSortValueDecls: Long2ObjectOpenHashMap + private val uninterpretedSortValueInterpreters: LongOpenHashSet + + + val uninterpretedValuesTracker: ExpressionUninterpretedValuesTracker + + init { + if (forkingSolverManager != null) { + with(forkingSolverManager) { + expressions = findExpressionsCache() + sorts = findSortsCache() + decls = findDeclsCache() + + z3Expressions = findExpressionsReversedCache() + z3Sorts = findSortsReversedCache() + z3Decls = findDeclsReversedCache() + tmpNativeObjects = findTmpNativeObjectsCache() + converterNativeObjects = findConverterNativeObjectsCache() + uninterpretedSortValueInterpreter = findUninterpretedSortValueInterpreter() + uninterpretedSortValueDecls = findUninterpretedSortValueDecls() + uninterpretedSortValueInterpreters = findUninterpretedSortValueInterpreters() + } + uninterpretedValuesTracker = ExpressionUninterpretedValuesTracker(ksmtCtx, this, forkingSolverManager) + } else { + expressions = Object2LongOpenHashMap>().apply { defaultReturnValue(NOT_INTERNALIZED) } + sorts = Object2LongOpenHashMap().apply { defaultReturnValue(NOT_INTERNALIZED) } + decls = Object2LongOpenHashMap>().apply { defaultReturnValue(NOT_INTERNALIZED) } + + z3Expressions = Long2ObjectOpenHashMap>() + z3Sorts = Long2ObjectOpenHashMap() + z3Decls = Long2ObjectOpenHashMap>() + tmpNativeObjects = LongOpenHashSet() + converterNativeObjects = LongOpenHashSet() + + uninterpretedSortValueInterpreter = hashMapOf() + uninterpretedSortValueDecls = Long2ObjectOpenHashMap() + uninterpretedSortValueInterpreters = LongOpenHashSet() + + uninterpretedValuesTracker = ExpressionUninterpretedValuesTracker(ksmtCtx, this) + } } - private val z3Expressions = Long2ObjectOpenHashMap>() - private val z3Sorts = Long2ObjectOpenHashMap() - private val z3Decls = Long2ObjectOpenHashMap>() - private val tmpNativeObjects = LongOpenHashSet() - private val converterNativeObjects = LongOpenHashSet() - - val uninterpretedValuesTracker = ExpressionUninterpretedValuesTracker(ksmtCtx, this) @JvmField val nCtx: Long = ctx.nCtx() @@ -54,17 +94,24 @@ class KZ3Context( val isActive: Boolean get() = !isClosed + internal fun fork(ksmtCtx: KContext, manager: KZ3ForkingSolverManager): KZ3Context { + require(isForking) { "Can't fork non-forking context" } + return KZ3Context(ksmtCtx, ctx, manager).also { + it.uninterpretedValuesTracker.fork(uninterpretedValuesTracker) + } + } + + internal fun findInternalizedExprWithoutAnalysis(expr: KExpr<*>): Long { + val result = expressions.getLong(expr) + return if (result == NOT_INTERNALIZED) NOT_INTERNALIZED else result + } + /** * Find internalized expr. * Returns [NOT_INTERNALIZED] if expression was not found. * */ - fun findInternalizedExpr(expr: KExpr<*>): Long { - val result = expressions.getLong(expr) - if (result == NOT_INTERNALIZED) return NOT_INTERNALIZED - - uninterpretedValuesTracker.expressionUse(expr) - - return result + fun findInternalizedExpr(expr: KExpr<*>): Long = findInternalizedExprWithoutAnalysis(expr).also { + if (it != NOT_INTERNALIZED) uninterpretedValuesTracker.expressionUse(expr) } fun saveInternalizedExpr(expr: KExpr<*>, internalized: Long) { @@ -148,11 +195,6 @@ class KZ3Context( return ast } - private val uninterpretedSortValueInterpreter = hashMapOf() - - private val uninterpretedSortValueDecls = Long2ObjectOpenHashMap() - private val uninterpretedSortValueInterpreters = LongOpenHashSet() - fun saveUninterpretedSortValueDecl(decl: Long, value: KUninterpretedSortValue): Long { if (uninterpretedSortValueDecls.putIfAbsent(decl, value) == null) { incRefUnsafe(nCtx, decl) @@ -264,37 +306,38 @@ class KZ3Context( if (isClosed) return isClosed = true + if (isForking) return + uninterpretedSortValueInterpreter.clear() - uninterpretedSortValueDecls.keys.decRefAll() + uninterpretedSortValueDecls.keys.decRefUnsafeAll(nCtx) uninterpretedSortValueDecls.clear() - uninterpretedSortValueInterpreters.decRefAll() + uninterpretedSortValueInterpreters.decRefUnsafeAll(nCtx) uninterpretedSortValueInterpreters.clear() - converterNativeObjects.decRefAll() + converterNativeObjects.decRefUnsafeAll(nCtx) converterNativeObjects.clear() - z3Expressions.keys.decRefAll() + z3Expressions.keys.decRefUnsafeAll(nCtx) expressions.clear() z3Expressions.clear() - tmpNativeObjects.decRefAll() + tmpNativeObjects.decRefUnsafeAll(nCtx) tmpNativeObjects.clear() - z3Decls.keys.decRefAll() + z3Decls.keys.decRefUnsafeAll(nCtx) decls.clear() z3Decls.clear() - z3Sorts.keys.decRefAll() + z3Sorts.keys.decRefUnsafeAll(nCtx) sorts.clear() z3Sorts.clear() - ctx.close() - } - - private fun LongSet.decRefAll() = - longIterator().forEachRemaining { - decRefUnsafe(nCtx, it) + try { + ctx.close() + } catch (e: Z3Exception) { + throw KSolverException(e) } + } } diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt new file mode 100644 index 000000000..cb93d09e2 --- /dev/null +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt @@ -0,0 +1,131 @@ +package io.ksmt.solver.z3 + +import com.microsoft.z3.Context +import com.microsoft.z3.solverAssert +import com.microsoft.z3.solverAssertAndTrack +import io.ksmt.KContext +import io.ksmt.expr.KExpr +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KSolverStatus +import io.ksmt.sort.KBoolSort +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap +import it.unimi.dsi.fastutil.longs.LongOpenHashSet +import kotlin.time.Duration + +open class KZ3ForkingSolver internal constructor( + ctx: KContext, + private val manager: KZ3ForkingSolverManager, + parent: KZ3ForkingSolver? +) : KZ3SolverBase(ctx), KForkingSolver { + final override val z3Ctx: KZ3Context + + private val trackedAssertions = ScopedLinkedFrame>>( + ::Long2ObjectOpenHashMap, ::Long2ObjectOpenHashMap + ) + private val z3Assertions = ScopedLinkedFrame(::LongOpenHashSet, ::LongOpenHashSet) + + private val isChild = parent != null + private var assertionsInitiated = !isChild + + init { + if (parent != null) { + z3Ctx = parent.z3Ctx.fork(ctx, manager) + trackedAssertions.fork(parent.trackedAssertions) + z3Assertions.fork(parent.z3Assertions) + } else { + val context = Context() + with(manager) { registerContext(context) } + z3Ctx = KZ3Context(ctx, context, manager) + } + } + + private val config: KZ3ForkingSolverConfigurationImpl by lazy { + z3Try { + z3Ctx.nativeContext.mkParams().let { + parent?.config?.fork(it)?.apply { setParameters(solver) } ?: KZ3ForkingSolverConfigurationImpl(it) + } + } + } + + init { + if (isChild) config // initialize child config + } + + override fun configure(configurator: KZ3SolverConfiguration.() -> Unit) { + config.configurator() + config.setParameters(solver) + } + + override fun fork(): KForkingSolver = manager.mkForkingSolver(this) + + override fun saveTrackedAssertion(track: Long, trackedExpr: KExpr) { + trackedAssertions.currentFrame[track] = trackedExpr + } + + override fun findTrackedExprByTrack(track: Long): KExpr? = trackedAssertions.find { it[track] } + + /** + * Asserts parental (in case of child) assertions if not + */ + private fun ensureAssertionsInitiated() { + if (assertionsInitiated) return + + z3Assertions.stacked() + .zip(trackedAssertions.stacked()) + .asReversed() + .forEachIndexed { scope, (z3AssertionFrame, trackedFrame) -> + if (scope > 0) { + solver.push() + currentScope++ + } + + z3AssertionFrame.forEach(solver::solverAssert) + trackedFrame.forEach { (track, expr) -> + /** tracked [expr] was previously internalized by parent */ + solver.solverAssertAndTrack(track, z3Ctx.findInternalizedExprWithoutAnalysis(expr)) + } + } + + assertionsInitiated = true + } + + override fun push() { + z3Try { ensureAssertionsInitiated() } + super.push() + trackedAssertions.push() + z3Assertions.push() + } + + override fun pop(n: UInt) { + z3Try { ensureAssertionsInitiated() } + super.pop(n) + trackedAssertions.pop(n) + z3Assertions.pop(n) + } + + override fun assert(expr: KExpr) = z3Try { + ensureAssertionsInitiated() + ctx.ensureContextMatch(expr) + + val z3Expr = with(exprInternalizer) { expr.internalizeExpr() } + solver.solverAssert(z3Expr) + + z3Ctx.assertPendingAxioms(solver) + z3Assertions.currentFrame += z3Expr + } + + override fun assertAndTrack(expr: KExpr) { + z3Try { ensureAssertionsInitiated() } + super.assertAndTrack(expr) + } + + override fun check(timeout: Duration): KSolverStatus { + z3Try { ensureAssertionsInitiated() } + return super.check(timeout) + } + + override fun checkWithAssumptions(assumptions: List>, timeout: Duration): KSolverStatus { + z3Try { ensureAssertionsInitiated() } + return super.checkWithAssumptions(assumptions, timeout) + } +} diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt new file mode 100644 index 000000000..cddbd90b4 --- /dev/null +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt @@ -0,0 +1,172 @@ +package io.ksmt.solver.z3 + +import com.microsoft.z3.Context +import com.microsoft.z3.Z3Exception +import com.microsoft.z3.decRefUnsafeAll +import io.ksmt.KAst +import io.ksmt.KContext +import io.ksmt.decl.KDecl +import io.ksmt.expr.KExpr +import io.ksmt.expr.KUninterpretedSortValue +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KForkingSolverManager +import io.ksmt.solver.KSolverException +import io.ksmt.solver.util.KExprLongInternalizerBase +import io.ksmt.sort.KSort +import io.ksmt.sort.KUninterpretedSort +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap +import it.unimi.dsi.fastutil.longs.LongOpenHashSet +import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap +import java.util.IdentityHashMap +import java.util.concurrent.ConcurrentHashMap + +class KZ3ForkingSolverManager(private val ctx: KContext) : KForkingSolverManager { + private val solvers = ConcurrentHashMap.newKeySet() + + /** + * for each parent-to-child hierarchy created only one Context. + * Each Context user is registered to control solver is alive + */ + private val forkingSolverToContext = IdentityHashMap() + private val contextReferences = IdentityHashMap() + + // shared cache + private val expressionsCache = IdentityHashMap() + private val expressionsReversedCache = IdentityHashMap() + private val sortsCache = IdentityHashMap() + private val sortsReversedCache = IdentityHashMap() + private val declsCache = IdentityHashMap() + private val declsReversedCache = IdentityHashMap() + + private val tmpNativeObjectsCache = IdentityHashMap() + private val converterNativeObjectsCache = IdentityHashMap() + + private val uninterpretedSortValueInterpreter = IdentityHashMap() + private val uninterpretedSortValueDecls = IdentityHashMap() + private val uninterpretedSortValueInterpreters = IdentityHashMap() + private val registeredUninterpretedSortValues = IdentityHashMap() + + internal fun KZ3Context.findExpressionsCache() = expressionsCache.getValue(nativeContext) + internal fun KZ3Context.findExpressionsReversedCache() = expressionsReversedCache.getValue(nativeContext) + internal fun KZ3Context.findSortsCache() = sortsCache.getValue(nativeContext) + internal fun KZ3Context.findSortsReversedCache() = sortsReversedCache.getValue(nativeContext) + internal fun KZ3Context.findDeclsCache() = declsCache.getValue(nativeContext) + internal fun KZ3Context.findDeclsReversedCache() = declsReversedCache.getValue(nativeContext) + internal fun KZ3Context.findTmpNativeObjectsCache() = tmpNativeObjectsCache.getValue(nativeContext) + internal fun KZ3Context.findConverterNativeObjectsCache() = converterNativeObjectsCache.getValue(nativeContext) + internal fun KZ3Context.findUninterpretedSortValueInterpreter() = + uninterpretedSortValueInterpreter.getValue(nativeContext) + + internal fun KZ3Context.findUninterpretedSortValueDecls() = + uninterpretedSortValueDecls.getValue(nativeContext) + + internal fun KZ3Context.findUninterpretedSortValueInterpreters() = + uninterpretedSortValueInterpreters.getValue(nativeContext) + + internal fun KZ3Context.findRegisteredUninterpretedSortValues() = + registeredUninterpretedSortValues.getValue(nativeContext) + + internal fun KZ3ForkingSolver.registerContext(sharedContext: Context) { + if (forkingSolverToContext.putIfAbsent(this, sharedContext) == null) { + incRef(sharedContext) + + expressionsCache[sharedContext] = ExpressionsCache().withNotInternalizedAsDefaultValue() + expressionsReversedCache[sharedContext] = ExpressionsReversedCache() + sortsCache[sharedContext] = SortsCache().withNotInternalizedAsDefaultValue() + sortsReversedCache[sharedContext] = SortsReversedCache() + declsCache[sharedContext] = DeclsCache().withNotInternalizedAsDefaultValue() + declsReversedCache[sharedContext] = DeclsReversedCache() + tmpNativeObjectsCache[sharedContext] = TmpNativeObjectsCache() + converterNativeObjectsCache[sharedContext] = ConverterNativeObjectsCache() + uninterpretedSortValueInterpreter[sharedContext] = UninterpretedSortValueInterpreterCache() + uninterpretedSortValueDecls[sharedContext] = UninterpretedSortValueDecls() + uninterpretedSortValueInterpreters[sharedContext] = UninterpretedSortValueInterpretersCache() + registeredUninterpretedSortValues[sharedContext] = RegisteredUninterpretedSortValues() + } + } + + private fun incRef(context: Context) { + contextReferences[context] = contextReferences.getOrDefault(context, 0) + 1 + } + + private fun decRef(context: Context) { + val referencesAfterDec = contextReferences.getValue(context) - 1 + if (referencesAfterDec == 0) { + val nCtx = context.nCtx() + contextReferences -= context + + expressionsReversedCache.remove(context)!!.keys.decRefUnsafeAll(nCtx) + expressionsCache -= context + + sortsReversedCache.remove(context)!!.keys.decRefUnsafeAll(nCtx) + sortsCache -= context + + declsReversedCache.remove(context)!!.keys.decRefUnsafeAll(nCtx) + declsCache -= context + + uninterpretedSortValueInterpreters.remove(context)!!.decRefUnsafeAll(nCtx) + uninterpretedSortValueInterpreter -= context + uninterpretedSortValueDecls -= context + registeredUninterpretedSortValues -= context + + converterNativeObjectsCache.remove(context)!!.decRefUnsafeAll(nCtx) + tmpNativeObjectsCache.remove(context)!!.decRefUnsafeAll(nCtx) + + try { + ctx.close() + } catch (e: Z3Exception) { + throw KSolverException(e) + } + } else { + contextReferences[context] = referencesAfterDec + } + } + + override fun mkForkingSolver(): KForkingSolver { + return KZ3ForkingSolver(ctx, this, null).also { solvers += it } + } + + internal fun mkForkingSolver(parent: KZ3ForkingSolver): KForkingSolver { + return KZ3ForkingSolver(ctx, this, parent).also { + solvers += it + forkingSolverToContext[it] = forkingSolverToContext[parent] + } + } + + /** + * unregister [solver] for this manager + */ + internal fun close(solver: KZ3ForkingSolver) { + solvers -= solver + val sharedContext = forkingSolverToContext.getValue(solver) + forkingSolverToContext -= solver + decRef(sharedContext) + } + + override fun close() { + solvers.forEach(KZ3ForkingSolver::close) + } + + private fun Object2LongOpenHashMap.withNotInternalizedAsDefaultValue() = apply { + defaultReturnValue(KExprLongInternalizerBase.NOT_INTERNALIZED) + } + +} + +private typealias ExpressionsCache = Object2LongOpenHashMap> +private typealias ExpressionsReversedCache = Long2ObjectOpenHashMap> + +private typealias SortsCache = Object2LongOpenHashMap +private typealias SortsReversedCache = Long2ObjectOpenHashMap + +private typealias DeclsCache = Object2LongOpenHashMap> +private typealias DeclsReversedCache = Long2ObjectOpenHashMap> + +private typealias TmpNativeObjectsCache = LongOpenHashSet +private typealias ConverterNativeObjectsCache = LongOpenHashSet + +private typealias UninterpretedSortValueInterpreterCache = HashMap +private typealias UninterpretedSortValueDecls = Long2ObjectOpenHashMap +private typealias UninterpretedSortValueInterpretersCache = LongOpenHashSet +@Suppress("MaxLineLength") +private typealias RegisteredUninterpretedSortValues = HashMap diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Solver.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Solver.kt index f8bdfa43f..114082f01 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Solver.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Solver.kt @@ -1,252 +1,28 @@ package io.ksmt.solver.z3 -import com.microsoft.z3.Solver -import com.microsoft.z3.Status -import com.microsoft.z3.Z3Exception -import com.microsoft.z3.solverAssert -import com.microsoft.z3.solverAssertAndTrack -import com.microsoft.z3.solverCheckAssumptions -import com.microsoft.z3.solverGetUnsatCore -import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap import io.ksmt.KContext import io.ksmt.expr.KExpr -import io.ksmt.solver.KModel import io.ksmt.solver.KSolver -import io.ksmt.solver.KSolverException -import io.ksmt.solver.KSolverStatus import io.ksmt.sort.KBoolSort -import io.ksmt.utils.NativeLibraryLoader -import java.lang.ref.PhantomReference -import java.lang.ref.ReferenceQueue -import java.util.IdentityHashMap -import kotlin.time.Duration -import kotlin.time.DurationUnit - -open class KZ3Solver(private val ctx: KContext) : KSolver { - private val z3Ctx = KZ3Context(ctx) - private val solver = createSolver() - - private var lastCheckStatus = KSolverStatus.UNKNOWN - private var lastReasonOfUnknown: String? = null - private var lastModel: KZ3Model? = null - private var lastUnsatCore: List>? = null - - private var currentScope: UInt = 0u +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap - @Suppress("LeakingThis") - private val contextCleanupActionHandler = registerContextForCleanup(this, z3Ctx) +open class KZ3Solver(ctx: KContext) : KZ3SolverBase(ctx), KSolver { + override val z3Ctx: KZ3Context = KZ3Context(ctx) + private val trackedAssertions = ScopedArrayFrame { Long2ObjectOpenHashMap>() } - private val exprInternalizer by lazy { - createExprInternalizer(z3Ctx) - } - private val exprConverter by lazy { - createExprConverter(z3Ctx) + override fun saveTrackedAssertion(track: Long, trackedExpr: KExpr) { + trackedAssertions.currentFrame[track] = trackedExpr } - open fun createExprInternalizer(z3Ctx: KZ3Context): KZ3ExprInternalizer = KZ3ExprInternalizer(ctx, z3Ctx) - - open fun createExprConverter(z3Ctx: KZ3Context) = KZ3ExprConverter(ctx, z3Ctx) - - private fun createSolver(): Solver = z3Ctx.nativeContext.mkSolver() - - override fun configure(configurator: KZ3SolverConfiguration.() -> Unit) { - val params = z3Ctx.nativeContext.mkParams() - KZ3SolverConfigurationImpl(params).configurator() - solver.setParameters(params) - } + override fun findTrackedExprByTrack(track: Long): KExpr? = trackedAssertions.find { it[track] } override fun push() { - solver.push() - z3Ctx.pushAssertionLevel() - currentScope++ + super.push() + trackedAssertions.push() } override fun pop(n: UInt) { - require(n <= currentScope) { - "Can not pop $n scope levels because current scope level is $currentScope" - } - if (n == 0u) return - - solver.pop(n.toInt()) - repeat(n.toInt()) { z3Ctx.popAssertionLevel() } - - currentScope -= n - } - - override fun assert(expr: KExpr) = z3Try { - ctx.ensureContextMatch(expr) - - val z3Expr = with(exprInternalizer) { expr.internalizeExpr() } - solver.solverAssert(z3Expr) - - z3Ctx.assertPendingAxioms(solver) - } - - private val trackedAssertions = Long2ObjectOpenHashMap>() - - override fun assertAndTrack(expr: KExpr) = z3Try { - ctx.ensureContextMatch(expr) - - val trackExpr = ctx.mkFreshConst("track", ctx.boolSort) - val z3Expr = with(exprInternalizer) { expr.internalizeExpr() } - val z3TrackVar = with(exprInternalizer) { trackExpr.internalizeExpr() } - - trackedAssertions.put(z3TrackVar, expr) - - solver.solverAssertAndTrack(z3Expr, z3TrackVar) - } - - override fun check(timeout: Duration): KSolverStatus = z3TryCheck { - solver.updateTimeout(timeout) - solver.check().processCheckResult() - } - - override fun checkWithAssumptions( - assumptions: List>, - timeout: Duration - ): KSolverStatus = z3TryCheck { - ctx.ensureContextMatch(assumptions) - - val z3Assumptions = with(exprInternalizer) { - LongArray(assumptions.size) { - val assumption = assumptions[it] - - /** - * Assumptions are trivially unsat and no check-sat is required. - * */ - if (assumption == ctx.falseExpr) { - lastUnsatCore = listOf(ctx.falseExpr) - lastCheckStatus = KSolverStatus.UNSAT - return KSolverStatus.UNSAT - } - - assumption.internalizeExpr() - } - } - - solver.updateTimeout(timeout) - - solver.solverCheckAssumptions(z3Assumptions).processCheckResult() - } - - override fun model(): KModel = z3Try { - require(lastCheckStatus == KSolverStatus.SAT) { - "Model are only available after SAT checks, current solver status: $lastCheckStatus" - } - - val model = lastModel ?: KZ3Model( - model = solver.model, - ctx = ctx, - z3Ctx = z3Ctx, - internalizer = exprInternalizer - ) - lastModel = model - - model - } - - override fun unsatCore(): List> = z3Try { - require(lastCheckStatus == KSolverStatus.UNSAT) { "Unsat cores are only available after UNSAT checks" } - - val unsatCore = lastUnsatCore ?: with(exprConverter) { - val solverUnsatCore = solver.solverGetUnsatCore() - solverUnsatCore.map { trackedAssertions.get(it) ?: it.convertExpr() } - } - lastUnsatCore = unsatCore - - unsatCore - } - - override fun reasonOfUnknown(): String = z3Try { - require(lastCheckStatus == KSolverStatus.UNKNOWN) { "Unknown reason is only available after UNKNOWN checks" } - lastReasonOfUnknown ?: solver.reasonUnknown - } - - override fun interrupt() = z3Try { - solver.interrupt() - } - - override fun close() { - unregisterContextCleanup(contextCleanupActionHandler) - z3Ctx.close() - } - - private fun Status?.processCheckResult() = when (this) { - Status.SATISFIABLE -> KSolverStatus.SAT - Status.UNSATISFIABLE -> KSolverStatus.UNSAT - Status.UNKNOWN -> KSolverStatus.UNKNOWN - null -> KSolverStatus.UNKNOWN - }.also { lastCheckStatus = it } - - private fun Solver.updateTimeout(timeout: Duration) { - val z3Timeout = if (timeout == Duration.INFINITE) { - UInt.MAX_VALUE.toInt() - } else { - timeout.toInt(DurationUnit.MILLISECONDS) - } - val params = z3Ctx.nativeContext.mkParams().apply { - add("timeout", z3Timeout) - } - setParameters(params) - } - - private inline fun z3Try(body: () -> T): T = try { - body() - } catch (ex: Z3Exception) { - throw KSolverException(ex) - } - - private fun invalidateSolverState() { - lastReasonOfUnknown = null - lastCheckStatus = KSolverStatus.UNKNOWN - lastModel = null - lastUnsatCore = null - } - - private inline fun z3TryCheck(body: () -> KSolverStatus): KSolverStatus = try { - invalidateSolverState() - body() - } catch (ex: Z3Exception) { - lastReasonOfUnknown = ex.message - KSolverStatus.UNKNOWN.also { lastCheckStatus = it } - } - - companion object { - init { - System.setProperty("z3.skipLibraryLoad", "true") - NativeLibraryLoader.load { os -> - when (os) { - NativeLibraryLoader.OS.LINUX -> listOf("libz3", "libz3java") - NativeLibraryLoader.OS.MACOS -> listOf("libz3", "libz3java") - NativeLibraryLoader.OS.WINDOWS -> listOf("vcruntime140", "vcruntime140_1", "libz3", "libz3java") - } - } - } - - private val cleanupHandlers = ReferenceQueue() - private val contextForCleanup = IdentityHashMap, KZ3Context>() - - /** Ensure Z3 native context is closed and all native memory is released. - * */ - private fun registerContextForCleanup(solver: KZ3Solver, context: KZ3Context): PhantomReference { - cleanupStaleContexts() - val cleanupHandler = PhantomReference(solver, cleanupHandlers) - contextForCleanup[cleanupHandler] = context - - return cleanupHandler - } - - private fun unregisterContextCleanup(handler: PhantomReference) { - contextForCleanup.remove(handler) - handler.clear() - cleanupStaleContexts() - } - - private fun cleanupStaleContexts() { - while (true) { - val handler = cleanupHandlers.poll() ?: break - contextForCleanup.remove(handler)?.close() - } - } + super.pop(n) + trackedAssertions.pop() } } diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverBase.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverBase.kt new file mode 100644 index 000000000..16a287a0d --- /dev/null +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverBase.kt @@ -0,0 +1,256 @@ +package io.ksmt.solver.z3 + +import com.microsoft.z3.Solver +import com.microsoft.z3.Status +import com.microsoft.z3.Z3Exception +import com.microsoft.z3.solverAssert +import com.microsoft.z3.solverAssertAndTrack +import com.microsoft.z3.solverCheckAssumptions +import com.microsoft.z3.solverGetUnsatCore +import io.ksmt.KContext +import io.ksmt.expr.KExpr +import io.ksmt.solver.KModel +import io.ksmt.solver.KSolver +import io.ksmt.solver.KSolverException +import io.ksmt.solver.KSolverStatus +import io.ksmt.sort.KBoolSort +import io.ksmt.utils.NativeLibraryLoader +import java.lang.ref.PhantomReference +import java.lang.ref.ReferenceQueue +import java.util.IdentityHashMap +import kotlin.time.Duration +import kotlin.time.DurationUnit + +abstract class KZ3SolverBase(protected val ctx: KContext) : KSolver { + protected abstract val z3Ctx: KZ3Context + protected val solver by lazy { createSolver() } + + protected var lastCheckStatus = KSolverStatus.UNKNOWN + protected var lastReasonOfUnknown: String? = null + protected var lastModel: KZ3Model? = null + protected var lastUnsatCore: List>? = null + + protected open var currentScope: UInt = 0u + + @Suppress("LeakingThis") + private val contextCleanupActionHandler = registerContextForCleanup(this, z3Ctx) + + protected val exprInternalizer by lazy { + createExprInternalizer(z3Ctx) + } + protected val exprConverter by lazy { + createExprConverter(z3Ctx) + } + + open fun createExprInternalizer(z3Ctx: KZ3Context): KZ3ExprInternalizer = KZ3ExprInternalizer(ctx, z3Ctx) + + open fun createExprConverter(z3Ctx: KZ3Context) = KZ3ExprConverter(ctx, z3Ctx) + + private fun createSolver(): Solver = z3Try { z3Ctx.nativeContext.mkSolver() } + + override fun configure(configurator: KZ3SolverConfiguration.() -> Unit) = z3Try { + val params = z3Ctx.nativeContext.mkParams() + KZ3SolverConfigurationImpl(params).configurator() + solver.setParameters(params) + } + + override fun push(): Unit = z3Try { + solver.push() + z3Ctx.pushAssertionLevel() + currentScope++ + } + + override fun pop(n: UInt) = z3Try { + require(n <= currentScope) { + "Can not pop $n scope levels because current scope level is $currentScope" + } + if (n == 0u) return + + solver.pop(n.toInt()) + repeat(n.toInt()) { z3Ctx.popAssertionLevel() } + + currentScope -= n + } + + override fun assert(expr: KExpr) = z3Try { + ctx.ensureContextMatch(expr) + + val z3Expr = with(exprInternalizer) { expr.internalizeExpr() } + solver.solverAssert(z3Expr) + + z3Ctx.assertPendingAxioms(solver) + } + + protected abstract fun saveTrackedAssertion(track: Long, trackedExpr: KExpr) + protected abstract fun findTrackedExprByTrack(track: Long): KExpr? + + override fun assertAndTrack(expr: KExpr) = z3Try { + ctx.ensureContextMatch(expr) + + val trackExpr = ctx.mkFreshConst("track", ctx.boolSort) + val z3Expr = with(exprInternalizer) { expr.internalizeExpr() } + val z3TrackVar = with(exprInternalizer) { trackExpr.internalizeExpr() } + + solver.solverAssertAndTrack(z3Expr, z3TrackVar) + saveTrackedAssertion(z3TrackVar, expr) + } + + override fun check(timeout: Duration): KSolverStatus = z3TryCheck { + solver.updateTimeout(timeout) + solver.check().processCheckResult() + } + + override fun checkWithAssumptions( + assumptions: List>, + timeout: Duration + ): KSolverStatus = z3TryCheck { + ctx.ensureContextMatch(assumptions) + + val z3Assumptions = with(exprInternalizer) { + LongArray(assumptions.size) { + val assumption = assumptions[it] + + /** + * Assumptions are trivially unsat and no check-sat is required. + * */ + if (assumption == ctx.falseExpr) { + lastUnsatCore = listOf(ctx.falseExpr) + lastCheckStatus = KSolverStatus.UNSAT + return KSolverStatus.UNSAT + } + + assumption.internalizeExpr() + } + } + + solver.updateTimeout(timeout) + + solver.solverCheckAssumptions(z3Assumptions).processCheckResult() + } + + override fun model(): KModel = z3Try { + require(lastCheckStatus == KSolverStatus.SAT) { + "Model are only available after SAT checks, current solver status: $lastCheckStatus" + } + + val model = lastModel ?: KZ3Model( + model = solver.model, + ctx = ctx, + z3Ctx = z3Ctx, + internalizer = exprInternalizer + ) + lastModel = model + + model + } + + override fun unsatCore(): List> = z3Try { + require(lastCheckStatus == KSolverStatus.UNSAT) { "Unsat cores are only available after UNSAT checks" } + + val unsatCore = lastUnsatCore ?: with(exprConverter) { + val solverUnsatCore = solver.solverGetUnsatCore() + solverUnsatCore.map { solverUnsatCoreExpr -> + findTrackedExprByTrack(solverUnsatCoreExpr) ?: solverUnsatCoreExpr.convertExpr() + } + } + lastUnsatCore = unsatCore + + unsatCore + } + + override fun reasonOfUnknown(): String = z3Try { + require(lastCheckStatus == KSolverStatus.UNKNOWN) { "Unknown reason is only available after UNKNOWN checks" } + lastReasonOfUnknown ?: solver.reasonUnknown + } + + override fun interrupt() = z3Try { + solver.interrupt() + } + + override fun close() { + unregisterContextCleanup(contextCleanupActionHandler) + z3Ctx.close() + } + + protected fun Status?.processCheckResult() = when (this) { + Status.SATISFIABLE -> KSolverStatus.SAT + Status.UNSATISFIABLE -> KSolverStatus.UNSAT + Status.UNKNOWN -> KSolverStatus.UNKNOWN + null -> KSolverStatus.UNKNOWN + }.also { lastCheckStatus = it } + + protected fun Solver.updateTimeout(timeout: Duration) { + val z3Timeout = if (timeout == Duration.INFINITE) { + UInt.MAX_VALUE.toInt() + } else { + timeout.toInt(DurationUnit.MILLISECONDS) + } + val params = z3Ctx.nativeContext.mkParams().apply { + add("timeout", z3Timeout) + } + setParameters(params) + } + + protected inline fun z3Try(body: () -> T): T = try { + body() + } catch (ex: Z3Exception) { + throw KSolverException(ex) + } + + protected fun invalidateSolverState() { + lastReasonOfUnknown = null + lastCheckStatus = KSolverStatus.UNKNOWN + lastModel = null + lastUnsatCore = null + } + + protected inline fun z3TryCheck(body: () -> KSolverStatus): KSolverStatus = try { + invalidateSolverState() + body() + } catch (ex: Z3Exception) { + lastReasonOfUnknown = ex.message + KSolverStatus.UNKNOWN.also { lastCheckStatus = it } + } + + companion object { + init { + System.setProperty("z3.skipLibraryLoad", "true") + NativeLibraryLoader.load { os -> + when (os) { + NativeLibraryLoader.OS.LINUX -> listOf("libz3", "libz3java") + NativeLibraryLoader.OS.MACOS -> listOf("libz3", "libz3java") + NativeLibraryLoader.OS.WINDOWS -> listOf("vcruntime140", "vcruntime140_1", "libz3", "libz3java") + } + } + } + + private val cleanupHandlers = ReferenceQueue>() + private val contextForCleanup = IdentityHashMap>, KZ3Context>() + + /** Ensure Z3 native context is closed and all native memory is released. + * */ + private fun registerContextForCleanup( + solver: KSolver, + context: KZ3Context + ): PhantomReference> { + cleanupStaleContexts() + val cleanupHandler = PhantomReference(solver, cleanupHandlers) + contextForCleanup[cleanupHandler] = context + + return cleanupHandler + } + + private fun unregisterContextCleanup(handler: PhantomReference>) { + contextForCleanup.remove(handler) + handler.clear() + cleanupStaleContexts() + } + + private fun cleanupStaleContexts() { + while (true) { + val handler = cleanupHandlers.poll() ?: break + contextForCleanup.remove(handler)?.close() + } + } + } +} diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverConfiguration.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverConfiguration.kt index 92c2c559a..a3cc6dcc6 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverConfiguration.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverConfiguration.kt @@ -1,6 +1,7 @@ package io.ksmt.solver.z3 import com.microsoft.z3.Params +import com.microsoft.z3.Solver import io.ksmt.solver.KSolverConfiguration import io.ksmt.solver.KSolverUniversalConfigurationBuilder @@ -45,6 +46,44 @@ class KZ3SolverConfigurationImpl(private val params: Params) : KZ3SolverConfigur } } +class KZ3ForkingSolverConfigurationImpl(private val params: Params) : KZ3SolverConfiguration { + private val booleanOptions = hashMapOf() + private val intOptions = hashMapOf() + private val doubleOptions = hashMapOf() + private val stringOptions = hashMapOf() + + override fun setZ3Option(option: String, value: Boolean) { + params.add(option, value) + booleanOptions[option] = value + } + + override fun setZ3Option(option: String, value: Int) { + params.add(option, value) + intOptions[option] = value + } + + override fun setZ3Option(option: String, value: Double) { + params.add(option, value) + doubleOptions[option] = value + } + + override fun setZ3Option(option: String, value: String) { + params.add(option, value) + stringOptions[option] = value + } + + fun fork(params: Params): KZ3ForkingSolverConfigurationImpl = KZ3ForkingSolverConfigurationImpl(params).also { + booleanOptions.forEach { (option, value) -> it.setZ3Option(option, value) } + intOptions.forEach { (option, value) -> it.setZ3Option(option, value) } + doubleOptions.forEach { (option, value) -> it.setZ3Option(option, value) } + stringOptions.forEach { (option, value) -> it.setZ3Option(option, value) } + } + + fun setParameters(solver: Solver) { + solver.setParameters(params) + } +} + class KZ3SolverUniversalConfiguration( private val builder: KSolverUniversalConfigurationBuilder ) : KZ3SolverConfiguration { diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ScopedFrame.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ScopedFrame.kt new file mode 100644 index 000000000..e7ccf7043 --- /dev/null +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ScopedFrame.kt @@ -0,0 +1,129 @@ +package io.ksmt.solver.z3 + +internal interface ScopedFrame { + val currentScope: UInt + val currentFrame: T + + fun flatten(collect: T.(T) -> Unit): T + + /** + * find value [V] in frame [T], and return it or null + */ + fun find(predicate: (T) -> V?): V? + + fun push() + fun pop(n: UInt = 1u) +} + +internal class ScopedArrayFrame( + currentFrame: T, + private val createNewFrame: () -> T +) : ScopedFrame { + constructor(createNewFrame: () -> T) : this(createNewFrame(), createNewFrame) + + private val frames = arrayListOf(currentFrame) + + override var currentFrame = currentFrame + private set + + override val currentScope: UInt + get() = frames.size.toUInt() + + override fun flatten(collect: T.(T) -> Unit) = createNewFrame().also { newFrame -> + frames.forEach { newFrame.collect(it) } + } + + override fun find(predicate: (T) -> V?): V? { + frames.forEach { frame -> + predicate(frame)?.let { return it } + } + return null + } + + override fun push() { + currentFrame = createNewFrame() + frames += currentFrame + } + + override fun pop(n: UInt) { + repeat(n.toInt()) { frames.removeLast() } + currentFrame = frames.last() + } +} + +internal class ScopedLinkedFrame private constructor( + private var current: LinkedFrame, + private val createNewFrame: () -> T, + private val copyFrame: (T) -> T +) : ScopedFrame { + constructor( + currentFrame: T, + createNewFrame: () -> T, + copyFrame: (T) -> T + ) : this(LinkedFrame(currentFrame), createNewFrame, copyFrame) + + constructor( + createNewFrame: () -> T, + copyFrame: (T) -> T + ) : this(createNewFrame(), createNewFrame, copyFrame) + + override val currentFrame: T + get() = current.value + + override val currentScope: UInt + get() = current.scope + + override fun flatten(collect: T.(T) -> Unit): T = createNewFrame().also { newFrame -> + forEachReversed { frame -> + newFrame.collect(frame) + } + } + + fun stacked(): ArrayDeque = ArrayDeque().also { stack -> + forEachReversed { frame -> + stack.addLast(frame) + } + } + + override fun find(predicate: (T) -> V?): V? { + forEachReversed { frame -> + predicate(frame)?.let { return it } + } + return null + } + + override fun push() { + current = LinkedFrame(createNewFrame(), current) + } + + override fun pop(n: UInt) { + current = current.previous ?: throw IllegalStateException("Can't pop the bottom scope") + recreateTopFrame() + } + + private fun recreateTopFrame() { + val newTopFrame = copyFrame(currentFrame) + current = LinkedFrame(newTopFrame, current.previous) + } + + fun fork(parent: ScopedLinkedFrame) { + current = parent.current + recreateTopFrame() + } + + private inline fun forEachReversed(action: (T) -> Unit) { + var cur: LinkedFrame? = current + while (cur != null) { + action(cur.value) + cur = cur.previous + } + } + + private class LinkedFrame( + val value: E, + val previous: LinkedFrame? = null + ) { + val scope: UInt = previous?.scope?.plus(1u) ?: 0u + } + +} From b5ed3abe3981644827ecd0aaed2b1f0d191b1dba Mon Sep 17 00:00:00 2001 From: Dmitriy Sokolov Date: Mon, 14 Aug 2023 13:22:02 +0300 Subject: [PATCH 05/12] Yices forking solver, uninterpreted values in universe fix --- .../kotlin/io/ksmt/test/TestWorkerProcess.kt | 2 +- .../kotlin/io/ksmt/test/KForkingSolverTest.kt | 25 ++ .../io/ksmt/test/MultiIndexedArrayTest.kt | 4 +- .../test/UninterpretedSortUniverseTest.kt | 76 ++++++ .../io/ksmt/solver/yices/KYicesContext.kt | 91 ++++--- .../solver/yices/KYicesExprInternalizer.kt | 2 +- .../ksmt/solver/yices/KYicesForkingContext.kt | 47 ++++ .../ksmt/solver/yices/KYicesForkingSolver.kt | 112 +++++++++ .../yices/KYicesForkingSolverManager.kt | 160 ++++++++++++ .../io/ksmt/solver/yices/KYicesModel.kt | 16 +- .../io/ksmt/solver/yices/KYicesSolver.kt | 233 ++---------------- .../io/ksmt/solver/yices/KYicesSolverBase.kt | 232 +++++++++++++++++ .../solver/yices/KYicesSolverConfiguration.kt | 13 + .../io/ksmt/solver/yices/ScopedFrame.kt | 147 +++++++++++ .../yices/UninterpretedValuesTracker.kt | 90 +++++++ 15 files changed, 989 insertions(+), 261 deletions(-) create mode 100644 ksmt-test/src/test/kotlin/io/ksmt/test/UninterpretedSortUniverseTest.kt create mode 100644 ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingContext.kt create mode 100644 ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt create mode 100644 ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt create mode 100644 ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolverBase.kt create mode 100644 ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/ScopedFrame.kt create mode 100644 ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/UninterpretedValuesTracker.kt diff --git a/ksmt-test/src/main/kotlin/io/ksmt/test/TestWorkerProcess.kt b/ksmt-test/src/main/kotlin/io/ksmt/test/TestWorkerProcess.kt index d6308d7de..ff56fd4c9 100644 --- a/ksmt-test/src/main/kotlin/io/ksmt/test/TestWorkerProcess.kt +++ b/ksmt-test/src/main/kotlin/io/ksmt/test/TestWorkerProcess.kt @@ -108,7 +108,7 @@ class TestWorkerProcess : ChildProcessBase() { private fun internalizeAndConvertYices(assertions: List>): List> { // Yices doesn't reverse cache internalized expressions (only interpreted values) - KYicesContext().use { internContext -> + KYicesContext(ctx).use { internContext -> val internalizer = KYicesExprInternalizer(internContext) val yicesAssertions = with(internalizer) { diff --git a/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt b/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt index 3a7e46595..c019cbf94 100644 --- a/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt +++ b/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt @@ -4,6 +4,7 @@ import io.ksmt.KContext import io.ksmt.solver.KForkingSolverManager import io.ksmt.solver.KSolverStatus import io.ksmt.solver.cvc5.KCvc5ForkingSolverManager +import io.ksmt.solver.yices.KYicesForkingSolverManager import io.ksmt.solver.z3.KZ3ForkingSolverManager import io.ksmt.utils.getValue import org.junit.jupiter.api.Nested @@ -38,6 +39,29 @@ class KForkingSolverTest { private fun mkCvc5ForkingSolverManager(ctx: KContext) = KCvc5ForkingSolverManager(ctx) } + @Nested + inner class KForkingSolverTestYices { + @Test + fun testCheckSat() = testCheckSat(::mkYicesForkingSolverManager) + + @Test + fun testModel() = testModel(::mkYicesForkingSolverManager) + + @Test + fun testUnsatCore() = testUnsatCore(::mkYicesForkingSolverManager) + + @Test + fun testUninterpretedSort() = testUninterpretedSort(::mkYicesForkingSolverManager) + + @Test + fun testScopedAssertions() = testScopedAssertions(::mkYicesForkingSolverManager) + + @Test + fun testLifeTime() = testLifeTime(::mkYicesForkingSolverManager) + + private fun mkYicesForkingSolverManager(ctx: KContext) = KYicesForkingSolverManager(ctx) + } + @Nested inner class KForkingSolverTestZ3 { @Test @@ -336,6 +360,7 @@ class KForkingSolverTest { } } + parentSolver.assert(u1 neq pu1v) assertEquals(KSolverStatus.SAT, parentSolver.check()) parentSolver.model().uninterpretedSortUniverse(uSort)?.also { universe -> assertContains(universe, pu1v) diff --git a/ksmt-test/src/test/kotlin/io/ksmt/test/MultiIndexedArrayTest.kt b/ksmt-test/src/test/kotlin/io/ksmt/test/MultiIndexedArrayTest.kt index 923850844..2e5420616 100644 --- a/ksmt-test/src/test/kotlin/io/ksmt/test/MultiIndexedArrayTest.kt +++ b/ksmt-test/src/test/kotlin/io/ksmt/test/MultiIndexedArrayTest.kt @@ -95,7 +95,7 @@ class MultiIndexedArrayTest { @Test fun testMultiIndexedArraysYicesWithZ3Oracle(): Unit = with(KContext(simplificationMode = NO_SIMPLIFY)) { oracleManager.createSolver(this, KZ3Solver::class).use { oracleSolver -> - KYicesContext().use { yicesNativeCtx -> + KYicesContext(this).use { yicesNativeCtx -> runMultiIndexedArraySamples(oracleSolver) { expr -> internalizeAndConvertYices(yicesNativeCtx, expr) } @@ -117,7 +117,7 @@ class MultiIndexedArrayTest { @Test fun testMultiIndexedArraysYicesWithYicesOracle(): Unit = with(KContext(simplificationMode = NO_SIMPLIFY)) { oracleManager.createSolver(this, KYicesSolver::class).use { oracleSolver -> - KYicesContext().use { yicesNativeCtx -> + KYicesContext(this).use { yicesNativeCtx -> runMultiIndexedArraySamples(oracleSolver) { expr -> internalizeAndConvertYices(yicesNativeCtx, expr) } diff --git a/ksmt-test/src/test/kotlin/io/ksmt/test/UninterpretedSortUniverseTest.kt b/ksmt-test/src/test/kotlin/io/ksmt/test/UninterpretedSortUniverseTest.kt new file mode 100644 index 000000000..71808ae4b --- /dev/null +++ b/ksmt-test/src/test/kotlin/io/ksmt/test/UninterpretedSortUniverseTest.kt @@ -0,0 +1,76 @@ +package io.ksmt.test + +import io.ksmt.KContext +import io.ksmt.solver.KSolver +import io.ksmt.solver.KSolverStatus +import io.ksmt.solver.bitwuzla.KBitwuzlaSolver +import io.ksmt.solver.cvc5.KCvc5Solver +import io.ksmt.solver.yices.KYicesSolver +import io.ksmt.solver.z3.KZ3Solver +import io.ksmt.utils.getValue +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test +import kotlin.test.assertContains +import kotlin.test.assertEquals +import kotlin.test.assertNotEquals +import kotlin.test.assertNotNull + +class UninterpretedSortUniverseTest { + + @Nested + inner class UninterpretedSortUniverseTestBitwuzla { + @Test + fun testUniverseContainsValue() = testUniverseContainsValue(::mkBitwuzlaSolver) + + private fun mkBitwuzlaSolver(ctx: KContext) = KBitwuzlaSolver(ctx) + } + + @Nested + inner class UninterpretedSortUniverseTestCvc5 { + @Test + fun testUniverseContainsValue() = testUniverseContainsValue(::mkCvc5Solver) + + private fun mkCvc5Solver(ctx: KContext) = KCvc5Solver(ctx) + } + + @Nested + inner class UninterpretedSortUniverseTestYices { + @Test + fun testUniverseContainsValue() = testUniverseContainsValue(::mkYicesSolver) + + private fun mkYicesSolver(ctx: KContext) = KYicesSolver(ctx) + } + + @Nested + inner class UninterpretedSortUniverseTestZ3 { + @Test + fun testUniverseContainsValue() = testUniverseContainsValue(::mkZ3Solver) + + private fun mkZ3Solver(ctx: KContext) = KZ3Solver(ctx) + } + + fun testUniverseContainsValue(mkSolver: (KContext) -> KSolver<*>): Unit = + KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> + mkSolver(ctx).use { s -> + with(ctx) { + val u = mkUninterpretedSort("u") + val u1 by u + val uval5 = mkUninterpretedSortValue(u, 5) + + s.assert(u1 neq uval5) + assertEquals(KSolverStatus.SAT, s.check()) + val u1v = s.model().eval(u1) + assertNotEquals(uval5, u1v) + + s.assert(u1 neq u1v) + assertEquals(KSolverStatus.SAT, s.check()) + + val universe = s.model().uninterpretedSortUniverse(u) + assertNotNull(universe) + + assertContains(universe, u1v) + assertContains(universe, uval5) + } + } + } +} diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesContext.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesContext.kt index 3fc8cda18..054bd9aca 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesContext.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesContext.kt @@ -3,13 +3,12 @@ package io.ksmt.solver.yices import com.sri.yices.Terms import com.sri.yices.Types import com.sri.yices.Yices -import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap -import it.unimi.dsi.fastutil.ints.IntOpenHashSet -import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap +import io.ksmt.KContext import io.ksmt.decl.KDecl import io.ksmt.expr.KConst import io.ksmt.expr.KExpr import io.ksmt.expr.KInterpretedValue +import io.ksmt.expr.KUninterpretedSortValue import io.ksmt.solver.KSolverUnsupportedFeatureException import io.ksmt.solver.util.KExprIntInternalizerBase.Companion.NOT_INTERNALIZED import io.ksmt.solver.yices.TermUtils.addTerm @@ -19,37 +18,45 @@ import io.ksmt.solver.yices.TermUtils.funApplicationTerm import io.ksmt.solver.yices.TermUtils.mulTerm import io.ksmt.solver.yices.TermUtils.orTerm import io.ksmt.sort.KSort -import io.ksmt.utils.NativeLibraryLoader +import io.ksmt.sort.KUninterpretedSort +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap +import it.unimi.dsi.fastutil.ints.IntOpenHashSet +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap import java.math.BigInteger import java.util.concurrent.atomic.AtomicInteger -open class KYicesContext : AutoCloseable { - private var isClosed = false +open class KYicesContext(ctx: KContext) : AutoCloseable { + protected var isClosed = false - private val expressions = mkTermCache>() - private val yicesExpressions = mkTermReverseCache>() + protected open val expressions = mkTermCache>() + protected open val yicesExpressions = mkTermReverseCache>() - private val sorts = mkSortCache() - private val yicesSorts = mkSortReverseCache() + protected open val sorts = mkSortCache() + protected open val yicesSorts = mkSortReverseCache() - private val decls = mkTermCache>() - private val yicesDecls = mkTermReverseCache>() + protected open val decls = mkTermCache>() + protected open val yicesDecls = mkTermReverseCache>() - private val vars = mkTermCache>() - private val yicesVars = mkTermReverseCache>() + protected open val vars = mkTermCache>() + protected open val yicesVars = mkTermReverseCache>() - private val yicesTypes = mkSortSet() - private val yicesTerms = mkTermSet() + protected open val yicesTypes = mkSortSet() + protected open val yicesTerms = mkTermSet() val isActive: Boolean get() = !isClosed - fun findInternalizedExpr(expr: KExpr<*>): YicesTerm = expressions.getInt(expr) + fun findInternalizedExpr(expr: KExpr<*>): YicesTerm = expressions.getInt(expr).also { + if (it != NOT_INTERNALIZED) + uninterpretedSortValuesTracker.expressionUse(expr) + } + fun saveInternalizedExpr(expr: KExpr<*>, internalized: YicesTerm) { if (expressions.putIfAbsent(expr, internalized) == NOT_INTERNALIZED) { if (expr is KInterpretedValue<*> || expr is KConst<*>) { yicesExpressions.put(internalized, expr) } + uninterpretedSortValuesTracker.expressionSave(expr) } } @@ -175,9 +182,9 @@ open class KYicesContext : AutoCloseable { fun functionType(domain: YicesSortArray, range: YicesSort) = mkType { Types.functionType(domain, range) } fun newUninterpretedType(name: String) = mkType { Types.newUninterpretedType(name) } - val zero = mkTerm { Terms.intConst(0L) } - val one = mkTerm { Terms.intConst(1L) } - val minusOne = mkTerm { Terms.intConst(-1L) } + val zero by lazy { mkTerm { Terms.intConst(0L) } } + val one by lazy { mkTerm { Terms.intConst(1L) } } + val minusOne by lazy { mkTerm { Terms.intConst(-1L) } } private inline fun mkTerm(mk: () -> YicesTerm): YicesTerm = withGcGuard { val term = mk() @@ -294,7 +301,32 @@ open class KYicesContext : AutoCloseable { fun uninterpretedSortConst(sort: YicesSort, idx: Int) = mkTerm { Terms.mkConst(sort, idx) } - private var maxValueIndex = 0 + protected open var maxValueIndex = 0 + + /** + * Collects uninterpreted sort values usage for [KYicesModel.uninterpretedSortUniverse] + */ + protected open val uninterpretedSortValuesTracker = UninterpretedValuesTracker( + ctx, + ScopedArrayFrame(::HashSet), + ScopedArrayFrame(::HashMap), + Object2IntOpenHashMap>() + ) + + fun pushAssertionLevel() { + uninterpretedSortValuesTracker.push() + } + + fun popAssertionLevel(n: UInt) { + uninterpretedSortValuesTracker.pop(n) + } + + fun registerUninterpretedSortValue(value: KUninterpretedSortValue) { + uninterpretedSortValuesTracker.addToCurrentLevel(value) + } + + fun uninterpretedSortValues(sort: KUninterpretedSort) = + uninterpretedSortValuesTracker.getUninterpretedSortValues(sort) /** * Yices can produce different values with the same index. @@ -339,20 +371,6 @@ open class KYicesContext : AutoCloseable { } companion object { - init { - if (!Yices.isReady()) { - NativeLibraryLoader.load { os -> - when (os) { - NativeLibraryLoader.OS.LINUX -> listOf("libyices", "libyices2java") - NativeLibraryLoader.OS.WINDOWS -> listOf("libyices", "libyices2java") - NativeLibraryLoader.OS.MACOS -> listOf("libyices", "libyices2java") - } - } - Yices.init() - Yices.setReadyFlag(true) - } - } - private const val UNINTERPRETED_SORT_VALUE_SHIFT = 1 shl 30 private const val UNINTERPRETED_SORT_MAX_ALLOWED_VALUE = UNINTERPRETED_SORT_VALUE_SHIFT / 2 private const val UNINTERPRETED_SORT_MIN_ALLOWED_VALUE = -UNINTERPRETED_SORT_MAX_ALLOWED_VALUE @@ -420,7 +438,8 @@ open class KYicesContext : AutoCloseable { } } - private fun performGc() { + @JvmStatic + protected fun performGc() { // spin wait until [gcGuard] == [FREE] while (true) { if (gcGuard.compareAndSet(FREE, ON_GC)) { diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesExprInternalizer.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesExprInternalizer.kt index fd9dde704..4ef22f8be 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesExprInternalizer.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesExprInternalizer.kt @@ -960,7 +960,7 @@ open class KYicesExprInternalizer( yicesCtx.uninterpretedSortConst( sort.internalizeSort(), yicesCtx.uninterpretedSortValueIndex(valueIdx) - ) + ).also { yicesCtx.registerUninterpretedSortValue(expr) } } } diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingContext.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingContext.kt new file mode 100644 index 000000000..750ae1831 --- /dev/null +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingContext.kt @@ -0,0 +1,47 @@ +package io.ksmt.solver.yices + +import io.ksmt.KContext + +/** + * Yices Context that allows forking and resources sharing via [KYicesForkingSolverManager]. + * To track resources, we have to use a unique solver specific ID, related to them. + * @param [solver] is used as "specific ID" for resource tracking, + * because we can't use initialized [com.sri.yices.Context] here. + */ +class KYicesForkingContext( + ctx: KContext, + manager: KYicesForkingSolverManager, + solver: KYicesForkingSolver +) : KYicesContext(ctx) { + override val expressions = manager.findExpressionsCache(solver) + override val yicesExpressions = manager.findExpressionsReversedCache(solver) + + override val sorts = manager.findSortsCache(solver) + override val yicesSorts = manager.findSortsReversedCache(solver) + + override val decls = manager.findDeclsCache(solver) + override val yicesDecls = manager.findDeclsReversedCache(solver) + + override val vars = manager.findVarsCache(solver) + override val yicesVars = manager.findVarsReversedCache(solver) + + override val yicesTypes = manager.findTypesCache(solver) + override val yicesTerms = manager.findTermsCache(solver) + + private val maxValueIndexAtomic = manager.findMaxUninterpretedSortValueIdx(solver) + + override var maxValueIndex: Int + get() = maxValueIndexAtomic.get() + set(value) { + maxValueIndexAtomic.set(value) + } + + override val uninterpretedSortValuesTracker = manager.createUninterpretedValuesTracker(solver) + + override fun close() { + if (isClosed) return + isClosed = true + + performGc() + } +} diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt new file mode 100644 index 000000000..d96be5c1d --- /dev/null +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt @@ -0,0 +1,112 @@ +package io.ksmt.solver.yices + +import io.ksmt.KContext +import io.ksmt.expr.KExpr +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KSolverStatus +import io.ksmt.sort.KBoolSort +import kotlin.time.Duration + +class KYicesForkingSolver( + ctx: KContext, + private val manager: KYicesForkingSolverManager, + parent: KYicesForkingSolver?, +) : KForkingSolver, KYicesSolverBase(ctx) { + + override val yicesCtx: KYicesForkingContext by lazy { KYicesForkingContext(ctx, manager, this) } + + private val trackedAssertions = + ScopedLinkedFrame, YicesTerm>>>(::ArrayList, ::ArrayList) + private val yicesAssertions = ScopedLinkedFrame>(::HashSet, ::HashSet) + + override val currentScope: UInt + get() = trackedAssertions.currentScope + + private val ksmtConfig: KYicesForkingSolverConfigurationImpl by lazy { + parent?.ksmtConfig?.fork(config) ?: KYicesForkingSolverConfigurationImpl(config) + } + + private var assertionsInitiated = parent == null + + init { + if (parent != null) { + trackedAssertions.fork(parent.trackedAssertions) + yicesAssertions.fork(parent.yicesAssertions) + + ksmtConfig // force initialization + } + } + + private fun ensureAssertionsInitiated() { + if (assertionsInitiated) return + + yicesAssertions.stacked() + .zip(trackedAssertions.stacked()) + .asReversed() + .forEachIndexed { scope, (yicesAssertionFrame, _) -> + if (scope > 0) nativeContext.push() + + yicesAssertionFrame.forEach(nativeContext::assertFormula) + } + + assertionsInitiated = true + } + + override fun configure(configurator: KYicesSolverConfiguration.() -> Unit) { + requireActiveConfig() + ksmtConfig.configurator() + } + + override fun fork(): KForkingSolver = manager.mkForkingSolver(this) + + override fun saveTrackedAssertion(track: YicesTerm, trackedExpr: KExpr) { + trackedAssertions.currentFrame += trackedExpr to track + } + + override fun collectTrackedAssertions(collector: (Pair, YicesTerm>) -> Unit) { + trackedAssertions.forEach { frame -> + frame.forEach(collector) + } + } + + override val hasTrackedAssertions: Boolean + get() = trackedAssertions.any { it.isNotEmpty() } + + override fun assert(expr: KExpr) = yicesTry { + ensureAssertionsInitiated() + ctx.ensureContextMatch(expr) + + val yicesExpr = with(exprInternalizer) { expr.internalize() } + nativeContext.assertFormula(yicesExpr) + yicesAssertions.currentFrame += yicesExpr + } + + override fun assertAndTrack(expr: KExpr) { + ensureAssertionsInitiated() + super.assertAndTrack(expr) + } + + override fun push() { + ensureAssertionsInitiated() + super.push() + trackedAssertions.push() + yicesAssertions.push() + } + + override fun pop(n: UInt) { + ensureAssertionsInitiated() + super.pop(n) + trackedAssertions.pop(n) + yicesAssertions.pop(n) + } + + override fun check(timeout: Duration): KSolverStatus { + ensureAssertionsInitiated() + return super.check(timeout) + } + + override fun checkWithAssumptions(assumptions: List>, timeout: Duration): KSolverStatus { + ensureAssertionsInitiated() + return super.checkWithAssumptions(assumptions, timeout) + } +} diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt new file mode 100644 index 000000000..04966037f --- /dev/null +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt @@ -0,0 +1,160 @@ +package io.ksmt.solver.yices + +import com.sri.yices.Yices +import io.ksmt.KAst +import io.ksmt.KContext +import io.ksmt.decl.KDecl +import io.ksmt.expr.KExpr +import io.ksmt.expr.KUninterpretedSortValue +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KForkingSolverManager +import io.ksmt.solver.util.KExprIntInternalizerBase +import io.ksmt.sort.KSort +import io.ksmt.sort.KUninterpretedSort +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap +import it.unimi.dsi.fastutil.ints.IntOpenHashSet +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap +import java.util.Collections.newSetFromMap +import java.util.Collections.synchronizedSet +import java.util.IdentityHashMap +import java.util.concurrent.atomic.AtomicInteger + +class KYicesForkingSolverManager( + private val ctx: KContext +) : KForkingSolverManager { + + private val solvers = synchronizedSet(newSetFromMap(IdentityHashMap())) + private val sharedCacheReferences = IdentityHashMap() + + private val expressionsCache = IdentityHashMap() + private val expressionsReversedCache = IdentityHashMap() + private val sortsCache = IdentityHashMap() + private val sortsReversedCache = IdentityHashMap() + private val declsCache = IdentityHashMap() + private val declsReversedCache = IdentityHashMap() + private val varsCache = IdentityHashMap() + private val varsReversedCache = IdentityHashMap() + private val typesCache = IdentityHashMap() + private val termsCache = IdentityHashMap() + private val maxUninterpretedSortValueIndex = IdentityHashMap() + + private val scopedExpressions = IdentityHashMap() + private val scopedUninterpretedValues = IdentityHashMap() + private val expressionLevels = IdentityHashMap() + + internal fun findExpressionsCache(s: KYicesForkingSolver): ExpressionsCache = expressionsCache.getValue(s) + internal fun findExpressionsReversedCache(s: KYicesForkingSolver): ExpressionsReversedCache = + expressionsReversedCache.getValue(s) + + internal fun findSortsCache(s: KYicesForkingSolver): SortsCache = sortsCache.getValue(s) + internal fun findSortsReversedCache(s: KYicesForkingSolver): SortsReversedCache = sortsReversedCache.getValue(s) + internal fun findDeclsCache(s: KYicesForkingSolver): DeclsCache = declsCache.getValue(s) + internal fun findDeclsReversedCache(s: KYicesForkingSolver): DeclsReversedCache = declsReversedCache.getValue(s) + internal fun findVarsCache(s: KYicesForkingSolver): VarsCache = varsCache.getValue(s) + internal fun findVarsReversedCache(s: KYicesForkingSolver): VarsReversedCache = varsReversedCache.getValue(s) + internal fun findTypesCache(s: KYicesForkingSolver): TypesCache = typesCache.getValue(s) + internal fun findTermsCache(s: KYicesForkingSolver): TermsCache = termsCache.getValue(s) + internal fun findMaxUninterpretedSortValueIdx(s: KYicesForkingSolver) = maxUninterpretedSortValueIndex.getValue(s) + + override fun mkForkingSolver(): KForkingSolver = + KYicesForkingSolver(ctx, this, null).also { + solvers += it + sharedCacheReferences[it] = AtomicInteger(1) + expressionsCache[it] = ExpressionsCache().withNotInternalizedDefaultValue() + expressionsReversedCache[it] = ExpressionsReversedCache() + sortsCache[it] = SortsCache().withNotInternalizedDefaultValue() + sortsReversedCache[it] = SortsReversedCache() + declsCache[it] = DeclsCache().withNotInternalizedDefaultValue() + declsReversedCache[it] = DeclsReversedCache() + varsCache[it] = VarsCache().withNotInternalizedDefaultValue() + varsReversedCache[it] = VarsReversedCache() + typesCache[it] = TypesCache() + termsCache[it] = TermsCache() + maxUninterpretedSortValueIndex[it] = AtomicInteger(0) + scopedExpressions[it] = ScopedExpressions(::HashSet, ::HashSet) + scopedUninterpretedValues[it] = ScopedUninterpretedSortValues(::HashMap, ::HashMap) + expressionLevels[it] = ExpressionLevels() + } + + internal fun mkForkingSolver(parent: KYicesForkingSolver) = KYicesForkingSolver(ctx, this, parent).also { + solvers += it + sharedCacheReferences[it] = sharedCacheReferences.getValue(parent).apply { incrementAndGet() } + expressionsCache[it] = expressionsCache[parent] + expressionsReversedCache[it] = expressionsReversedCache[parent] + sortsCache[it] = sortsCache[parent] + sortsReversedCache[it] = sortsReversedCache[parent] + declsCache[it] = declsCache[parent] + declsReversedCache[it] = declsReversedCache[parent] + varsCache[it] = varsCache[parent] + varsReversedCache[it] = varsReversedCache[parent] + typesCache[it] = typesCache[parent] + termsCache[it] = termsCache[parent] + scopedExpressions[it] = ScopedExpressions(::HashSet, ::HashSet) + .apply { fork(scopedExpressions.getValue(parent)) } + scopedUninterpretedValues[it] = ScopedUninterpretedSortValues(::HashMap, ::HashMap) + .apply { fork(scopedUninterpretedValues.getValue(parent)) } + expressionLevels[it] = ExpressionLevels(expressionLevels.getValue(parent)) + + val parentMaxUninterpretedSortValueIdx = maxUninterpretedSortValueIndex.getValue(parent).get() + maxUninterpretedSortValueIndex[it] = AtomicInteger(parentMaxUninterpretedSortValueIdx) + } + + internal fun createUninterpretedValuesTracker(solver: KYicesForkingSolver) = UninterpretedValuesTracker( + ctx, + scopedExpressions.getValue(solver), + scopedUninterpretedValues.getValue(solver), + expressionLevels.getValue(solver) + ) + + /** + * Unregisters [solver] for this manager + */ + internal fun close(solver: KYicesForkingSolver) { + solvers -= solver + decRef(solver) + } + + override fun close() { + solvers.forEach(KYicesForkingSolver::close) + } + + private fun decRef(solver: KYicesForkingSolver) { + val referencesAfterDec = sharedCacheReferences.getValue(solver).decrementAndGet() + if (referencesAfterDec == 0) { + sharedCacheReferences -= solver + expressionsCache -= solver + expressionsReversedCache -= solver + sortsCache -= solver + sortsReversedCache -= solver + declsCache -= solver + declsReversedCache -= solver + varsCache -= solver + varsReversedCache -= solver + typesCache.remove(solver)?.forEach(Yices::yicesDecrefType) + termsCache.remove(solver)?.forEach(Yices::yicesDecrefTerm) + maxUninterpretedSortValueIndex -= solver + scopedExpressions -= solver + scopedUninterpretedValues -= solver + expressionLevels -= solver + } + } + + private fun Object2IntOpenHashMap.withNotInternalizedDefaultValue() = apply { + defaultReturnValue(KExprIntInternalizerBase.NOT_INTERNALIZED) + } +} + +private typealias ExpressionsCache = Object2IntOpenHashMap> +private typealias ExpressionsReversedCache = Int2ObjectOpenHashMap> +private typealias SortsCache = Object2IntOpenHashMap +private typealias SortsReversedCache = Int2ObjectOpenHashMap +private typealias DeclsCache = Object2IntOpenHashMap> +private typealias DeclsReversedCache = Int2ObjectOpenHashMap> +private typealias VarsCache = Object2IntOpenHashMap> +private typealias VarsReversedCache = Int2ObjectOpenHashMap> +private typealias TypesCache = IntOpenHashSet +private typealias TermsCache = IntOpenHashSet +private typealias ScopedExpressions = ScopedLinkedFrame>> +@Suppress("MaxLineLength") +private typealias ScopedUninterpretedSortValues = ScopedLinkedFrame>> +private typealias ExpressionLevels = Object2IntOpenHashMap> diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesModel.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesModel.kt index b643189f9..8e461e797 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesModel.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesModel.kt @@ -9,10 +9,10 @@ import io.ksmt.decl.KDecl import io.ksmt.decl.KFuncDecl import io.ksmt.expr.KExpr import io.ksmt.expr.KUninterpretedSortValue +import io.ksmt.solver.KModel import io.ksmt.solver.model.KFuncInterp import io.ksmt.solver.model.KFuncInterpEntryVarsFree import io.ksmt.solver.model.KFuncInterpVarsFree -import io.ksmt.solver.KModel import io.ksmt.solver.model.KModelEvaluator import io.ksmt.solver.model.KModelImpl import io.ksmt.sort.KArray2Sort @@ -62,14 +62,14 @@ class KYicesModel( private val interpretations = hashMapOf, KFuncInterp<*>>() private val funcInterpretationsToDo = arrayListOf>>() - override fun uninterpretedSortUniverse( - sort: KUninterpretedSort - ): Set? = uninterpretedSortUniverse.getOrPut(sort) { - val sortDependencies = uninterpretedSortDependencies[sort] ?: return null + override fun uninterpretedSortUniverse(sort: KUninterpretedSort) = uninterpretedSortUniverse.getOrPut(sort) { + val knownTrackedValues = yicesCtx.uninterpretedSortValues(sort) + + val sortDependencies = uninterpretedSortDependencies[sort] ?: return knownTrackedValues sortDependencies.forEach { interpretation(it) } - knownUninterpretedSortValues[sort]?.values?.toHashSet() ?: hashSetOf() + knownTrackedValues + (knownUninterpretedSortValues[sort]?.values?.toHashSet() ?: hashSetOf()) } private val evaluatorWithModelCompletion by lazy { KModelEvaluator(ctx, this, isComplete = true) } @@ -107,7 +107,7 @@ class KYicesModel( } } - private fun functionInterpretation(yval: YVal, decl: KFuncDecl): KFuncInterp { + private fun functionInterpretation(yval: YVal, decl: KFuncDecl): KFuncInterp { val functionChildren = model.expandFunction(yval) val default = if (yval.tag != YValTag.UNKNOWN) { getValue(functionChildren.value, decl.sort).uncheckedCast<_, KExpr>() @@ -163,7 +163,7 @@ class KYicesModel( declarations.forEach { interpretation(it) } val uninterpretedSortsUniverses = uninterpretedSorts.associateWith { - uninterpretedSortUniverse(it) ?: error("missed sort universe for $it") + uninterpretedSortUniverse(it) } return KModelImpl(ctx, interpretations.toMap(), uninterpretedSortsUniverses) diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolver.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolver.kt index 5d54a73ad..a1220a44b 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolver.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolver.kt @@ -1,234 +1,41 @@ package io.ksmt.solver.yices -import com.sri.yices.Config -import com.sri.yices.Context -import com.sri.yices.Status -import com.sri.yices.YicesException -import it.unimi.dsi.fastutil.ints.IntArrayList -import it.unimi.dsi.fastutil.ints.IntOpenHashSet import io.ksmt.KContext import io.ksmt.expr.KExpr -import io.ksmt.solver.KModel -import io.ksmt.solver.KSolver -import io.ksmt.solver.KSolverException -import io.ksmt.solver.KSolverStatus import io.ksmt.sort.KBoolSort -import java.util.Timer -import java.util.TimerTask -import kotlin.time.Duration -class KYicesSolver(private val ctx: KContext) : KSolver { - private val yicesCtx = KYicesContext() +class KYicesSolver(ctx: KContext) : KYicesSolverBase(ctx) { + override val yicesCtx = KYicesContext(ctx) - private val config = Config() - private val nativeContext by lazy { - Context(config).also { - config.close() - } - } - - private val exprInternalizer: KYicesExprInternalizer by lazy { - KYicesExprInternalizer(yicesCtx) - } - private val exprConverter: KYicesExprConverter by lazy { - KYicesExprConverter(ctx, yicesCtx) - } - - private var lastAssumptions: TrackedAssumptions? = null - private var lastCheckStatus = KSolverStatus.UNKNOWN - private var lastReasonOfUnknown: String? = null - - private var currentLevelTrackedAssertions = mutableListOf, YicesTerm>>() - private val trackedAssertions = mutableListOf(currentLevelTrackedAssertions) - - private val timer = Timer() - - override fun configure(configurator: KYicesSolverConfiguration.() -> Unit) { - require(config.isActive) { - "Solver instance has already been created" - } - - KYicesSolverConfigurationImpl(config).configurator() - } - - override fun assert(expr: KExpr) = yicesTry { - ctx.ensureContextMatch(expr) - - val yicesExpr = with(exprInternalizer) { expr.internalize() } - nativeContext.assertFormula(yicesExpr) - } - - override fun assertAndTrack(expr: KExpr) = yicesTry { - ctx.ensureContextMatch(expr) - - val trackVarExpr = ctx.mkFreshConst("track", ctx.boolSort) - val trackedExpr = with(ctx) { !trackVarExpr or expr } + private val trackedAssertions = ScopedArrayFrame, YicesTerm>>>(::ArrayList) + override val currentScope: UInt + get() = trackedAssertions.currentScope - assert(trackedExpr) + override val hasTrackedAssertions: Boolean + get() = trackedAssertions.any { it.isNotEmpty() } - val yicesTrackVar = with(exprInternalizer) { trackVarExpr.internalize() } - currentLevelTrackedAssertions += expr to yicesTrackVar + override fun saveTrackedAssertion(track: YicesTerm, trackedExpr: KExpr) { + trackedAssertions.currentFrame += trackedExpr to track } - override fun push(): Unit = yicesTry { - nativeContext.push() - - currentLevelTrackedAssertions = mutableListOf() - trackedAssertions.add(currentLevelTrackedAssertions) - } - - override fun pop(n: UInt) = yicesTry { - val currentScope = trackedAssertions.lastIndex.toUInt() - require(n <= currentScope) { - "Can not pop $n scope levels because current scope level is $currentScope" - } - - if (n == 0u) return - - repeat(n.toInt()) { - nativeContext.pop() - trackedAssertions.removeLast() - } - currentLevelTrackedAssertions = trackedAssertions.last() - } - - override fun check(timeout: Duration): KSolverStatus = yicesTryCheck { - if (trackedAssertions.any { it.isNotEmpty() }) { - return checkWithAssumptions(emptyList(), timeout) - } - - checkWithTimer(timeout) { - nativeContext.check() - }.processCheckResult() - } - - override fun checkWithAssumptions( - assumptions: List>, - timeout: Duration - ): KSolverStatus = yicesTryCheck { - ctx.ensureContextMatch(assumptions) - - val yicesAssumptions = TrackedAssumptions().also { lastAssumptions = it } - + override fun collectTrackedAssertions(collector: (Pair, YicesTerm>) -> Unit) { trackedAssertions.forEach { frame -> - frame.forEach { assertion -> - yicesAssumptions.assumeTrackedAssertion(assertion) - } - } - - with(exprInternalizer) { - assumptions.forEach { assumedExpr -> - yicesAssumptions.assumeAssumption(assumedExpr, assumedExpr.internalize()) - } - } - - checkWithTimer(timeout) { - nativeContext.checkWithAssumptions(yicesAssumptions.assumedTerms()) - }.processCheckResult() - } - - override fun model(): KModel = yicesTry { - require(lastCheckStatus == KSolverStatus.SAT) { - "Model are only available after SAT checks, current solver status: $lastCheckStatus" - } - val model = nativeContext.model - - return KYicesModel(model, ctx, yicesCtx, exprInternalizer, exprConverter) - } - - override fun unsatCore(): List> = yicesTry { - require(lastCheckStatus == KSolverStatus.UNSAT) { - "Unsat cores are only available after UNSAT checks" + frame.forEach(collector) } - - lastAssumptions?.resolveUnsatCore(nativeContext.unsatCore) ?: emptyList() } - override fun reasonOfUnknown(): String { - require(lastCheckStatus == KSolverStatus.UNKNOWN) { - "Unknown reason is only available after UNKNOWN checks" - } - - // There is no way to retrieve reason of unknown from Yices in general case. - return lastReasonOfUnknown ?: "unknown" - } - - override fun interrupt() = yicesTry { - nativeContext.stopSearch() - } - - private inline fun checkWithTimer(timeout: Duration, body: () -> T): T { - val task = StopSearchTask() - - if (timeout.isFinite()) { - timer.schedule(task, timeout.inWholeMilliseconds) - } - - return try { - body() - } finally { - task.cancel() - } - } - - private inline fun yicesTry(body: () -> T): T = try { - body() - } catch (ex: YicesException) { - throw KSolverException(ex) - } - - private inline fun yicesTryCheck(body: () -> KSolverStatus): KSolverStatus = try { - invalidateSolverState() - body() - } catch (ex: YicesException) { - lastReasonOfUnknown = ex.message - KSolverStatus.UNKNOWN.also { lastCheckStatus = it } - } - - private fun invalidateSolverState() { - lastCheckStatus = KSolverStatus.UNKNOWN - lastReasonOfUnknown = null - lastAssumptions = null - } - - private fun Status.processCheckResult() = when (this) { - Status.SAT -> KSolverStatus.SAT - Status.UNSAT -> KSolverStatus.UNSAT - else -> KSolverStatus.UNKNOWN - }.also { lastCheckStatus = it } - - override fun close() { - nativeContext.close() - yicesCtx.close() - timer.cancel() + override fun configure(configurator: KYicesSolverConfiguration.() -> Unit) { + requireActiveConfig() + KYicesSolverConfigurationImpl(config).configurator() } - private inner class StopSearchTask : TimerTask() { - override fun run() { - nativeContext.stopSearch() - } + override fun push() { + super.push() + trackedAssertions.push() } - private class TrackedAssumptions { - private val assumedExprs = arrayListOf, YicesTerm>>() - private val assumedTerms = IntArrayList() - - fun assumeTrackedAssertion(trackedAssertion: Pair, YicesTerm>) { - assumedExprs.add(trackedAssertion) - assumedTerms.add(trackedAssertion.second) - } - - fun assumeAssumption(expr: KExpr, term: YicesTerm) = - assumeTrackedAssertion(expr to term) - - fun assumedTerms(): YicesTermArray { - assumedTerms.trim() // Elements length now equal to size - return assumedTerms.elements() - } - - fun resolveUnsatCore(yicesUnsatCore: YicesTermArray): List> { - val unsatCoreTerms = IntOpenHashSet(yicesUnsatCore) - return assumedExprs.mapNotNull { (expr, term) -> expr.takeIf { unsatCoreTerms.contains(term) } } - } + override fun pop(n: UInt) { + super.pop(n) + trackedAssertions.pop(n) } } diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolverBase.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolverBase.kt new file mode 100644 index 000000000..d8ad8eb00 --- /dev/null +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolverBase.kt @@ -0,0 +1,232 @@ +package io.ksmt.solver.yices + +import com.sri.yices.Config +import com.sri.yices.Context +import com.sri.yices.Status +import com.sri.yices.Yices +import com.sri.yices.YicesException +import io.ksmt.KContext +import io.ksmt.expr.KExpr +import io.ksmt.solver.KModel +import io.ksmt.solver.KSolver +import io.ksmt.solver.KSolverException +import io.ksmt.solver.KSolverStatus +import io.ksmt.sort.KBoolSort +import io.ksmt.utils.NativeLibraryLoader +import it.unimi.dsi.fastutil.ints.IntArrayList +import it.unimi.dsi.fastutil.ints.IntOpenHashSet +import java.util.Timer +import java.util.TimerTask +import kotlin.time.Duration + +abstract class KYicesSolverBase(protected val ctx: KContext) : KSolver { + protected abstract val yicesCtx: KYicesContext + protected val nativeContext by lazy { Context(config).also { config.close() } } + protected val config by lazy { Config() } + + protected val exprInternalizer: KYicesExprInternalizer by lazy { KYicesExprInternalizer(yicesCtx) } + protected val exprConverter: KYicesExprConverter by lazy { KYicesExprConverter(ctx, yicesCtx) } + + private var lastAssumptions: TrackedAssumptions? = null + private var lastCheckStatus = KSolverStatus.UNKNOWN + private var lastReasonOfUnknown: String? = null + + protected abstract val currentScope: UInt + + protected abstract fun saveTrackedAssertion(track: YicesTerm, trackedExpr: KExpr) + protected abstract fun collectTrackedAssertions(collector: (Pair, YicesTerm>) -> Unit) + protected abstract val hasTrackedAssertions: Boolean + + private val timer = Timer() + + protected fun requireActiveConfig() = require(config.isActive) { + "Solver instance has already been created" + } + + override fun assert(expr: KExpr) = yicesTry { + ctx.ensureContextMatch(expr) + + val yicesExpr = with(exprInternalizer) { expr.internalize() } + nativeContext.assertFormula(yicesExpr) + } + + override fun assertAndTrack(expr: KExpr) = yicesTry { + ctx.ensureContextMatch(expr) + + val trackVarExpr = ctx.mkFreshConst("track", ctx.boolSort) + val trackedExpr = with(ctx) { !trackVarExpr or expr } + + assert(trackedExpr) + + val yicesTrackVar = with(exprInternalizer) { trackVarExpr.internalize() } + saveTrackedAssertion(yicesTrackVar, expr) + } + + override fun push(): Unit = yicesTry { + nativeContext.push() + yicesCtx.pushAssertionLevel() + } + + override fun pop(n: UInt) = yicesTry { + require(n <= currentScope) { + "Can not pop $n scope levels because current scope level is $currentScope" + } + + if (n == 0u) return + + repeat(n.toInt()) { + nativeContext.pop() + } + yicesCtx.popAssertionLevel(n) + } + + override fun check(timeout: Duration): KSolverStatus = if (hasTrackedAssertions) { + checkWithAssumptions(emptyList(), timeout) + } else yicesTryCheck { + checkWithTimer(timeout) { + nativeContext.check() + }.processCheckResult() + } + + override fun checkWithAssumptions( + assumptions: List>, + timeout: Duration + ): KSolverStatus = yicesTryCheck { + ctx.ensureContextMatch(assumptions) + + val yicesAssumptions = TrackedAssumptions().also { lastAssumptions = it } + + collectTrackedAssertions(yicesAssumptions::assumeTrackedAssertion) + + with(exprInternalizer) { + assumptions.forEach { assumedExpr -> + yicesAssumptions.assumeAssumption(assumedExpr, assumedExpr.internalize()) + } + } + + checkWithTimer(timeout) { + nativeContext.checkWithAssumptions(yicesAssumptions.assumedTerms()) + }.processCheckResult() + } + + override fun model(): KModel = yicesTry { + require(lastCheckStatus == KSolverStatus.SAT) { + "Model are only available after SAT checks, current solver status: $lastCheckStatus" + } + val model = nativeContext.model + + return KYicesModel(model, ctx, yicesCtx, exprInternalizer, exprConverter) + } + + override fun unsatCore(): List> = yicesTry { + require(lastCheckStatus == KSolverStatus.UNSAT) { + "Unsat cores are only available after UNSAT checks" + } + + lastAssumptions?.resolveUnsatCore(nativeContext.unsatCore) ?: emptyList() + } + + override fun reasonOfUnknown(): String { + require(lastCheckStatus == KSolverStatus.UNKNOWN) { + "Unknown reason is only available after UNKNOWN checks" + } + + // There is no way to retrieve reason of unknown from Yices in general case. + return lastReasonOfUnknown ?: "unknown" + } + + override fun interrupt() = yicesTry { + nativeContext.stopSearch() + } + + private inline fun checkWithTimer(timeout: Duration, body: () -> T): T { + val task = StopSearchTask() + + if (timeout.isFinite()) { + timer.schedule(task, timeout.inWholeMilliseconds) + } + + return try { + body() + } finally { + task.cancel() + } + } + + protected inline fun yicesTry(body: () -> T): T = try { + body() + } catch (ex: YicesException) { + throw KSolverException(ex) + } + + private inline fun yicesTryCheck(body: () -> KSolverStatus): KSolverStatus = try { + invalidateSolverState() + body() + } catch (ex: YicesException) { + lastReasonOfUnknown = ex.message + KSolverStatus.UNKNOWN.also { lastCheckStatus = it } + } + + private fun invalidateSolverState() { + lastCheckStatus = KSolverStatus.UNKNOWN + lastReasonOfUnknown = null + lastAssumptions = null + } + + private fun Status.processCheckResult() = when (this) { + Status.SAT -> KSolverStatus.SAT + Status.UNSAT -> KSolverStatus.UNSAT + else -> KSolverStatus.UNKNOWN + }.also { lastCheckStatus = it } + + override fun close() { + nativeContext.close() + yicesCtx.close() + timer.cancel() + } + + private inner class StopSearchTask : TimerTask() { + override fun run() { + nativeContext.stopSearch() + } + } + + private class TrackedAssumptions { + private val assumedExprs = arrayListOf, YicesTerm>>() + private val assumedTerms = IntArrayList() + + fun assumeTrackedAssertion(trackedAssertion: Pair, YicesTerm>) { + assumedExprs.add(trackedAssertion) + assumedTerms.add(trackedAssertion.second) + } + + fun assumeAssumption(expr: KExpr, term: YicesTerm) = + assumeTrackedAssertion(expr to term) + + fun assumedTerms(): YicesTermArray { + assumedTerms.trim() // Elements length now equal to size + return assumedTerms.elements() + } + + fun resolveUnsatCore(yicesUnsatCore: YicesTermArray): List> { + val unsatCoreTerms = IntOpenHashSet(yicesUnsatCore) + return assumedExprs.mapNotNull { (expr, term) -> expr.takeIf { unsatCoreTerms.contains(term) } } + } + } + + companion object { + init { + if (!Yices.isReady()) { + NativeLibraryLoader.load { os -> + when (os) { + NativeLibraryLoader.OS.LINUX -> listOf("libyices", "libyices2java") + NativeLibraryLoader.OS.WINDOWS -> listOf("libyices", "libyices2java") + NativeLibraryLoader.OS.MACOS -> listOf("libyices", "libyices2java") + } + } + Yices.init() + Yices.setReadyFlag(true) + } + } + } +} diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolverConfiguration.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolverConfiguration.kt index f97b72673..0d477b2a9 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolverConfiguration.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolverConfiguration.kt @@ -31,6 +31,19 @@ class KYicesSolverConfigurationImpl(private val config: Config) : KYicesSolverCo } } +class KYicesForkingSolverConfigurationImpl(private val config: Config) : KYicesSolverConfiguration { + private val options = hashMapOf() + + override fun setYicesOption(option: String, value: String) { + config.set(option, value) + options[option] = value + } + + fun fork(config: Config) = KYicesForkingSolverConfigurationImpl(config).also { + options.forEach { (option, value) -> it.setYicesOption(option, value) } + } +} + class KYicesSolverUniversalConfiguration( private val builder: KSolverUniversalConfigurationBuilder ) : KYicesSolverConfiguration { diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/ScopedFrame.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/ScopedFrame.kt new file mode 100644 index 000000000..b5ed5f43c --- /dev/null +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/ScopedFrame.kt @@ -0,0 +1,147 @@ +package io.ksmt.solver.yices + +internal interface ScopedFrame { + val currentScope: UInt + val currentFrame: T + + fun getFrame(level: Int): T + + fun any(predicate: (T) -> Boolean): Boolean + + /** + * find value [V] in frame [T], and return it or null + */ + fun find(predicate: (T) -> V?): V? + fun forEach(action: (T) -> Unit) + + fun push() + fun pop(n: UInt = 1u) +} + +internal class ScopedArrayFrame( + currentFrame: T, + private inline val createNewFrame: () -> T +) : ScopedFrame { + constructor(createNewFrame: () -> T) : this(createNewFrame(), createNewFrame) + + private val frames = arrayListOf(currentFrame) + + override var currentFrame = currentFrame + private set + + override val currentScope: UInt + get() = frames.size.toUInt() + + override fun getFrame(level: Int): T = frames[level] + + override fun any(predicate: (T) -> Boolean): Boolean = frames.any(predicate) + + override fun find(predicate: (T) -> V?): V? { + frames.forEach { frame -> + predicate(frame)?.let { return it } + } + return null + } + + override fun forEach(action: (T) -> Unit) = frames.forEach(action) + + override fun push() { + currentFrame = createNewFrame() + frames += currentFrame + } + + override fun pop(n: UInt) { + repeat(n.toInt()) { frames.removeLast() } + currentFrame = frames.last() + } +} + +internal class ScopedLinkedFrame private constructor( + private var current: LinkedFrame, + private inline val createNewFrame: () -> T, + private inline val copyFrame: (T) -> T +) : ScopedFrame { + constructor( + currentFrame: T, + createNewFrame: () -> T, + copyFrame: (T) -> T + ) : this(LinkedFrame(currentFrame), createNewFrame, copyFrame) + + constructor( + createNewFrame: () -> T, + copyFrame: (T) -> T + ) : this(createNewFrame(), createNewFrame, copyFrame) + + override val currentFrame: T + get() = current.value + + override val currentScope: UInt + get() = current.scope + + override fun getFrame(level: Int): T { + if (level > current.scope.toInt() || level < 0) throw IllegalArgumentException("Level $level is out of scope") + var cur: LinkedFrame? = current + + while (cur != null && level < cur.scope.toInt()) { + cur = cur.previous + } + return cur!!.value + } + + override fun any(predicate: (T) -> Boolean): Boolean { + forEachReversed { frame -> + if (predicate(frame)) return true + } + return false + } + + fun stacked(): ArrayDeque = ArrayDeque().also { stack -> + forEachReversed { frame -> + stack.addLast(frame) + } + } + + override fun find(predicate: (T) -> V?): V? { + forEachReversed { frame -> + predicate(frame)?.let { return it } + } + return null + } + + override fun forEach(action: (T) -> Unit) = forEachReversed(action) + + override fun push() { + current = LinkedFrame(createNewFrame(), current) + } + + override fun pop(n: UInt) { + current = current.previous ?: throw IllegalStateException("Can't pop the bottom scope") + recreateTopFrame() + } + + private fun recreateTopFrame() { + val newTopFrame = copyFrame(currentFrame) + current = LinkedFrame(newTopFrame, current.previous) + } + + fun fork(parent: ScopedLinkedFrame) { + current = parent.current + recreateTopFrame() + } + + private inline fun forEachReversed(action: (T) -> Unit) { + var cur: LinkedFrame? = current + while (cur != null) { + action(cur.value) + cur = cur.previous + } + } + + private class LinkedFrame( + val value: E, + val previous: LinkedFrame? = null + ) { + val scope: UInt = previous?.scope?.plus(1u) ?: 0u + } + +} diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/UninterpretedValuesTracker.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/UninterpretedValuesTracker.kt new file mode 100644 index 000000000..e8cc27a09 --- /dev/null +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/UninterpretedValuesTracker.kt @@ -0,0 +1,90 @@ +package io.ksmt.solver.yices + +import io.ksmt.KContext +import io.ksmt.expr.KExpr +import io.ksmt.expr.KUninterpretedSortValue +import io.ksmt.expr.transformer.KNonRecursiveTransformer +import io.ksmt.sort.KSort +import io.ksmt.sort.KUninterpretedSort +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap + +class UninterpretedValuesTracker internal constructor( + private val ctx: KContext, + private val scopedExpressions: ScopedFrame>>, + private val uninterpretedValues: ScopedFrame>>, + private val expressionLevels: Object2IntOpenHashMap> +) { + private var analyzer: ExprUninterpretedValuesAnalyzer = createNewAnalyzer() + + private fun createNewAnalyzer() = ExprUninterpretedValuesAnalyzer( + ctx, + scopedExpressions, + uninterpretedValues, + expressionLevels + ) + + fun expressionUse(expr: KExpr<*>) { + if (expr in scopedExpressions.currentFrame) return + analyzer.apply(expr) + } + + fun expressionSave(expr: KExpr<*>) { + if (scopedExpressions.currentFrame.add(expr)) { + expressionLevels.put(expr, scopedExpressions.currentScope.toInt()) + } + } + + fun addToCurrentLevel(value: KUninterpretedSortValue) { + analyzer.addToCurrentLevel(value) + } + + fun getUninterpretedSortValues(sort: KUninterpretedSort) = hashSetOf().apply { + uninterpretedValues.forEach { frame -> + frame[sort]?.also { this += it } + } + } + + fun push() { + scopedExpressions.push() + uninterpretedValues.push() + } + + fun pop(n: UInt) { + scopedExpressions.pop(n) + uninterpretedValues.pop(n) + + analyzer = createNewAnalyzer() + } + + private class ExprUninterpretedValuesAnalyzer( + ctx: KContext, + val scopedExpressions: ScopedFrame>>, + val uninterpretedValues: ScopedFrame>>, + val expressionLevels: Object2IntOpenHashMap> + ) : KNonRecursiveTransformer(ctx) { + + fun addToCurrentLevel(value: KUninterpretedSortValue) { + uninterpretedValues.currentFrame.getOrPut(value.sort) { hashSetOf() } += value + } + + override fun transformExpr(expr: KExpr): KExpr { + if (scopedExpressions.currentFrame.add(expr)) + expressionLevels[expr] = scopedExpressions.currentScope.toInt() + return super.transformExpr(expr) + } + + override fun transform(expr: KUninterpretedSortValue): KExpr { + addToCurrentLevel(expr) + return super.transform(expr) + } + + override fun exprTransformationRequired(expr: KExpr): Boolean { + val frameLevel = expressionLevels.getInt(expr) + if (frameLevel < scopedExpressions.currentScope.toInt()) { + // If expr is valid on its level we don't need to move it + return expr !in scopedExpressions.getFrame(frameLevel) + } + return super.exprTransformationRequired(expr) + } + } +} From 60bda4e20f58e5ba2b03738d86ad1e8420c61eae Mon Sep 17 00:00:00 2001 From: Dmitriy Sokolov Date: Mon, 14 Aug 2023 15:19:01 +0300 Subject: [PATCH 06/12] Bitwuzla uninterpreted sort values universe fix --- .../bitwuzla/KBitwuzlaExprInternalizer.kt | 36 ++++++++++--------- .../io/ksmt/solver/bitwuzla/KBitwuzlaModel.kt | 20 +++++++---- .../KBitwuzlaUninterpretedSortValueContext.kt | 10 +++--- 3 files changed, 38 insertions(+), 28 deletions(-) diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaExprInternalizer.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaExprInternalizer.kt index 5938f1e19..2aff66301 100644 --- a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaExprInternalizer.kt +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaExprInternalizer.kt @@ -151,17 +151,11 @@ import io.ksmt.expr.KUnaryMinusArithExpr import io.ksmt.expr.KUninterpretedSortValue import io.ksmt.expr.KUniversalQuantifier import io.ksmt.expr.KXorExpr -import io.ksmt.expr.rewrite.simplify.rewriteBvAddNoUnderflowExpr -import io.ksmt.expr.rewrite.simplify.rewriteBvMulNoUnderflowExpr import io.ksmt.expr.rewrite.simplify.rewriteBvNegNoOverflowExpr import io.ksmt.expr.rewrite.simplify.rewriteBvSubNoUnderflowExpr import io.ksmt.solver.KSolverUnsupportedFeatureException -import org.ksmt.solver.bitwuzla.bindings.Bitwuzla -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaKind -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaRoundingMode -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm -import org.ksmt.solver.bitwuzla.bindings.Native +import io.ksmt.solver.bitwuzla.KBitwuzlaExprInternalizer.BvOverflowCheckMode.OVERFLOW +import io.ksmt.solver.bitwuzla.KBitwuzlaExprInternalizer.BvOverflowCheckMode.UNDERFLOW import io.ksmt.solver.util.KExprLongInternalizerBase import io.ksmt.sort.KArithSort import io.ksmt.sort.KArray2Sort @@ -186,7 +180,13 @@ import io.ksmt.sort.KRealSort import io.ksmt.sort.KSort import io.ksmt.sort.KSortVisitor import io.ksmt.sort.KUninterpretedSort +import org.ksmt.solver.bitwuzla.bindings.Bitwuzla +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaKind +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaRoundingMode +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTermArray +import org.ksmt.solver.bitwuzla.bindings.Native import java.math.BigInteger @Suppress("LargeClass") @@ -726,7 +726,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL override fun transform(expr: KBvAddNoOverflowExpr) = with(expr) { transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm -> if (isSigned) { - mkBvAddSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.OVERFLOW) + mkBvAddSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, OVERFLOW) } else { val overflowCheck = Native.bitwuzlaMkTerm2( bitwuzla, BitwuzlaKind.BITWUZLA_KIND_BV_UADD_OVERFLOW, a0, a1 @@ -738,20 +738,20 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL override fun transform(expr: KBvAddNoUnderflowExpr) = with(expr) { transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm -> - mkBvAddSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.UNDERFLOW) + mkBvAddSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, UNDERFLOW) } } override fun transform(expr: KBvSubNoOverflowExpr) = with(expr) { transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm -> - mkBvSubSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.OVERFLOW) + mkBvSubSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, OVERFLOW) } } override fun transform(expr: KBvSubNoUnderflowExpr) = with(expr) { if (isSigned) { transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm -> - mkBvSubSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.UNDERFLOW) + mkBvSubSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, UNDERFLOW) } } else { transform { @@ -776,7 +776,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL override fun transform(expr: KBvMulNoOverflowExpr) = with(expr) { transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm -> if (isSigned) { - mkBvMulSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.OVERFLOW) + mkBvMulSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, OVERFLOW) } else { val overflowCheck = Native.bitwuzlaMkTerm2( bitwuzla, BitwuzlaKind.BITWUZLA_KIND_BV_UMUL_OVERFLOW, a0, a1 @@ -788,7 +788,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL override fun transform(expr: KBvMulNoUnderflowExpr) = with(expr) { transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm -> - mkBvMulSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.UNDERFLOW) + mkBvMulSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, UNDERFLOW) } } @@ -813,7 +813,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL a1, BitwuzlaKind.BITWUZLA_KIND_BV_SADD_OVERFLOW ) { a0Sign, a1Sign -> - if (mode == BvOverflowCheckMode.OVERFLOW) { + if (mode == OVERFLOW) { // Both positive mkAndTerm(longArrayOf(mkNotTerm(a0Sign), mkNotTerm(a1Sign))) } else { @@ -833,7 +833,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL a1, BitwuzlaKind.BITWUZLA_KIND_BV_SSUB_OVERFLOW ) { a0Sign, a1Sign -> - if (mode == BvOverflowCheckMode.OVERFLOW) { + if (mode == OVERFLOW) { // Positive sub negative mkAndTerm(longArrayOf(mkNotTerm(a0Sign), a1Sign)) } else { @@ -853,7 +853,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL a1, BitwuzlaKind.BITWUZLA_KIND_BV_SMUL_OVERFLOW ) { a0Sign, a1Sign -> - if (mode == BvOverflowCheckMode.OVERFLOW) { + if (mode == OVERFLOW) { // Overflow is possible when sign bits are equal mkEqTerm(bitwuzlaCtx.ctx.boolSort, a0Sign, a1Sign) } else { @@ -1401,6 +1401,8 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL } override fun transform(expr: KUninterpretedSortValue): KExpr = expr.transform { + // register it for uninterpreted sort universe + bitwuzlaCtx.registerDeclaration(expr.decl) Native.bitwuzlaMkBvValueUint32( bitwuzla, expr.sort.internalizeSort(), diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaModel.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaModel.kt index f36e43c99..f4520736b 100644 --- a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaModel.kt +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaModel.kt @@ -2,18 +2,15 @@ package io.ksmt.solver.bitwuzla import io.ksmt.KContext import io.ksmt.decl.KDecl +import io.ksmt.decl.KUninterpretedSortValueDecl import io.ksmt.expr.KExpr import io.ksmt.expr.KUninterpretedSortValue +import io.ksmt.solver.KModel +import io.ksmt.solver.KSolverUnsupportedFeatureException import io.ksmt.solver.model.KFuncInterp import io.ksmt.solver.model.KFuncInterpEntryVarsFree import io.ksmt.solver.model.KFuncInterpEntryVarsFreeOneAry import io.ksmt.solver.model.KFuncInterpVarsFree -import io.ksmt.solver.KModel -import io.ksmt.solver.KSolverUnsupportedFeatureException -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm -import org.ksmt.solver.bitwuzla.bindings.FunValue -import org.ksmt.solver.bitwuzla.bindings.Native import io.ksmt.solver.model.KFuncInterpWithVars import io.ksmt.solver.model.KModelEvaluator import io.ksmt.solver.model.KModelImpl @@ -23,6 +20,10 @@ import io.ksmt.sort.KSort import io.ksmt.sort.KUninterpretedSort import io.ksmt.utils.mkFreshConstDecl import io.ksmt.utils.uncheckedCast +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm +import org.ksmt.solver.bitwuzla.bindings.FunValue +import org.ksmt.solver.bitwuzla.bindings.Native open class KBitwuzlaModel( private val ctx: KContext, @@ -77,7 +78,12 @@ open class KBitwuzlaModel( * to ensure that [uninterpretedSortValueContext] contains * all possible values for the given sort. * */ - sortDependency.forEach { interpretation(it) } + sortDependency.forEach { + if (it is KUninterpretedSortValueDecl) { + val value = ctx.mkUninterpretedSortValue(it.sort, it.valueIdx) + uninterpretedSortValueContext.registerValue(value) + } else interpretation(it) + } uninterpretedSortValueContext.currentSortUniverse(sort) } diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaUninterpretedSortValueContext.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaUninterpretedSortValueContext.kt index cd7e184ba..d0c4d8d9f 100644 --- a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaUninterpretedSortValueContext.kt +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaUninterpretedSortValueContext.kt @@ -10,10 +10,12 @@ class KBitwuzlaUninterpretedSortValueContext(private val ctx: KContext) { private val sortsUniverses = hashMapOf, KUninterpretedSortValue>>() fun mkValue(sort: KUninterpretedSort, value: KBitVec32Value): KUninterpretedSortValue { - val sortUniverse = sortsUniverses.getOrPut(sort) { hashMapOf() } - return sortUniverse.getOrPut(value) { - ctx.mkUninterpretedSortValue(sort, value.intValue) - } + return registerValue(ctx.mkUninterpretedSortValue(sort, value.intValue)) + } + + fun registerValue(value: KUninterpretedSortValue): KUninterpretedSortValue { + val sortsUniverse = sortsUniverses.getOrPut(value.sort) { hashMapOf() } + return sortsUniverse.getOrPut(value) { value } } fun currentSortUniverse(sort: KUninterpretedSort): Set = From 6c626c4469967f178c653596f3374f873bf604bb Mon Sep 17 00:00:00 2001 From: Dmitriy Sokolov Date: Mon, 14 Aug 2023 20:49:12 +0300 Subject: [PATCH 07/12] bitwuzla - forking solver, scoped uninterpreted values tracking fix, yices - forking solver init exceptions wrapper, forking solver close fix, all solvers - ScopedLinkedFrame.pop fix --- .../ksmt/solver/bitwuzla/KBitwuzlaContext.kt | 22 +- .../solver/bitwuzla/KBitwuzlaForkingSolver.kt | 106 +++++++++ .../bitwuzla/KBitwuzlaForkingSolverManager.kt | 28 +++ .../ksmt/solver/bitwuzla/KBitwuzlaSolver.kt | 191 +---------------- .../solver/bitwuzla/KBitwuzlaSolverBase.kt | 201 ++++++++++++++++++ .../bitwuzla/KBitwuzlaSolverConfiguration.kt | 20 ++ .../ksmt/solver/bitwuzla/ScopedLinkedFrame.kt | 67 ++++++ .../kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt | 4 +- .../kotlin/io/ksmt/test/KForkingSolverTest.kt | 28 ++- .../ksmt/solver/yices/KYicesForkingSolver.kt | 17 +- .../yices/KYicesForkingSolverManager.kt | 5 +- .../io/ksmt/solver/yices/ScopedFrame.kt | 4 +- 12 files changed, 482 insertions(+), 211 deletions(-) create mode 100644 ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolver.kt create mode 100644 ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolverManager.kt create mode 100644 ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolverBase.kt create mode 100644 ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/ScopedLinkedFrame.kt diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaContext.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaContext.kt index 5fb618e35..5b845795c 100644 --- a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaContext.kt +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaContext.kt @@ -1,8 +1,5 @@ package io.ksmt.solver.bitwuzla -import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap -import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap -import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap import io.ksmt.KContext import io.ksmt.decl.KDecl import io.ksmt.decl.KFuncDecl @@ -15,13 +12,10 @@ import io.ksmt.expr.KExistentialQuantifier import io.ksmt.expr.KExpr import io.ksmt.expr.KFunctionApp import io.ksmt.expr.KFunctionAsArray +import io.ksmt.expr.KUninterpretedSortValue import io.ksmt.expr.KUniversalQuantifier import io.ksmt.expr.transformer.KNonRecursiveTransformer import io.ksmt.solver.KSolverException -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm -import org.ksmt.solver.bitwuzla.bindings.Native import io.ksmt.solver.util.KExprLongInternalizerBase.Companion.NOT_INTERNALIZED import io.ksmt.sort.KArray2Sort import io.ksmt.sort.KArray3Sort @@ -37,6 +31,13 @@ import io.ksmt.sort.KRealSort import io.ksmt.sort.KSort import io.ksmt.sort.KSortVisitor import io.ksmt.sort.KUninterpretedSort +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap +import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm +import org.ksmt.solver.bitwuzla.bindings.Native open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable { private var isClosed = false @@ -433,6 +434,11 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable { return super.transform(expr) } + override fun transform(expr: KUninterpretedSortValue): KExpr { + registerDeclIfNotIgnored(expr.decl) + return super.transform(expr) + } + private val quantifiedVarsScopeOwner = arrayListOf>() private val quantifiedVarsScope = arrayListOf>?>() @@ -474,7 +480,7 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable { override fun transform(expr: KExistentialQuantifier): KExpr = expr.transformQuantifier(expr.bounds, expr.body) - override fun transform(expr: KUniversalQuantifier): KExpr = + override fun transform(expr: KUniversalQuantifier): KExpr = expr.transformQuantifier(expr.bounds, expr.body) } diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolver.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolver.kt new file mode 100644 index 000000000..845296cb3 --- /dev/null +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolver.kt @@ -0,0 +1,106 @@ +package io.ksmt.solver.bitwuzla + +import io.ksmt.KContext +import io.ksmt.expr.KExpr +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KSolverStatus +import io.ksmt.sort.KBoolSort +import kotlin.time.Duration + +class KBitwuzlaForkingSolver( + private val ctx: KContext, + private val manager: KBitwuzlaForkingSolverManager, + parent: KBitwuzlaForkingSolver? +) : KBitwuzlaSolverBase(ctx), + KForkingSolver { + + private val assertions = ScopedLinkedFrame>>(::ArrayList, ::ArrayList) + private val trackToExprFrames = + ScopedLinkedFrame, KExpr>>>(::ArrayList, ::ArrayList) + + private val config: KBitwuzlaForkingSolverConfigurationImpl + + init { + if (parent != null) { + config = parent.config.fork(bitwuzlaCtx.bitwuzla) + assertions.fork(parent.assertions) + trackToExprFrames.fork(parent.trackToExprFrames) + } else { + config = KBitwuzlaForkingSolverConfigurationImpl(bitwuzlaCtx.bitwuzla) + } + } + + override fun configure(configurator: KBitwuzlaSolverConfiguration.() -> Unit) { + config.configurator() + } + + override fun fork(): KForkingSolver = manager.mkForkingSolver(this) + + private var assertionsInitiated = parent == null + + private fun ensureAssertionsInitiated() { + if (assertionsInitiated) return + + assertions.stacked().zip(trackToExprFrames.stacked()) + .asReversed() + .forEachIndexed { scope, (assertionsFrame, trackedExprsFrame) -> + if (scope > 0) super.push() + + assertionsFrame.forEach { assertion -> + internalizeAndAssertWithAxioms(assertion) + } + + trackedExprsFrame.forEach { (track, trackedExpr) -> + super.registerTrackForExpr(trackedExpr, track) + } + } + assertionsInitiated = true + } + + override fun assert(expr: KExpr) = bitwuzlaCtx.bitwuzlaTry { + ctx.ensureContextMatch(expr) + ensureAssertionsInitiated() + + internalizeAndAssertWithAxioms(expr) + assertions.currentFrame += expr + } + + override fun assertAndTrack(expr: KExpr) { + bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() } + super.assertAndTrack(expr) + } + + override fun registerTrackForExpr(expr: KExpr, track: KExpr) { + super.registerTrackForExpr(expr, track) + trackToExprFrames.currentFrame += track to expr + } + + override fun push() { + bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() } + super.push() + assertions.push() + trackToExprFrames.push() + } + + override fun pop(n: UInt) { + bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() } + super.pop(n) + assertions.pop(n) + trackToExprFrames.pop(n) + } + + override fun check(timeout: Duration): KSolverStatus { + bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() } + return super.check(timeout) + } + + override fun checkWithAssumptions(assumptions: List>, timeout: Duration): KSolverStatus { + bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() } + return super.checkWithAssumptions(assumptions, timeout) + } + + override fun close() { + super.close() + manager.close(this) + } +} diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolverManager.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolverManager.kt new file mode 100644 index 000000000..b941e94ec --- /dev/null +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolverManager.kt @@ -0,0 +1,28 @@ +package io.ksmt.solver.bitwuzla + +import io.ksmt.KContext +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KForkingSolverManager +import java.util.concurrent.ConcurrentHashMap + +class KBitwuzlaForkingSolverManager(private val ctx: KContext) : KForkingSolverManager { + private val solvers = ConcurrentHashMap.newKeySet() + + override fun mkForkingSolver(): KForkingSolver { + return KBitwuzlaForkingSolver(ctx, this, null).also { + solvers += it + } + } + + internal fun mkForkingSolver(parent: KBitwuzlaForkingSolver) = KBitwuzlaForkingSolver(ctx, this, parent).also { + solvers += it + } + + internal fun close(solver: KBitwuzlaForkingSolver) { + solvers -= solver + } + + override fun close() { + solvers.forEach(KBitwuzlaForkingSolver::close) + } +} diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolver.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolver.kt index 5d86b25b9..30c0b0f71 100644 --- a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolver.kt +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolver.kt @@ -1,199 +1,10 @@ package io.ksmt.solver.bitwuzla -import it.unimi.dsi.fastutil.longs.LongOpenHashSet import io.ksmt.KContext -import io.ksmt.expr.KExpr -import io.ksmt.solver.KModel -import io.ksmt.solver.KSolver -import io.ksmt.solver.KSolverStatus -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaOption -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaResult -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTermArray -import org.ksmt.solver.bitwuzla.bindings.Native -import io.ksmt.sort.KBoolSort -import kotlin.time.Duration -open class KBitwuzlaSolver(private val ctx: KContext) : KSolver { - open val bitwuzlaCtx = KBitwuzlaContext(ctx) - open val exprInternalizer: KBitwuzlaExprInternalizer by lazy { - KBitwuzlaExprInternalizer(bitwuzlaCtx) - } - open val exprConverter: KBitwuzlaExprConverter by lazy { - KBitwuzlaExprConverter(ctx, bitwuzlaCtx) - } - - private var lastCheckStatus = KSolverStatus.UNKNOWN - private var lastReasonOfUnknown: String? = null - private var lastAssumptions: TrackedAssumptions? = null - private var lastModel: KBitwuzlaModel? = null - - init { - Native.bitwuzlaSetOption(bitwuzlaCtx.bitwuzla, BitwuzlaOption.BITWUZLA_OPT_INCREMENTAL, value = 1) - Native.bitwuzlaSetOption(bitwuzlaCtx.bitwuzla, BitwuzlaOption.BITWUZLA_OPT_PRODUCE_MODELS, value = 1) - } - - private var trackedAssertions = mutableListOf, BitwuzlaTerm>>() - private val trackVarsAssertionFrames = arrayListOf(trackedAssertions) +open class KBitwuzlaSolver(ctx: KContext) : KBitwuzlaSolverBase(ctx) { override fun configure(configurator: KBitwuzlaSolverConfiguration.() -> Unit) { KBitwuzlaSolverConfigurationImpl(bitwuzlaCtx.bitwuzla).configurator() } - - override fun assert(expr: KExpr) = bitwuzlaCtx.bitwuzlaTry { - ctx.ensureContextMatch(expr) - - val assertionWithAxioms = with(exprInternalizer) { expr.internalizeAssertion() } - - assertionWithAxioms.axioms.forEach { - Native.bitwuzlaAssert(bitwuzlaCtx.bitwuzla, it) - } - Native.bitwuzlaAssert(bitwuzlaCtx.bitwuzla, assertionWithAxioms.assertion) - } - - override fun assertAndTrack(expr: KExpr) = bitwuzlaCtx.bitwuzlaTry { - ctx.ensureContextMatch(expr) - - val trackVarExpr = ctx.mkFreshConst("track", ctx.boolSort) - val trackedExpr = with(ctx) { !trackVarExpr or expr } - - assert(trackedExpr) - - val trackVarTerm = with(exprInternalizer) { trackVarExpr.internalize() } - trackedAssertions += expr to trackVarTerm - } - - override fun push(): Unit = bitwuzlaCtx.bitwuzlaTry { - Native.bitwuzlaPush(bitwuzlaCtx.bitwuzla, nlevels = 1) - - trackedAssertions = trackedAssertions.toMutableList() - trackVarsAssertionFrames.add(trackedAssertions) - - bitwuzlaCtx.createNestedDeclarationScope() - } - - override fun pop(n: UInt): Unit = bitwuzlaCtx.bitwuzlaTry { - val currentLevel = trackVarsAssertionFrames.lastIndex.toUInt() - require(n <= currentLevel) { - "Cannot pop $n scope levels because current scope level is $currentLevel" - } - - if (n == 0u) return - - repeat(n.toInt()) { - trackVarsAssertionFrames.removeLast() - bitwuzlaCtx.popDeclarationScope() - } - - trackedAssertions = trackVarsAssertionFrames.last() - - Native.bitwuzlaPop(bitwuzlaCtx.bitwuzla, n.toInt()) - } - - override fun check(timeout: Duration): KSolverStatus = - checkWithAssumptions(emptyList(), timeout) - - override fun checkWithAssumptions(assumptions: List>, timeout: Duration): KSolverStatus = - bitwuzlaTryCheck { - ctx.ensureContextMatch(assumptions) - - val currentAssumptions = TrackedAssumptions().also { lastAssumptions = it } - - trackedAssertions.forEach { - currentAssumptions.assumeTrackedAssertion(it) - } - - with(exprInternalizer) { - assumptions.forEach { - currentAssumptions.assumeAssumption(it, it.internalize()) - } - } - - checkWithTimeout(timeout).processCheckResult() - } - - private fun checkWithTimeout(timeout: Duration): BitwuzlaResult = if (timeout.isInfinite()) { - Native.bitwuzlaCheckSatResult(bitwuzlaCtx.bitwuzla) - } else { - Native.bitwuzlaCheckSatTimeoutResult(bitwuzlaCtx.bitwuzla, timeout.inWholeMilliseconds) - } - - override fun model(): KModel = bitwuzlaCtx.bitwuzlaTry { - require(lastCheckStatus == KSolverStatus.SAT) { "Model are only available after SAT checks" } - val model = lastModel ?: KBitwuzlaModel( - ctx, bitwuzlaCtx, exprConverter, - bitwuzlaCtx.declarations(), - bitwuzlaCtx.uninterpretedSortsWithRelevantDecls() - ) - lastModel = model - model - } - - override fun unsatCore(): List> = bitwuzlaCtx.bitwuzlaTry { - require(lastCheckStatus == KSolverStatus.UNSAT) { "Unsat cores are only available after UNSAT checks" } - val unsatAssumptions = Native.bitwuzlaGetUnsatAssumptions(bitwuzlaCtx.bitwuzla) - lastAssumptions?.resolveUnsatCore(unsatAssumptions) ?: emptyList() - } - - override fun reasonOfUnknown(): String = bitwuzlaCtx.bitwuzlaTry { - require(lastCheckStatus == KSolverStatus.UNKNOWN) { - "Unknown reason is only available after UNKNOWN checks" - } - - // There is no way to retrieve reason of unknown from Bitwuzla in general case. - return lastReasonOfUnknown ?: "unknown" - } - - override fun interrupt() = bitwuzlaCtx.bitwuzlaTry { - Native.bitwuzlaForceTerminate(bitwuzlaCtx.bitwuzla) - } - - override fun close() = bitwuzlaCtx.bitwuzlaTry { - bitwuzlaCtx.close() - } - - private fun BitwuzlaResult.processCheckResult() = when (this) { - BitwuzlaResult.BITWUZLA_SAT -> KSolverStatus.SAT - BitwuzlaResult.BITWUZLA_UNSAT -> KSolverStatus.UNSAT - BitwuzlaResult.BITWUZLA_UNKNOWN -> KSolverStatus.UNKNOWN - }.also { lastCheckStatus = it } - - private fun invalidateSolverState() { - /** - * Bitwuzla model is only valid until the next check-sat call. - * */ - lastModel?.markInvalid() - lastModel = null - - lastCheckStatus = KSolverStatus.UNKNOWN - lastReasonOfUnknown = null - - lastAssumptions = null - } - - private inline fun bitwuzlaTryCheck(body: () -> KSolverStatus): KSolverStatus = try { - invalidateSolverState() - body() - } catch (ex: BitwuzlaNativeException) { - lastReasonOfUnknown = ex.message - KSolverStatus.UNKNOWN.also { lastCheckStatus = it } - } - - private inner class TrackedAssumptions { - private val assumedExprs = arrayListOf, BitwuzlaTerm>>() - - fun assumeTrackedAssertion(trackedAssertion: Pair, BitwuzlaTerm>) { - assumedExprs.add(trackedAssertion) - Native.bitwuzlaAssume(bitwuzlaCtx.bitwuzla, trackedAssertion.second) - } - - fun assumeAssumption(expr: KExpr, term: BitwuzlaTerm) = - assumeTrackedAssertion(expr to term) - - fun resolveUnsatCore(unsatAssumptions: BitwuzlaTermArray): List> { - val unsatCoreTerms = LongOpenHashSet(unsatAssumptions) - return assumedExprs.mapNotNull { (expr, term) -> expr.takeIf { unsatCoreTerms.contains(term) } } - } - } } diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolverBase.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolverBase.kt new file mode 100644 index 000000000..fa8ae269b --- /dev/null +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolverBase.kt @@ -0,0 +1,201 @@ +package io.ksmt.solver.bitwuzla + +import io.ksmt.KContext +import io.ksmt.expr.KExpr +import io.ksmt.solver.KModel +import io.ksmt.solver.KSolver +import io.ksmt.solver.KSolverStatus +import io.ksmt.sort.KBoolSort +import it.unimi.dsi.fastutil.longs.LongOpenHashSet +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaOption +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaResult +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTermArray +import org.ksmt.solver.bitwuzla.bindings.Native +import kotlin.time.Duration + +abstract class KBitwuzlaSolverBase(private val ctx: KContext) : KSolver { + open val bitwuzlaCtx = KBitwuzlaContext(ctx) + open val exprInternalizer: KBitwuzlaExprInternalizer by lazy { + KBitwuzlaExprInternalizer(bitwuzlaCtx) + } + open val exprConverter: KBitwuzlaExprConverter by lazy { + KBitwuzlaExprConverter(ctx, bitwuzlaCtx) + } + + private var lastCheckStatus = KSolverStatus.UNKNOWN + private var lastReasonOfUnknown: String? = null + private var lastAssumptions: TrackedAssumptions? = null + private var lastModel: KBitwuzlaModel? = null + + init { + Native.bitwuzlaSetOption(bitwuzlaCtx.bitwuzla, BitwuzlaOption.BITWUZLA_OPT_INCREMENTAL, value = 1) + Native.bitwuzlaSetOption(bitwuzlaCtx.bitwuzla, BitwuzlaOption.BITWUZLA_OPT_PRODUCE_MODELS, value = 1) + } + + private var trackedAssertions = mutableListOf, BitwuzlaTerm>>() + private val trackVarsAssertionFrames = arrayListOf(trackedAssertions) + + protected fun internalizeAndAssertWithAxioms(expr: KExpr) { + val assertionWithAxioms = with(exprInternalizer) { expr.internalizeAssertion() } + + assertionWithAxioms.axioms.forEach { + Native.bitwuzlaAssert(bitwuzlaCtx.bitwuzla, it) + } + Native.bitwuzlaAssert(bitwuzlaCtx.bitwuzla, assertionWithAxioms.assertion) + } + + override fun assert(expr: KExpr) = bitwuzlaCtx.bitwuzlaTry { + ctx.ensureContextMatch(expr) + internalizeAndAssertWithAxioms(expr) + } + + protected open fun registerTrackForExpr(expr: KExpr, track: KExpr) { + val trackVarTerm = with(exprInternalizer) { track.internalize() } + trackedAssertions += expr to trackVarTerm + } + + override fun assertAndTrack(expr: KExpr) = bitwuzlaCtx.bitwuzlaTry { + ctx.ensureContextMatch(expr) + + val trackVarExpr = ctx.mkFreshConst("track", ctx.boolSort) + val trackedExpr = with(ctx) { !trackVarExpr or expr } + + assert(trackedExpr) + registerTrackForExpr(expr, trackVarExpr) + } + + override fun push(): Unit = bitwuzlaCtx.bitwuzlaTry { + Native.bitwuzlaPush(bitwuzlaCtx.bitwuzla, nlevels = 1) + + trackedAssertions = trackedAssertions.toMutableList() + trackVarsAssertionFrames.add(trackedAssertions) + + bitwuzlaCtx.createNestedDeclarationScope() + } + + override fun pop(n: UInt): Unit = bitwuzlaCtx.bitwuzlaTry { + val currentLevel = trackVarsAssertionFrames.lastIndex.toUInt() + require(n <= currentLevel) { + "Cannot pop $n scope levels because current scope level is $currentLevel" + } + + if (n == 0u) return + + repeat(n.toInt()) { + trackVarsAssertionFrames.removeLast() + bitwuzlaCtx.popDeclarationScope() + } + + trackedAssertions = trackVarsAssertionFrames.last() + + Native.bitwuzlaPop(bitwuzlaCtx.bitwuzla, n.toInt()) + } + + override fun check(timeout: Duration): KSolverStatus = + checkWithAssumptions(emptyList(), timeout) + + override fun checkWithAssumptions(assumptions: List>, timeout: Duration): KSolverStatus = + bitwuzlaTryCheck { + ctx.ensureContextMatch(assumptions) + + val currentAssumptions = TrackedAssumptions().also { lastAssumptions = it } + + trackedAssertions.forEach { + currentAssumptions.assumeTrackedAssertion(it) + } + + with(exprInternalizer) { + assumptions.forEach { + currentAssumptions.assumeAssumption(it, it.internalize()) + } + } + + checkWithTimeout(timeout).processCheckResult() + } + + protected fun checkWithTimeout(timeout: Duration): BitwuzlaResult = if (timeout.isInfinite()) { + Native.bitwuzlaCheckSatResult(bitwuzlaCtx.bitwuzla) + } else { + Native.bitwuzlaCheckSatTimeoutResult(bitwuzlaCtx.bitwuzla, timeout.inWholeMilliseconds) + } + + override fun model(): KModel = bitwuzlaCtx.bitwuzlaTry { + require(lastCheckStatus == KSolverStatus.SAT) { "Model are only available after SAT checks" } + val model = lastModel ?: KBitwuzlaModel( + ctx, bitwuzlaCtx, exprConverter, + bitwuzlaCtx.declarations(), + bitwuzlaCtx.uninterpretedSortsWithRelevantDecls() + ) + lastModel = model + model + } + + override fun unsatCore(): List> = bitwuzlaCtx.bitwuzlaTry { + require(lastCheckStatus == KSolverStatus.UNSAT) { "Unsat cores are only available after UNSAT checks" } + val unsatAssumptions = Native.bitwuzlaGetUnsatAssumptions(bitwuzlaCtx.bitwuzla) + lastAssumptions?.resolveUnsatCore(unsatAssumptions) ?: emptyList() + } + + override fun reasonOfUnknown(): String = bitwuzlaCtx.bitwuzlaTry { + require(lastCheckStatus == KSolverStatus.UNKNOWN) { + "Unknown reason is only available after UNKNOWN checks" + } + + // There is no way to retrieve reason of unknown from Bitwuzla in general case. + return lastReasonOfUnknown ?: "unknown" + } + + override fun interrupt() = bitwuzlaCtx.bitwuzlaTry { + Native.bitwuzlaForceTerminate(bitwuzlaCtx.bitwuzla) + } + + override fun close() = bitwuzlaCtx.bitwuzlaTry { + bitwuzlaCtx.close() + } + + protected fun BitwuzlaResult.processCheckResult() = when (this) { + BitwuzlaResult.BITWUZLA_SAT -> KSolverStatus.SAT + BitwuzlaResult.BITWUZLA_UNSAT -> KSolverStatus.UNSAT + BitwuzlaResult.BITWUZLA_UNKNOWN -> KSolverStatus.UNKNOWN + }.also { lastCheckStatus = it } + + private fun invalidateSolverState() { + /** + * Bitwuzla model is only valid until the next check-sat call. + * */ + lastModel?.markInvalid() + lastModel = null + + lastCheckStatus = KSolverStatus.UNKNOWN + lastReasonOfUnknown = null + + lastAssumptions = null + } + + private inline fun bitwuzlaTryCheck(body: () -> KSolverStatus): KSolverStatus = try { + invalidateSolverState() + body() + } catch (ex: BitwuzlaNativeException) { + lastReasonOfUnknown = ex.message + KSolverStatus.UNKNOWN.also { lastCheckStatus = it } + } + + private inner class TrackedAssumptions { + private val assumedExprs = arrayListOf, BitwuzlaTerm>>() + + fun assumeTrackedAssertion(trackedAssertion: Pair, BitwuzlaTerm>) { + assumedExprs.add(trackedAssertion) + Native.bitwuzlaAssume(bitwuzlaCtx.bitwuzla, trackedAssertion.second) + } + + fun assumeAssumption(expr: KExpr, term: BitwuzlaTerm) = + assumeTrackedAssertion(expr to term) + + fun resolveUnsatCore(unsatAssumptions: BitwuzlaTermArray): List> { + val unsatCoreTerms = LongOpenHashSet(unsatAssumptions) + return assumedExprs.mapNotNull { (expr, term) -> expr.takeIf { unsatCoreTerms.contains(term) } } + } + } +} diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolverConfiguration.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolverConfiguration.kt index a40ee8a60..5ee777c8f 100644 --- a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolverConfiguration.kt +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolverConfiguration.kt @@ -6,6 +6,7 @@ import io.ksmt.solver.KSolverUnsupportedParameterException import org.ksmt.solver.bitwuzla.bindings.Bitwuzla import org.ksmt.solver.bitwuzla.bindings.BitwuzlaOption import org.ksmt.solver.bitwuzla.bindings.Native +import java.util.EnumMap interface KBitwuzlaSolverConfiguration : KSolverConfiguration { fun setBitwuzlaOption(option: BitwuzlaOption, value: Int) @@ -44,6 +45,25 @@ class KBitwuzlaSolverConfigurationImpl(private val bitwuzla: Bitwuzla) : KBitwuz } } +class KBitwuzlaForkingSolverConfigurationImpl(private val bitwuzla: Bitwuzla) : KBitwuzlaSolverConfiguration { + private val intOptions = EnumMap<_, Int>(BitwuzlaOption::class.java) + private val stringOptions = EnumMap<_, String>(BitwuzlaOption::class.java) + override fun setBitwuzlaOption(option: BitwuzlaOption, value: Int) { + Native.bitwuzlaSetOption(bitwuzla, option, value) + intOptions[option] = value + } + + override fun setBitwuzlaOption(option: BitwuzlaOption, value: String) { + Native.bitwuzlaSetOptionStr(bitwuzla, option, value) + stringOptions[option] = value + } + + fun fork(childBitwuzla: Bitwuzla) = KBitwuzlaForkingSolverConfigurationImpl(childBitwuzla).also { + intOptions.forEach { (option, value) -> it.setBitwuzlaOption(option, value) } + stringOptions.forEach { (option, value) -> it.setBitwuzlaOption(option, value) } + } +} + class KBitwuzlaSolverUniversalConfiguration( private val builder: KSolverUniversalConfigurationBuilder ) : KBitwuzlaSolverConfiguration { diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/ScopedLinkedFrame.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/ScopedLinkedFrame.kt new file mode 100644 index 000000000..d3e4e5931 --- /dev/null +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/ScopedLinkedFrame.kt @@ -0,0 +1,67 @@ +package io.ksmt.solver.bitwuzla + +internal class ScopedLinkedFrame private constructor( + private var current: LinkedFrame, + private inline val createNewFrame: () -> T, + private inline val copyFrame: (T) -> T +) { + constructor( + currentFrame: T, + createNewFrame: () -> T, + copyFrame: (T) -> T + ) : this(LinkedFrame(currentFrame), createNewFrame, copyFrame) + + constructor( + createNewFrame: () -> T, + copyFrame: (T) -> T + ) : this(createNewFrame(), createNewFrame, copyFrame) + + val currentFrame: T + get() = current.value + + val currentScope: UInt + get() = current.scope + + fun stacked(): ArrayDeque = ArrayDeque().also { stack -> + forEachReversed { frame -> + stack.addLast(frame) + } + } + + fun push() { + current = LinkedFrame(createNewFrame(), current) + } + + fun pop(n: UInt) { + repeat(n.toInt()) { + current = current.previous ?: throw IllegalStateException("Can't pop the bottom scope") + } + recreateTopFrame() + } + + private fun recreateTopFrame() { + val newTopFrame = copyFrame(currentFrame) + current = LinkedFrame(newTopFrame, current.previous) + } + + fun fork(parent: ScopedLinkedFrame) { + current = parent.current + recreateTopFrame() + } + + private inline fun forEachReversed(action: (T) -> Unit) { + var cur: LinkedFrame? = current + while (cur != null) { + action(cur.value) + cur = cur.previous + } + } + + private class LinkedFrame( + val value: E, + val previous: LinkedFrame? = null + ) { + val scope: UInt = previous?.scope?.plus(1u) ?: 0u + } + +} diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt index e8fdcd351..1ee768826 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt @@ -97,7 +97,9 @@ internal class ScopedLinkedFrame private constructor( } override fun pop(n: UInt) { - current = current.previous ?: throw IllegalStateException("Can't pop the bottom scope") + repeat(n.toInt()) { + current = current.previous ?: throw IllegalStateException("Can't pop the bottom scope") + } recreateTopFrame() } diff --git a/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt b/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt index c019cbf94..edf9c9361 100644 --- a/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt +++ b/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt @@ -3,6 +3,7 @@ package io.ksmt.test import io.ksmt.KContext import io.ksmt.solver.KForkingSolverManager import io.ksmt.solver.KSolverStatus +import io.ksmt.solver.bitwuzla.KBitwuzlaForkingSolverManager import io.ksmt.solver.cvc5.KCvc5ForkingSolverManager import io.ksmt.solver.yices.KYicesForkingSolverManager import io.ksmt.solver.z3.KZ3ForkingSolverManager @@ -16,6 +17,29 @@ import kotlin.test.assertNotEquals import kotlin.test.assertTrue class KForkingSolverTest { + @Nested + inner class KForkingSolverTestBitwuzla { + @Test + fun testCheckSat() = testCheckSat(::mkBitwuzlaForkingSolverManager) + + @Test + fun testModel() = testModel(::mkBitwuzlaForkingSolverManager) + + @Test + fun testUnsatCore() = testUnsatCore(::mkBitwuzlaForkingSolverManager) + + @Test + fun testUninterpretedSort() = testUninterpretedSort(::mkBitwuzlaForkingSolverManager) + + @Test + fun testScopedAssertions() = testScopedAssertions(::mkBitwuzlaForkingSolverManager) + + @Test + fun testLifeTime() = testLifeTime(::mkBitwuzlaForkingSolverManager) + + private fun mkBitwuzlaForkingSolverManager(ctx: KContext) = KBitwuzlaForkingSolverManager(ctx) + } + @Nested inner class KForkingSolverTestCvc5 { @Test @@ -376,8 +400,8 @@ class KForkingSolverTest { mkForkingSolverManager(ctx).use { man -> with(ctx) { val parent = man.mkForkingSolver() - val x by intSort - val f = x gt 100.expr + val x by bv8Sort + val f = mkBvSignedGreaterExpr(x, mkBv(100, bv8Sort)) parent.assert(f) parent.check().also { require(it == KSolverStatus.SAT) } diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt index d96be5c1d..0d75af7e1 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt @@ -73,7 +73,7 @@ class KYicesForkingSolver( get() = trackedAssertions.any { it.isNotEmpty() } override fun assert(expr: KExpr) = yicesTry { - ensureAssertionsInitiated() + yicesTry { ensureAssertionsInitiated() } ctx.ensureContextMatch(expr) val yicesExpr = with(exprInternalizer) { expr.internalize() } @@ -82,31 +82,36 @@ class KYicesForkingSolver( } override fun assertAndTrack(expr: KExpr) { - ensureAssertionsInitiated() + yicesTry { ensureAssertionsInitiated() } super.assertAndTrack(expr) } override fun push() { - ensureAssertionsInitiated() + yicesTry { ensureAssertionsInitiated() } super.push() trackedAssertions.push() yicesAssertions.push() } override fun pop(n: UInt) { - ensureAssertionsInitiated() + yicesTry { ensureAssertionsInitiated() } super.pop(n) trackedAssertions.pop(n) yicesAssertions.pop(n) } override fun check(timeout: Duration): KSolverStatus { - ensureAssertionsInitiated() + yicesTry { ensureAssertionsInitiated() } return super.check(timeout) } override fun checkWithAssumptions(assumptions: List>, timeout: Duration): KSolverStatus { - ensureAssertionsInitiated() + yicesTry { ensureAssertionsInitiated() } return super.checkWithAssumptions(assumptions, timeout) } + + override fun close() { + super.close() + manager.close(this) + } } diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt index 04966037f..239b124ba 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt @@ -14,16 +14,15 @@ import io.ksmt.sort.KUninterpretedSort import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap import it.unimi.dsi.fastutil.ints.IntOpenHashSet import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap -import java.util.Collections.newSetFromMap -import java.util.Collections.synchronizedSet import java.util.IdentityHashMap +import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicInteger class KYicesForkingSolverManager( private val ctx: KContext ) : KForkingSolverManager { - private val solvers = synchronizedSet(newSetFromMap(IdentityHashMap())) + private val solvers = ConcurrentHashMap.newKeySet() private val sharedCacheReferences = IdentityHashMap() private val expressionsCache = IdentityHashMap() diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/ScopedFrame.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/ScopedFrame.kt index b5ed5f43c..93f757e6f 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/ScopedFrame.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/ScopedFrame.kt @@ -115,7 +115,9 @@ internal class ScopedLinkedFrame private constructor( } override fun pop(n: UInt) { - current = current.previous ?: throw IllegalStateException("Can't pop the bottom scope") + repeat(n.toInt()) { + current = current.previous ?: throw IllegalStateException("Can't pop the bottom scope") + } recreateTopFrame() } From 68a0696304914133d03c322812f51a434b643a75 Mon Sep 17 00:00:00 2001 From: Dmitriy Sokolov Date: Mon, 28 Aug 2023 17:07:33 +0300 Subject: [PATCH 08/12] z3 pop fix, post-review refactoring --- .../main/kotlin/com/microsoft/z3/UnsafeApi.kt | 6 ---- .../kotlin/io/ksmt/solver/z3/KZ3Context.kt | 21 +++++++----- .../io/ksmt/solver/z3/KZ3ForkingSolver.kt | 4 ++- .../ksmt/solver/z3/KZ3ForkingSolverManager.kt | 20 +++++++---- .../kotlin/io/ksmt/solver/z3/KZ3Solver.kt | 9 ++--- .../kotlin/io/ksmt/solver/z3/ScopedFrame.kt | 34 +++++-------------- 6 files changed, 43 insertions(+), 51 deletions(-) diff --git a/ksmt-z3/src/main/kotlin/com/microsoft/z3/UnsafeApi.kt b/ksmt-z3/src/main/kotlin/com/microsoft/z3/UnsafeApi.kt index 87ea03ce5..3d745114e 100644 --- a/ksmt-z3/src/main/kotlin/com/microsoft/z3/UnsafeApi.kt +++ b/ksmt-z3/src/main/kotlin/com/microsoft/z3/UnsafeApi.kt @@ -1,7 +1,5 @@ package com.microsoft.z3 -import it.unimi.dsi.fastutil.longs.LongSet - fun incRefUnsafe(ctx: Long, ast: Long) { // Invoke incRef directly without status check Native.INTERNALincRef(ctx, ast) @@ -11,7 +9,3 @@ fun decRefUnsafe(ctx: Long, ast: Long) { // Invoke decRef directly without status check Native.INTERNALdecRef(ctx, ast) } - -fun LongSet.decRefUnsafeAll(ctx: Long) = longIterator().forEachRemaining { - decRefUnsafe(ctx, it) -} diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt index 4001591e6..18af6c350 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt @@ -4,7 +4,6 @@ import com.microsoft.z3.Context import com.microsoft.z3.Solver import com.microsoft.z3.Z3Exception import com.microsoft.z3.decRefUnsafe -import com.microsoft.z3.decRefUnsafeAll import com.microsoft.z3.incRefUnsafe import io.ksmt.KContext import io.ksmt.decl.KDecl @@ -16,6 +15,7 @@ import io.ksmt.sort.KSort import io.ksmt.sort.KUninterpretedSort import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap import it.unimi.dsi.fastutil.longs.LongOpenHashSet +import it.unimi.dsi.fastutil.longs.LongSet import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap @Suppress("TooManyFunctions") @@ -310,27 +310,27 @@ class KZ3Context internal constructor( uninterpretedSortValueInterpreter.clear() - uninterpretedSortValueDecls.keys.decRefUnsafeAll(nCtx) + uninterpretedSortValueDecls.keys.decRefAll() uninterpretedSortValueDecls.clear() - uninterpretedSortValueInterpreters.decRefUnsafeAll(nCtx) + uninterpretedSortValueInterpreters.decRefAll() uninterpretedSortValueInterpreters.clear() - converterNativeObjects.decRefUnsafeAll(nCtx) + converterNativeObjects.decRefAll() converterNativeObjects.clear() - z3Expressions.keys.decRefUnsafeAll(nCtx) + z3Expressions.keys.decRefAll() expressions.clear() z3Expressions.clear() - tmpNativeObjects.decRefUnsafeAll(nCtx) + tmpNativeObjects.decRefAll() tmpNativeObjects.clear() - z3Decls.keys.decRefUnsafeAll(nCtx) + z3Decls.keys.decRefAll() decls.clear() z3Decls.clear() - z3Sorts.keys.decRefUnsafeAll(nCtx) + z3Sorts.keys.decRefAll() sorts.clear() z3Sorts.clear() @@ -340,4 +340,9 @@ class KZ3Context internal constructor( throw KSolverException(e) } } + + private fun LongSet.decRefAll() = + longIterator().forEachRemaining { + decRefUnsafe(nCtx, it) + } } diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt index cb93d09e2..ffe29bfed 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt @@ -62,7 +62,9 @@ open class KZ3ForkingSolver internal constructor( trackedAssertions.currentFrame[track] = trackedExpr } - override fun findTrackedExprByTrack(track: Long): KExpr? = trackedAssertions.find { it[track] } + override fun findTrackedExprByTrack(track: Long): KExpr? = trackedAssertions.findNonNullValue { + it[track] + } /** * Asserts parental (in case of child) assertions if not diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt index cddbd90b4..30b75c8ff 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt @@ -2,7 +2,7 @@ package io.ksmt.solver.z3 import com.microsoft.z3.Context import com.microsoft.z3.Z3Exception -import com.microsoft.z3.decRefUnsafeAll +import com.microsoft.z3.decRefUnsafe import io.ksmt.KAst import io.ksmt.KContext import io.ksmt.decl.KDecl @@ -16,6 +16,7 @@ import io.ksmt.sort.KSort import io.ksmt.sort.KUninterpretedSort import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap import it.unimi.dsi.fastutil.longs.LongOpenHashSet +import it.unimi.dsi.fastutil.longs.LongSet import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap import java.util.IdentityHashMap import java.util.concurrent.ConcurrentHashMap @@ -95,22 +96,22 @@ class KZ3ForkingSolverManager(private val ctx: KContext) : KForkingSolverManager val nCtx = context.nCtx() contextReferences -= context - expressionsReversedCache.remove(context)!!.keys.decRefUnsafeAll(nCtx) + expressionsReversedCache.remove(context)!!.keys.decRefAll(nCtx) expressionsCache -= context - sortsReversedCache.remove(context)!!.keys.decRefUnsafeAll(nCtx) + sortsReversedCache.remove(context)!!.keys.decRefAll(nCtx) sortsCache -= context - declsReversedCache.remove(context)!!.keys.decRefUnsafeAll(nCtx) + declsReversedCache.remove(context)!!.keys.decRefAll(nCtx) declsCache -= context - uninterpretedSortValueInterpreters.remove(context)!!.decRefUnsafeAll(nCtx) + uninterpretedSortValueInterpreters.remove(context)!!.decRefAll(nCtx) uninterpretedSortValueInterpreter -= context uninterpretedSortValueDecls -= context registeredUninterpretedSortValues -= context - converterNativeObjectsCache.remove(context)!!.decRefUnsafeAll(nCtx) - tmpNativeObjectsCache.remove(context)!!.decRefUnsafeAll(nCtx) + converterNativeObjectsCache.remove(context)!!.decRefAll(nCtx) + tmpNativeObjectsCache.remove(context)!!.decRefAll(nCtx) try { ctx.close() @@ -151,6 +152,11 @@ class KZ3ForkingSolverManager(private val ctx: KContext) : KForkingSolverManager defaultReturnValue(KExprLongInternalizerBase.NOT_INTERNALIZED) } + private fun LongSet.decRefAll(nCtx: Long) = + longIterator().forEachRemaining { + decRefUnsafe(nCtx, it) + } + } private typealias ExpressionsCache = Object2LongOpenHashMap> diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Solver.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Solver.kt index 114082f01..f76330cf8 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Solver.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Solver.kt @@ -4,17 +4,18 @@ import io.ksmt.KContext import io.ksmt.expr.KExpr import io.ksmt.solver.KSolver import io.ksmt.sort.KBoolSort -import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap open class KZ3Solver(ctx: KContext) : KZ3SolverBase(ctx), KSolver { override val z3Ctx: KZ3Context = KZ3Context(ctx) - private val trackedAssertions = ScopedArrayFrame { Long2ObjectOpenHashMap>() } + private val trackedAssertions = ScopedArrayFrameOfLong2ObjectOpenHashMap>() override fun saveTrackedAssertion(track: Long, trackedExpr: KExpr) { trackedAssertions.currentFrame[track] = trackedExpr } - override fun findTrackedExprByTrack(track: Long): KExpr? = trackedAssertions.find { it[track] } + override fun findTrackedExprByTrack(track: Long): KExpr? = trackedAssertions.findNonNullValue { + it[track] + } override fun push() { super.push() @@ -23,6 +24,6 @@ open class KZ3Solver(ctx: KContext) : KZ3SolverBase(ctx), KSolver { val currentScope: UInt val currentFrame: T - fun flatten(collect: T.(T) -> Unit): T - - /** - * find value [V] in frame [T], and return it or null - */ - fun find(predicate: (T) -> V?): V? - fun push() fun pop(n: UInt = 1u) } -internal class ScopedArrayFrame( - currentFrame: T, - private val createNewFrame: () -> T -) : ScopedFrame { - constructor(createNewFrame: () -> T) : this(createNewFrame(), createNewFrame) +internal class ScopedArrayFrameOfLong2ObjectOpenHashMap( + currentFrame: Long2ObjectOpenHashMap +) : ScopedFrame> { + constructor() : this(Long2ObjectOpenHashMap()) private val frames = arrayListOf(currentFrame) @@ -29,11 +23,7 @@ internal class ScopedArrayFrame( override val currentScope: UInt get() = frames.size.toUInt() - override fun flatten(collect: T.(T) -> Unit) = createNewFrame().also { newFrame -> - frames.forEach { newFrame.collect(it) } - } - - override fun find(predicate: (T) -> V?): V? { + inline fun findNonNullValue(predicate: (Long2ObjectOpenHashMap) -> V?): V? { frames.forEach { frame -> predicate(frame)?.let { return it } } @@ -41,7 +31,7 @@ internal class ScopedArrayFrame( } override fun push() { - currentFrame = createNewFrame() + currentFrame = Long2ObjectOpenHashMap() frames += currentFrame } @@ -73,19 +63,13 @@ internal class ScopedLinkedFrame private constructor( override val currentScope: UInt get() = current.scope - override fun flatten(collect: T.(T) -> Unit): T = createNewFrame().also { newFrame -> - forEachReversed { frame -> - newFrame.collect(frame) - } - } - fun stacked(): ArrayDeque = ArrayDeque().also { stack -> forEachReversed { frame -> stack.addLast(frame) } } - override fun find(predicate: (T) -> V?): V? { + inline fun findNonNullValue(predicate: (T) -> V?): V? { forEachReversed { frame -> predicate(frame)?.let { return it } } From 4438aff370be37bfd91591062a1cff116d12e290 Mon Sep 17 00:00:00 2001 From: Dmitriy Sokolov Date: Wed, 30 Aug 2023 15:56:32 +0300 Subject: [PATCH 09/12] z3 forking solver refactoring by uninterpreted sort values tracker extraction --- ...essionUninterpretedValuesForkingTracker.kt | 42 ++++ .../ExpressionUninterpretedValuesTracker.kt | 46 ++--- .../kotlin/io/ksmt/solver/z3/KZ3Context.kt | 114 ++++------- .../io/ksmt/solver/z3/KZ3ForkingContext.kt | 49 +++++ .../io/ksmt/solver/z3/KZ3ForkingSolver.kt | 46 ++--- .../ksmt/solver/z3/KZ3ForkingSolverManager.kt | 186 ++++++++---------- .../kotlin/io/ksmt/solver/z3/KZ3SolverBase.kt | 27 +-- 7 files changed, 262 insertions(+), 248 deletions(-) create mode 100644 ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesForkingTracker.kt create mode 100644 ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingContext.kt diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesForkingTracker.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesForkingTracker.kt new file mode 100644 index 000000000..e67fae0b4 --- /dev/null +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesForkingTracker.kt @@ -0,0 +1,42 @@ +package io.ksmt.solver.z3 + +import io.ksmt.KContext +import io.ksmt.expr.KUninterpretedSortValue + +/** + * Uninterpreted sort values tracker with ability to fork. + * On child tracker creation ([fork]), it will have the same [registeredUninterpretedSortValues] as its parent + * to prevent descriptors loss. [expressionLevels] and [valueTrackerFrames] are copied + * to restore parental state of caching. + * Also, all axioms are asserted lazily on [assertPendingUninterpretedValueConstraints] + */ +class ExpressionUninterpretedValuesForkingTracker : ExpressionUninterpretedValuesTracker { + private constructor( + ctx: KContext, + z3Ctx: KZ3Context, + registeredUninterpretedSortValues: HashMap + ) : super(ctx, z3Ctx, registeredUninterpretedSortValues) + + constructor(ctx: KContext, z3Ctx: KZ3Context) : super(ctx, z3Ctx) + + fun fork(z3Ctx: KZ3Context) = ExpressionUninterpretedValuesForkingTracker( + ctx, z3Ctx, registeredUninterpretedSortValues + ).also { child -> + child.expressionLevels += expressionLevels + + var isFirstFrame = true + valueTrackerFrames.forEach { frame -> + if (!isFirstFrame) { + child.pushAssertionLevel() + } + + if (frame.initialized) { + child.currentFrame.ensureInitialized() + child.currentFrame.currentLevelUninterpretedValues += frame.currentLevelUninterpretedValues + child.currentFrame.currentLevelExpressions += frame.currentLevelExpressions + } + + isFirstFrame = false + } + } +} diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesTracker.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesTracker.kt index 89f560420..6915acd4e 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesTracker.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesTracker.kt @@ -3,13 +3,13 @@ package io.ksmt.solver.z3 import com.microsoft.z3.Native import com.microsoft.z3.Solver import com.microsoft.z3.solverAssert -import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap import io.ksmt.KContext import io.ksmt.expr.KExpr import io.ksmt.expr.KUninterpretedSortValue import io.ksmt.expr.transformer.KNonRecursiveTransformer import io.ksmt.sort.KSort import io.ksmt.sort.KUninterpretedSort +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap /** * Uninterpreted sort values distinct constraints management. @@ -19,29 +19,25 @@ import io.ksmt.sort.KUninterpretedSort * 2. Assert distinct constraints ([assertPendingUninterpretedValueConstraints]) * that may be introduced during internalization. * */ -class ExpressionUninterpretedValuesTracker private constructor( +open class ExpressionUninterpretedValuesTracker protected constructor( val ctx: KContext, val z3Ctx: KZ3Context, - private val registeredUninterpretedSortValues: HashMap + protected val registeredUninterpretedSortValues: HashMap ) { constructor(ctx: KContext, z3Ctx: KZ3Context) : this(ctx, z3Ctx, hashMapOf()) - constructor(ctx: KContext, z3Ctx: KZ3Context, forkingSolverManager: KZ3ForkingSolverManager) : this( - ctx, - z3Ctx, - with(forkingSolverManager) { z3Ctx.findRegisteredUninterpretedSortValues() } - ) - private val expressionLevels = Object2IntOpenHashMap>().apply { + protected val expressionLevels = Object2IntOpenHashMap>().apply { defaultReturnValue(Int.MAX_VALUE) // Level which is greater than any possible level } - private var currentFrame = ValueTrackerAssertionFrame( - ctx, this, expressionLevels, + protected var currentFrame = ValueTrackerAssertionFrame( + ctx, expressionLevels, level = 0, notAssertedConstraintsFromPreviousLevels = 0 ) + private set - private val valueTrackerFrames = arrayListOf(currentFrame) + protected val valueTrackerFrames = arrayListOf(currentFrame) /** * Skip any value tracking related actions until @@ -57,11 +53,6 @@ class ExpressionUninterpretedValuesTracker private constructor( body() } - fun fork(parent: ExpressionUninterpretedValuesTracker) = also { - expressionLevels += parent.expressionLevels - repeat(parent.valueTrackerFrames.size - 1) { pushAssertionLevel() } - } - fun expressionUse(expr: KExpr<*>) = ifTrackingEnabled { currentFrame.analyzeUsedExpression(expr) } @@ -134,21 +125,22 @@ class ExpressionUninterpretedValuesTracker private constructor( z3Ctx.releaseTemporaryAst(constraintLhs) } - internal data class UninterpretedSortValueDescriptor( + protected data class UninterpretedSortValueDescriptor( val value: KUninterpretedSortValue, val nativeUniqueValueDescriptor: Long, val nativeValueExpr: Long ) - private class ValueTrackerAssertionFrame( + protected inner class ValueTrackerAssertionFrame( val ctx: KContext, - val tracker: ExpressionUninterpretedValuesTracker, val expressionLevels: Object2IntOpenHashMap>, val level: Int, val notAssertedConstraintsFromPreviousLevels: Int ) { - private var initialized = false - private var lastAssertedConstraint = 0 + var initialized = false + private set + + var lastAssertedConstraint = 0 lateinit var currentLevelExpressions: MutableSet> lateinit var currentLevelUninterpretedValues: MutableList @@ -159,7 +151,7 @@ class ExpressionUninterpretedValuesTracker private constructor( * since we might not have any uninterpreted values on * a current assertion level. * */ - private fun ensureInitialized() { + fun ensureInitialized() { if (initialized) return currentLevelExpressions = hashSetOf() @@ -176,7 +168,7 @@ class ExpressionUninterpretedValuesTracker private constructor( val notAssertedConstraints = numberOfConstraints - lastAssertedConstraint val nextLevelRemainingConstraints = notAssertedConstraintsFromPreviousLevels + notAssertedConstraints return ValueTrackerAssertionFrame( - ctx, tracker, expressionLevels, + ctx, expressionLevels, level = level + 1, notAssertedConstraintsFromPreviousLevels = nextLevelRemainingConstraints ) @@ -228,7 +220,7 @@ class ExpressionUninterpretedValuesTracker private constructor( } fun addRegisteredValueToCurrentLevel(value: KUninterpretedSortValue) { - val descriptor = tracker.registeredUninterpretedSortValues[value] + val descriptor = registeredUninterpretedSortValues[value] ?: error("Value $value was not registered") addRegisteredValueToCurrentLevel(descriptor) } @@ -239,10 +231,10 @@ class ExpressionUninterpretedValuesTracker private constructor( currentLevelUninterpretedValues.add(descriptor) } - fun getFrame(level: Int) = tracker.valueTrackerFrames[level] + fun getFrame(level: Int) = valueTrackerFrames[level] } - private class ExprUninterpretedValuesAnalyzer( + protected class ExprUninterpretedValuesAnalyzer( ctx: KContext, val frame: ValueTrackerAssertionFrame ) : KNonRecursiveTransformer(ctx) { diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt index 18af6c350..714f2b178 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt @@ -19,70 +19,29 @@ import it.unimi.dsi.fastutil.longs.LongSet import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap @Suppress("TooManyFunctions") -class KZ3Context internal constructor( +open class KZ3Context( ksmtCtx: KContext, private val ctx: Context, - forkingSolverManager: KZ3ForkingSolverManager?, ) : AutoCloseable { - constructor(ksmtCtx: KContext, ctx: Context) : this(ksmtCtx, ctx, null) - constructor(ksmtCtx: KContext) : this(ksmtCtx, Context(), null) - - private var isClosed = false - private val isForking = forkingSolverManager != null - - // common for parent and child structures - private val expressions: Object2LongOpenHashMap> - private val sorts: Object2LongOpenHashMap - private val decls: Object2LongOpenHashMap> - - private val z3Expressions: Long2ObjectOpenHashMap> - private val z3Sorts: Long2ObjectOpenHashMap - private val z3Decls: Long2ObjectOpenHashMap> - private val tmpNativeObjects: LongOpenHashSet - private val converterNativeObjects: LongOpenHashSet - - private val uninterpretedSortValueInterpreter: HashMap - private val uninterpretedSortValueDecls: Long2ObjectOpenHashMap - private val uninterpretedSortValueInterpreters: LongOpenHashSet - - - val uninterpretedValuesTracker: ExpressionUninterpretedValuesTracker - - init { - if (forkingSolverManager != null) { - with(forkingSolverManager) { - expressions = findExpressionsCache() - sorts = findSortsCache() - decls = findDeclsCache() - - z3Expressions = findExpressionsReversedCache() - z3Sorts = findSortsReversedCache() - z3Decls = findDeclsReversedCache() - tmpNativeObjects = findTmpNativeObjectsCache() - converterNativeObjects = findConverterNativeObjectsCache() - uninterpretedSortValueInterpreter = findUninterpretedSortValueInterpreter() - uninterpretedSortValueDecls = findUninterpretedSortValueDecls() - uninterpretedSortValueInterpreters = findUninterpretedSortValueInterpreters() - } - uninterpretedValuesTracker = ExpressionUninterpretedValuesTracker(ksmtCtx, this, forkingSolverManager) - } else { - expressions = Object2LongOpenHashMap>().apply { defaultReturnValue(NOT_INTERNALIZED) } - sorts = Object2LongOpenHashMap().apply { defaultReturnValue(NOT_INTERNALIZED) } - decls = Object2LongOpenHashMap>().apply { defaultReturnValue(NOT_INTERNALIZED) } - - z3Expressions = Long2ObjectOpenHashMap>() - z3Sorts = Long2ObjectOpenHashMap() - z3Decls = Long2ObjectOpenHashMap>() - tmpNativeObjects = LongOpenHashSet() - converterNativeObjects = LongOpenHashSet() - - uninterpretedSortValueInterpreter = hashMapOf() - uninterpretedSortValueDecls = Long2ObjectOpenHashMap() - uninterpretedSortValueInterpreters = LongOpenHashSet() - - uninterpretedValuesTracker = ExpressionUninterpretedValuesTracker(ksmtCtx, this) - } - } + constructor(ksmtCtx: KContext) : this(ksmtCtx, Context()) + + protected var isClosed = false + + protected open val expressions = Object2LongOpenHashMap>().apply { defaultReturnValue(NOT_INTERNALIZED) } + protected open val sorts = Object2LongOpenHashMap().apply { defaultReturnValue(NOT_INTERNALIZED) } + protected open val decls = Object2LongOpenHashMap>().apply { defaultReturnValue(NOT_INTERNALIZED) } + + protected open val z3Expressions = Long2ObjectOpenHashMap>() + protected open val z3Sorts = Long2ObjectOpenHashMap() + protected open val z3Decls = Long2ObjectOpenHashMap>() + protected open val tmpNativeObjects = LongOpenHashSet() + protected open val converterNativeObjects = LongOpenHashSet() + + protected open val uninterpretedSortValueInterpreter = hashMapOf() + protected open val uninterpretedSortValueDecls = Long2ObjectOpenHashMap() + protected open val uninterpretedSortValueInterpreters = LongOpenHashSet() + + open val uninterpretedValuesTracker = ExpressionUninterpretedValuesTracker(ksmtCtx, this) @JvmField @@ -94,13 +53,6 @@ class KZ3Context internal constructor( val isActive: Boolean get() = !isClosed - internal fun fork(ksmtCtx: KContext, manager: KZ3ForkingSolverManager): KZ3Context { - require(isForking) { "Can't fork non-forking context" } - return KZ3Context(ksmtCtx, ctx, manager).also { - it.uninterpretedValuesTracker.fork(uninterpretedValuesTracker) - } - } - internal fun findInternalizedExprWithoutAnalysis(expr: KExpr<*>): Long { val result = expressions.getLong(expr) return if (result == NOT_INTERNALIZED) NOT_INTERNALIZED else result @@ -304,9 +256,6 @@ class KZ3Context internal constructor( override fun close() { if (isClosed) return - isClosed = true - - if (isForking) return uninterpretedSortValueInterpreter.clear() @@ -334,15 +283,24 @@ class KZ3Context internal constructor( sorts.clear() z3Sorts.clear() - try { + z3Try { + isClosed = true ctx.close() - } catch (e: Z3Exception) { - throw KSolverException(e) } } - private fun LongSet.decRefAll() = - longIterator().forEachRemaining { - decRefUnsafe(nCtx, it) - } + private fun LongSet.decRefAll() = longIterator().forEachRemaining { + decRefUnsafe(nCtx, it) + } + + fun ensureActive() { + check(!isClosed) { "The context is already closed." } + } + + inline fun z3Try(body: () -> T): T = try { + ensureActive() + body() + } catch (ex: Z3Exception) { + throw KSolverException(ex) + } } diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingContext.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingContext.kt new file mode 100644 index 000000000..cabd6bac9 --- /dev/null +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingContext.kt @@ -0,0 +1,49 @@ +package io.ksmt.solver.z3 + +import com.microsoft.z3.Context +import io.ksmt.KContext + +class KZ3ForkingContext private constructor( + ksmtCtx: KContext, + private val ctx: Context, + manager: KZ3ForkingSolverManager, + parent: KZ3ForkingContext?, +) : KZ3Context(ksmtCtx, ctx) { + + constructor( + ksmtCtx: KContext, + ctx: Context, + manager: KZ3ForkingSolverManager + ) : this(ksmtCtx, ctx, manager, null) + + constructor(ksmtCtx: KContext, manager: KZ3ForkingSolverManager) : this(ksmtCtx, Context(), manager) + + // common for parent and child structures + override val expressions = with(manager) { getExpressionsCache() } + override val sorts = with(manager) { getSortsCache() } + override val decls = with(manager) { getDeclsCache() } + + override val z3Expressions = with(manager) { getExpressionsReversedCache() } + override val z3Sorts = with(manager) { getSortsReversedCache() } + override val z3Decls = with(manager) { getDeclsReversedCache() } + override val tmpNativeObjects = with(manager) { getTmpNativeObjectsCache() } + override val converterNativeObjects = with(manager) { getConverterNativeObjectsCache() } + + override val uninterpretedSortValueInterpreter = with(manager) { getUninterpretedSortValueInterpreter() } + override val uninterpretedSortValueDecls = with(manager) { getUninterpretedSortValueDecls() } + override val uninterpretedSortValueInterpreters = with(manager) { getUninterpretedSortValueInterpreters() } + + override val uninterpretedValuesTracker: ExpressionUninterpretedValuesForkingTracker = parent + ?.uninterpretedValuesTracker?.fork(this) + ?: ExpressionUninterpretedValuesForkingTracker(ksmtCtx, this) + + internal fun fork(ksmtCtx: KContext, manager: KZ3ForkingSolverManager): KZ3ForkingContext { + ensureActive() + return KZ3ForkingContext(ksmtCtx, ctx, manager, this) + } + + override fun close() { + if (isClosed) return + isClosed = true + } +} diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt index ffe29bfed..e4dd2560d 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt @@ -1,6 +1,5 @@ package io.ksmt.solver.z3 -import com.microsoft.z3.Context import com.microsoft.z3.solverAssert import com.microsoft.z3.solverAssertAndTrack import io.ksmt.KContext @@ -17,7 +16,7 @@ open class KZ3ForkingSolver internal constructor( private val manager: KZ3ForkingSolverManager, parent: KZ3ForkingSolver? ) : KZ3SolverBase(ctx), KForkingSolver { - final override val z3Ctx: KZ3Context + final override val z3Ctx: KZ3ForkingContext = manager.createZ3ForkingContext(parent?.z3Ctx) private val trackedAssertions = ScopedLinkedFrame>>( ::Long2ObjectOpenHashMap, ::Long2ObjectOpenHashMap @@ -27,20 +26,8 @@ open class KZ3ForkingSolver internal constructor( private val isChild = parent != null private var assertionsInitiated = !isChild - init { - if (parent != null) { - z3Ctx = parent.z3Ctx.fork(ctx, manager) - trackedAssertions.fork(parent.trackedAssertions) - z3Assertions.fork(parent.z3Assertions) - } else { - val context = Context() - with(manager) { registerContext(context) } - z3Ctx = KZ3Context(ctx, context, manager) - } - } - private val config: KZ3ForkingSolverConfigurationImpl by lazy { - z3Try { + z3Ctx.z3Try { z3Ctx.nativeContext.mkParams().let { parent?.config?.fork(it)?.apply { setParameters(solver) } ?: KZ3ForkingSolverConfigurationImpl(it) } @@ -48,7 +35,12 @@ open class KZ3ForkingSolver internal constructor( } init { - if (isChild) config // initialize child config + if (parent != null) { + // lightweight copying via copying of the linked list node + trackedAssertions.fork(parent.trackedAssertions) + z3Assertions.fork(parent.z3Assertions) + config // initialize child config + } } override fun configure(configurator: KZ3SolverConfiguration.() -> Unit) { @@ -56,6 +48,9 @@ open class KZ3ForkingSolver internal constructor( config.setParameters(solver) } + /** + * Creates lazily initiated forked solver with shared cache, preserving parental assertions and configuration. + */ override fun fork(): KForkingSolver = manager.mkForkingSolver(this) override fun saveTrackedAssertion(track: Long, trackedExpr: KExpr) { @@ -67,7 +62,7 @@ open class KZ3ForkingSolver internal constructor( } /** - * Asserts parental (in case of child) assertions if not + * Asserts parental (in case of child) assertions if already not */ private fun ensureAssertionsInitiated() { if (assertionsInitiated) return @@ -92,20 +87,20 @@ open class KZ3ForkingSolver internal constructor( } override fun push() { - z3Try { ensureAssertionsInitiated() } + z3Ctx.z3Try { ensureAssertionsInitiated() } super.push() trackedAssertions.push() z3Assertions.push() } override fun pop(n: UInt) { - z3Try { ensureAssertionsInitiated() } + z3Ctx.z3Try { ensureAssertionsInitiated() } super.pop(n) trackedAssertions.pop(n) z3Assertions.pop(n) } - override fun assert(expr: KExpr) = z3Try { + override fun assert(expr: KExpr) = z3Ctx.z3Try { ensureAssertionsInitiated() ctx.ensureContextMatch(expr) @@ -117,17 +112,22 @@ open class KZ3ForkingSolver internal constructor( } override fun assertAndTrack(expr: KExpr) { - z3Try { ensureAssertionsInitiated() } + z3Ctx.z3Try { ensureAssertionsInitiated() } super.assertAndTrack(expr) } override fun check(timeout: Duration): KSolverStatus { - z3Try { ensureAssertionsInitiated() } + z3Ctx.z3Try { ensureAssertionsInitiated() } return super.check(timeout) } override fun checkWithAssumptions(assumptions: List>, timeout: Duration): KSolverStatus { - z3Try { ensureAssertionsInitiated() } + z3Ctx.z3Try { ensureAssertionsInitiated() } return super.checkWithAssumptions(assumptions, timeout) } + + override fun close() { + super.close() + manager.close(this) + } } diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt index 30b75c8ff..0711873d3 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt @@ -18,145 +18,127 @@ import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap import it.unimi.dsi.fastutil.longs.LongOpenHashSet import it.unimi.dsi.fastutil.longs.LongSet import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap -import java.util.IdentityHashMap import java.util.concurrent.ConcurrentHashMap +/** + * Responsible for creation and managing of [KZ3ForkingSolver]. + * + * It's cheaper to create multiple copies of solvers with [KZ3ForkingSolver.fork] + * instead of assertions transferring in [KZ3Solver] instances. + * + * All created solvers with one manager (via both [KZ3ForkingSolver.fork] and [mkForkingSolver]) + * use the same [Context], cache, and registered uninterpreted sort values. + */ class KZ3ForkingSolverManager(private val ctx: KContext) : KForkingSolverManager { + private val z3Context by lazy { Context() } private val solvers = ConcurrentHashMap.newKeySet() - /** - * for each parent-to-child hierarchy created only one Context. - * Each Context user is registered to control solver is alive - */ - private val forkingSolverToContext = IdentityHashMap() - private val contextReferences = IdentityHashMap() - // shared cache - private val expressionsCache = IdentityHashMap() - private val expressionsReversedCache = IdentityHashMap() - private val sortsCache = IdentityHashMap() - private val sortsReversedCache = IdentityHashMap() - private val declsCache = IdentityHashMap() - private val declsReversedCache = IdentityHashMap() - - private val tmpNativeObjectsCache = IdentityHashMap() - private val converterNativeObjectsCache = IdentityHashMap() - - private val uninterpretedSortValueInterpreter = IdentityHashMap() - private val uninterpretedSortValueDecls = IdentityHashMap() - private val uninterpretedSortValueInterpreters = IdentityHashMap() - private val registeredUninterpretedSortValues = IdentityHashMap() - - internal fun KZ3Context.findExpressionsCache() = expressionsCache.getValue(nativeContext) - internal fun KZ3Context.findExpressionsReversedCache() = expressionsReversedCache.getValue(nativeContext) - internal fun KZ3Context.findSortsCache() = sortsCache.getValue(nativeContext) - internal fun KZ3Context.findSortsReversedCache() = sortsReversedCache.getValue(nativeContext) - internal fun KZ3Context.findDeclsCache() = declsCache.getValue(nativeContext) - internal fun KZ3Context.findDeclsReversedCache() = declsReversedCache.getValue(nativeContext) - internal fun KZ3Context.findTmpNativeObjectsCache() = tmpNativeObjectsCache.getValue(nativeContext) - internal fun KZ3Context.findConverterNativeObjectsCache() = converterNativeObjectsCache.getValue(nativeContext) - internal fun KZ3Context.findUninterpretedSortValueInterpreter() = - uninterpretedSortValueInterpreter.getValue(nativeContext) - - internal fun KZ3Context.findUninterpretedSortValueDecls() = - uninterpretedSortValueDecls.getValue(nativeContext) - - internal fun KZ3Context.findUninterpretedSortValueInterpreters() = - uninterpretedSortValueInterpreters.getValue(nativeContext) - - internal fun KZ3Context.findRegisteredUninterpretedSortValues() = - registeredUninterpretedSortValues.getValue(nativeContext) - - internal fun KZ3ForkingSolver.registerContext(sharedContext: Context) { - if (forkingSolverToContext.putIfAbsent(this, sharedContext) == null) { - incRef(sharedContext) - - expressionsCache[sharedContext] = ExpressionsCache().withNotInternalizedAsDefaultValue() - expressionsReversedCache[sharedContext] = ExpressionsReversedCache() - sortsCache[sharedContext] = SortsCache().withNotInternalizedAsDefaultValue() - sortsReversedCache[sharedContext] = SortsReversedCache() - declsCache[sharedContext] = DeclsCache().withNotInternalizedAsDefaultValue() - declsReversedCache[sharedContext] = DeclsReversedCache() - tmpNativeObjectsCache[sharedContext] = TmpNativeObjectsCache() - converterNativeObjectsCache[sharedContext] = ConverterNativeObjectsCache() - uninterpretedSortValueInterpreter[sharedContext] = UninterpretedSortValueInterpreterCache() - uninterpretedSortValueDecls[sharedContext] = UninterpretedSortValueDecls() - uninterpretedSortValueInterpreters[sharedContext] = UninterpretedSortValueInterpretersCache() - registeredUninterpretedSortValues[sharedContext] = RegisteredUninterpretedSortValues() - } - } + private val expressionsCache = ExpressionsCache().withNotInternalizedAsDefaultValue() + private val expressionsReversedCache = ExpressionsReversedCache() + private val sortsCache = SortsCache().withNotInternalizedAsDefaultValue() + private val sortsReversedCache = SortsReversedCache() + private val declsCache = DeclsCache().withNotInternalizedAsDefaultValue() + private val declsReversedCache = DeclsReversedCache() - private fun incRef(context: Context) { - contextReferences[context] = contextReferences.getOrDefault(context, 0) + 1 - } + private val tmpNativeObjectsCache = TmpNativeObjectsCache() + private val converterNativeObjectsCache = ConverterNativeObjectsCache() - private fun decRef(context: Context) { - val referencesAfterDec = contextReferences.getValue(context) - 1 - if (referencesAfterDec == 0) { - val nCtx = context.nCtx() - contextReferences -= context + private val uninterpretedSortValueInterpreter = UninterpretedSortValueInterpreterCache() + private val uninterpretedSortValueDecls = UninterpretedSortValueDecls() + private val uninterpretedSortValueInterpreters = UninterpretedSortValueInterpretersCache() - expressionsReversedCache.remove(context)!!.keys.decRefAll(nCtx) - expressionsCache -= context + internal fun KZ3Context.getExpressionsCache() = ensureContextMatches(nativeContext).let { expressionsCache } + internal fun KZ3Context.getExpressionsReversedCache() = ensureContextMatches(nativeContext) + .let { expressionsReversedCache } - sortsReversedCache.remove(context)!!.keys.decRefAll(nCtx) - sortsCache -= context + internal fun KZ3Context.getSortsCache() = ensureContextMatches(nativeContext).let { sortsCache } + internal fun KZ3Context.getSortsReversedCache() = ensureContextMatches(nativeContext).let { sortsReversedCache } + internal fun KZ3Context.getDeclsCache() = ensureContextMatches(nativeContext).let { declsCache } + internal fun KZ3Context.getDeclsReversedCache() = ensureContextMatches(nativeContext).let { declsReversedCache } + internal fun KZ3Context.getTmpNativeObjectsCache() = ensureContextMatches(nativeContext) + .let { tmpNativeObjectsCache } - declsReversedCache.remove(context)!!.keys.decRefAll(nCtx) - declsCache -= context + internal fun KZ3Context.getConverterNativeObjectsCache() = ensureContextMatches(nativeContext) + .let { converterNativeObjectsCache } - uninterpretedSortValueInterpreters.remove(context)!!.decRefAll(nCtx) - uninterpretedSortValueInterpreter -= context - uninterpretedSortValueDecls -= context - registeredUninterpretedSortValues -= context + internal fun KZ3Context.getUninterpretedSortValueInterpreter() = ensureContextMatches(nativeContext) + .let { uninterpretedSortValueInterpreter } - converterNativeObjectsCache.remove(context)!!.decRefAll(nCtx) - tmpNativeObjectsCache.remove(context)!!.decRefAll(nCtx) + internal fun KZ3Context.getUninterpretedSortValueDecls() = ensureContextMatches(nativeContext) + .let { uninterpretedSortValueDecls } - try { - ctx.close() - } catch (e: Z3Exception) { - throw KSolverException(e) - } - } else { - contextReferences[context] = referencesAfterDec - } - } + internal fun KZ3Context.getUninterpretedSortValueInterpreters() = ensureContextMatches(nativeContext) + .let { uninterpretedSortValueInterpreters } override fun mkForkingSolver(): KForkingSolver { return KZ3ForkingSolver(ctx, this, null).also { solvers += it } } internal fun mkForkingSolver(parent: KZ3ForkingSolver): KForkingSolver { - return KZ3ForkingSolver(ctx, this, parent).also { - solvers += it - forkingSolverToContext[it] = forkingSolverToContext[parent] - } + return KZ3ForkingSolver(ctx, this, parent).also { solvers += it } } + internal fun createZ3ForkingContext(parentCtx: KZ3ForkingContext? = null) = parentCtx?.fork(ctx, this) + ?: KZ3ForkingContext(ctx, z3Context, this) + /** * unregister [solver] for this manager */ internal fun close(solver: KZ3ForkingSolver) { solvers -= solver - val sharedContext = forkingSolverToContext.getValue(solver) - forkingSolverToContext -= solver - decRef(sharedContext) + closeContextIfStale() } override fun close() { solvers.forEach(KZ3ForkingSolver::close) } + private fun closeContextIfStale() { + if (solvers.isEmpty()) { + val nCtx = z3Context.nCtx() + + expressionsReversedCache.keys.decRefAll(nCtx) + expressionsReversedCache.clear() + expressionsCache.clear() + + sortsReversedCache.keys.decRefAll(nCtx) + sortsReversedCache.clear() + sortsCache.clear() + + declsReversedCache.keys.decRefAll(nCtx) + declsReversedCache.clear() + declsCache.clear() + + uninterpretedSortValueInterpreters.decRefAll(nCtx) + uninterpretedSortValueInterpreters.clear() + uninterpretedSortValueInterpreter.clear() + uninterpretedSortValueDecls.clear() + + converterNativeObjectsCache.decRefAll(nCtx) + converterNativeObjectsCache.clear() + tmpNativeObjectsCache.decRefAll(nCtx) + tmpNativeObjectsCache.clear() + + try { + ctx.close() + } catch (e: Z3Exception) { + throw KSolverException(e) + } + } + } + private fun Object2LongOpenHashMap.withNotInternalizedAsDefaultValue() = apply { defaultReturnValue(KExprLongInternalizerBase.NOT_INTERNALIZED) } - private fun LongSet.decRefAll(nCtx: Long) = - longIterator().forEachRemaining { - decRefUnsafe(nCtx, it) - } + private fun LongSet.decRefAll(nCtx: Long) = longIterator().forEachRemaining { + decRefUnsafe(nCtx, it) + } + private fun ensureContextMatches(ctx: Context) { + require(ctx == z3Context) { "Context is not registered by manager." } + } } private typealias ExpressionsCache = Object2LongOpenHashMap> @@ -174,5 +156,3 @@ private typealias ConverterNativeObjectsCache = LongOpenHashSet private typealias UninterpretedSortValueInterpreterCache = HashMap private typealias UninterpretedSortValueDecls = Long2ObjectOpenHashMap private typealias UninterpretedSortValueInterpretersCache = LongOpenHashSet -@Suppress("MaxLineLength") -private typealias RegisteredUninterpretedSortValues = HashMap diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverBase.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverBase.kt index 16a287a0d..c95d0f636 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverBase.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverBase.kt @@ -11,7 +11,6 @@ import io.ksmt.KContext import io.ksmt.expr.KExpr import io.ksmt.solver.KModel import io.ksmt.solver.KSolver -import io.ksmt.solver.KSolverException import io.ksmt.solver.KSolverStatus import io.ksmt.sort.KBoolSort import io.ksmt.utils.NativeLibraryLoader @@ -46,21 +45,21 @@ abstract class KZ3SolverBase(protected val ctx: KContext) : KSolver Unit) = z3Try { + override fun configure(configurator: KZ3SolverConfiguration.() -> Unit) = z3Ctx.z3Try { val params = z3Ctx.nativeContext.mkParams() KZ3SolverConfigurationImpl(params).configurator() solver.setParameters(params) } - override fun push(): Unit = z3Try { + override fun push(): Unit = z3Ctx.z3Try { solver.push() z3Ctx.pushAssertionLevel() currentScope++ } - override fun pop(n: UInt) = z3Try { + override fun pop(n: UInt) = z3Ctx.z3Try { require(n <= currentScope) { "Can not pop $n scope levels because current scope level is $currentScope" } @@ -72,7 +71,7 @@ abstract class KZ3SolverBase(protected val ctx: KContext) : KSolver) = z3Try { + override fun assert(expr: KExpr) = z3Ctx.z3Try { ctx.ensureContextMatch(expr) val z3Expr = with(exprInternalizer) { expr.internalizeExpr() } @@ -84,7 +83,7 @@ abstract class KZ3SolverBase(protected val ctx: KContext) : KSolver) protected abstract fun findTrackedExprByTrack(track: Long): KExpr? - override fun assertAndTrack(expr: KExpr) = z3Try { + override fun assertAndTrack(expr: KExpr) = z3Ctx.z3Try { ctx.ensureContextMatch(expr) val trackExpr = ctx.mkFreshConst("track", ctx.boolSort) @@ -128,7 +127,7 @@ abstract class KZ3SolverBase(protected val ctx: KContext) : KSolver> = z3Try { + override fun unsatCore(): List> = z3Ctx.z3Try { require(lastCheckStatus == KSolverStatus.UNSAT) { "Unsat cores are only available after UNSAT checks" } val unsatCore = lastUnsatCore ?: with(exprConverter) { @@ -158,12 +157,12 @@ abstract class KZ3SolverBase(protected val ctx: KContext) : KSolver z3Try(body: () -> T): T = try { - body() - } catch (ex: Z3Exception) { - throw KSolverException(ex) - } - protected fun invalidateSolverState() { lastReasonOfUnknown = null lastCheckStatus = KSolverStatus.UNKNOWN From 6f88204d2e6a1b6f112e909d3b86171d786d2fda Mon Sep 17 00:00:00 2001 From: Dmitriy Sokolov Date: Wed, 30 Aug 2023 22:16:19 +0300 Subject: [PATCH 10/12] cvc5 forking solver refactoring by uninterpreted sort values forking tracker extraction --- ...essionUninterpretedValuesForkingTracker.kt | 22 ++ .../ExpressionUninterpretedValuesTracker.kt | 135 ++++++++ .../io/ksmt/solver/cvc5/KCvc5Context.kt | 288 ++++-------------- .../ksmt/solver/cvc5/KCvc5ForkingContext.kt | 75 +++++ .../io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt | 41 +-- .../solver/cvc5/KCvc5ForkingSolverManager.kt | 137 ++++----- .../kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt | 17 +- .../kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt | 2 +- .../io/ksmt/solver/cvc5/KCvc5SolverBase.kt | 21 +- .../kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt | 11 +- 10 files changed, 383 insertions(+), 366 deletions(-) create mode 100644 ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ExpressionUninterpretedValuesForkingTracker.kt create mode 100644 ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ExpressionUninterpretedValuesTracker.kt create mode 100644 ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingContext.kt diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ExpressionUninterpretedValuesForkingTracker.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ExpressionUninterpretedValuesForkingTracker.kt new file mode 100644 index 000000000..5e3679d14 --- /dev/null +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ExpressionUninterpretedValuesForkingTracker.kt @@ -0,0 +1,22 @@ +package io.ksmt.solver.cvc5 + +/** + * An uninterpreted sort values tracker with ability to fork, + * preserving all registered descriptors ([uninterpretedSortValueDescriptors]). + * In a newly-forked tracker all known axioms will be asserted at + * the nearest call of [assertPendingUninterpretedValueConstraints] + */ +class ExpressionUninterpretedValuesForkingTracker : ExpressionUninterpretedValuesTracker { + constructor(cvc5Ctx: KCvc5Context) : super(cvc5Ctx) + private constructor( + cvc5Ctx: KCvc5Context, + uninterpretedSortValueDescriptors: ArrayList + ) : super(cvc5Ctx, uninterpretedSortValueDescriptors) + + fun fork(childCvc5Ctx: KCvc5Context) = + ExpressionUninterpretedValuesForkingTracker(childCvc5Ctx, uninterpretedSortValueDescriptors).also { child -> + repeat(assertedConstraintLevels.size) { + child.pushAssertionLevel() + } + } +} diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ExpressionUninterpretedValuesTracker.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ExpressionUninterpretedValuesTracker.kt new file mode 100644 index 000000000..b1f9acef7 --- /dev/null +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ExpressionUninterpretedValuesTracker.kt @@ -0,0 +1,135 @@ +package io.ksmt.solver.cvc5 + +import io.github.cvc5.Kind +import io.github.cvc5.Solver +import io.github.cvc5.Term +import io.ksmt.decl.KDecl +import io.ksmt.expr.KUninterpretedSortValue +import io.ksmt.sort.KArray2Sort +import io.ksmt.sort.KArray3Sort +import io.ksmt.sort.KArrayNSort +import io.ksmt.sort.KArraySort +import io.ksmt.sort.KBoolSort +import io.ksmt.sort.KBvSort +import io.ksmt.sort.KFpRoundingModeSort +import io.ksmt.sort.KFpSort +import io.ksmt.sort.KIntSort +import io.ksmt.sort.KRealSort +import io.ksmt.sort.KSort +import io.ksmt.sort.KSortVisitor +import io.ksmt.sort.KUninterpretedSort + +open class ExpressionUninterpretedValuesTracker protected constructor( + private val cvc5Ctx: KCvc5Context, + protected val uninterpretedSortValueDescriptors: ArrayList +) { + constructor(cvc5Ctx: KCvc5Context) : this(cvc5Ctx, arrayListOf()) + + private var currentValueConstraintsLevel = 0 + protected val assertedConstraintLevels = arrayListOf() + + private val uninterpretedSortCollector = KUninterpretedSortCollector(cvc5Ctx) + + fun collectUninterpretedSorts(decl: KDecl<*>) { + uninterpretedSortCollector.collect(decl) + } + + fun pushAssertionLevel() { + assertedConstraintLevels += currentValueConstraintsLevel + } + + fun popAssertionLevel() { + currentValueConstraintsLevel = assertedConstraintLevels.removeLast() + } + + fun registerUninterpretedSortValue( + value: KUninterpretedSortValue, + uniqueValueDescriptorTerm: Term, + uninterpretedValueTerm: Term + ) { + uninterpretedSortValueDescriptors += UninterpretedSortValueDescriptor( + value = value, + nativeUniqueValueDescriptor = uniqueValueDescriptorTerm, + nativeValueTerm = uninterpretedValueTerm + ) + } + + fun assertPendingUninterpretedValueConstraints(solver: Solver) { + while (currentValueConstraintsLevel < uninterpretedSortValueDescriptors.size) { + assertUninterpretedSortValueConstraint( + solver, + uninterpretedSortValueDescriptors[currentValueConstraintsLevel] + ) + currentValueConstraintsLevel++ + } + } + + private fun assertUninterpretedSortValueConstraint(solver: Solver, value: UninterpretedSortValueDescriptor) { + val interpreter = cvc5Ctx.getUninterpretedSortValueInterpreter(value.value.sort) + ?: error("Interpreter was not registered for sort: ${value.value.sort}") + + val constraintLhs = solver.mkTerm(Kind.APPLY_UF, arrayOf(interpreter, value.nativeValueTerm)) + val constraint = constraintLhs.eqTerm(value.nativeUniqueValueDescriptor) + + solver.assertFormula(constraint) + } + + @Suppress("ForbiddenComment") + /** + * Uninterpreted sort values distinct constraints management. + * + * 1. save/register uninterpreted value. + * See [KUninterpretedSortValue] internalization for the details. + * 2. Assert distinct constraints ([assertPendingUninterpretedValueConstraints]) that may be introduced + * during internalization. + * Currently, we assert constraints for all the values we have ever internalized. + * + * todo: precise uninterpreted sort values tracking + * */ + protected data class UninterpretedSortValueDescriptor( + val value: KUninterpretedSortValue, + val nativeUniqueValueDescriptor: Term, + val nativeValueTerm: Term + ) + + class KUninterpretedSortCollector(private val cvc5Ctx: KCvc5Context) : KSortVisitor { + override fun visit(sort: KBoolSort) = Unit + + override fun visit(sort: KIntSort) = Unit + + override fun visit(sort: KRealSort) = Unit + + override fun visit(sort: S) = Unit + + override fun visit(sort: S) = Unit + + override fun visit(sort: KArraySort) { + sort.domain.accept(this) + sort.range.accept(this) + } + + override fun visit(sort: KArray3Sort) { + sort.domainSorts.forEach { it.accept(this) } + sort.range.accept(this) + } + + override fun visit(sort: KArray2Sort) { + sort.domainSorts.forEach { it.accept(this) } + sort.range.accept(this) + } + + override fun visit(sort: KArrayNSort) { + sort.domainSorts.forEach { it.accept(this) } + sort.range.accept(this) + } + + override fun visit(sort: KFpRoundingModeSort) = Unit + + override fun visit(sort: KUninterpretedSort) = cvc5Ctx.addUninterpretedSort(sort) + + fun collect(decl: KDecl<*>) { + decl.argSorts.map { it.accept(this) } + decl.sort.accept(this) + } + } +} diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt index 8f72a81e7..84550a933 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt @@ -1,6 +1,6 @@ package io.ksmt.solver.cvc5 -import io.github.cvc5.Kind +import io.github.cvc5.CVC5ApiException import io.github.cvc5.Solver import io.github.cvc5.Sort import io.github.cvc5.Term @@ -16,46 +16,34 @@ import io.ksmt.expr.KUninterpretedSortValue import io.ksmt.expr.KUniversalQuantifier import io.ksmt.expr.rewrite.KExprUninterpretedDeclCollector import io.ksmt.expr.transformer.KNonRecursiveTransformer -import io.ksmt.sort.KArray2Sort -import io.ksmt.sort.KArray3Sort -import io.ksmt.sort.KArrayNSort +import io.ksmt.solver.KSolverException import io.ksmt.sort.KArraySort import io.ksmt.sort.KArraySortBase import io.ksmt.sort.KBoolSort -import io.ksmt.sort.KBvSort -import io.ksmt.sort.KFpRoundingModeSort -import io.ksmt.sort.KFpSort -import io.ksmt.sort.KIntSort -import io.ksmt.sort.KRealSort import io.ksmt.sort.KSort -import io.ksmt.sort.KSortVisitor import io.ksmt.sort.KUninterpretedSort import java.util.TreeMap -class KCvc5Context internal constructor( - private val solver: Solver, - /** - * Used as context for expressions lifetime separation. - * Exprs which stored in [KCvc5Context], created with [mkExprSolver] - */ +/** + * @param mkExprSolver used as "context" for creation of all native expressions, which are stored in [KCvc5Context]. + */ +open class KCvc5Context( + protected val solver: Solver, val mkExprSolver: Solver, - private val ctx: KContext, - forkingSolverManager: KCvc5ForkingSolverManager? = null + protected val ctx: KContext ) : AutoCloseable { - constructor(solver: Solver, mkExprSolver: Solver, ctx: KContext) - : this(solver, mkExprSolver, ctx, null) - constructor(solver: Solver, ctx: KContext) - : this(solver, solver, ctx, null) + /** + * Creates [KCvc5Context]. All native expressions will be created via [solver]. + */ + constructor(solver: Solver, ctx: KContext) : this(solver, solver, ctx) - private var isClosed = false - val isForking = forkingSolverManager != null + protected var isClosed = false - private val uninterpretedSortCollector = KUninterpretedSortCollector(this) - private var exprCurrentLevelCacheRestorer = KCurrentScopeExprCacheRestorer(uninterpretedSortCollector, ctx) + private var exprCurrentLevelCacheRestorer = KCurrentScopeExprCacheRestorer(ctx) - private val uninterpretedSorts: ScopedFrame> - private val declarations: ScopedFrame>> + protected open val uninterpretedSorts: ScopedFrame> = ScopedArrayFrame(::HashSet) + protected open val declarations: ScopedFrame>> = ScopedArrayFrame(::HashSet) /** @@ -76,64 +64,23 @@ class KCvc5Context internal constructor( * that is in global cache, but whose sorts / decls have been erased after pop() * (and put this expr to the cache of current accumulated scope) */ - private val currentAccumulatedScopeExpressions = HashMap, Term>() - private val expressions: HashMap, Term> + protected val currentAccumulatedScopeExpressions = HashMap, Term>() + protected open val expressions = HashMap, Term>() /** * We can't use HashMap with Term and Sort (hashcode is not implemented) */ - private val cvc5Expressions: TreeMap> - private val sorts: HashMap - private val cvc5Sorts: TreeMap - private val decls: HashMap, Term> - private val cvc5Decls: TreeMap> - - private val uninterpretedSortValueDescriptors: ArrayList - private val uninterpretedSortValueInterpreter: HashMap - - /** - * Uninterpreted sort values and universe are shared for whole forking hierarchy (from parent to children) - * due to shared expressions cache, - * that's why once [registerUninterpretedSortValue] and [saveUninterpretedSortValue] are called, - * each solver in hierarchy should assert newly internalized uninterpreted sort values via [assertPendingAxioms] - * - * @see KCvc5Model.uninterpretedSortUniverse - */ - private val uninterpretedSortValues: HashMap>> - - init { - if (isForking) { - uninterpretedSorts = ScopedLinkedFrame(::HashSet, ::HashSet) - declarations = ScopedLinkedFrame(::HashSet, ::HashSet) - } else { - uninterpretedSorts = ScopedArrayFrame(::HashSet) - declarations = ScopedArrayFrame(::HashSet) - } - - if (forkingSolverManager != null) { - with(forkingSolverManager) { - expressions = findExpressionsCache() - cvc5Expressions = findExpressionsReversedCache() - sorts = findSortsCache() - cvc5Sorts = findSortsReversedCache() - decls = findDeclsCache() - cvc5Decls = findDeclsReversedCache() - uninterpretedSortValueDescriptors = findUninterpretedSortsValueDescriptors() - uninterpretedSortValueInterpreter = findUninterpretedSortsValueInterpretersCache() - uninterpretedSortValues = findUninterpretedSortValues() - } - } else { - expressions = HashMap() - cvc5Expressions = TreeMap() - sorts = HashMap() - cvc5Sorts = TreeMap() - decls = HashMap() - cvc5Decls = TreeMap() - uninterpretedSortValueDescriptors = arrayListOf() - uninterpretedSortValueInterpreter = hashMapOf() - uninterpretedSortValues = hashMapOf() - } - } + protected open val cvc5Expressions = TreeMap>() + protected open val sorts = HashMap() + protected open val cvc5Sorts = TreeMap() + protected open val decls = HashMap, Term>() + protected open val cvc5Decls = TreeMap>() + + @Suppress("LeakingThis") + open val uninterpretedValuesTracker = ExpressionUninterpretedValuesTracker(this) + protected open val uninterpretedSortValueInterpreter = HashMap() + protected open val uninterpretedSortValues = + HashMap>>() fun addUninterpretedSort(sort: KUninterpretedSort) { uninterpretedSorts.currentFrame += sort @@ -143,7 +90,7 @@ class KCvc5Context internal constructor( fun addDeclaration(decl: KDecl<*>) { declarations.currentFrame += decl - uninterpretedSortCollector.collect(decl) + uninterpretedValuesTracker.collectUninterpretedSorts(decl) } fun declarations(): Set> = declarations.flatten { this += it } @@ -154,35 +101,23 @@ class KCvc5Context internal constructor( val isActive: Boolean get() = !isClosed - fun fork(solver: Solver, forkingSolverManager: KCvc5ForkingSolverManager): KCvc5Context { - require(isForking) { "Can't fork non-forking context" } - return KCvc5Context(solver, mkExprSolver, ctx, forkingSolverManager).also { forkCtx -> - forkCtx.currentAccumulatedScopeExpressions += currentAccumulatedScopeExpressions - (forkCtx.uninterpretedSorts as ScopedLinkedFrame).fork(uninterpretedSorts as ScopedLinkedFrame) - (forkCtx.declarations as ScopedLinkedFrame).fork(declarations as ScopedLinkedFrame) - - repeat(assertedConstraintLevels.size) { - forkCtx.pushAssertionLevel() - } - } - } + fun ensureActive() = check(isActive) { "The context is already closed." } fun push() { declarations.push() uninterpretedSorts.push() - - pushAssertionLevel() + uninterpretedValuesTracker.pushAssertionLevel() } fun pop(n: UInt) { declarations.pop(n) uninterpretedSorts.pop(n) - repeat(n.toInt()) { popAssertionLevel() } + repeat(n.toInt()) { uninterpretedValuesTracker.popAssertionLevel() } currentAccumulatedScopeExpressions.clear() // recreate cache restorer to avoid KNonRecursiveTransformer cache - exprCurrentLevelCacheRestorer = KCurrentScopeExprCacheRestorer(uninterpretedSortCollector, ctx) + exprCurrentLevelCacheRestorer = KCurrentScopeExprCacheRestorer(ctx) } fun findInternalizedExpr(expr: KExpr<*>): Term? = currentAccumulatedScopeExpressions[expr] @@ -196,6 +131,8 @@ class KCvc5Context internal constructor( exprCurrentLevelCacheRestorer.apply(expr) } + open fun Term.convert(converter: KCvc5ExprConverter) = with(converter) { convertExpr() } + fun findConvertedExpr(expr: Term): KExpr<*>? = cvc5Expressions[expr] fun saveInternalizedExpr(expr: KExpr<*>, internalized: Term): Term = @@ -255,35 +192,6 @@ class KCvc5Context internal constructor( return save(key, computeValue()) } - - @Suppress("ForbiddenComment") - /** - * Uninterpreted sort values distinct constraints management. - * - * 1. save/register uninterpreted value. - * See [KUninterpretedSortValue] internalization for the details. - * 2. Assert distinct constraints ([assertPendingAxioms]) that may be introduced during internalization. - * Currently, we assert constraints for all the values we have ever internalized. - * - * todo: precise uninterpreted sort values tracking - * */ - internal data class UninterpretedSortValueDescriptor( - val value: KUninterpretedSortValue, - val nativeUniqueValueDescriptor: Term, - val nativeValueTerm: Term - ) - - /** - * Uninterpreted sort value axioms will not be lost for [KCvc5ForkingSolver] on [fork]. - * - * On child initialization, "[currentValueConstraintsLevel] = 0" - * will be pushed to [assertedConstraintLevels] for each push-level ([currentValueConstraintsLevel] times). - * At the first call of [assertPendingAxioms] each descriptor from [uninterpretedSortValueDescriptors] - * will be asserted to the child [KCvc5ForkingSolver] - */ - private var currentValueConstraintsLevel = 0 - private val assertedConstraintLevels = arrayListOf() - fun saveUninterpretedSortValue(nativeValue: Term, value: KUninterpretedSortValue): Term { val sortValues = uninterpretedSortValues.getOrPut(value.sort) { arrayListOf() } sortValues += nativeValue to value @@ -301,11 +209,15 @@ class KCvc5Context internal constructor( registerUninterpretedSortValueInterpreter(value.sort, mkInterpreter()) } - registerUninterpretedSortValue(value, uniqueValueDescriptorTerm, uninterpretedValueTerm) + uninterpretedValuesTracker.registerUninterpretedSortValue( + value, + uniqueValueDescriptorTerm, + uninterpretedValueTerm + ) } fun assertPendingAxioms(solver: Solver) { - assertPendingUninterpretedValueConstraints(solver) + uninterpretedValuesTracker.assertPendingUninterpretedValueConstraints(solver) } fun getUninterpretedSortValueInterpreter(sort: KUninterpretedSort): Term? = @@ -315,49 +227,9 @@ class KCvc5Context internal constructor( uninterpretedSortValueInterpreter[sort] = interpreter } - fun registerUninterpretedSortValue( - value: KUninterpretedSortValue, - uniqueValueDescriptorTerm: Term, - uninterpretedValueTerm: Term - ) { - uninterpretedSortValueDescriptors += UninterpretedSortValueDescriptor( - value = value, - nativeUniqueValueDescriptor = uniqueValueDescriptorTerm, - nativeValueTerm = uninterpretedValueTerm - ) - } - fun getRegisteredSortValues(sort: KUninterpretedSort): List> = uninterpretedSortValues[sort] ?: emptyList() - private fun pushAssertionLevel() { - assertedConstraintLevels += currentValueConstraintsLevel - } - - private fun popAssertionLevel() { - currentValueConstraintsLevel = assertedConstraintLevels.removeLast() - } - - private fun assertPendingUninterpretedValueConstraints(solver: Solver) { - while (currentValueConstraintsLevel < uninterpretedSortValueDescriptors.size) { - assertUninterpretedSortValueConstraint( - solver, - uninterpretedSortValueDescriptors[currentValueConstraintsLevel] - ) - currentValueConstraintsLevel++ - } - } - - private fun assertUninterpretedSortValueConstraint(solver: Solver, value: UninterpretedSortValueDescriptor) { - val interpreter = uninterpretedSortValueInterpreter[value.value.sort] - ?: error("Interpreter was not registered for sort: ${value.value.sort}") - - val constraintLhs = solver.mkTerm(Kind.APPLY_UF, arrayOf(interpreter, value.nativeValueTerm)) - val constraint = constraintLhs.eqTerm(value.nativeUniqueValueDescriptor) - - solver.assertFormula(constraint) - } - private inline fun internalizeAst( cache: MutableMap, reverseCache: MutableMap, @@ -419,84 +291,42 @@ class KCvc5Context internal constructor( isClosed = true currentAccumulatedScopeExpressions.clear() - - if (!isForking) { - expressions.clear() - cvc5Expressions.clear() - sorts.clear() - cvc5Sorts.clear() - decls.clear() - cvc5Decls.clear() - uninterpretedSortValueDescriptors.clear() - uninterpretedSortValueInterpreter.clear() - uninterpretedSortValues.clear() - } + expressions.clear() + cvc5Expressions.clear() + sorts.clear() + cvc5Sorts.clear() + decls.clear() + cvc5Decls.clear() + uninterpretedSortValueInterpreter.clear() + uninterpretedSortValues.clear() } - class KUninterpretedSortCollector(private val cvc5Ctx: KCvc5Context) : KSortVisitor { - override fun visit(sort: KBoolSort) = Unit - - override fun visit(sort: KIntSort) = Unit - - override fun visit(sort: KRealSort) = Unit - - override fun visit(sort: S) = Unit - - override fun visit(sort: S) = Unit - - override fun visit(sort: KArraySort) { - sort.domain.accept(this) - sort.range.accept(this) - } - - override fun visit(sort: KArray3Sort) { - sort.domainSorts.forEach { it.accept(this) } - sort.range.accept(this) - } - - override fun visit(sort: KArray2Sort) { - sort.domainSorts.forEach { it.accept(this) } - sort.range.accept(this) - } - - override fun visit(sort: KArrayNSort) { - sort.domainSorts.forEach { it.accept(this) } - sort.range.accept(this) - } - - override fun visit(sort: KFpRoundingModeSort) = Unit - - override fun visit(sort: KUninterpretedSort) = cvc5Ctx.addUninterpretedSort(sort) - - fun collect(decl: KDecl<*>) { - decl.argSorts.map { it.accept(this) } - decl.sort.accept(this) - } + inline fun cvc5Try(body: () -> T): T = try { + ensureActive() + body() + } catch (ex: CVC5ApiException) { + throw KSolverException(ex) } - inner class KCurrentScopeExprCacheRestorer( - private val uninterpretedSortCollector: KUninterpretedSortCollector, - ctx: KContext - ) : KNonRecursiveTransformer(ctx) { - + inner class KCurrentScopeExprCacheRestorer(ctx: KContext) : KNonRecursiveTransformer(ctx) { override fun exprTransformationRequired(expr: KExpr): Boolean = expr !in currentAccumulatedScopeExpressions override fun transform(expr: KFunctionApp): KExpr = cacheIfNeed(expr) { this@KCvc5Context.addDeclaration(expr.decl) - uninterpretedSortCollector.collect(expr.decl) + uninterpretedValuesTracker.collectUninterpretedSorts(expr.decl) } override fun transform(expr: KConst): KExpr = cacheIfNeed(expr) { this@KCvc5Context.addDeclaration(expr.decl) - uninterpretedSortCollector.collect(expr.decl) + uninterpretedValuesTracker.collectUninterpretedSorts(expr.decl) saveInternalizedExprToCurrentAccumulatedScope(expr) } override fun , R : KSort> transform(expr: KFunctionAsArray): KExpr = cacheIfNeed(expr) { this@KCvc5Context.addDeclaration(expr.function) - uninterpretedSortCollector.collect(expr.function) + uninterpretedValuesTracker.collectUninterpretedSorts(expr.function) } override fun transform(expr: KArrayLambda): KExpr> = diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingContext.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingContext.kt new file mode 100644 index 000000000..dc1e8d5b7 --- /dev/null +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingContext.kt @@ -0,0 +1,75 @@ +package io.ksmt.solver.cvc5 + +import io.github.cvc5.Solver +import io.github.cvc5.Term +import io.ksmt.KContext +import io.ksmt.decl.KDecl +import io.ksmt.expr.KExpr +import io.ksmt.sort.KSort +import io.ksmt.sort.KUninterpretedSort + +class KCvc5ForkingContext private constructor( + solver: Solver, + mkExprSolver: Solver, + ctx: KContext, + manager: KCvc5ForkingSolverManager, + parent: KCvc5ForkingContext? +) : KCvc5Context(solver, mkExprSolver, ctx) { + constructor(solver: Solver, mkExprSolver: Solver, ctx: KContext, manager: KCvc5ForkingSolverManager) : this( + solver, mkExprSolver, ctx, manager, null + ) + + private val uninterpretedSortsLinkedFrame = ScopedLinkedFrame>(::HashSet, ::HashSet) + private val declarationsLinkedFrame = ScopedLinkedFrame>>(::HashSet, ::HashSet) + + override val uninterpretedSorts: ScopedFrame> + get() = uninterpretedSortsLinkedFrame + override val declarations: ScopedFrame>> + get() = declarationsLinkedFrame + + override val expressions = with(manager) { getExpressionsCache() } + override val cvc5Expressions = with(manager) { getExpressionsReversedCache() } + override val sorts = with(manager) { getSortsCache() } + override val cvc5Sorts = with(manager) { getSortsReversedCache() } + override val decls = with(manager) { getDeclsCache() } + override val cvc5Decls = with(manager) { getDeclsReversedCache() } + + override val uninterpretedSortValueInterpreter = with(manager) { getUninterpretedSortsValueInterpretersCache() } + + /** + * Uninterpreted sort values and universe are shared for whole forking hierarchy (from parent to children) + * due to shared expressions cache, + * that's why once [registerUninterpretedSortValue] and [saveUninterpretedSortValue] are called, + * each solver in hierarchy should assert newly internalized uninterpreted sort values via [assertPendingAxioms] + * + * @see KCvc5Model.uninterpretedSortUniverse + */ + override val uninterpretedSortValues = with(manager) { getUninterpretedSortValues() } + + override val uninterpretedValuesTracker: ExpressionUninterpretedValuesForkingTracker = parent + ?.uninterpretedValuesTracker?.fork(this) + ?: ExpressionUninterpretedValuesForkingTracker(this) + + + init { + if (parent != null) { + currentAccumulatedScopeExpressions += parent.currentAccumulatedScopeExpressions + uninterpretedSortsLinkedFrame.fork(parent.uninterpretedSortsLinkedFrame) + declarationsLinkedFrame.fork(parent.declarationsLinkedFrame) + } + } + + fun fork(solver: Solver, forkingSolverManager: KCvc5ForkingSolverManager): KCvc5ForkingContext { + ensureActive() + return KCvc5ForkingContext(solver, mkExprSolver, ctx, forkingSolverManager, this) + } + + override fun close() { + if (isClosed) return + isClosed = true + } + + override fun Term.convert(converter: KCvc5ExprConverter): KExpr = with(converter) { + convertExprWithMkExprSolver() + } +} diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt index d49f52187..9dcecf4ab 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt @@ -1,11 +1,9 @@ package io.ksmt.solver.cvc5 -import io.github.cvc5.Solver import io.github.cvc5.Term import io.ksmt.KContext import io.ksmt.expr.KExpr import io.ksmt.solver.KForkingSolver -import io.ksmt.solver.KSolver import io.ksmt.solver.KSolverStatus import io.ksmt.sort.KBoolSort import java.util.TreeMap @@ -15,14 +13,11 @@ import kotlin.time.Duration open class KCvc5ForkingSolver internal constructor( ctx: KContext, private val manager: KCvc5ForkingSolverManager, - /** store reference on Solver to separate lifetime of native expressions */ - mkExprSolver: Solver, parent: KCvc5ForkingSolver? = null -) : KCvc5SolverBase(ctx), KForkingSolver, KSolver { - - final override val cvc5Ctx: KCvc5Context +) : KCvc5SolverBase(ctx), KForkingSolver { + final override val cvc5Ctx: KCvc5ForkingContext = manager.createCvc5ForkingContext(solver, parent?.cvc5Ctx) private val isChild = parent != null - private var assertionsInitiated = !isChild + private var assertionsInitiated = !isChild // don't need to initiate assertions for root solver private val trackedAssertions = ScopedLinkedFrame>>(::TreeMap, ::TreeMap) private val cvc5Assertions = ScopedLinkedFrame>(::TreeSet, ::TreeSet) @@ -30,22 +25,16 @@ open class KCvc5ForkingSolver internal constructor( override val currentScope: UInt get() = trackedAssertions.currentScope - init { - if (parent != null) { - cvc5Ctx = parent.cvc5Ctx.fork(solver, manager) - trackedAssertions.fork(parent.trackedAssertions) - cvc5Assertions.fork(parent.cvc5Assertions) - } else { - cvc5Ctx = KCvc5Context(solver, mkExprSolver, ctx, manager) - } - } - private val config: KCvc5ForkingSolverConfigurationImpl by lazy { parent?.config?.fork(solver) ?: KCvc5ForkingSolverConfigurationImpl(solver) } init { - if (isChild) config // initialize child config + if (parent != null) { + trackedAssertions.fork(parent.trackedAssertions) + cvc5Assertions.fork(parent.cvc5Assertions) + config // initialize child config + } } override fun configure(configurator: KCvc5SolverConfiguration.() -> Unit) { @@ -74,9 +63,9 @@ open class KCvc5ForkingSolver internal constructor( trackedAssertions.currentFrame[track] = trackedExpr } - override fun findTrackedExprByTrack(track: Term): KExpr? = trackedAssertions.find { it[track] } + override fun findTrackedExprByTrack(track: Term) = trackedAssertions.findNonNullValue { it[track] } - override fun assert(expr: KExpr): Unit = cvc5Try { + override fun assert(expr: KExpr): Unit = cvc5Ctx.cvc5Try { ctx.ensureContextMatch(expr) ensureAssertionsInitiated() @@ -87,32 +76,32 @@ open class KCvc5ForkingSolver internal constructor( } override fun assertAndTrack(expr: KExpr) { - cvc5Try { ensureAssertionsInitiated() } + cvc5Ctx.cvc5Try { ensureAssertionsInitiated() } super.assertAndTrack(expr) } override fun push() { - cvc5Try { ensureAssertionsInitiated() } + cvc5Ctx.cvc5Try { ensureAssertionsInitiated() } super.push() trackedAssertions.push() cvc5Assertions.push() } override fun pop(n: UInt) { - cvc5Try { ensureAssertionsInitiated() } + cvc5Ctx.cvc5Try { ensureAssertionsInitiated() } super.pop(n) trackedAssertions.pop(n) cvc5Assertions.pop(n) } override fun check(timeout: Duration): KSolverStatus { - cvc5Try { ensureAssertionsInitiated() } + cvc5Ctx.cvc5Try { ensureAssertionsInitiated() } cvc5Ctx.assertPendingAxioms(solver) return super.check(timeout) } override fun checkWithAssumptions(assumptions: List>, timeout: Duration): KSolverStatus { - cvc5Try { ensureAssertionsInitiated() } + cvc5Ctx.cvc5Try { ensureAssertionsInitiated() } cvc5Ctx.assertPendingAxioms(solver) return super.checkWithAssumptions(assumptions, timeout) } diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt index 303083f2c..c0f7b300d 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt @@ -11,91 +11,80 @@ import io.ksmt.solver.KForkingSolver import io.ksmt.solver.KForkingSolverManager import io.ksmt.sort.KSort import io.ksmt.sort.KUninterpretedSort -import java.util.IdentityHashMap import java.util.TreeMap import java.util.concurrent.ConcurrentHashMap +/** + * Responsible for creation and managing of [KCvc5ForkingSolver]. + * + * It's cheaper to create multiple copies of solvers with [KCvc5ForkingSolver.fork] + * instead of assertions transferring in [KCvc5Solver] instances manually. + * + * All solvers created with one manager (via both [KCvc5ForkingSolver.fork] and [mkForkingSolver]) + * use the same [mkExprContext]*, cache, and registered uninterpreted sort values. + * + * (*) [mkExprContext] is responsible for native expressions creation for each [KCvc5ForkingSolver] + * in one [KCvc5ForkingSolverManager]. Therefore, life scope of native expressions is the same with + * life scope of [KCvc5ForkingSolverManager] + */ open class KCvc5ForkingSolverManager(private val ctx: KContext) : KForkingSolverManager { - + private val mkExprContext by lazy { Solver() } private val solvers: MutableSet = ConcurrentHashMap.newKeySet() - /** - * for each parent-to-child hierarchy created only one mkExprSolver, - * which is responsible for native expressions lifetime - */ - private val forkingSolverToMkExprSolver = IdentityHashMap() - private val mkExprSolverReferences = IdentityHashMap() - // shared cache - private val expressionsCache = IdentityHashMap() - private val expressionsReversedCache = IdentityHashMap() - private val sortsCache = IdentityHashMap() - private val sortsReversedCache = IdentityHashMap() - private val declsCache = IdentityHashMap() - private val declsReversedCache = IdentityHashMap() - - private val uninterpretedSortValueDescriptors = IdentityHashMap() - private val uninterpretedSortValueInterpretersCache = - IdentityHashMap() - private val uninterpretedSortValues = IdentityHashMap() - - private fun Solver.ensureRegisteredAsMkExprSolver() = require(this in mkExprSolverReferences) { + private val expressionsCache = ExpressionsCache() + private val expressionsReversedCache = ExpressionsReversedCache() + private val sortsCache = SortsCache() + private val sortsReversedCache = SortsReversedCache() + private val declsCache = DeclsCache() + private val declsReversedCache = DeclsReversedCache() + private val uninterpretedSortValueInterpretersCache = UninterpretedSortValueInterpretersCache() + private val uninterpretedSortValues = UninterpretedSortValues() + + private fun Solver.ensureMkExprContextMatches() = require(this == mkExprContext) { "Solver is not registered by this manager" } - internal fun KCvc5Context.findExpressionsCache() = mkExprSolver.ensureRegisteredAsMkExprSolver().let { - expressionsCache.getOrPut(mkExprSolver) { ExpressionsCache() } + internal fun KCvc5Context.getExpressionsCache() = mkExprSolver.ensureMkExprContextMatches().let { + expressionsCache } - internal fun KCvc5Context.findExpressionsReversedCache() = mkExprSolver.ensureRegisteredAsMkExprSolver().let { - expressionsReversedCache.getOrPut(mkExprSolver) { ExpressionsReversedCache() } + internal fun KCvc5Context.getExpressionsReversedCache() = mkExprSolver.ensureMkExprContextMatches().let { + expressionsReversedCache } - internal fun KCvc5Context.findSortsCache() = mkExprSolver.ensureRegisteredAsMkExprSolver().let { - sortsCache.getOrPut(mkExprSolver) { SortsCache() } + internal fun KCvc5Context.getSortsCache() = mkExprSolver.ensureMkExprContextMatches().let { + sortsCache } - internal fun KCvc5Context.findSortsReversedCache() = mkExprSolver.ensureRegisteredAsMkExprSolver().let { - sortsReversedCache.getOrPut(mkExprSolver) { SortsReversedCache() } + internal fun KCvc5Context.getSortsReversedCache() = mkExprSolver.ensureMkExprContextMatches().let { + sortsReversedCache } - internal fun KCvc5Context.findDeclsCache() = mkExprSolver.ensureRegisteredAsMkExprSolver().let { - declsCache.getOrPut(mkExprSolver) { DeclsCache() } + internal fun KCvc5Context.getDeclsCache() = mkExprSolver.ensureMkExprContextMatches().let { + declsCache } - internal fun KCvc5Context.findDeclsReversedCache() = mkExprSolver.ensureRegisteredAsMkExprSolver().let { - declsReversedCache.getOrPut(mkExprSolver) { DeclsReversedCache() } + internal fun KCvc5Context.getDeclsReversedCache() = mkExprSolver.ensureMkExprContextMatches().let { + declsReversedCache } - internal fun KCvc5Context.findUninterpretedSortsValueDescriptors() = mkExprSolver.ensureRegisteredAsMkExprSolver() - .let { - uninterpretedSortValueDescriptors.getOrPut(mkExprSolver) { UninterpretedSortValueDescriptors() } - } + internal fun KCvc5Context.getUninterpretedSortsValueInterpretersCache() = mkExprSolver + .ensureMkExprContextMatches().let { uninterpretedSortValueInterpretersCache } - internal fun KCvc5Context.findUninterpretedSortsValueInterpretersCache() = mkExprSolver - .ensureRegisteredAsMkExprSolver().let { - uninterpretedSortValueInterpretersCache.getOrPut(mkExprSolver) { UninterpretedSortValueInterpretersCache() } - } - - internal fun KCvc5Context.findUninterpretedSortValues() = mkExprSolver.ensureRegisteredAsMkExprSolver().let { - uninterpretedSortValues.getOrPut(mkExprSolver) { UninterpretedSortValues() } + internal fun KCvc5Context.getUninterpretedSortValues() = mkExprSolver.ensureMkExprContextMatches().let { + uninterpretedSortValues } override fun mkForkingSolver(): KForkingSolver { - val mkExprSolver = Solver() - incRef(mkExprSolver) - return KCvc5ForkingSolver(ctx, this, mkExprSolver, null).also { + return KCvc5ForkingSolver(ctx, this, null).also { solvers += it - forkingSolverToMkExprSolver[it] = mkExprSolver } } internal fun mkForkingSolver(parent: KCvc5ForkingSolver): KForkingSolver { - val mkExprSolver = forkingSolverToMkExprSolver.getValue(parent) - incRef(mkExprSolver) - return KCvc5ForkingSolver(ctx, this, mkExprSolver, parent).also { + return KCvc5ForkingSolver(ctx, this, parent).also { solvers += it - forkingSolverToMkExprSolver[it] = mkExprSolver } } @@ -104,37 +93,30 @@ open class KCvc5ForkingSolverManager(private val ctx: KContext) : KForkingSolver */ internal fun close(solver: KCvc5ForkingSolver) { solvers -= solver - val mkExprSolver = forkingSolverToMkExprSolver.getValue(solver) - forkingSolverToMkExprSolver -= solver - decRef(mkExprSolver) + closeContextIfStale() } override fun close() { solvers.forEach(KCvc5ForkingSolver::close) } - private fun incRef(mkExprSolver: Solver) { - mkExprSolverReferences[mkExprSolver] = mkExprSolverReferences.getOrDefault(mkExprSolver, 0) + 1 - } + internal fun createCvc5ForkingContext(solver: Solver, parent: KCvc5ForkingContext? = null) = parent + ?.fork(solver, this) + ?: KCvc5ForkingContext(solver, mkExprContext, ctx, this) - private fun decRef(mkExprSolver: Solver) { - val referencesAfterDec = mkExprSolverReferences.getValue(mkExprSolver) - 1 - if (referencesAfterDec == 0) { - mkExprSolverReferences -= mkExprSolver - expressionsCache -= mkExprSolver - expressionsReversedCache -= mkExprSolver - sortsCache -= mkExprSolver - sortsReversedCache -= mkExprSolver - declsCache -= mkExprSolver - declsReversedCache -= mkExprSolver - uninterpretedSortValueDescriptors -= mkExprSolver - uninterpretedSortValueInterpretersCache -= mkExprSolver - uninterpretedSortValues -= mkExprSolver - - mkExprSolver.close() - } else { - mkExprSolverReferences[mkExprSolver] = referencesAfterDec - } + private fun closeContextIfStale() { + if (solvers.isNotEmpty()) return + + expressionsCache.clear() + expressionsReversedCache.clear() + sortsCache.clear() + sortsReversedCache.clear() + declsCache.clear() + declsReversedCache.clear() + uninterpretedSortValueInterpretersCache.clear() + uninterpretedSortValues.clear() + + mkExprContext.close() } companion object { @@ -150,7 +132,6 @@ private typealias SortsCache = HashMap private typealias SortsReversedCache = TreeMap private typealias DeclsCache = HashMap, Term> private typealias DeclsReversedCache = TreeMap> -private typealias UninterpretedSortValueDescriptors = ArrayList private typealias UninterpretedSortValueInterpretersCache = HashMap @Suppress("MaxLineLength") private typealias UninterpretedSortValues = HashMap>> diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt index 2c2c28d5c..66c00d1c1 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt @@ -72,29 +72,26 @@ open class KCvc5Model( private fun funcInterp( decl: KDecl, internalizedDecl: Term - ): KFuncInterp = with(converter) { + ): KFuncInterp { val cvc5Interp = cvc5Ctx.nativeSolver.getValue(internalizedDecl) val vars = decl.argSorts.map { it.mkFreshConst("x") } val cvc5Vars = vars.map { with(internalizer) { it.internalizeExpr() } }.toTypedArray() - val cvc5InterpArgs = cvc5Interp.getChild(0).getChildren() + val cvc5InterpArgs = with(converter) { cvc5Interp.getChild(0).getChildren() } val cvc5FreshVarsInterp = cvc5Interp.substitute(cvc5InterpArgs, cvc5Vars) - // in case of forking solver, save in cache mkExprSolver's terms val defaultBody = cvc5FreshVarsInterp.getChild(1).let { - if (cvc5Ctx.isForking) it.convertExprWithMkExprSolver() else it.convertExpr() + with(cvc5Ctx) { it.convert(converter) } } - KFuncInterpWithVars(decl, vars.map { it.decl }, emptyList(), defaultBody) + return KFuncInterpWithVars(decl, vars.map { it.decl }, emptyList(), defaultBody) } - private fun constInterp(decl: KDecl, const: Term): KFuncInterp = with(converter) { + private fun constInterp(decl: KDecl, const: Term): KFuncInterp { val cvc5Interp = cvc5Ctx.nativeSolver.getValue(const) - // in case of forking solver, save in cache mkExprSolver's terms - val interp = if (cvc5Ctx.isForking) cvc5Interp.convertExprWithMkExprSolver() else cvc5Interp.convertExpr() - - KFuncInterpVarsFree(decl = decl, entries = emptyList(), default = interp) + val interp = with(cvc5Ctx) { cvc5Interp.convert(converter) } + return KFuncInterpVarsFree(decl = decl, entries = emptyList(), default = interp) } /** diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt index c2d64c3e9..69c31ad0d 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt @@ -18,7 +18,7 @@ open class KCvc5Solver(ctx: KContext) : KCvc5SolverBase(ctx), KSolver? = trackedAssertions.find { it[track] } + override fun findTrackedExprByTrack(track: Term) = trackedAssertions.findNonNullValue { it[track] } override fun push() { super.push() diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverBase.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverBase.kt index 7a8ce5639..07134ece9 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverBase.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverBase.kt @@ -9,7 +9,6 @@ import io.ksmt.expr.KApp import io.ksmt.expr.KExpr import io.ksmt.solver.KModel import io.ksmt.solver.KSolver -import io.ksmt.solver.KSolverException import io.ksmt.solver.KSolverStatus import io.ksmt.sort.KBoolSort import io.ksmt.utils.NativeLibraryLoader @@ -50,7 +49,7 @@ abstract class KCvc5SolverBase internal constructor( KCvc5SolverConfigurationImpl(solver).configurator() } - override fun assert(expr: KExpr) = cvc5Try { + override fun assert(expr: KExpr) = cvc5Ctx.cvc5Try { ctx.ensureContextMatch(expr) val cvc5Expr = with(exprInternalizer) { expr.internalizeExpr() } @@ -61,7 +60,7 @@ abstract class KCvc5SolverBase internal constructor( protected abstract fun saveTrackedAssertion(track: Term, trackedExpr: KExpr) protected abstract fun findTrackedExprByTrack(track: Term): KExpr? - override fun assertAndTrack(expr: KExpr) = cvc5Try { + override fun assertAndTrack(expr: KExpr) = cvc5Ctx.cvc5Try { ctx.ensureContextMatch(expr) val trackVarApp = createTrackVarApp() @@ -72,12 +71,12 @@ abstract class KCvc5SolverBase internal constructor( saveTrackedAssertion(cvc5TrackVar, expr) } - override fun push() = cvc5Try { + override fun push() = cvc5Ctx.cvc5Try { solver.push() cvc5Ctx.push() } - override fun pop(n: UInt) = cvc5Try { + override fun pop(n: UInt) = cvc5Ctx.cvc5Try { require(n <= currentScope) { "Can not pop $n scope levels because current scope level is $currentScope" } @@ -120,13 +119,13 @@ abstract class KCvc5SolverBase internal constructor( cvc5Ctx.uninterpretedSorts(), ) - override fun model(): KModel = cvc5Try { + override fun model(): KModel = cvc5Ctx.cvc5Try { require(lastCheckStatus == KSolverStatus.SAT) { "Models are only available after SAT checks" } val model = lastModel ?: freshModel() model.also { lastModel = it } } - override fun reasonOfUnknown(): String = cvc5Try { + override fun reasonOfUnknown(): String = cvc5Ctx.cvc5Try { require(lastCheckStatus == KSolverStatus.UNKNOWN) { "Unknown reason is only available after UNKNOWN checks" } @@ -145,7 +144,7 @@ abstract class KCvc5SolverBase internal constructor( return unsatCore } - protected fun cvc5UnsatCore(): Array = cvc5Try { + protected fun cvc5UnsatCore(): Array = cvc5Ctx.cvc5Try { require(lastCheckStatus == KSolverStatus.UNSAT) { "Unsat cores are only available after UNSAT checks" } solver.unsatCore } @@ -179,12 +178,6 @@ abstract class KCvc5SolverBase internal constructor( setOption("tlimit-per", cvc5Timeout.toString()) } - protected inline fun cvc5Try(body: () -> T): T = try { - body() - } catch (ex: CVC5ApiException) { - throw KSolverException(ex) - } - protected inline fun cvc5TryCheck(body: () -> KSolverStatus): KSolverStatus = try { invalidateSolverState() body() diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt index 1ee768826..769327387 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt @@ -1,16 +1,11 @@ package io.ksmt.solver.cvc5 -internal interface ScopedFrame { +interface ScopedFrame { val currentScope: UInt val currentFrame: T fun flatten(collect: T.(T) -> Unit): T - /** - * find value [V] in frame [T], and return it or null - */ - fun find(predicate: (T) -> V?): V? - fun push() fun pop(n: UInt = 1u) } @@ -33,7 +28,7 @@ internal class ScopedArrayFrame( frames.forEach { newFrame.collect(it) } } - override fun find(predicate: (T) -> V?): V? { + inline fun findNonNullValue(predicate: (T) -> V?): V? { frames.forEach { frame -> predicate(frame)?.let { return it } } @@ -85,7 +80,7 @@ internal class ScopedLinkedFrame private constructor( } } - override fun find(predicate: (T) -> V?): V? { + inline fun findNonNullValue(predicate: (T) -> V?): V? { forEachReversed { frame -> predicate(frame)?.let { return it } } From c6bbd315b4eef008530d5ecf51061909a90083eb Mon Sep 17 00:00:00 2001 From: Dmitriy Sokolov Date: Wed, 30 Aug 2023 23:53:57 +0300 Subject: [PATCH 11/12] yices cache sharing in forking solver manager, added description on fork in bitwuzla --- .../solver/bitwuzla/KBitwuzlaForkingSolver.kt | 3 + .../bitwuzla/KBitwuzlaForkingSolverManager.kt | 6 + .../ksmt/solver/yices/KYicesForkingContext.kt | 22 +-- .../ksmt/solver/yices/KYicesForkingSolver.kt | 3 + .../yices/KYicesForkingSolverManager.kt | 128 +++++++++--------- 5 files changed, 87 insertions(+), 75 deletions(-) diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolver.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolver.kt index 845296cb3..21a74a5d8 100644 --- a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolver.kt +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolver.kt @@ -34,6 +34,9 @@ class KBitwuzlaForkingSolver( config.configurator() } + /** + * Creates lazily initiated forked solver (without cache sharing), preserving parental assertions and configuration. + */ override fun fork(): KForkingSolver = manager.mkForkingSolver(this) private var assertionsInitiated = parent == null diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolverManager.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolverManager.kt index b941e94ec..8c92674e9 100644 --- a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolverManager.kt +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolverManager.kt @@ -5,6 +5,12 @@ import io.ksmt.solver.KForkingSolver import io.ksmt.solver.KForkingSolverManager import java.util.concurrent.ConcurrentHashMap +/** + * Responsible for creation and managing of [KBitwuzlaForkingSolver]. + * + * Neither native cache is shared between [KBitwuzlaForkingSolver]s + * because cache sharing is not supported in bitwuzla. + */ class KBitwuzlaForkingSolverManager(private val ctx: KContext) : KForkingSolverManager { private val solvers = ConcurrentHashMap.newKeySet() diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingContext.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingContext.kt index 750ae1831..f3b65ba0e 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingContext.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingContext.kt @@ -13,22 +13,22 @@ class KYicesForkingContext( manager: KYicesForkingSolverManager, solver: KYicesForkingSolver ) : KYicesContext(ctx) { - override val expressions = manager.findExpressionsCache(solver) - override val yicesExpressions = manager.findExpressionsReversedCache(solver) + override val expressions = manager.getExpressionsCache(solver) + override val yicesExpressions = manager.getExpressionsReversedCache(solver) - override val sorts = manager.findSortsCache(solver) - override val yicesSorts = manager.findSortsReversedCache(solver) + override val sorts = manager.getSortsCache(solver) + override val yicesSorts = manager.getSortsReversedCache(solver) - override val decls = manager.findDeclsCache(solver) - override val yicesDecls = manager.findDeclsReversedCache(solver) + override val decls = manager.getDeclsCache(solver) + override val yicesDecls = manager.getDeclsReversedCache(solver) - override val vars = manager.findVarsCache(solver) - override val yicesVars = manager.findVarsReversedCache(solver) + override val vars = manager.getVarsCache(solver) + override val yicesVars = manager.getVarsReversedCache(solver) - override val yicesTypes = manager.findTypesCache(solver) - override val yicesTerms = manager.findTermsCache(solver) + override val yicesTypes = manager.getTypesCache(solver) + override val yicesTerms = manager.getTermsCache(solver) - private val maxValueIndexAtomic = manager.findMaxUninterpretedSortValueIdx(solver) + private val maxValueIndexAtomic = manager.getMaxUninterpretedSortValueIdx(solver) override var maxValueIndex: Int get() = maxValueIndexAtomic.get() diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt index 0d75af7e1..14022027a 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt @@ -57,6 +57,9 @@ class KYicesForkingSolver( ksmtConfig.configurator() } + /** + * Creates lazily initiated forked solver with shared cache, preserving parental assertions and configuration. + */ override fun fork(): KForkingSolver = manager.mkForkingSolver(this) override fun saveTrackedAssertion(track: YicesTerm, trackedExpr: KExpr) { diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt index 239b124ba..9950961ee 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt @@ -18,57 +18,67 @@ import java.util.IdentityHashMap import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicInteger +/** + * Responsible for creation and managing of [KYicesForkingSolver]. + * + * It's cheaper to create multiple copies of solvers with [KYicesForkingSolver.fork] + * instead of assertions transferring in [KYicesSolver] instances manually. + * + * All created solvers with one manager (via both [KYicesForkingSolver.fork] and [mkForkingSolver]) use the same cache. + */ class KYicesForkingSolverManager( private val ctx: KContext ) : KForkingSolverManager { private val solvers = ConcurrentHashMap.newKeySet() - private val sharedCacheReferences = IdentityHashMap() - - private val expressionsCache = IdentityHashMap() - private val expressionsReversedCache = IdentityHashMap() - private val sortsCache = IdentityHashMap() - private val sortsReversedCache = IdentityHashMap() - private val declsCache = IdentityHashMap() - private val declsReversedCache = IdentityHashMap() - private val varsCache = IdentityHashMap() - private val varsReversedCache = IdentityHashMap() - private val typesCache = IdentityHashMap() - private val termsCache = IdentityHashMap() + + private fun ensureSolverRegistered(s: KYicesForkingSolver) = check(s in solvers) { + "Solver is not registered by the manager." + } + + private val expressionsCache = ExpressionsCache().withNotInternalizedDefaultValue() + private val expressionsReversedCache = ExpressionsReversedCache() + private val sortsCache = SortsCache().withNotInternalizedDefaultValue() + private val sortsReversedCache = SortsReversedCache() + private val declsCache = DeclsCache().withNotInternalizedDefaultValue() + private val declsReversedCache = DeclsReversedCache() + private val varsCache = VarsCache().withNotInternalizedDefaultValue() + private val varsReversedCache = VarsReversedCache() + private val typesCache = TypesCache() + private val termsCache = TermsCache() private val maxUninterpretedSortValueIndex = IdentityHashMap() private val scopedExpressions = IdentityHashMap() private val scopedUninterpretedValues = IdentityHashMap() private val expressionLevels = IdentityHashMap() - internal fun findExpressionsCache(s: KYicesForkingSolver): ExpressionsCache = expressionsCache.getValue(s) - internal fun findExpressionsReversedCache(s: KYicesForkingSolver): ExpressionsReversedCache = - expressionsReversedCache.getValue(s) - - internal fun findSortsCache(s: KYicesForkingSolver): SortsCache = sortsCache.getValue(s) - internal fun findSortsReversedCache(s: KYicesForkingSolver): SortsReversedCache = sortsReversedCache.getValue(s) - internal fun findDeclsCache(s: KYicesForkingSolver): DeclsCache = declsCache.getValue(s) - internal fun findDeclsReversedCache(s: KYicesForkingSolver): DeclsReversedCache = declsReversedCache.getValue(s) - internal fun findVarsCache(s: KYicesForkingSolver): VarsCache = varsCache.getValue(s) - internal fun findVarsReversedCache(s: KYicesForkingSolver): VarsReversedCache = varsReversedCache.getValue(s) - internal fun findTypesCache(s: KYicesForkingSolver): TypesCache = typesCache.getValue(s) - internal fun findTermsCache(s: KYicesForkingSolver): TermsCache = termsCache.getValue(s) - internal fun findMaxUninterpretedSortValueIdx(s: KYicesForkingSolver) = maxUninterpretedSortValueIndex.getValue(s) + internal fun getExpressionsCache(s: KYicesForkingSolver): ExpressionsCache = ensureSolverRegistered(s).let { + expressionsCache + } + internal fun getExpressionsReversedCache(s: KYicesForkingSolver) = ensureSolverRegistered(s).let { + expressionsReversedCache + } + internal fun getSortsCache(s: KYicesForkingSolver): SortsCache = ensureSolverRegistered(s).let { sortsCache } + internal fun getSortsReversedCache(s: KYicesForkingSolver): SortsReversedCache = ensureSolverRegistered(s).let { + sortsReversedCache + } + internal fun getDeclsCache(s: KYicesForkingSolver): DeclsCache = ensureSolverRegistered(s).let { declsCache } + internal fun getDeclsReversedCache(s: KYicesForkingSolver): DeclsReversedCache = ensureSolverRegistered(s).let { + declsReversedCache + } + internal fun getVarsCache(s: KYicesForkingSolver): VarsCache = ensureSolverRegistered(s).let { varsCache } + internal fun getVarsReversedCache(s: KYicesForkingSolver): VarsReversedCache = ensureSolverRegistered(s).let { + varsReversedCache + } + internal fun getTypesCache(s: KYicesForkingSolver): TypesCache = ensureSolverRegistered(s).let { typesCache } + internal fun getTermsCache(s: KYicesForkingSolver): TermsCache = ensureSolverRegistered(s).let { termsCache } + internal fun getMaxUninterpretedSortValueIdx(s: KYicesForkingSolver) = ensureSolverRegistered(s).let { + maxUninterpretedSortValueIndex.getValue(s) + } override fun mkForkingSolver(): KForkingSolver = KYicesForkingSolver(ctx, this, null).also { solvers += it - sharedCacheReferences[it] = AtomicInteger(1) - expressionsCache[it] = ExpressionsCache().withNotInternalizedDefaultValue() - expressionsReversedCache[it] = ExpressionsReversedCache() - sortsCache[it] = SortsCache().withNotInternalizedDefaultValue() - sortsReversedCache[it] = SortsReversedCache() - declsCache[it] = DeclsCache().withNotInternalizedDefaultValue() - declsReversedCache[it] = DeclsReversedCache() - varsCache[it] = VarsCache().withNotInternalizedDefaultValue() - varsReversedCache[it] = VarsReversedCache() - typesCache[it] = TypesCache() - termsCache[it] = TermsCache() maxUninterpretedSortValueIndex[it] = AtomicInteger(0) scopedExpressions[it] = ScopedExpressions(::HashSet, ::HashSet) scopedUninterpretedValues[it] = ScopedUninterpretedSortValues(::HashMap, ::HashMap) @@ -77,17 +87,6 @@ class KYicesForkingSolverManager( internal fun mkForkingSolver(parent: KYicesForkingSolver) = KYicesForkingSolver(ctx, this, parent).also { solvers += it - sharedCacheReferences[it] = sharedCacheReferences.getValue(parent).apply { incrementAndGet() } - expressionsCache[it] = expressionsCache[parent] - expressionsReversedCache[it] = expressionsReversedCache[parent] - sortsCache[it] = sortsCache[parent] - sortsReversedCache[it] = sortsReversedCache[parent] - declsCache[it] = declsCache[parent] - declsReversedCache[it] = declsReversedCache[parent] - varsCache[it] = varsCache[parent] - varsReversedCache[it] = varsReversedCache[parent] - typesCache[it] = typesCache[parent] - termsCache[it] = termsCache[parent] scopedExpressions[it] = ScopedExpressions(::HashSet, ::HashSet) .apply { fork(scopedExpressions.getValue(parent)) } scopedUninterpretedValues[it] = ScopedUninterpretedSortValues(::HashMap, ::HashMap) @@ -118,23 +117,24 @@ class KYicesForkingSolverManager( } private fun decRef(solver: KYicesForkingSolver) { - val referencesAfterDec = sharedCacheReferences.getValue(solver).decrementAndGet() - if (referencesAfterDec == 0) { - sharedCacheReferences -= solver - expressionsCache -= solver - expressionsReversedCache -= solver - sortsCache -= solver - sortsReversedCache -= solver - declsCache -= solver - declsReversedCache -= solver - varsCache -= solver - varsReversedCache -= solver - typesCache.remove(solver)?.forEach(Yices::yicesDecrefType) - termsCache.remove(solver)?.forEach(Yices::yicesDecrefTerm) - maxUninterpretedSortValueIndex -= solver - scopedExpressions -= solver - scopedUninterpretedValues -= solver - expressionLevels -= solver + scopedExpressions -= solver + scopedUninterpretedValues -= solver + maxUninterpretedSortValueIndex -= solver + expressionLevels -= solver + + if (solvers.isEmpty()) { + expressionsCache.clear() + expressionsReversedCache.clear() + sortsCache.clear() + sortsReversedCache.clear() + declsCache.clear() + declsReversedCache.clear() + varsCache.clear() + varsReversedCache.clear() + typesCache.forEach(Yices::yicesDecrefType) + termsCache.forEach(Yices::yicesDecrefTerm) + typesCache.clear() + termsCache.clear() } } From 6284a01302c0c1cfc5b5bbc544c9c2e121974a5e Mon Sep 17 00:00:00 2001 From: Dmitriy Sokolov Date: Thu, 31 Aug 2023 00:05:03 +0300 Subject: [PATCH 12/12] rename: mkForkingSolver -> createForkingSolver --- .../solver/bitwuzla/KBitwuzlaForkingSolver.kt | 2 +- .../bitwuzla/KBitwuzlaForkingSolverManager.kt | 7 ++-- .../io/ksmt/solver/KForkingSolverManager.kt | 2 +- .../io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt | 2 +- .../solver/cvc5/KCvc5ForkingSolverManager.kt | 6 ++-- .../kotlin/io/ksmt/test/KForkingSolverTest.kt | 36 +++++++++---------- .../ksmt/solver/yices/KYicesForkingSolver.kt | 2 +- .../yices/KYicesForkingSolverManager.kt | 7 ++-- .../io/ksmt/solver/z3/KZ3ForkingSolver.kt | 2 +- .../ksmt/solver/z3/KZ3ForkingSolverManager.kt | 6 ++-- 10 files changed, 36 insertions(+), 36 deletions(-) diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolver.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolver.kt index 21a74a5d8..7be9fe378 100644 --- a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolver.kt +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolver.kt @@ -37,7 +37,7 @@ class KBitwuzlaForkingSolver( /** * Creates lazily initiated forked solver (without cache sharing), preserving parental assertions and configuration. */ - override fun fork(): KForkingSolver = manager.mkForkingSolver(this) + override fun fork(): KForkingSolver = manager.createForkingSolver(this) private var assertionsInitiated = parent == null diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolverManager.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolverManager.kt index 8c92674e9..64dd3145a 100644 --- a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolverManager.kt +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolverManager.kt @@ -14,15 +14,14 @@ import java.util.concurrent.ConcurrentHashMap class KBitwuzlaForkingSolverManager(private val ctx: KContext) : KForkingSolverManager { private val solvers = ConcurrentHashMap.newKeySet() - override fun mkForkingSolver(): KForkingSolver { + override fun createForkingSolver(): KForkingSolver { return KBitwuzlaForkingSolver(ctx, this, null).also { solvers += it } } - internal fun mkForkingSolver(parent: KBitwuzlaForkingSolver) = KBitwuzlaForkingSolver(ctx, this, parent).also { - solvers += it - } + internal fun createForkingSolver(parent: KBitwuzlaForkingSolver) = KBitwuzlaForkingSolver(ctx, this, parent) + .also { solvers += it } internal fun close(solver: KBitwuzlaForkingSolver) { solvers -= solver diff --git a/ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolverManager.kt b/ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolverManager.kt index 7df85068a..34c63ff11 100644 --- a/ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolverManager.kt +++ b/ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolverManager.kt @@ -7,7 +7,7 @@ package io.ksmt.solver */ interface KForkingSolverManager : AutoCloseable { - fun mkForkingSolver(): KForkingSolver + fun createForkingSolver(): KForkingSolver /** * Closes the manager and all opened solvers ([KForkingSolver]) managed by this diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt index 9dcecf4ab..d0bb588fb 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt @@ -41,7 +41,7 @@ open class KCvc5ForkingSolver internal constructor( config.configurator() } - override fun fork(): KForkingSolver = manager.mkForkingSolver(this) + override fun fork(): KForkingSolver = manager.createForkingSolver(this) private fun ensureAssertionsInitiated() { if (assertionsInitiated) return diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt index c0f7b300d..d6d9f6406 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt @@ -20,7 +20,7 @@ import java.util.concurrent.ConcurrentHashMap * It's cheaper to create multiple copies of solvers with [KCvc5ForkingSolver.fork] * instead of assertions transferring in [KCvc5Solver] instances manually. * - * All solvers created with one manager (via both [KCvc5ForkingSolver.fork] and [mkForkingSolver]) + * All solvers created with one manager (via both [KCvc5ForkingSolver.fork] and [createForkingSolver]) * use the same [mkExprContext]*, cache, and registered uninterpreted sort values. * * (*) [mkExprContext] is responsible for native expressions creation for each [KCvc5ForkingSolver] @@ -76,13 +76,13 @@ open class KCvc5ForkingSolverManager(private val ctx: KContext) : KForkingSolver uninterpretedSortValues } - override fun mkForkingSolver(): KForkingSolver { + override fun createForkingSolver(): KForkingSolver { return KCvc5ForkingSolver(ctx, this, null).also { solvers += it } } - internal fun mkForkingSolver(parent: KCvc5ForkingSolver): KForkingSolver { + internal fun createForkingSolver(parent: KCvc5ForkingSolver): KForkingSolver { return KCvc5ForkingSolver(ctx, this, parent).also { solvers += it } diff --git a/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt b/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt index edf9c9361..189bf7245 100644 --- a/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt +++ b/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt @@ -109,10 +109,10 @@ class KForkingSolverTest { private fun mkZ3ForkingSolverManager(ctx: KContext) = KZ3ForkingSolverManager(ctx) } - private fun testCheckSat(mkForkingSolverManager: (KContext) -> KForkingSolverManager<*>) = + private fun testCheckSat(createForkingSolverManager: (KContext) -> KForkingSolverManager<*>) = KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> - mkForkingSolverManager(ctx).use { man -> - man.mkForkingSolver().use { parentSolver -> + createForkingSolverManager(ctx).use { man -> + man.createForkingSolver().use { parentSolver -> with(ctx) { val a by boolSort val b by boolSort @@ -166,10 +166,10 @@ class KForkingSolverTest { } } - private fun testUnsatCore(mkForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = + private fun testUnsatCore(createForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> - mkForkingSolverManager(ctx).use { man -> - man.mkForkingSolver().use { parentSolver -> + createForkingSolverManager(ctx).use { man -> + man.createForkingSolver().use { parentSolver -> with(ctx) { val a by boolSort val b by boolSort @@ -219,10 +219,10 @@ class KForkingSolverTest { } } - private fun testModel(mkForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = + private fun testModel(createForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> - mkForkingSolverManager(ctx).use { man -> - man.mkForkingSolver().use { parentSolver -> + createForkingSolverManager(ctx).use { man -> + man.createForkingSolver().use { parentSolver -> with(ctx) { val a by boolSort val b by boolSort @@ -244,10 +244,10 @@ class KForkingSolverTest { } } - private fun testScopedAssertions(mkForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = + private fun testScopedAssertions(createForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> - mkForkingSolverManager(ctx).use { man -> - man.mkForkingSolver().use { parent -> + createForkingSolverManager(ctx).use { man -> + man.createForkingSolver().use { parent -> with(ctx) { val a by boolSort val b by boolSort @@ -315,10 +315,10 @@ class KForkingSolverTest { } @Suppress("LongMethod") - private fun testUninterpretedSort(mkForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = + private fun testUninterpretedSort(createForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> - mkForkingSolverManager(ctx).use { man -> - man.mkForkingSolver().use { parentSolver -> + createForkingSolverManager(ctx).use { man -> + man.createForkingSolver().use { parentSolver -> with(ctx) { val uSort = mkUninterpretedSort("u") val u1 by uSort @@ -395,11 +395,11 @@ class KForkingSolverTest { } } - fun testLifeTime(mkForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = + fun testLifeTime(createForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> - mkForkingSolverManager(ctx).use { man -> + createForkingSolverManager(ctx).use { man -> with(ctx) { - val parent = man.mkForkingSolver() + val parent = man.createForkingSolver() val x by bv8Sort val f = mkBvSignedGreaterExpr(x, mkBv(100, bv8Sort)) diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt index 14022027a..fd94e4799 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt @@ -60,7 +60,7 @@ class KYicesForkingSolver( /** * Creates lazily initiated forked solver with shared cache, preserving parental assertions and configuration. */ - override fun fork(): KForkingSolver = manager.mkForkingSolver(this) + override fun fork(): KForkingSolver = manager.createForkingSolver(this) override fun saveTrackedAssertion(track: YicesTerm, trackedExpr: KExpr) { trackedAssertions.currentFrame += trackedExpr to track diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt index 9950961ee..098cc6601 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt @@ -24,7 +24,8 @@ import java.util.concurrent.atomic.AtomicInteger * It's cheaper to create multiple copies of solvers with [KYicesForkingSolver.fork] * instead of assertions transferring in [KYicesSolver] instances manually. * - * All created solvers with one manager (via both [KYicesForkingSolver.fork] and [mkForkingSolver]) use the same cache. + * All solvers created with one manager (via both [KYicesForkingSolver.fork] and [createForkingSolver]) + * use the same cache. */ class KYicesForkingSolverManager( private val ctx: KContext @@ -76,7 +77,7 @@ class KYicesForkingSolverManager( maxUninterpretedSortValueIndex.getValue(s) } - override fun mkForkingSolver(): KForkingSolver = + override fun createForkingSolver(): KForkingSolver = KYicesForkingSolver(ctx, this, null).also { solvers += it maxUninterpretedSortValueIndex[it] = AtomicInteger(0) @@ -85,7 +86,7 @@ class KYicesForkingSolverManager( expressionLevels[it] = ExpressionLevels() } - internal fun mkForkingSolver(parent: KYicesForkingSolver) = KYicesForkingSolver(ctx, this, parent).also { + internal fun createForkingSolver(parent: KYicesForkingSolver) = KYicesForkingSolver(ctx, this, parent).also { solvers += it scopedExpressions[it] = ScopedExpressions(::HashSet, ::HashSet) .apply { fork(scopedExpressions.getValue(parent)) } diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt index e4dd2560d..82592bb72 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt @@ -51,7 +51,7 @@ open class KZ3ForkingSolver internal constructor( /** * Creates lazily initiated forked solver with shared cache, preserving parental assertions and configuration. */ - override fun fork(): KForkingSolver = manager.mkForkingSolver(this) + override fun fork(): KForkingSolver = manager.createForkingSolver(this) override fun saveTrackedAssertion(track: Long, trackedExpr: KExpr) { trackedAssertions.currentFrame[track] = trackedExpr diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt index 0711873d3..a3c2a6406 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt @@ -26,7 +26,7 @@ import java.util.concurrent.ConcurrentHashMap * It's cheaper to create multiple copies of solvers with [KZ3ForkingSolver.fork] * instead of assertions transferring in [KZ3Solver] instances. * - * All created solvers with one manager (via both [KZ3ForkingSolver.fork] and [mkForkingSolver]) + * All created solvers with one manager (via both [KZ3ForkingSolver.fork] and [createForkingSolver]) * use the same [Context], cache, and registered uninterpreted sort values. */ class KZ3ForkingSolverManager(private val ctx: KContext) : KForkingSolverManager { @@ -71,11 +71,11 @@ class KZ3ForkingSolverManager(private val ctx: KContext) : KForkingSolverManager internal fun KZ3Context.getUninterpretedSortValueInterpreters() = ensureContextMatches(nativeContext) .let { uninterpretedSortValueInterpreters } - override fun mkForkingSolver(): KForkingSolver { + override fun createForkingSolver(): KForkingSolver { return KZ3ForkingSolver(ctx, this, null).also { solvers += it } } - internal fun mkForkingSolver(parent: KZ3ForkingSolver): KForkingSolver { + internal fun createForkingSolver(parent: KZ3ForkingSolver): KForkingSolver { return KZ3ForkingSolver(ctx, this, parent).also { solvers += it } }