diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaContext.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaContext.kt index 5fb618e35..5b845795c 100644 --- a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaContext.kt +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaContext.kt @@ -1,8 +1,5 @@ package io.ksmt.solver.bitwuzla -import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap -import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap -import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap import io.ksmt.KContext import io.ksmt.decl.KDecl import io.ksmt.decl.KFuncDecl @@ -15,13 +12,10 @@ import io.ksmt.expr.KExistentialQuantifier import io.ksmt.expr.KExpr import io.ksmt.expr.KFunctionApp import io.ksmt.expr.KFunctionAsArray +import io.ksmt.expr.KUninterpretedSortValue import io.ksmt.expr.KUniversalQuantifier import io.ksmt.expr.transformer.KNonRecursiveTransformer import io.ksmt.solver.KSolverException -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm -import org.ksmt.solver.bitwuzla.bindings.Native import io.ksmt.solver.util.KExprLongInternalizerBase.Companion.NOT_INTERNALIZED import io.ksmt.sort.KArray2Sort import io.ksmt.sort.KArray3Sort @@ -37,6 +31,13 @@ import io.ksmt.sort.KRealSort import io.ksmt.sort.KSort import io.ksmt.sort.KSortVisitor import io.ksmt.sort.KUninterpretedSort +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap +import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm +import org.ksmt.solver.bitwuzla.bindings.Native open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable { private var isClosed = false @@ -433,6 +434,11 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable { return super.transform(expr) } + override fun transform(expr: KUninterpretedSortValue): KExpr { + registerDeclIfNotIgnored(expr.decl) + return super.transform(expr) + } + private val quantifiedVarsScopeOwner = arrayListOf>() private val quantifiedVarsScope = arrayListOf>?>() @@ -474,7 +480,7 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable { override fun transform(expr: KExistentialQuantifier): KExpr = expr.transformQuantifier(expr.bounds, expr.body) - override fun transform(expr: KUniversalQuantifier): KExpr = + override fun transform(expr: KUniversalQuantifier): KExpr = expr.transformQuantifier(expr.bounds, expr.body) } diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaExprInternalizer.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaExprInternalizer.kt index 5938f1e19..2aff66301 100644 --- a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaExprInternalizer.kt +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaExprInternalizer.kt @@ -151,17 +151,11 @@ import io.ksmt.expr.KUnaryMinusArithExpr import io.ksmt.expr.KUninterpretedSortValue import io.ksmt.expr.KUniversalQuantifier import io.ksmt.expr.KXorExpr -import io.ksmt.expr.rewrite.simplify.rewriteBvAddNoUnderflowExpr -import io.ksmt.expr.rewrite.simplify.rewriteBvMulNoUnderflowExpr import io.ksmt.expr.rewrite.simplify.rewriteBvNegNoOverflowExpr import io.ksmt.expr.rewrite.simplify.rewriteBvSubNoUnderflowExpr import io.ksmt.solver.KSolverUnsupportedFeatureException -import org.ksmt.solver.bitwuzla.bindings.Bitwuzla -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaKind -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaRoundingMode -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm -import org.ksmt.solver.bitwuzla.bindings.Native +import io.ksmt.solver.bitwuzla.KBitwuzlaExprInternalizer.BvOverflowCheckMode.OVERFLOW +import io.ksmt.solver.bitwuzla.KBitwuzlaExprInternalizer.BvOverflowCheckMode.UNDERFLOW import io.ksmt.solver.util.KExprLongInternalizerBase import io.ksmt.sort.KArithSort import io.ksmt.sort.KArray2Sort @@ -186,7 +180,13 @@ import io.ksmt.sort.KRealSort import io.ksmt.sort.KSort import io.ksmt.sort.KSortVisitor import io.ksmt.sort.KUninterpretedSort +import org.ksmt.solver.bitwuzla.bindings.Bitwuzla +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaKind +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaRoundingMode +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTermArray +import org.ksmt.solver.bitwuzla.bindings.Native import java.math.BigInteger @Suppress("LargeClass") @@ -726,7 +726,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL override fun transform(expr: KBvAddNoOverflowExpr) = with(expr) { transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm -> if (isSigned) { - mkBvAddSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.OVERFLOW) + mkBvAddSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, OVERFLOW) } else { val overflowCheck = Native.bitwuzlaMkTerm2( bitwuzla, BitwuzlaKind.BITWUZLA_KIND_BV_UADD_OVERFLOW, a0, a1 @@ -738,20 +738,20 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL override fun transform(expr: KBvAddNoUnderflowExpr) = with(expr) { transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm -> - mkBvAddSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.UNDERFLOW) + mkBvAddSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, UNDERFLOW) } } override fun transform(expr: KBvSubNoOverflowExpr) = with(expr) { transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm -> - mkBvSubSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.OVERFLOW) + mkBvSubSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, OVERFLOW) } } override fun transform(expr: KBvSubNoUnderflowExpr) = with(expr) { if (isSigned) { transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm -> - mkBvSubSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.UNDERFLOW) + mkBvSubSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, UNDERFLOW) } } else { transform { @@ -776,7 +776,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL override fun transform(expr: KBvMulNoOverflowExpr) = with(expr) { transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm -> if (isSigned) { - mkBvMulSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.OVERFLOW) + mkBvMulSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, OVERFLOW) } else { val overflowCheck = Native.bitwuzlaMkTerm2( bitwuzla, BitwuzlaKind.BITWUZLA_KIND_BV_UMUL_OVERFLOW, a0, a1 @@ -788,7 +788,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL override fun transform(expr: KBvMulNoUnderflowExpr) = with(expr) { transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm -> - mkBvMulSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.UNDERFLOW) + mkBvMulSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, UNDERFLOW) } } @@ -813,7 +813,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL a1, BitwuzlaKind.BITWUZLA_KIND_BV_SADD_OVERFLOW ) { a0Sign, a1Sign -> - if (mode == BvOverflowCheckMode.OVERFLOW) { + if (mode == OVERFLOW) { // Both positive mkAndTerm(longArrayOf(mkNotTerm(a0Sign), mkNotTerm(a1Sign))) } else { @@ -833,7 +833,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL a1, BitwuzlaKind.BITWUZLA_KIND_BV_SSUB_OVERFLOW ) { a0Sign, a1Sign -> - if (mode == BvOverflowCheckMode.OVERFLOW) { + if (mode == OVERFLOW) { // Positive sub negative mkAndTerm(longArrayOf(mkNotTerm(a0Sign), a1Sign)) } else { @@ -853,7 +853,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL a1, BitwuzlaKind.BITWUZLA_KIND_BV_SMUL_OVERFLOW ) { a0Sign, a1Sign -> - if (mode == BvOverflowCheckMode.OVERFLOW) { + if (mode == OVERFLOW) { // Overflow is possible when sign bits are equal mkEqTerm(bitwuzlaCtx.ctx.boolSort, a0Sign, a1Sign) } else { @@ -1401,6 +1401,8 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL } override fun transform(expr: KUninterpretedSortValue): KExpr = expr.transform { + // register it for uninterpreted sort universe + bitwuzlaCtx.registerDeclaration(expr.decl) Native.bitwuzlaMkBvValueUint32( bitwuzla, expr.sort.internalizeSort(), diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolver.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolver.kt new file mode 100644 index 000000000..7be9fe378 --- /dev/null +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolver.kt @@ -0,0 +1,109 @@ +package io.ksmt.solver.bitwuzla + +import io.ksmt.KContext +import io.ksmt.expr.KExpr +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KSolverStatus +import io.ksmt.sort.KBoolSort +import kotlin.time.Duration + +class KBitwuzlaForkingSolver( + private val ctx: KContext, + private val manager: KBitwuzlaForkingSolverManager, + parent: KBitwuzlaForkingSolver? +) : KBitwuzlaSolverBase(ctx), + KForkingSolver { + + private val assertions = ScopedLinkedFrame>>(::ArrayList, ::ArrayList) + private val trackToExprFrames = + ScopedLinkedFrame, KExpr>>>(::ArrayList, ::ArrayList) + + private val config: KBitwuzlaForkingSolverConfigurationImpl + + init { + if (parent != null) { + config = parent.config.fork(bitwuzlaCtx.bitwuzla) + assertions.fork(parent.assertions) + trackToExprFrames.fork(parent.trackToExprFrames) + } else { + config = KBitwuzlaForkingSolverConfigurationImpl(bitwuzlaCtx.bitwuzla) + } + } + + override fun configure(configurator: KBitwuzlaSolverConfiguration.() -> Unit) { + config.configurator() + } + + /** + * Creates lazily initiated forked solver (without cache sharing), preserving parental assertions and configuration. + */ + override fun fork(): KForkingSolver = manager.createForkingSolver(this) + + private var assertionsInitiated = parent == null + + private fun ensureAssertionsInitiated() { + if (assertionsInitiated) return + + assertions.stacked().zip(trackToExprFrames.stacked()) + .asReversed() + .forEachIndexed { scope, (assertionsFrame, trackedExprsFrame) -> + if (scope > 0) super.push() + + assertionsFrame.forEach { assertion -> + internalizeAndAssertWithAxioms(assertion) + } + + trackedExprsFrame.forEach { (track, trackedExpr) -> + super.registerTrackForExpr(trackedExpr, track) + } + } + assertionsInitiated = true + } + + override fun assert(expr: KExpr) = bitwuzlaCtx.bitwuzlaTry { + ctx.ensureContextMatch(expr) + ensureAssertionsInitiated() + + internalizeAndAssertWithAxioms(expr) + assertions.currentFrame += expr + } + + override fun assertAndTrack(expr: KExpr) { + bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() } + super.assertAndTrack(expr) + } + + override fun registerTrackForExpr(expr: KExpr, track: KExpr) { + super.registerTrackForExpr(expr, track) + trackToExprFrames.currentFrame += track to expr + } + + override fun push() { + bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() } + super.push() + assertions.push() + trackToExprFrames.push() + } + + override fun pop(n: UInt) { + bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() } + super.pop(n) + assertions.pop(n) + trackToExprFrames.pop(n) + } + + override fun check(timeout: Duration): KSolverStatus { + bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() } + return super.check(timeout) + } + + override fun checkWithAssumptions(assumptions: List>, timeout: Duration): KSolverStatus { + bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() } + return super.checkWithAssumptions(assumptions, timeout) + } + + override fun close() { + super.close() + manager.close(this) + } +} diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolverManager.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolverManager.kt new file mode 100644 index 000000000..64dd3145a --- /dev/null +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaForkingSolverManager.kt @@ -0,0 +1,33 @@ +package io.ksmt.solver.bitwuzla + +import io.ksmt.KContext +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KForkingSolverManager +import java.util.concurrent.ConcurrentHashMap + +/** + * Responsible for creation and managing of [KBitwuzlaForkingSolver]. + * + * Neither native cache is shared between [KBitwuzlaForkingSolver]s + * because cache sharing is not supported in bitwuzla. + */ +class KBitwuzlaForkingSolverManager(private val ctx: KContext) : KForkingSolverManager { + private val solvers = ConcurrentHashMap.newKeySet() + + override fun createForkingSolver(): KForkingSolver { + return KBitwuzlaForkingSolver(ctx, this, null).also { + solvers += it + } + } + + internal fun createForkingSolver(parent: KBitwuzlaForkingSolver) = KBitwuzlaForkingSolver(ctx, this, parent) + .also { solvers += it } + + internal fun close(solver: KBitwuzlaForkingSolver) { + solvers -= solver + } + + override fun close() { + solvers.forEach(KBitwuzlaForkingSolver::close) + } +} diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaModel.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaModel.kt index f36e43c99..f4520736b 100644 --- a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaModel.kt +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaModel.kt @@ -2,18 +2,15 @@ package io.ksmt.solver.bitwuzla import io.ksmt.KContext import io.ksmt.decl.KDecl +import io.ksmt.decl.KUninterpretedSortValueDecl import io.ksmt.expr.KExpr import io.ksmt.expr.KUninterpretedSortValue +import io.ksmt.solver.KModel +import io.ksmt.solver.KSolverUnsupportedFeatureException import io.ksmt.solver.model.KFuncInterp import io.ksmt.solver.model.KFuncInterpEntryVarsFree import io.ksmt.solver.model.KFuncInterpEntryVarsFreeOneAry import io.ksmt.solver.model.KFuncInterpVarsFree -import io.ksmt.solver.KModel -import io.ksmt.solver.KSolverUnsupportedFeatureException -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm -import org.ksmt.solver.bitwuzla.bindings.FunValue -import org.ksmt.solver.bitwuzla.bindings.Native import io.ksmt.solver.model.KFuncInterpWithVars import io.ksmt.solver.model.KModelEvaluator import io.ksmt.solver.model.KModelImpl @@ -23,6 +20,10 @@ import io.ksmt.sort.KSort import io.ksmt.sort.KUninterpretedSort import io.ksmt.utils.mkFreshConstDecl import io.ksmt.utils.uncheckedCast +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm +import org.ksmt.solver.bitwuzla.bindings.FunValue +import org.ksmt.solver.bitwuzla.bindings.Native open class KBitwuzlaModel( private val ctx: KContext, @@ -77,7 +78,12 @@ open class KBitwuzlaModel( * to ensure that [uninterpretedSortValueContext] contains * all possible values for the given sort. * */ - sortDependency.forEach { interpretation(it) } + sortDependency.forEach { + if (it is KUninterpretedSortValueDecl) { + val value = ctx.mkUninterpretedSortValue(it.sort, it.valueIdx) + uninterpretedSortValueContext.registerValue(value) + } else interpretation(it) + } uninterpretedSortValueContext.currentSortUniverse(sort) } diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolver.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolver.kt index 5d86b25b9..30c0b0f71 100644 --- a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolver.kt +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolver.kt @@ -1,199 +1,10 @@ package io.ksmt.solver.bitwuzla -import it.unimi.dsi.fastutil.longs.LongOpenHashSet import io.ksmt.KContext -import io.ksmt.expr.KExpr -import io.ksmt.solver.KModel -import io.ksmt.solver.KSolver -import io.ksmt.solver.KSolverStatus -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaOption -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaResult -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm -import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTermArray -import org.ksmt.solver.bitwuzla.bindings.Native -import io.ksmt.sort.KBoolSort -import kotlin.time.Duration -open class KBitwuzlaSolver(private val ctx: KContext) : KSolver { - open val bitwuzlaCtx = KBitwuzlaContext(ctx) - open val exprInternalizer: KBitwuzlaExprInternalizer by lazy { - KBitwuzlaExprInternalizer(bitwuzlaCtx) - } - open val exprConverter: KBitwuzlaExprConverter by lazy { - KBitwuzlaExprConverter(ctx, bitwuzlaCtx) - } - - private var lastCheckStatus = KSolverStatus.UNKNOWN - private var lastReasonOfUnknown: String? = null - private var lastAssumptions: TrackedAssumptions? = null - private var lastModel: KBitwuzlaModel? = null - - init { - Native.bitwuzlaSetOption(bitwuzlaCtx.bitwuzla, BitwuzlaOption.BITWUZLA_OPT_INCREMENTAL, value = 1) - Native.bitwuzlaSetOption(bitwuzlaCtx.bitwuzla, BitwuzlaOption.BITWUZLA_OPT_PRODUCE_MODELS, value = 1) - } - - private var trackedAssertions = mutableListOf, BitwuzlaTerm>>() - private val trackVarsAssertionFrames = arrayListOf(trackedAssertions) +open class KBitwuzlaSolver(ctx: KContext) : KBitwuzlaSolverBase(ctx) { override fun configure(configurator: KBitwuzlaSolverConfiguration.() -> Unit) { KBitwuzlaSolverConfigurationImpl(bitwuzlaCtx.bitwuzla).configurator() } - - override fun assert(expr: KExpr) = bitwuzlaCtx.bitwuzlaTry { - ctx.ensureContextMatch(expr) - - val assertionWithAxioms = with(exprInternalizer) { expr.internalizeAssertion() } - - assertionWithAxioms.axioms.forEach { - Native.bitwuzlaAssert(bitwuzlaCtx.bitwuzla, it) - } - Native.bitwuzlaAssert(bitwuzlaCtx.bitwuzla, assertionWithAxioms.assertion) - } - - override fun assertAndTrack(expr: KExpr) = bitwuzlaCtx.bitwuzlaTry { - ctx.ensureContextMatch(expr) - - val trackVarExpr = ctx.mkFreshConst("track", ctx.boolSort) - val trackedExpr = with(ctx) { !trackVarExpr or expr } - - assert(trackedExpr) - - val trackVarTerm = with(exprInternalizer) { trackVarExpr.internalize() } - trackedAssertions += expr to trackVarTerm - } - - override fun push(): Unit = bitwuzlaCtx.bitwuzlaTry { - Native.bitwuzlaPush(bitwuzlaCtx.bitwuzla, nlevels = 1) - - trackedAssertions = trackedAssertions.toMutableList() - trackVarsAssertionFrames.add(trackedAssertions) - - bitwuzlaCtx.createNestedDeclarationScope() - } - - override fun pop(n: UInt): Unit = bitwuzlaCtx.bitwuzlaTry { - val currentLevel = trackVarsAssertionFrames.lastIndex.toUInt() - require(n <= currentLevel) { - "Cannot pop $n scope levels because current scope level is $currentLevel" - } - - if (n == 0u) return - - repeat(n.toInt()) { - trackVarsAssertionFrames.removeLast() - bitwuzlaCtx.popDeclarationScope() - } - - trackedAssertions = trackVarsAssertionFrames.last() - - Native.bitwuzlaPop(bitwuzlaCtx.bitwuzla, n.toInt()) - } - - override fun check(timeout: Duration): KSolverStatus = - checkWithAssumptions(emptyList(), timeout) - - override fun checkWithAssumptions(assumptions: List>, timeout: Duration): KSolverStatus = - bitwuzlaTryCheck { - ctx.ensureContextMatch(assumptions) - - val currentAssumptions = TrackedAssumptions().also { lastAssumptions = it } - - trackedAssertions.forEach { - currentAssumptions.assumeTrackedAssertion(it) - } - - with(exprInternalizer) { - assumptions.forEach { - currentAssumptions.assumeAssumption(it, it.internalize()) - } - } - - checkWithTimeout(timeout).processCheckResult() - } - - private fun checkWithTimeout(timeout: Duration): BitwuzlaResult = if (timeout.isInfinite()) { - Native.bitwuzlaCheckSatResult(bitwuzlaCtx.bitwuzla) - } else { - Native.bitwuzlaCheckSatTimeoutResult(bitwuzlaCtx.bitwuzla, timeout.inWholeMilliseconds) - } - - override fun model(): KModel = bitwuzlaCtx.bitwuzlaTry { - require(lastCheckStatus == KSolverStatus.SAT) { "Model are only available after SAT checks" } - val model = lastModel ?: KBitwuzlaModel( - ctx, bitwuzlaCtx, exprConverter, - bitwuzlaCtx.declarations(), - bitwuzlaCtx.uninterpretedSortsWithRelevantDecls() - ) - lastModel = model - model - } - - override fun unsatCore(): List> = bitwuzlaCtx.bitwuzlaTry { - require(lastCheckStatus == KSolverStatus.UNSAT) { "Unsat cores are only available after UNSAT checks" } - val unsatAssumptions = Native.bitwuzlaGetUnsatAssumptions(bitwuzlaCtx.bitwuzla) - lastAssumptions?.resolveUnsatCore(unsatAssumptions) ?: emptyList() - } - - override fun reasonOfUnknown(): String = bitwuzlaCtx.bitwuzlaTry { - require(lastCheckStatus == KSolverStatus.UNKNOWN) { - "Unknown reason is only available after UNKNOWN checks" - } - - // There is no way to retrieve reason of unknown from Bitwuzla in general case. - return lastReasonOfUnknown ?: "unknown" - } - - override fun interrupt() = bitwuzlaCtx.bitwuzlaTry { - Native.bitwuzlaForceTerminate(bitwuzlaCtx.bitwuzla) - } - - override fun close() = bitwuzlaCtx.bitwuzlaTry { - bitwuzlaCtx.close() - } - - private fun BitwuzlaResult.processCheckResult() = when (this) { - BitwuzlaResult.BITWUZLA_SAT -> KSolverStatus.SAT - BitwuzlaResult.BITWUZLA_UNSAT -> KSolverStatus.UNSAT - BitwuzlaResult.BITWUZLA_UNKNOWN -> KSolverStatus.UNKNOWN - }.also { lastCheckStatus = it } - - private fun invalidateSolverState() { - /** - * Bitwuzla model is only valid until the next check-sat call. - * */ - lastModel?.markInvalid() - lastModel = null - - lastCheckStatus = KSolverStatus.UNKNOWN - lastReasonOfUnknown = null - - lastAssumptions = null - } - - private inline fun bitwuzlaTryCheck(body: () -> KSolverStatus): KSolverStatus = try { - invalidateSolverState() - body() - } catch (ex: BitwuzlaNativeException) { - lastReasonOfUnknown = ex.message - KSolverStatus.UNKNOWN.also { lastCheckStatus = it } - } - - private inner class TrackedAssumptions { - private val assumedExprs = arrayListOf, BitwuzlaTerm>>() - - fun assumeTrackedAssertion(trackedAssertion: Pair, BitwuzlaTerm>) { - assumedExprs.add(trackedAssertion) - Native.bitwuzlaAssume(bitwuzlaCtx.bitwuzla, trackedAssertion.second) - } - - fun assumeAssumption(expr: KExpr, term: BitwuzlaTerm) = - assumeTrackedAssertion(expr to term) - - fun resolveUnsatCore(unsatAssumptions: BitwuzlaTermArray): List> { - val unsatCoreTerms = LongOpenHashSet(unsatAssumptions) - return assumedExprs.mapNotNull { (expr, term) -> expr.takeIf { unsatCoreTerms.contains(term) } } - } - } } diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolverBase.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolverBase.kt new file mode 100644 index 000000000..fa8ae269b --- /dev/null +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolverBase.kt @@ -0,0 +1,201 @@ +package io.ksmt.solver.bitwuzla + +import io.ksmt.KContext +import io.ksmt.expr.KExpr +import io.ksmt.solver.KModel +import io.ksmt.solver.KSolver +import io.ksmt.solver.KSolverStatus +import io.ksmt.sort.KBoolSort +import it.unimi.dsi.fastutil.longs.LongOpenHashSet +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaOption +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaResult +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm +import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTermArray +import org.ksmt.solver.bitwuzla.bindings.Native +import kotlin.time.Duration + +abstract class KBitwuzlaSolverBase(private val ctx: KContext) : KSolver { + open val bitwuzlaCtx = KBitwuzlaContext(ctx) + open val exprInternalizer: KBitwuzlaExprInternalizer by lazy { + KBitwuzlaExprInternalizer(bitwuzlaCtx) + } + open val exprConverter: KBitwuzlaExprConverter by lazy { + KBitwuzlaExprConverter(ctx, bitwuzlaCtx) + } + + private var lastCheckStatus = KSolverStatus.UNKNOWN + private var lastReasonOfUnknown: String? = null + private var lastAssumptions: TrackedAssumptions? = null + private var lastModel: KBitwuzlaModel? = null + + init { + Native.bitwuzlaSetOption(bitwuzlaCtx.bitwuzla, BitwuzlaOption.BITWUZLA_OPT_INCREMENTAL, value = 1) + Native.bitwuzlaSetOption(bitwuzlaCtx.bitwuzla, BitwuzlaOption.BITWUZLA_OPT_PRODUCE_MODELS, value = 1) + } + + private var trackedAssertions = mutableListOf, BitwuzlaTerm>>() + private val trackVarsAssertionFrames = arrayListOf(trackedAssertions) + + protected fun internalizeAndAssertWithAxioms(expr: KExpr) { + val assertionWithAxioms = with(exprInternalizer) { expr.internalizeAssertion() } + + assertionWithAxioms.axioms.forEach { + Native.bitwuzlaAssert(bitwuzlaCtx.bitwuzla, it) + } + Native.bitwuzlaAssert(bitwuzlaCtx.bitwuzla, assertionWithAxioms.assertion) + } + + override fun assert(expr: KExpr) = bitwuzlaCtx.bitwuzlaTry { + ctx.ensureContextMatch(expr) + internalizeAndAssertWithAxioms(expr) + } + + protected open fun registerTrackForExpr(expr: KExpr, track: KExpr) { + val trackVarTerm = with(exprInternalizer) { track.internalize() } + trackedAssertions += expr to trackVarTerm + } + + override fun assertAndTrack(expr: KExpr) = bitwuzlaCtx.bitwuzlaTry { + ctx.ensureContextMatch(expr) + + val trackVarExpr = ctx.mkFreshConst("track", ctx.boolSort) + val trackedExpr = with(ctx) { !trackVarExpr or expr } + + assert(trackedExpr) + registerTrackForExpr(expr, trackVarExpr) + } + + override fun push(): Unit = bitwuzlaCtx.bitwuzlaTry { + Native.bitwuzlaPush(bitwuzlaCtx.bitwuzla, nlevels = 1) + + trackedAssertions = trackedAssertions.toMutableList() + trackVarsAssertionFrames.add(trackedAssertions) + + bitwuzlaCtx.createNestedDeclarationScope() + } + + override fun pop(n: UInt): Unit = bitwuzlaCtx.bitwuzlaTry { + val currentLevel = trackVarsAssertionFrames.lastIndex.toUInt() + require(n <= currentLevel) { + "Cannot pop $n scope levels because current scope level is $currentLevel" + } + + if (n == 0u) return + + repeat(n.toInt()) { + trackVarsAssertionFrames.removeLast() + bitwuzlaCtx.popDeclarationScope() + } + + trackedAssertions = trackVarsAssertionFrames.last() + + Native.bitwuzlaPop(bitwuzlaCtx.bitwuzla, n.toInt()) + } + + override fun check(timeout: Duration): KSolverStatus = + checkWithAssumptions(emptyList(), timeout) + + override fun checkWithAssumptions(assumptions: List>, timeout: Duration): KSolverStatus = + bitwuzlaTryCheck { + ctx.ensureContextMatch(assumptions) + + val currentAssumptions = TrackedAssumptions().also { lastAssumptions = it } + + trackedAssertions.forEach { + currentAssumptions.assumeTrackedAssertion(it) + } + + with(exprInternalizer) { + assumptions.forEach { + currentAssumptions.assumeAssumption(it, it.internalize()) + } + } + + checkWithTimeout(timeout).processCheckResult() + } + + protected fun checkWithTimeout(timeout: Duration): BitwuzlaResult = if (timeout.isInfinite()) { + Native.bitwuzlaCheckSatResult(bitwuzlaCtx.bitwuzla) + } else { + Native.bitwuzlaCheckSatTimeoutResult(bitwuzlaCtx.bitwuzla, timeout.inWholeMilliseconds) + } + + override fun model(): KModel = bitwuzlaCtx.bitwuzlaTry { + require(lastCheckStatus == KSolverStatus.SAT) { "Model are only available after SAT checks" } + val model = lastModel ?: KBitwuzlaModel( + ctx, bitwuzlaCtx, exprConverter, + bitwuzlaCtx.declarations(), + bitwuzlaCtx.uninterpretedSortsWithRelevantDecls() + ) + lastModel = model + model + } + + override fun unsatCore(): List> = bitwuzlaCtx.bitwuzlaTry { + require(lastCheckStatus == KSolverStatus.UNSAT) { "Unsat cores are only available after UNSAT checks" } + val unsatAssumptions = Native.bitwuzlaGetUnsatAssumptions(bitwuzlaCtx.bitwuzla) + lastAssumptions?.resolveUnsatCore(unsatAssumptions) ?: emptyList() + } + + override fun reasonOfUnknown(): String = bitwuzlaCtx.bitwuzlaTry { + require(lastCheckStatus == KSolverStatus.UNKNOWN) { + "Unknown reason is only available after UNKNOWN checks" + } + + // There is no way to retrieve reason of unknown from Bitwuzla in general case. + return lastReasonOfUnknown ?: "unknown" + } + + override fun interrupt() = bitwuzlaCtx.bitwuzlaTry { + Native.bitwuzlaForceTerminate(bitwuzlaCtx.bitwuzla) + } + + override fun close() = bitwuzlaCtx.bitwuzlaTry { + bitwuzlaCtx.close() + } + + protected fun BitwuzlaResult.processCheckResult() = when (this) { + BitwuzlaResult.BITWUZLA_SAT -> KSolverStatus.SAT + BitwuzlaResult.BITWUZLA_UNSAT -> KSolverStatus.UNSAT + BitwuzlaResult.BITWUZLA_UNKNOWN -> KSolverStatus.UNKNOWN + }.also { lastCheckStatus = it } + + private fun invalidateSolverState() { + /** + * Bitwuzla model is only valid until the next check-sat call. + * */ + lastModel?.markInvalid() + lastModel = null + + lastCheckStatus = KSolverStatus.UNKNOWN + lastReasonOfUnknown = null + + lastAssumptions = null + } + + private inline fun bitwuzlaTryCheck(body: () -> KSolverStatus): KSolverStatus = try { + invalidateSolverState() + body() + } catch (ex: BitwuzlaNativeException) { + lastReasonOfUnknown = ex.message + KSolverStatus.UNKNOWN.also { lastCheckStatus = it } + } + + private inner class TrackedAssumptions { + private val assumedExprs = arrayListOf, BitwuzlaTerm>>() + + fun assumeTrackedAssertion(trackedAssertion: Pair, BitwuzlaTerm>) { + assumedExprs.add(trackedAssertion) + Native.bitwuzlaAssume(bitwuzlaCtx.bitwuzla, trackedAssertion.second) + } + + fun assumeAssumption(expr: KExpr, term: BitwuzlaTerm) = + assumeTrackedAssertion(expr to term) + + fun resolveUnsatCore(unsatAssumptions: BitwuzlaTermArray): List> { + val unsatCoreTerms = LongOpenHashSet(unsatAssumptions) + return assumedExprs.mapNotNull { (expr, term) -> expr.takeIf { unsatCoreTerms.contains(term) } } + } + } +} diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolverConfiguration.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolverConfiguration.kt index a40ee8a60..5ee777c8f 100644 --- a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolverConfiguration.kt +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaSolverConfiguration.kt @@ -6,6 +6,7 @@ import io.ksmt.solver.KSolverUnsupportedParameterException import org.ksmt.solver.bitwuzla.bindings.Bitwuzla import org.ksmt.solver.bitwuzla.bindings.BitwuzlaOption import org.ksmt.solver.bitwuzla.bindings.Native +import java.util.EnumMap interface KBitwuzlaSolverConfiguration : KSolverConfiguration { fun setBitwuzlaOption(option: BitwuzlaOption, value: Int) @@ -44,6 +45,25 @@ class KBitwuzlaSolverConfigurationImpl(private val bitwuzla: Bitwuzla) : KBitwuz } } +class KBitwuzlaForkingSolverConfigurationImpl(private val bitwuzla: Bitwuzla) : KBitwuzlaSolverConfiguration { + private val intOptions = EnumMap<_, Int>(BitwuzlaOption::class.java) + private val stringOptions = EnumMap<_, String>(BitwuzlaOption::class.java) + override fun setBitwuzlaOption(option: BitwuzlaOption, value: Int) { + Native.bitwuzlaSetOption(bitwuzla, option, value) + intOptions[option] = value + } + + override fun setBitwuzlaOption(option: BitwuzlaOption, value: String) { + Native.bitwuzlaSetOptionStr(bitwuzla, option, value) + stringOptions[option] = value + } + + fun fork(childBitwuzla: Bitwuzla) = KBitwuzlaForkingSolverConfigurationImpl(childBitwuzla).also { + intOptions.forEach { (option, value) -> it.setBitwuzlaOption(option, value) } + stringOptions.forEach { (option, value) -> it.setBitwuzlaOption(option, value) } + } +} + class KBitwuzlaSolverUniversalConfiguration( private val builder: KSolverUniversalConfigurationBuilder ) : KBitwuzlaSolverConfiguration { diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaUninterpretedSortValueContext.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaUninterpretedSortValueContext.kt index cd7e184ba..d0c4d8d9f 100644 --- a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaUninterpretedSortValueContext.kt +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/KBitwuzlaUninterpretedSortValueContext.kt @@ -10,10 +10,12 @@ class KBitwuzlaUninterpretedSortValueContext(private val ctx: KContext) { private val sortsUniverses = hashMapOf, KUninterpretedSortValue>>() fun mkValue(sort: KUninterpretedSort, value: KBitVec32Value): KUninterpretedSortValue { - val sortUniverse = sortsUniverses.getOrPut(sort) { hashMapOf() } - return sortUniverse.getOrPut(value) { - ctx.mkUninterpretedSortValue(sort, value.intValue) - } + return registerValue(ctx.mkUninterpretedSortValue(sort, value.intValue)) + } + + fun registerValue(value: KUninterpretedSortValue): KUninterpretedSortValue { + val sortsUniverse = sortsUniverses.getOrPut(value.sort) { hashMapOf() } + return sortsUniverse.getOrPut(value) { value } } fun currentSortUniverse(sort: KUninterpretedSort): Set = diff --git a/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/ScopedLinkedFrame.kt b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/ScopedLinkedFrame.kt new file mode 100644 index 000000000..d3e4e5931 --- /dev/null +++ b/ksmt-bitwuzla/src/main/kotlin/io/ksmt/solver/bitwuzla/ScopedLinkedFrame.kt @@ -0,0 +1,67 @@ +package io.ksmt.solver.bitwuzla + +internal class ScopedLinkedFrame private constructor( + private var current: LinkedFrame, + private inline val createNewFrame: () -> T, + private inline val copyFrame: (T) -> T +) { + constructor( + currentFrame: T, + createNewFrame: () -> T, + copyFrame: (T) -> T + ) : this(LinkedFrame(currentFrame), createNewFrame, copyFrame) + + constructor( + createNewFrame: () -> T, + copyFrame: (T) -> T + ) : this(createNewFrame(), createNewFrame, copyFrame) + + val currentFrame: T + get() = current.value + + val currentScope: UInt + get() = current.scope + + fun stacked(): ArrayDeque = ArrayDeque().also { stack -> + forEachReversed { frame -> + stack.addLast(frame) + } + } + + fun push() { + current = LinkedFrame(createNewFrame(), current) + } + + fun pop(n: UInt) { + repeat(n.toInt()) { + current = current.previous ?: throw IllegalStateException("Can't pop the bottom scope") + } + recreateTopFrame() + } + + private fun recreateTopFrame() { + val newTopFrame = copyFrame(currentFrame) + current = LinkedFrame(newTopFrame, current.previous) + } + + fun fork(parent: ScopedLinkedFrame) { + current = parent.current + recreateTopFrame() + } + + private inline fun forEachReversed(action: (T) -> Unit) { + var cur: LinkedFrame? = current + while (cur != null) { + action(cur.value) + cur = cur.previous + } + } + + private class LinkedFrame( + val value: E, + val previous: LinkedFrame? = null + ) { + val scope: UInt = previous?.scope?.plus(1u) ?: 0u + } + +} diff --git a/ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolver.kt b/ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolver.kt new file mode 100644 index 000000000..86a7ba498 --- /dev/null +++ b/ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolver.kt @@ -0,0 +1,14 @@ +package io.ksmt.solver + +/** + * A solver capable of creating forks (copies) of itself, preserving assertions and assertion scopes + * + * @see KForkingSolverManager + */ +interface KForkingSolver : KSolver { + + /** + * Creates forked solver + */ + fun fork(): KForkingSolver +} diff --git a/ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolverManager.kt b/ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolverManager.kt new file mode 100644 index 000000000..34c63ff11 --- /dev/null +++ b/ksmt-core/src/main/kotlin/io/ksmt/solver/KForkingSolverManager.kt @@ -0,0 +1,16 @@ +package io.ksmt.solver + +/** + * Responsible for creation of [KForkingSolver] and managing its lifetime + * + * @see KForkingSolver + */ +interface KForkingSolverManager : AutoCloseable { + + fun createForkingSolver(): KForkingSolver + + /** + * Closes the manager and all opened solvers ([KForkingSolver]) managed by this + */ + override fun close() +} diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ExpressionUninterpretedValuesForkingTracker.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ExpressionUninterpretedValuesForkingTracker.kt new file mode 100644 index 000000000..5e3679d14 --- /dev/null +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ExpressionUninterpretedValuesForkingTracker.kt @@ -0,0 +1,22 @@ +package io.ksmt.solver.cvc5 + +/** + * An uninterpreted sort values tracker with ability to fork, + * preserving all registered descriptors ([uninterpretedSortValueDescriptors]). + * In a newly-forked tracker all known axioms will be asserted at + * the nearest call of [assertPendingUninterpretedValueConstraints] + */ +class ExpressionUninterpretedValuesForkingTracker : ExpressionUninterpretedValuesTracker { + constructor(cvc5Ctx: KCvc5Context) : super(cvc5Ctx) + private constructor( + cvc5Ctx: KCvc5Context, + uninterpretedSortValueDescriptors: ArrayList + ) : super(cvc5Ctx, uninterpretedSortValueDescriptors) + + fun fork(childCvc5Ctx: KCvc5Context) = + ExpressionUninterpretedValuesForkingTracker(childCvc5Ctx, uninterpretedSortValueDescriptors).also { child -> + repeat(assertedConstraintLevels.size) { + child.pushAssertionLevel() + } + } +} diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ExpressionUninterpretedValuesTracker.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ExpressionUninterpretedValuesTracker.kt new file mode 100644 index 000000000..b1f9acef7 --- /dev/null +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ExpressionUninterpretedValuesTracker.kt @@ -0,0 +1,135 @@ +package io.ksmt.solver.cvc5 + +import io.github.cvc5.Kind +import io.github.cvc5.Solver +import io.github.cvc5.Term +import io.ksmt.decl.KDecl +import io.ksmt.expr.KUninterpretedSortValue +import io.ksmt.sort.KArray2Sort +import io.ksmt.sort.KArray3Sort +import io.ksmt.sort.KArrayNSort +import io.ksmt.sort.KArraySort +import io.ksmt.sort.KBoolSort +import io.ksmt.sort.KBvSort +import io.ksmt.sort.KFpRoundingModeSort +import io.ksmt.sort.KFpSort +import io.ksmt.sort.KIntSort +import io.ksmt.sort.KRealSort +import io.ksmt.sort.KSort +import io.ksmt.sort.KSortVisitor +import io.ksmt.sort.KUninterpretedSort + +open class ExpressionUninterpretedValuesTracker protected constructor( + private val cvc5Ctx: KCvc5Context, + protected val uninterpretedSortValueDescriptors: ArrayList +) { + constructor(cvc5Ctx: KCvc5Context) : this(cvc5Ctx, arrayListOf()) + + private var currentValueConstraintsLevel = 0 + protected val assertedConstraintLevels = arrayListOf() + + private val uninterpretedSortCollector = KUninterpretedSortCollector(cvc5Ctx) + + fun collectUninterpretedSorts(decl: KDecl<*>) { + uninterpretedSortCollector.collect(decl) + } + + fun pushAssertionLevel() { + assertedConstraintLevels += currentValueConstraintsLevel + } + + fun popAssertionLevel() { + currentValueConstraintsLevel = assertedConstraintLevels.removeLast() + } + + fun registerUninterpretedSortValue( + value: KUninterpretedSortValue, + uniqueValueDescriptorTerm: Term, + uninterpretedValueTerm: Term + ) { + uninterpretedSortValueDescriptors += UninterpretedSortValueDescriptor( + value = value, + nativeUniqueValueDescriptor = uniqueValueDescriptorTerm, + nativeValueTerm = uninterpretedValueTerm + ) + } + + fun assertPendingUninterpretedValueConstraints(solver: Solver) { + while (currentValueConstraintsLevel < uninterpretedSortValueDescriptors.size) { + assertUninterpretedSortValueConstraint( + solver, + uninterpretedSortValueDescriptors[currentValueConstraintsLevel] + ) + currentValueConstraintsLevel++ + } + } + + private fun assertUninterpretedSortValueConstraint(solver: Solver, value: UninterpretedSortValueDescriptor) { + val interpreter = cvc5Ctx.getUninterpretedSortValueInterpreter(value.value.sort) + ?: error("Interpreter was not registered for sort: ${value.value.sort}") + + val constraintLhs = solver.mkTerm(Kind.APPLY_UF, arrayOf(interpreter, value.nativeValueTerm)) + val constraint = constraintLhs.eqTerm(value.nativeUniqueValueDescriptor) + + solver.assertFormula(constraint) + } + + @Suppress("ForbiddenComment") + /** + * Uninterpreted sort values distinct constraints management. + * + * 1. save/register uninterpreted value. + * See [KUninterpretedSortValue] internalization for the details. + * 2. Assert distinct constraints ([assertPendingUninterpretedValueConstraints]) that may be introduced + * during internalization. + * Currently, we assert constraints for all the values we have ever internalized. + * + * todo: precise uninterpreted sort values tracking + * */ + protected data class UninterpretedSortValueDescriptor( + val value: KUninterpretedSortValue, + val nativeUniqueValueDescriptor: Term, + val nativeValueTerm: Term + ) + + class KUninterpretedSortCollector(private val cvc5Ctx: KCvc5Context) : KSortVisitor { + override fun visit(sort: KBoolSort) = Unit + + override fun visit(sort: KIntSort) = Unit + + override fun visit(sort: KRealSort) = Unit + + override fun visit(sort: S) = Unit + + override fun visit(sort: S) = Unit + + override fun visit(sort: KArraySort) { + sort.domain.accept(this) + sort.range.accept(this) + } + + override fun visit(sort: KArray3Sort) { + sort.domainSorts.forEach { it.accept(this) } + sort.range.accept(this) + } + + override fun visit(sort: KArray2Sort) { + sort.domainSorts.forEach { it.accept(this) } + sort.range.accept(this) + } + + override fun visit(sort: KArrayNSort) { + sort.domainSorts.forEach { it.accept(this) } + sort.range.accept(this) + } + + override fun visit(sort: KFpRoundingModeSort) = Unit + + override fun visit(sort: KUninterpretedSort) = cvc5Ctx.addUninterpretedSort(sort) + + fun collect(decl: KDecl<*>) { + decl.argSorts.map { it.accept(this) } + decl.sort.accept(this) + } + } +} diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt index 0368b6118..84550a933 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Context.kt @@ -1,6 +1,6 @@ package io.ksmt.solver.cvc5 -import io.github.cvc5.Kind +import io.github.cvc5.CVC5ApiException import io.github.cvc5.Solver import io.github.cvc5.Sort import io.github.cvc5.Term @@ -16,35 +16,40 @@ import io.ksmt.expr.KUninterpretedSortValue import io.ksmt.expr.KUniversalQuantifier import io.ksmt.expr.rewrite.KExprUninterpretedDeclCollector import io.ksmt.expr.transformer.KNonRecursiveTransformer -import io.ksmt.sort.KArray2Sort -import io.ksmt.sort.KArray3Sort -import io.ksmt.sort.KArrayNSort +import io.ksmt.solver.KSolverException import io.ksmt.sort.KArraySort import io.ksmt.sort.KArraySortBase import io.ksmt.sort.KBoolSort -import io.ksmt.sort.KBvSort -import io.ksmt.sort.KFpRoundingModeSort -import io.ksmt.sort.KFpSort -import io.ksmt.sort.KIntSort -import io.ksmt.sort.KRealSort import io.ksmt.sort.KSort -import io.ksmt.sort.KSortVisitor import io.ksmt.sort.KUninterpretedSort import java.util.TreeMap -class KCvc5Context( - private val solver: Solver, - private val ctx: KContext +/** + * @param mkExprSolver used as "context" for creation of all native expressions, which are stored in [KCvc5Context]. + */ +open class KCvc5Context( + protected val solver: Solver, + val mkExprSolver: Solver, + protected val ctx: KContext ) : AutoCloseable { - private var isClosed = false - private val uninterpretedSortCollector = KUninterpretedSortCollector(this) - private var exprCurrentLevelCacheRestorer = KCurrentScopeExprCacheRestorer(uninterpretedSortCollector, ctx) + /** + * Creates [KCvc5Context]. All native expressions will be created via [solver]. + */ + constructor(solver: Solver, ctx: KContext) : this(solver, solver, ctx) + + protected var isClosed = false + + private var exprCurrentLevelCacheRestorer = KCurrentScopeExprCacheRestorer(ctx) + + protected open val uninterpretedSorts: ScopedFrame> = ScopedArrayFrame(::HashSet) + protected open val declarations: ScopedFrame>> = ScopedArrayFrame(::HashSet) + /** * We use double-scoped expression internalization cache: - * * current (before pop operation) - [currentScopeExpressions] - * * global (after pop operation) - [expressions] + * * current accumulated (before pop operation) - [currentAccumulatedScopeExpressions] + * * global (current + all previously met)- [expressions] * * Due to incremental collect of declarations and uninterpreted sorts for the model, * we collect them during internalizing. @@ -57,39 +62,38 @@ class KCvc5Context( * * **solution**: Recollect sorts / decls for each expression * that is in global cache, but whose sorts / decls have been erased after pop() - * (and put this expr to the current scope cache) + * (and put this expr to the cache of current accumulated scope) */ - private val currentScopeExpressions = HashMap, Term>() - private val expressions = HashMap, Term>() - // we can't use HashMap with Term and Sort (hashcode is not implemented) - private val cvc5Expressions = TreeMap>() - private val sorts = HashMap() - private val cvc5Sorts = TreeMap() - private val decls = HashMap, Term>() - private val cvc5Decls = TreeMap>() - - private var currentLevelUninterpretedSorts = hashSetOf() - private val uninterpretedSorts = mutableListOf(currentLevelUninterpretedSorts) - - private var currentLevelDeclarations = hashSetOf>() - private val declarations = mutableListOf(currentLevelDeclarations) - - fun addUninterpretedSort(sort: KUninterpretedSort) { currentLevelUninterpretedSorts += sort } + protected val currentAccumulatedScopeExpressions = HashMap, Term>() + protected open val expressions = HashMap, Term>() /** - * uninterpreted sorts of active push-levels + * We can't use HashMap with Term and Sort (hashcode is not implemented) */ - fun uninterpretedSorts(): List> = uninterpretedSorts + protected open val cvc5Expressions = TreeMap>() + protected open val sorts = HashMap() + protected open val cvc5Sorts = TreeMap() + protected open val decls = HashMap, Term>() + protected open val cvc5Decls = TreeMap>() + + @Suppress("LeakingThis") + open val uninterpretedValuesTracker = ExpressionUninterpretedValuesTracker(this) + protected open val uninterpretedSortValueInterpreter = HashMap() + protected open val uninterpretedSortValues = + HashMap>>() + + fun addUninterpretedSort(sort: KUninterpretedSort) { + uninterpretedSorts.currentFrame += sort + } + + fun uninterpretedSorts(): Set = uninterpretedSorts.flatten { this += it } fun addDeclaration(decl: KDecl<*>) { - currentLevelDeclarations += decl - uninterpretedSortCollector.collect(decl) + declarations.currentFrame += decl + uninterpretedValuesTracker.collectUninterpretedSorts(decl) } - /** - * declarations of active push-levels - */ - fun declarations(): List>> = declarations + fun declarations(): Set> = declarations.flatten { this += it } val nativeSolver: Solver get() = solver @@ -97,58 +101,54 @@ class KCvc5Context( val isActive: Boolean get() = !isClosed - fun push() { - currentLevelDeclarations = hashSetOf() - declarations.add(currentLevelDeclarations) - currentLevelUninterpretedSorts = hashSetOf() - uninterpretedSorts.add(currentLevelUninterpretedSorts) + fun ensureActive() = check(isActive) { "The context is already closed." } - pushAssertionLevel() + fun push() { + declarations.push() + uninterpretedSorts.push() + uninterpretedValuesTracker.pushAssertionLevel() } fun pop(n: UInt) { - repeat(n.toInt()) { - declarations.removeLast() - uninterpretedSorts.removeLast() + declarations.pop(n) + uninterpretedSorts.pop(n) - popAssertionLevel() - } + repeat(n.toInt()) { uninterpretedValuesTracker.popAssertionLevel() } - currentLevelDeclarations = declarations.last() - currentLevelUninterpretedSorts = uninterpretedSorts.last() - - expressions += currentScopeExpressions - currentScopeExpressions.clear() + currentAccumulatedScopeExpressions.clear() // recreate cache restorer to avoid KNonRecursiveTransformer cache - exprCurrentLevelCacheRestorer = KCurrentScopeExprCacheRestorer(uninterpretedSortCollector, ctx) + exprCurrentLevelCacheRestorer = KCurrentScopeExprCacheRestorer(ctx) } - // expr - fun findInternalizedExpr(expr: KExpr<*>): Term? = currentScopeExpressions[expr] + fun findInternalizedExpr(expr: KExpr<*>): Term? = currentAccumulatedScopeExpressions[expr] ?: expressions[expr]?.also { /* - * expr is not in current scope cache, but in global cache. + * expr is not in cache of current accumulated scope, but in global cache. * Recollect declarations and uninterpreted sorts - * and add entire expression tree to the current scope cache from the global - * to avoid re-internalizing with native calls + * and add entire expression tree to the current accumulated scope cache from the global + * to avoid re-internalization */ exprCurrentLevelCacheRestorer.apply(expr) } + open fun Term.convert(converter: KCvc5ExprConverter) = with(converter) { convertExpr() } + fun findConvertedExpr(expr: Term): KExpr<*>? = cvc5Expressions[expr] fun saveInternalizedExpr(expr: KExpr<*>, internalized: Term): Term = - internalizeAst(currentScopeExpressions, cvc5Expressions, expr) { internalized } + internalizeAst(currentAccumulatedScopeExpressions, cvc5Expressions, expr) { internalized } + .also { expressions[expr] = internalized } /** * save expr, which is in global cache, to the current scope cache */ - fun savePreviouslyInternalizedExpr(expr: KExpr<*>): Term = saveInternalizedExpr(expr, expressions[expr]!!) + fun saveInternalizedExprToCurrentAccumulatedScope(expr: KExpr<*>): Term = + currentAccumulatedScopeExpressions.getOrPut(expr) { expressions[expr]!! } fun saveConvertedExpr(expr: Term, converted: KExpr<*>): KExpr<*> = - convertAst(currentScopeExpressions, cvc5Expressions, expr) { converted } + convertAst(currentAccumulatedScopeExpressions, cvc5Expressions, expr) { converted } + .also { expressions[converted] = expr } - // sort fun findInternalizedSort(sort: KSort): Sort? = sorts[sort] fun findConvertedSort(sort: Sort): KSort? = cvc5Sorts[sort] @@ -165,7 +165,6 @@ class KCvc5Context( inline fun convertSort(sort: Sort, converter: () -> KSort): KSort = findOrSave(sort, converter, ::findConvertedSort, ::saveConvertedSort) - // decl fun findInternalizedDecl(decl: KDecl<*>): Term? = decls[decl] fun findConvertedDecl(decl: Term): KDecl<*>? = cvc5Decls[decl] @@ -193,35 +192,9 @@ class KCvc5Context( return save(key, computeValue()) } - - @Suppress("ForbiddenComment") - /** - * Uninterpreted sort values distinct constraints management. - * - * 1. save/register uninterpreted value. - * See [KUninterpretedSortValue] internalization for the details. - * 2. Assert distinct constraints ([assertPendingAxioms]) that may be introduced during internalization. - * Currently, we assert constraints for all the values we have ever internalized. - * - * todo: precise uninterpreted sort values tracking - * */ - private data class UninterpretedSortValueDescriptor( - val value: KUninterpretedSortValue, - val nativeUniqueValueDescriptor: Term, - val nativeValueTerm: Term - ) - - private var currentValueConstraintsLevel = 0 - private val assertedConstraintLevels = arrayListOf() - private val uninterpretedSortValueDescriptors = arrayListOf() - private val uninterpretedSortValueInterpreter = hashMapOf() - - private val uninterpretedSortValues = - hashMapOf>>() - fun saveUninterpretedSortValue(nativeValue: Term, value: KUninterpretedSortValue): Term { val sortValues = uninterpretedSortValues.getOrPut(value.sort) { arrayListOf() } - sortValues.add(nativeValue to value) + sortValues += nativeValue to value return nativeValue } @@ -236,11 +209,15 @@ class KCvc5Context( registerUninterpretedSortValueInterpreter(value.sort, mkInterpreter()) } - registerUninterpretedSortValue(value, uniqueValueDescriptorTerm, uninterpretedValueTerm) + uninterpretedValuesTracker.registerUninterpretedSortValue( + value, + uniqueValueDescriptorTerm, + uninterpretedValueTerm + ) } fun assertPendingAxioms(solver: Solver) { - assertPendingUninterpretedValueConstraints(solver) + uninterpretedValuesTracker.assertPendingUninterpretedValueConstraints(solver) } fun getUninterpretedSortValueInterpreter(sort: KUninterpretedSort): Term? = @@ -250,49 +227,9 @@ class KCvc5Context( uninterpretedSortValueInterpreter[sort] = interpreter } - fun registerUninterpretedSortValue( - value: KUninterpretedSortValue, - uniqueValueDescriptorTerm: Term, - uninterpretedValueTerm: Term - ) { - uninterpretedSortValueDescriptors += UninterpretedSortValueDescriptor( - value = value, - nativeUniqueValueDescriptor = uniqueValueDescriptorTerm, - nativeValueTerm = uninterpretedValueTerm - ) - } - fun getRegisteredSortValues(sort: KUninterpretedSort): List> = uninterpretedSortValues[sort] ?: emptyList() - private fun pushAssertionLevel() { - assertedConstraintLevels.add(currentValueConstraintsLevel) - } - - private fun popAssertionLevel() { - currentValueConstraintsLevel = assertedConstraintLevels.removeLast() - } - - private fun assertPendingUninterpretedValueConstraints(solver: Solver) { - while (currentValueConstraintsLevel < uninterpretedSortValueDescriptors.size) { - assertUninterpretedSortValueConstraint( - solver, - uninterpretedSortValueDescriptors[currentValueConstraintsLevel] - ) - currentValueConstraintsLevel++ - } - } - - private fun assertUninterpretedSortValueConstraint(solver: Solver, value: UninterpretedSortValueDescriptor) { - val interpreter = uninterpretedSortValueInterpreter[value.value.sort] - ?: error("Interpreter was not registered for sort: ${value.value.sort}") - - val constraintLhs = solver.mkTerm(Kind.APPLY_UF, arrayOf(interpreter, value.nativeValueTerm)) - val constraint = constraintLhs.eqTerm(value.nativeUniqueValueDescriptor) - - solver.assertFormula(constraint) - } - private inline fun internalizeAst( cache: MutableMap, reverseCache: MutableMap, @@ -349,90 +286,47 @@ class KCvc5Context( return converted } - override fun close() { if (isClosed) return isClosed = true - currentScopeExpressions.clear() + currentAccumulatedScopeExpressions.clear() expressions.clear() cvc5Expressions.clear() - - uninterpretedSorts.clear() - currentLevelUninterpretedSorts.clear() - - declarations.clear() - currentLevelDeclarations.clear() - sorts.clear() cvc5Sorts.clear() decls.clear() cvc5Decls.clear() + uninterpretedSortValueInterpreter.clear() + uninterpretedSortValues.clear() } - class KUninterpretedSortCollector(private val cvc5Ctx: KCvc5Context) : KSortVisitor { - override fun visit(sort: KBoolSort) = Unit - - override fun visit(sort: KIntSort) = Unit - - override fun visit(sort: KRealSort) = Unit - - override fun visit(sort: S) = Unit - - override fun visit(sort: S) = Unit - - override fun visit(sort: KArraySort) { - sort.domain.accept(this) - sort.range.accept(this) - } - - override fun visit(sort: KArray3Sort) { - sort.domainSorts.forEach { it.accept(this) } - sort.range.accept(this) - } - - override fun visit(sort: KArray2Sort) { - sort.domainSorts.forEach { it.accept(this) } - sort.range.accept(this) - } - - override fun visit(sort: KArrayNSort) { - sort.domainSorts.forEach { it.accept(this) } - sort.range.accept(this) - } - - override fun visit(sort: KFpRoundingModeSort) = Unit - - override fun visit(sort: KUninterpretedSort) = cvc5Ctx.addUninterpretedSort(sort) - - fun collect(decl: KDecl<*>) { - decl.argSorts.map { it.accept(this) } - decl.sort.accept(this) - } + inline fun cvc5Try(body: () -> T): T = try { + ensureActive() + body() + } catch (ex: CVC5ApiException) { + throw KSolverException(ex) } - inner class KCurrentScopeExprCacheRestorer( - private val uninterpretedSortCollector: KUninterpretedSortCollector, - ctx: KContext - ) : KNonRecursiveTransformer(ctx) { - - override fun exprTransformationRequired(expr: KExpr): Boolean = expr !in currentScopeExpressions + inner class KCurrentScopeExprCacheRestorer(ctx: KContext) : KNonRecursiveTransformer(ctx) { + override fun exprTransformationRequired(expr: KExpr): Boolean = + expr !in currentAccumulatedScopeExpressions override fun transform(expr: KFunctionApp): KExpr = cacheIfNeed(expr) { this@KCvc5Context.addDeclaration(expr.decl) - uninterpretedSortCollector.collect(expr.decl) + uninterpretedValuesTracker.collectUninterpretedSorts(expr.decl) } override fun transform(expr: KConst): KExpr = cacheIfNeed(expr) { this@KCvc5Context.addDeclaration(expr.decl) - uninterpretedSortCollector.collect(expr.decl) - this@KCvc5Context.savePreviouslyInternalizedExpr(expr) + uninterpretedValuesTracker.collectUninterpretedSorts(expr.decl) + saveInternalizedExprToCurrentAccumulatedScope(expr) } override fun , R : KSort> transform(expr: KFunctionAsArray): KExpr = cacheIfNeed(expr) { this@KCvc5Context.addDeclaration(expr.function) - uninterpretedSortCollector.collect(expr.function) + uninterpretedValuesTracker.collectUninterpretedSorts(expr.function) } override fun transform(expr: KArrayLambda): KExpr> = @@ -446,7 +340,7 @@ class KCvc5Context( override fun transform(expr: KUniversalQuantifier): KExpr = cacheIfNeed(expr) { transformQuantifier(expr.bounds.toSet(), expr.body) - this@KCvc5Context.savePreviouslyInternalizedExpr(expr) + saveInternalizedExprToCurrentAccumulatedScope(expr) } private fun transformQuantifier(bounds: Set>, body: KExpr<*>) { @@ -455,11 +349,11 @@ class KCvc5Context( } private fun > cacheIfNeed(expr: E, transform: E.() -> Unit): KExpr { - if (expr in currentScopeExpressions) + if (expr in currentAccumulatedScopeExpressions) return expr expr.transform() - this@KCvc5Context.savePreviouslyInternalizedExpr(expr) + saveInternalizedExprToCurrentAccumulatedScope(expr) return expr } } diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5DeclInternalizer.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5DeclInternalizer.kt index ba9c03ce2..2cb1fefdc 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5DeclInternalizer.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5DeclInternalizer.kt @@ -18,7 +18,7 @@ open class KCvc5DeclInternalizer( val domainSorts = decl.argSorts.map { it.accept(sortInternalizer) } val rangeSort = decl.sort.accept(sortInternalizer) - cvc5Ctx.nativeSolver.declareFun( + cvc5Ctx.mkExprSolver.declareFun( decl.name, domainSorts.toTypedArray(), rangeSort @@ -29,6 +29,6 @@ open class KCvc5DeclInternalizer( cvc5Ctx.addDeclaration(decl) val sort = decl.sort.accept(sortInternalizer) - cvc5Ctx.nativeSolver.mkConst(sort, decl.name) + cvc5Ctx.mkExprSolver.mkConst(sort, decl.name) } } diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ExprConverter.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ExprConverter.kt index e7f69e186..b79d970eb 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ExprConverter.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ExprConverter.kt @@ -24,9 +24,10 @@ import io.ksmt.expr.KExpr import io.ksmt.expr.KFpRoundingMode import io.ksmt.expr.KInterpretedValue import io.ksmt.solver.KSolverUnsupportedFeatureException -import io.ksmt.solver.util.KExprConverterBase import io.ksmt.solver.util.ExprConversionResult +import io.ksmt.solver.util.KExprConverterBase import io.ksmt.solver.util.KExprConverterUtils.argumentsConversionRequired +import io.ksmt.solver.util.conversionLoop import io.ksmt.sort.KArithSort import io.ksmt.sort.KArray2Sort import io.ksmt.sort.KArray3Sort @@ -42,7 +43,7 @@ import io.ksmt.sort.KRealSort import io.ksmt.sort.KSort import io.ksmt.utils.asExpr import io.ksmt.utils.uncheckedCast -import java.util.* +import java.util.TreeMap @Suppress("LargeClass") open class KCvc5ExprConverter( @@ -734,6 +735,26 @@ open class KCvc5ExprConverter( } } + fun Term.convertExprWithMkExprSolver(): KExpr = convertExprWithoutCacheSave().also { + with(internalizer) { + it.internalizeExpr() + } + } + + fun Term.convertExprWithoutCacheSave(): KExpr { + val wasteCache = TreeMap>() + return conversionLoop( + stack = exprStack, + native = this, + stackPush = { stack, element -> stack.add(element) }, + stackPop = { stack -> stack.removeLast() }, + stackIsNotEmpty = { stack -> stack.isNotEmpty() }, + convertNative = { expr -> convertNativeExpr(expr) }, + findConverted = { expr -> wasteCache[expr] ?: findConvertedNative(expr) }, + saveConverted = { expr, converted -> wasteCache[expr] = converted } + ) + } + fun Term.convertExpr(): KExpr = convertFromNative() @Suppress("UNCHECKED_CAST") diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ExprInternalizer.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ExprInternalizer.kt index 7a369f3dd..c5fa7837b 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ExprInternalizer.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ExprInternalizer.kt @@ -160,11 +160,11 @@ import io.ksmt.expr.KXorExpr import io.ksmt.expr.rewrite.simplify.rewriteBvAddNoOverflowExpr import io.ksmt.expr.rewrite.simplify.rewriteBvAddNoUnderflowExpr import io.ksmt.expr.rewrite.simplify.rewriteBvDivNoOverflowExpr +import io.ksmt.expr.rewrite.simplify.rewriteBvMulNoOverflowExpr +import io.ksmt.expr.rewrite.simplify.rewriteBvMulNoUnderflowExpr import io.ksmt.expr.rewrite.simplify.rewriteBvNegNoOverflowExpr import io.ksmt.expr.rewrite.simplify.rewriteBvSubNoOverflowExpr import io.ksmt.expr.rewrite.simplify.rewriteBvSubNoUnderflowExpr -import io.ksmt.expr.rewrite.simplify.rewriteBvMulNoOverflowExpr -import io.ksmt.expr.rewrite.simplify.rewriteBvMulNoUnderflowExpr import io.ksmt.expr.rewrite.simplify.simplifyBvRotateLeftExpr import io.ksmt.expr.rewrite.simplify.simplifyBvRotateRightExpr import io.ksmt.solver.KSolverUnsupportedFeatureException @@ -203,7 +203,7 @@ class KCvc5ExprInternalizer( } private val nsolver: Solver - get() = cvc5Ctx.nativeSolver + get() = cvc5Ctx.mkExprSolver private val zeroIntValueTerm: Term by lazy { nsolver.mkInteger(0L) } diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingContext.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingContext.kt new file mode 100644 index 000000000..dc1e8d5b7 --- /dev/null +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingContext.kt @@ -0,0 +1,75 @@ +package io.ksmt.solver.cvc5 + +import io.github.cvc5.Solver +import io.github.cvc5.Term +import io.ksmt.KContext +import io.ksmt.decl.KDecl +import io.ksmt.expr.KExpr +import io.ksmt.sort.KSort +import io.ksmt.sort.KUninterpretedSort + +class KCvc5ForkingContext private constructor( + solver: Solver, + mkExprSolver: Solver, + ctx: KContext, + manager: KCvc5ForkingSolverManager, + parent: KCvc5ForkingContext? +) : KCvc5Context(solver, mkExprSolver, ctx) { + constructor(solver: Solver, mkExprSolver: Solver, ctx: KContext, manager: KCvc5ForkingSolverManager) : this( + solver, mkExprSolver, ctx, manager, null + ) + + private val uninterpretedSortsLinkedFrame = ScopedLinkedFrame>(::HashSet, ::HashSet) + private val declarationsLinkedFrame = ScopedLinkedFrame>>(::HashSet, ::HashSet) + + override val uninterpretedSorts: ScopedFrame> + get() = uninterpretedSortsLinkedFrame + override val declarations: ScopedFrame>> + get() = declarationsLinkedFrame + + override val expressions = with(manager) { getExpressionsCache() } + override val cvc5Expressions = with(manager) { getExpressionsReversedCache() } + override val sorts = with(manager) { getSortsCache() } + override val cvc5Sorts = with(manager) { getSortsReversedCache() } + override val decls = with(manager) { getDeclsCache() } + override val cvc5Decls = with(manager) { getDeclsReversedCache() } + + override val uninterpretedSortValueInterpreter = with(manager) { getUninterpretedSortsValueInterpretersCache() } + + /** + * Uninterpreted sort values and universe are shared for whole forking hierarchy (from parent to children) + * due to shared expressions cache, + * that's why once [registerUninterpretedSortValue] and [saveUninterpretedSortValue] are called, + * each solver in hierarchy should assert newly internalized uninterpreted sort values via [assertPendingAxioms] + * + * @see KCvc5Model.uninterpretedSortUniverse + */ + override val uninterpretedSortValues = with(manager) { getUninterpretedSortValues() } + + override val uninterpretedValuesTracker: ExpressionUninterpretedValuesForkingTracker = parent + ?.uninterpretedValuesTracker?.fork(this) + ?: ExpressionUninterpretedValuesForkingTracker(this) + + + init { + if (parent != null) { + currentAccumulatedScopeExpressions += parent.currentAccumulatedScopeExpressions + uninterpretedSortsLinkedFrame.fork(parent.uninterpretedSortsLinkedFrame) + declarationsLinkedFrame.fork(parent.declarationsLinkedFrame) + } + } + + fun fork(solver: Solver, forkingSolverManager: KCvc5ForkingSolverManager): KCvc5ForkingContext { + ensureActive() + return KCvc5ForkingContext(solver, mkExprSolver, ctx, forkingSolverManager, this) + } + + override fun close() { + if (isClosed) return + isClosed = true + } + + override fun Term.convert(converter: KCvc5ExprConverter): KExpr = with(converter) { + convertExprWithMkExprSolver() + } +} diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt new file mode 100644 index 000000000..d0bb588fb --- /dev/null +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolver.kt @@ -0,0 +1,114 @@ +package io.ksmt.solver.cvc5 + +import io.github.cvc5.Term +import io.ksmt.KContext +import io.ksmt.expr.KExpr +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KSolverStatus +import io.ksmt.sort.KBoolSort +import java.util.TreeMap +import java.util.TreeSet +import kotlin.time.Duration + +open class KCvc5ForkingSolver internal constructor( + ctx: KContext, + private val manager: KCvc5ForkingSolverManager, + parent: KCvc5ForkingSolver? = null +) : KCvc5SolverBase(ctx), KForkingSolver { + final override val cvc5Ctx: KCvc5ForkingContext = manager.createCvc5ForkingContext(solver, parent?.cvc5Ctx) + private val isChild = parent != null + private var assertionsInitiated = !isChild // don't need to initiate assertions for root solver + + private val trackedAssertions = ScopedLinkedFrame>>(::TreeMap, ::TreeMap) + private val cvc5Assertions = ScopedLinkedFrame>(::TreeSet, ::TreeSet) + + override val currentScope: UInt + get() = trackedAssertions.currentScope + + private val config: KCvc5ForkingSolverConfigurationImpl by lazy { + parent?.config?.fork(solver) ?: KCvc5ForkingSolverConfigurationImpl(solver) + } + + init { + if (parent != null) { + trackedAssertions.fork(parent.trackedAssertions) + cvc5Assertions.fork(parent.cvc5Assertions) + config // initialize child config + } + } + + override fun configure(configurator: KCvc5SolverConfiguration.() -> Unit) { + config.configurator() + } + + override fun fork(): KForkingSolver = manager.createForkingSolver(this) + + private fun ensureAssertionsInitiated() { + if (assertionsInitiated) return + + cvc5Assertions.stacked() + .zip(trackedAssertions.stacked()) + .asReversed() + .forEachIndexed { scope, (cvc5AssertionFrame, trackedFrame) -> + if (scope > 0) solver.push() + + cvc5AssertionFrame.forEach(solver::assertFormula) + trackedFrame.forEach { (track, _) -> solver.assertFormula(track) } + } + + assertionsInitiated = true + } + + override fun saveTrackedAssertion(track: Term, trackedExpr: KExpr) { + trackedAssertions.currentFrame[track] = trackedExpr + } + + override fun findTrackedExprByTrack(track: Term) = trackedAssertions.findNonNullValue { it[track] } + + override fun assert(expr: KExpr): Unit = cvc5Ctx.cvc5Try { + ctx.ensureContextMatch(expr) + ensureAssertionsInitiated() + + val cvc5Expr = with(exprInternalizer) { expr.internalizeExpr() } + solver.assertFormula(cvc5Expr) + cvc5Ctx.assertPendingAxioms(solver) + cvc5Assertions.currentFrame.add(cvc5Expr) + } + + override fun assertAndTrack(expr: KExpr) { + cvc5Ctx.cvc5Try { ensureAssertionsInitiated() } + super.assertAndTrack(expr) + } + + override fun push() { + cvc5Ctx.cvc5Try { ensureAssertionsInitiated() } + super.push() + trackedAssertions.push() + cvc5Assertions.push() + } + + override fun pop(n: UInt) { + cvc5Ctx.cvc5Try { ensureAssertionsInitiated() } + super.pop(n) + trackedAssertions.pop(n) + cvc5Assertions.pop(n) + } + + override fun check(timeout: Duration): KSolverStatus { + cvc5Ctx.cvc5Try { ensureAssertionsInitiated() } + cvc5Ctx.assertPendingAxioms(solver) + return super.check(timeout) + } + + override fun checkWithAssumptions(assumptions: List>, timeout: Duration): KSolverStatus { + cvc5Ctx.cvc5Try { ensureAssertionsInitiated() } + cvc5Ctx.assertPendingAxioms(solver) + return super.checkWithAssumptions(assumptions, timeout) + } + + override fun close() { + manager.close(this) + solver.close() + cvc5Ctx.close() + } +} diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt new file mode 100644 index 000000000..d6d9f6406 --- /dev/null +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5ForkingSolverManager.kt @@ -0,0 +1,137 @@ +package io.ksmt.solver.cvc5 + +import io.github.cvc5.Solver +import io.github.cvc5.Sort +import io.github.cvc5.Term +import io.ksmt.KContext +import io.ksmt.decl.KDecl +import io.ksmt.expr.KExpr +import io.ksmt.expr.KUninterpretedSortValue +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KForkingSolverManager +import io.ksmt.sort.KSort +import io.ksmt.sort.KUninterpretedSort +import java.util.TreeMap +import java.util.concurrent.ConcurrentHashMap + +/** + * Responsible for creation and managing of [KCvc5ForkingSolver]. + * + * It's cheaper to create multiple copies of solvers with [KCvc5ForkingSolver.fork] + * instead of assertions transferring in [KCvc5Solver] instances manually. + * + * All solvers created with one manager (via both [KCvc5ForkingSolver.fork] and [createForkingSolver]) + * use the same [mkExprContext]*, cache, and registered uninterpreted sort values. + * + * (*) [mkExprContext] is responsible for native expressions creation for each [KCvc5ForkingSolver] + * in one [KCvc5ForkingSolverManager]. Therefore, life scope of native expressions is the same with + * life scope of [KCvc5ForkingSolverManager] + */ +open class KCvc5ForkingSolverManager(private val ctx: KContext) : KForkingSolverManager { + private val mkExprContext by lazy { Solver() } + private val solvers: MutableSet = ConcurrentHashMap.newKeySet() + + // shared cache + private val expressionsCache = ExpressionsCache() + private val expressionsReversedCache = ExpressionsReversedCache() + private val sortsCache = SortsCache() + private val sortsReversedCache = SortsReversedCache() + private val declsCache = DeclsCache() + private val declsReversedCache = DeclsReversedCache() + private val uninterpretedSortValueInterpretersCache = UninterpretedSortValueInterpretersCache() + private val uninterpretedSortValues = UninterpretedSortValues() + + private fun Solver.ensureMkExprContextMatches() = require(this == mkExprContext) { + "Solver is not registered by this manager" + } + + internal fun KCvc5Context.getExpressionsCache() = mkExprSolver.ensureMkExprContextMatches().let { + expressionsCache + } + + internal fun KCvc5Context.getExpressionsReversedCache() = mkExprSolver.ensureMkExprContextMatches().let { + expressionsReversedCache + } + + internal fun KCvc5Context.getSortsCache() = mkExprSolver.ensureMkExprContextMatches().let { + sortsCache + } + + internal fun KCvc5Context.getSortsReversedCache() = mkExprSolver.ensureMkExprContextMatches().let { + sortsReversedCache + } + + internal fun KCvc5Context.getDeclsCache() = mkExprSolver.ensureMkExprContextMatches().let { + declsCache + } + + internal fun KCvc5Context.getDeclsReversedCache() = mkExprSolver.ensureMkExprContextMatches().let { + declsReversedCache + } + + internal fun KCvc5Context.getUninterpretedSortsValueInterpretersCache() = mkExprSolver + .ensureMkExprContextMatches().let { uninterpretedSortValueInterpretersCache } + + internal fun KCvc5Context.getUninterpretedSortValues() = mkExprSolver.ensureMkExprContextMatches().let { + uninterpretedSortValues + } + + override fun createForkingSolver(): KForkingSolver { + return KCvc5ForkingSolver(ctx, this, null).also { + solvers += it + } + } + + internal fun createForkingSolver(parent: KCvc5ForkingSolver): KForkingSolver { + return KCvc5ForkingSolver(ctx, this, parent).also { + solvers += it + } + } + + /** + * unregister [solver] for this manager + */ + internal fun close(solver: KCvc5ForkingSolver) { + solvers -= solver + closeContextIfStale() + } + + override fun close() { + solvers.forEach(KCvc5ForkingSolver::close) + } + + internal fun createCvc5ForkingContext(solver: Solver, parent: KCvc5ForkingContext? = null) = parent + ?.fork(solver, this) + ?: KCvc5ForkingContext(solver, mkExprContext, ctx, this) + + private fun closeContextIfStale() { + if (solvers.isNotEmpty()) return + + expressionsCache.clear() + expressionsReversedCache.clear() + sortsCache.clear() + sortsReversedCache.clear() + declsCache.clear() + declsReversedCache.clear() + uninterpretedSortValueInterpretersCache.clear() + uninterpretedSortValues.clear() + + mkExprContext.close() + } + + companion object { + init { + KCvc5SolverBase.ensureCvc5LibLoaded() + } + } +} + +private typealias ExpressionsCache = HashMap, Term> +private typealias ExpressionsReversedCache = TreeMap> +private typealias SortsCache = HashMap +private typealias SortsReversedCache = TreeMap +private typealias DeclsCache = HashMap, Term> +private typealias DeclsReversedCache = TreeMap> +private typealias UninterpretedSortValueInterpretersCache = HashMap +@Suppress("MaxLineLength") +private typealias UninterpretedSortValues = HashMap>> diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt index 306706549..66c00d1c1 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Model.kt @@ -72,27 +72,32 @@ open class KCvc5Model( private fun funcInterp( decl: KDecl, internalizedDecl: Term - ): KFuncInterp = with(converter) { + ): KFuncInterp { val cvc5Interp = cvc5Ctx.nativeSolver.getValue(internalizedDecl) val vars = decl.argSorts.map { it.mkFreshConst("x") } val cvc5Vars = vars.map { with(internalizer) { it.internalizeExpr() } }.toTypedArray() - val cvc5InterpArgs = cvc5Interp.getChild(0).getChildren() + val cvc5InterpArgs = with(converter) { cvc5Interp.getChild(0).getChildren() } val cvc5FreshVarsInterp = cvc5Interp.substitute(cvc5InterpArgs, cvc5Vars) - val defaultBody = cvc5FreshVarsInterp.getChild(1).convertExpr() + val defaultBody = cvc5FreshVarsInterp.getChild(1).let { + with(cvc5Ctx) { it.convert(converter) } + } - KFuncInterpWithVars(decl, vars.map { it.decl }, emptyList(), defaultBody) + return KFuncInterpWithVars(decl, vars.map { it.decl }, emptyList(), defaultBody) } - private fun constInterp(decl: KDecl, const: Term): KFuncInterp = with(converter) { + private fun constInterp(decl: KDecl, const: Term): KFuncInterp { val cvc5Interp = cvc5Ctx.nativeSolver.getValue(const) - val interp = cvc5Interp.convertExpr() - - KFuncInterpVarsFree(decl = decl, entries = emptyList(), default = interp) + val interp = with(cvc5Ctx) { cvc5Interp.convert(converter) } + return KFuncInterpVarsFree(decl = decl, entries = emptyList(), default = interp) } + /** + * In case of [KCvc5ForkingSolver.model] call, uninterpreted sort universe extends values of whole forking hierarchy + * @see KCvc5Context.getRegisteredSortValues + */ override fun uninterpretedSortUniverse(sort: KUninterpretedSort): Set? = getUninterpretedSortContext(sort).getSortUniverse() diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt index 1e92797fe..69c31ad0d 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5Solver.kt @@ -1,230 +1,32 @@ package io.ksmt.solver.cvc5 -import io.github.cvc5.CVC5ApiException -import io.github.cvc5.Result -import io.github.cvc5.Solver import io.github.cvc5.Term import io.ksmt.KContext import io.ksmt.expr.KExpr -import io.ksmt.solver.KModel import io.ksmt.solver.KSolver -import io.ksmt.solver.KSolverException -import io.ksmt.solver.KSolverStatus import io.ksmt.sort.KBoolSort -import io.ksmt.utils.NativeLibraryLoader import java.util.TreeMap -import kotlin.time.Duration -import kotlin.time.DurationUnit -open class KCvc5Solver(private val ctx: KContext) : KSolver { - private val solver = Solver() - private val cvc5Ctx = KCvc5Context(solver, ctx) +open class KCvc5Solver(ctx: KContext) : KCvc5SolverBase(ctx), KSolver { + override val cvc5Ctx: KCvc5Context = KCvc5Context(solver, ctx) + private val trackedAssertions = ScopedArrayFrame>> { TreeMap() } - private val exprInternalizer by lazy { createExprInternalizer(cvc5Ctx) } + override val currentScope: UInt + get() = trackedAssertions.currentScope - private val currentScope: UInt - get() = cvc5TrackedAssertions.lastIndex.toUInt() - - private var lastCheckStatus = KSolverStatus.UNKNOWN - private var lastReasonOfUnknown: String? = null - private var lastModel: KCvc5Model? = null - // we need TreeMap here (hashcode not implemented in Term) - private var cvc5LastAssumptions: TreeMap>? = null - - private var cvc5CurrentLevelTrackedAssertions = TreeMap>() - private val cvc5TrackedAssertions = mutableListOf(cvc5CurrentLevelTrackedAssertions) - - init { - solver.setOption("produce-models", "true") - solver.setOption("produce-unsat-cores", "true") - /** - * Allow floating-point sorts of all sizes, rather than - * only Float32 (8/24) or Float64 (11/53) (experimental in cvc5 1.0.2) - */ - solver.setOption("fp-exp", "true") - } - - open fun createExprInternalizer(cvc5Ctx: KCvc5Context): KCvc5ExprInternalizer = KCvc5ExprInternalizer(cvc5Ctx) - - override fun configure(configurator: KCvc5SolverConfiguration.() -> Unit) { - KCvc5SolverConfigurationImpl(solver).configurator() - } - - override fun assert(expr: KExpr) = cvc5Try { - ctx.ensureContextMatch(expr) - - val cvc5Expr = with(exprInternalizer) { expr.internalizeExpr() } - solver.assertFormula(cvc5Expr) - - cvc5Ctx.assertPendingAxioms(solver) - } - - override fun assertAndTrack(expr: KExpr) { - ctx.ensureContextMatch(expr) - - val trackVarApp = ctx.mkFreshConst("track", ctx.boolSort) - val cvc5TrackVar = with(exprInternalizer) { trackVarApp.internalizeExpr() } - val trackedExpr = with(ctx) { trackVarApp implies expr } - - cvc5CurrentLevelTrackedAssertions[cvc5TrackVar] = expr - - assert(trackedExpr) - solver.assertFormula(cvc5TrackVar) + override fun saveTrackedAssertion(track: Term, trackedExpr: KExpr) { + trackedAssertions.currentFrame[track] = trackedExpr } - override fun push() = solver.push().also { - cvc5CurrentLevelTrackedAssertions = TreeMap() - cvc5TrackedAssertions.add(cvc5CurrentLevelTrackedAssertions) + override fun findTrackedExprByTrack(track: Term) = trackedAssertions.findNonNullValue { it[track] } - cvc5Ctx.push() + override fun push() { + super.push() + trackedAssertions.push() } override fun pop(n: UInt) { - require(n <= currentScope) { - "Can not pop $n scope levels because current scope level is $currentScope" - } - - if (n == 0u) return - - repeat(n.toInt()) { - cvc5TrackedAssertions.removeLast() - } - cvc5CurrentLevelTrackedAssertions = cvc5TrackedAssertions.last() - - cvc5Ctx.pop(n) - solver.pop(n.toInt()) - } - - override fun check(timeout: Duration): KSolverStatus = cvc5TryCheck { - solver.updateTimeout(timeout) - solver.checkSat().processCheckResult() - } - - override fun checkWithAssumptions( - assumptions: List>, - timeout: Duration - ): KSolverStatus = cvc5TryCheck { - ctx.ensureContextMatch(assumptions) - - val lastAssumptions = TreeMap>().also { cvc5LastAssumptions = it } - val cvc5Assumptions = Array(assumptions.size) { idx -> - val assumedExpr = assumptions[idx] - with(exprInternalizer) { - assumedExpr.internalizeExpr().also { - lastAssumptions[it] = assumedExpr - } - } - } - - solver.updateTimeout(timeout) - solver.checkSatAssuming(cvc5Assumptions).processCheckResult() - } - - override fun reasonOfUnknown(): String = cvc5Try { - require(lastCheckStatus == KSolverStatus.UNKNOWN) { - "Unknown reason is only available after UNKNOWN checks" - } - lastReasonOfUnknown ?: "no explanation" - } - - override fun model(): KModel = cvc5Try { - require(lastCheckStatus == KSolverStatus.SAT) { "Models are only available after SAT checks" } - val model = lastModel ?: KCvc5Model( - ctx, - cvc5Ctx, - exprInternalizer, - cvc5Ctx.declarations().flatMapTo(hashSetOf()) { it }, - cvc5Ctx.uninterpretedSorts().flatMapTo(hashSetOf()) { it }, - ) - lastModel = model - - model - } - - override fun unsatCore(): List> = cvc5Try { - require(lastCheckStatus == KSolverStatus.UNSAT) { "Unsat cores are only available after UNSAT checks" } - - val cvc5FullCore = solver.unsatCore - - val trackedTerms = TreeMap>() - cvc5TrackedAssertions.forEach { frame -> - trackedTerms.putAll(frame) - } - cvc5LastAssumptions?.also { trackedTerms.putAll(it) } - - cvc5FullCore.mapNotNull { trackedTerms[it] } - } - - override fun close() { - cvc5CurrentLevelTrackedAssertions.clear() - cvc5TrackedAssertions.clear() - - cvc5Ctx.close() - solver.close() - } - - /* - there are no methods to interrupt cvc5, - but maybe CVC5ApiRecoverableException can be thrown in someway - */ - override fun interrupt() = Unit - - private fun Result.processCheckResult() = when { - isSat -> KSolverStatus.SAT - isUnsat -> KSolverStatus.UNSAT - isUnknown || isNull -> KSolverStatus.UNKNOWN - else -> KSolverStatus.UNKNOWN - }.also { - lastCheckStatus = it - if (it == KSolverStatus.UNKNOWN) { - lastReasonOfUnknown = this.unknownExplanation?.name - } - } - - private fun Solver.updateTimeout(timeout: Duration) { - val cvc5Timeout = if (timeout == Duration.INFINITE) 0 else timeout.toInt(DurationUnit.MILLISECONDS) - setOption("tlimit-per", cvc5Timeout.toString()) - } - - private inline fun cvc5Try(body: () -> T): T = try { - body() - } catch (ex: CVC5ApiException) { - throw KSolverException(ex) - } - - private inline fun cvc5TryCheck(body: () -> KSolverStatus): KSolverStatus = try { - invalidateSolverState() - body() - } catch (ex: CVC5ApiException) { - lastReasonOfUnknown = ex.message - KSolverStatus.UNKNOWN.also { lastCheckStatus = it } - } - - private fun invalidateSolverState() { - /** - * Cvc5 model is only valid until the next check-sat call. - * */ - lastModel?.markInvalid() - lastModel = null - - lastCheckStatus = KSolverStatus.UNKNOWN - lastReasonOfUnknown = null - - cvc5LastAssumptions = null - } - - companion object { - init { - if (System.getProperty("cvc5.skipLibraryLoad") != "true") { - NativeLibraryLoader.load { os -> - when (os) { - NativeLibraryLoader.OS.LINUX -> listOf("libcvc5", "libcvc5jni") - NativeLibraryLoader.OS.WINDOWS -> listOf("libcvc5", "libcvc5jni") - NativeLibraryLoader.OS.MACOS -> listOf("libcvc5", "libcvc5jni") - } - } - System.setProperty("cvc5.skipLibraryLoad", "true") - } - } + super.pop(n) + trackedAssertions.pop(n) } } diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverBase.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverBase.kt new file mode 100644 index 000000000..07134ece9 --- /dev/null +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverBase.kt @@ -0,0 +1,218 @@ +package io.ksmt.solver.cvc5 + +import io.github.cvc5.CVC5ApiException +import io.github.cvc5.Result +import io.github.cvc5.Solver +import io.github.cvc5.Term +import io.ksmt.KContext +import io.ksmt.expr.KApp +import io.ksmt.expr.KExpr +import io.ksmt.solver.KModel +import io.ksmt.solver.KSolver +import io.ksmt.solver.KSolverStatus +import io.ksmt.sort.KBoolSort +import io.ksmt.utils.NativeLibraryLoader +import java.util.TreeMap +import kotlin.time.Duration +import kotlin.time.DurationUnit + +abstract class KCvc5SolverBase internal constructor( + protected val ctx: KContext +) : KSolver { + + protected abstract val currentScope: UInt + + protected val solver = Solver().apply { configureInitially() } + protected abstract val cvc5Ctx: KCvc5Context + protected val exprInternalizer by lazy { createExprInternalizer(cvc5Ctx) } + + protected var lastCheckStatus = KSolverStatus.UNKNOWN + protected var lastReasonOfUnknown: String? = null + protected var lastModel: KCvc5Model? = null + + // use TreeMap for cvc5 Term (hashcode not implemented) + protected var lastCvc5Assumptions: TreeMap>? = null + + open fun createExprInternalizer(cvc5Ctx: KCvc5Context): KCvc5ExprInternalizer = KCvc5ExprInternalizer(cvc5Ctx) + + private fun Solver.configureInitially() { + setOption("produce-models", "true") + setOption("produce-unsat-cores", "true") + /** + * Allow floating-point sorts of all sizes, rather than + * only Float32 (8/24) or Float64 (11/53) + */ + setOption("fp-exp", "true") + } + + override fun configure(configurator: KCvc5SolverConfiguration.() -> Unit) { + KCvc5SolverConfigurationImpl(solver).configurator() + } + + override fun assert(expr: KExpr) = cvc5Ctx.cvc5Try { + ctx.ensureContextMatch(expr) + + val cvc5Expr = with(exprInternalizer) { expr.internalizeExpr() } + solver.assertFormula(cvc5Expr) + cvc5Ctx.assertPendingAxioms(solver) + } + + protected abstract fun saveTrackedAssertion(track: Term, trackedExpr: KExpr) + protected abstract fun findTrackedExprByTrack(track: Term): KExpr? + + override fun assertAndTrack(expr: KExpr) = cvc5Ctx.cvc5Try { + ctx.ensureContextMatch(expr) + + val trackVarApp = createTrackVarApp() + val cvc5TrackVar = with(exprInternalizer) { trackVarApp.internalizeExpr() } + val trackedExpr = with(ctx) { trackVarApp implies expr } + assert(trackedExpr) + solver.assertFormula(cvc5TrackVar) + saveTrackedAssertion(cvc5TrackVar, expr) + } + + override fun push() = cvc5Ctx.cvc5Try { + solver.push() + cvc5Ctx.push() + } + + override fun pop(n: UInt) = cvc5Ctx.cvc5Try { + require(n <= currentScope) { + "Can not pop $n scope levels because current scope level is $currentScope" + } + + if (n == 0u) return + solver.pop(n.toInt()) + cvc5Ctx.pop(n) + } + + override fun check(timeout: Duration): KSolverStatus = cvc5TryCheck { + solver.updateTimeout(timeout) + solver.checkSat().processCheckResult() + } + + override fun checkWithAssumptions( + assumptions: List>, + timeout: Duration + ): KSolverStatus = cvc5TryCheck { + ctx.ensureContextMatch(assumptions) + + val lastAssumptions = TreeMap>().also { lastCvc5Assumptions = it } + val cvc5Assumptions = Array(assumptions.size) { idx -> + val assumedExpr = assumptions[idx] + with(exprInternalizer) { + assumedExpr.internalizeExpr().also { + lastAssumptions[it] = assumedExpr + } + } + } + + solver.updateTimeout(timeout) + solver.checkSatAssuming(cvc5Assumptions).processCheckResult() + } + + protected open fun freshModel(): KCvc5Model = KCvc5Model( + ctx, + cvc5Ctx, + exprInternalizer, + cvc5Ctx.declarations(), + cvc5Ctx.uninterpretedSorts(), + ) + + override fun model(): KModel = cvc5Ctx.cvc5Try { + require(lastCheckStatus == KSolverStatus.SAT) { "Models are only available after SAT checks" } + val model = lastModel ?: freshModel() + model.also { lastModel = it } + } + + override fun reasonOfUnknown(): String = cvc5Ctx.cvc5Try { + require(lastCheckStatus == KSolverStatus.UNKNOWN) { + "Unknown reason is only available after UNKNOWN checks" + } + lastReasonOfUnknown ?: "no explanation" + } + + override fun unsatCore(): List> { + val cvc5FullCore = cvc5UnsatCore() + + val unsatCore = mutableListOf>() + + cvc5FullCore.forEach { unsatCoreTerm -> + lastCvc5Assumptions?.get(unsatCoreTerm)?.also { unsatCore += it } + ?: findTrackedExprByTrack(unsatCoreTerm)?.also { unsatCore += it } + } + return unsatCore + } + + protected fun cvc5UnsatCore(): Array = cvc5Ctx.cvc5Try { + require(lastCheckStatus == KSolverStatus.UNSAT) { "Unsat cores are only available after UNSAT checks" } + solver.unsatCore + } + + override fun close() { + cvc5Ctx.close() + solver.close() + } + + /* + there is no method to interrupt cvc5 + */ + override fun interrupt() = Unit + + protected fun createTrackVarApp(): KApp = ctx.mkFreshConst("track", ctx.boolSort) + + protected fun Result.processCheckResult() = when { + isSat -> KSolverStatus.SAT + isUnsat -> KSolverStatus.UNSAT + isUnknown || isNull -> KSolverStatus.UNKNOWN + else -> KSolverStatus.UNKNOWN + }.also { + lastCheckStatus = it + if (it == KSolverStatus.UNKNOWN) { + lastReasonOfUnknown = this.unknownExplanation?.name + } + } + + protected fun Solver.updateTimeout(timeout: Duration) { + val cvc5Timeout = if (timeout == Duration.INFINITE) 0 else timeout.toInt(DurationUnit.MILLISECONDS) + setOption("tlimit-per", cvc5Timeout.toString()) + } + + protected inline fun cvc5TryCheck(body: () -> KSolverStatus): KSolverStatus = try { + invalidateSolverState() + body() + } catch (ex: CVC5ApiException) { + lastReasonOfUnknown = ex.message + KSolverStatus.UNKNOWN.also { lastCheckStatus = it } + } + + protected fun invalidateSolverState() { + /** + * Cvc5 model is only valid until the next check-sat call. + * */ + lastModel?.markInvalid() + lastModel = null + lastCheckStatus = KSolverStatus.UNKNOWN + lastReasonOfUnknown = null + lastCvc5Assumptions = null + } + + companion object { + internal fun ensureCvc5LibLoaded() { + if (System.getProperty("cvc5.skipLibraryLoad") != "true") { + NativeLibraryLoader.load { os -> + when (os) { + NativeLibraryLoader.OS.LINUX -> listOf("libcvc5", "libcvc5jni") + NativeLibraryLoader.OS.WINDOWS -> listOf("libcvc5", "libcvc5jni") + NativeLibraryLoader.OS.MACOS -> listOf("libcvc5", "libcvc5jni") + } + } + System.setProperty("cvc5.skipLibraryLoad", "true") + } + } + + init { + ensureCvc5LibLoaded() + } + } +} diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverConfiguration.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverConfiguration.kt index 7ef58b0cd..ec743a154 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverConfiguration.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SolverConfiguration.kt @@ -45,6 +45,29 @@ class KCvc5SolverConfigurationImpl(val solver: Solver) : KCvc5SolverConfiguratio } } +class KCvc5ForkingSolverConfigurationImpl(val solver: Solver) : KCvc5SolverConfiguration { + private val options = hashMapOf() + private lateinit var logic: String + override fun setCvc5Option(option: String, value: String) { + solver.setOption(option, value) + options[option] = value + } + + override fun setCvc5Logic(value: String) { + solver.setLogic(value) + logic = value + } + + fun fork(solver: Solver): KCvc5ForkingSolverConfigurationImpl = KCvc5ForkingSolverConfigurationImpl(solver).also { + if (::logic.isInitialized) { + it.setCvc5Logic(logic) + } + options.forEach { (option, value) -> + it.setCvc5Option(option, value) + } + } +} + class KCvc5SolverUniversalConfiguration( private val builder: KSolverUniversalConfigurationBuilder ) : KCvc5SolverConfiguration { diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SortInternalizer.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SortInternalizer.kt index 17bfd918b..48153212b 100644 --- a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SortInternalizer.kt +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/KCvc5SortInternalizer.kt @@ -20,7 +20,7 @@ import io.ksmt.sort.KUninterpretedSort open class KCvc5SortInternalizer( private val cvc5Ctx: KCvc5Context ) : KSortVisitor { - private val nSolver: Solver = cvc5Ctx.nativeSolver + private val nSolver: Solver = cvc5Ctx.mkExprSolver override fun visit(sort: KBoolSort): Sort = cvc5Ctx.internalizeSort(sort) { nSolver.booleanSort diff --git a/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt new file mode 100644 index 000000000..769327387 --- /dev/null +++ b/ksmt-cvc5/src/main/kotlin/io/ksmt/solver/cvc5/ScopedFrame.kt @@ -0,0 +1,126 @@ +package io.ksmt.solver.cvc5 + +interface ScopedFrame { + val currentScope: UInt + val currentFrame: T + + fun flatten(collect: T.(T) -> Unit): T + + fun push() + fun pop(n: UInt = 1u) +} + +internal class ScopedArrayFrame( + currentFrame: T, + private val createNewFrame: () -> T +) : ScopedFrame { + constructor(createNewFrame: () -> T) : this(createNewFrame(), createNewFrame) + + private val frames = arrayListOf(currentFrame) + + override var currentFrame = currentFrame + private set + + override val currentScope: UInt + get() = frames.size.toUInt() + + override fun flatten(collect: T.(T) -> Unit) = createNewFrame().also { newFrame -> + frames.forEach { newFrame.collect(it) } + } + + inline fun findNonNullValue(predicate: (T) -> V?): V? { + frames.forEach { frame -> + predicate(frame)?.let { return it } + } + return null + } + + override fun push() { + currentFrame = createNewFrame() + frames += currentFrame + } + + override fun pop(n: UInt) { + repeat(n.toInt()) { frames.removeLast() } + currentFrame = frames.last() + } +} + +internal class ScopedLinkedFrame private constructor( + private var current: LinkedFrame, + private val createNewFrame: () -> T, + private val copyFrame: (T) -> T +) : ScopedFrame { + constructor( + currentFrame: T, + createNewFrame: () -> T, + copyFrame: (T) -> T + ) : this(LinkedFrame(currentFrame), createNewFrame, copyFrame) + + constructor( + createNewFrame: () -> T, + copyFrame: (T) -> T + ) : this(createNewFrame(), createNewFrame, copyFrame) + + override val currentFrame: T + get() = current.value + + override val currentScope: UInt + get() = current.scope + + override fun flatten(collect: T.(T) -> Unit): T = createNewFrame().also { newFrame -> + forEachReversed { frame -> + newFrame.collect(frame) + } + } + + fun stacked(): ArrayDeque = ArrayDeque().also { stack -> + forEachReversed { frame -> + stack.addLast(frame) + } + } + + inline fun findNonNullValue(predicate: (T) -> V?): V? { + forEachReversed { frame -> + predicate(frame)?.let { return it } + } + return null + } + + override fun push() { + current = LinkedFrame(createNewFrame(), current) + } + + override fun pop(n: UInt) { + repeat(n.toInt()) { + current = current.previous ?: throw IllegalStateException("Can't pop the bottom scope") + } + recreateTopFrame() + } + + private fun recreateTopFrame() { + val newTopFrame = copyFrame(currentFrame) + current = LinkedFrame(newTopFrame, current.previous) + } + + fun fork(parent: ScopedLinkedFrame) { + current = parent.current + recreateTopFrame() + } + + private inline fun forEachReversed(action: (T) -> Unit) { + var cur: LinkedFrame? = current + while (cur != null) { + action(cur.value) + cur = cur.previous + } + } + + private class LinkedFrame( + val value: E, + val previous: LinkedFrame? = null + ) { + val scope: UInt = previous?.scope?.plus(1u) ?: 0u + } + +} diff --git a/ksmt-test/src/main/kotlin/io/ksmt/test/TestWorkerProcess.kt b/ksmt-test/src/main/kotlin/io/ksmt/test/TestWorkerProcess.kt index d6308d7de..ff56fd4c9 100644 --- a/ksmt-test/src/main/kotlin/io/ksmt/test/TestWorkerProcess.kt +++ b/ksmt-test/src/main/kotlin/io/ksmt/test/TestWorkerProcess.kt @@ -108,7 +108,7 @@ class TestWorkerProcess : ChildProcessBase() { private fun internalizeAndConvertYices(assertions: List>): List> { // Yices doesn't reverse cache internalized expressions (only interpreted values) - KYicesContext().use { internContext -> + KYicesContext(ctx).use { internContext -> val internalizer = KYicesExprInternalizer(internContext) val yicesAssertions = with(internalizer) { diff --git a/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt b/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt new file mode 100644 index 000000000..189bf7245 --- /dev/null +++ b/ksmt-test/src/test/kotlin/io/ksmt/test/KForkingSolverTest.kt @@ -0,0 +1,420 @@ +package io.ksmt.test + +import io.ksmt.KContext +import io.ksmt.solver.KForkingSolverManager +import io.ksmt.solver.KSolverStatus +import io.ksmt.solver.bitwuzla.KBitwuzlaForkingSolverManager +import io.ksmt.solver.cvc5.KCvc5ForkingSolverManager +import io.ksmt.solver.yices.KYicesForkingSolverManager +import io.ksmt.solver.z3.KZ3ForkingSolverManager +import io.ksmt.utils.getValue +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertDoesNotThrow +import kotlin.test.assertContains +import kotlin.test.assertEquals +import kotlin.test.assertNotEquals +import kotlin.test.assertTrue + +class KForkingSolverTest { + @Nested + inner class KForkingSolverTestBitwuzla { + @Test + fun testCheckSat() = testCheckSat(::mkBitwuzlaForkingSolverManager) + + @Test + fun testModel() = testModel(::mkBitwuzlaForkingSolverManager) + + @Test + fun testUnsatCore() = testUnsatCore(::mkBitwuzlaForkingSolverManager) + + @Test + fun testUninterpretedSort() = testUninterpretedSort(::mkBitwuzlaForkingSolverManager) + + @Test + fun testScopedAssertions() = testScopedAssertions(::mkBitwuzlaForkingSolverManager) + + @Test + fun testLifeTime() = testLifeTime(::mkBitwuzlaForkingSolverManager) + + private fun mkBitwuzlaForkingSolverManager(ctx: KContext) = KBitwuzlaForkingSolverManager(ctx) + } + + @Nested + inner class KForkingSolverTestCvc5 { + @Test + fun testCheckSat() = testCheckSat(::mkCvc5ForkingSolverManager) + + @Test + fun testModel() = testModel(::mkCvc5ForkingSolverManager) + + @Test + fun testUnsatCore() = testUnsatCore(::mkCvc5ForkingSolverManager) + + @Test + fun testUninterpretedSort() = testUninterpretedSort(::mkCvc5ForkingSolverManager) + + @Test + fun testScopedAssertions() = testScopedAssertions(::mkCvc5ForkingSolverManager) + + @Test + fun testLifeTime() = testLifeTime(::mkCvc5ForkingSolverManager) + + private fun mkCvc5ForkingSolverManager(ctx: KContext) = KCvc5ForkingSolverManager(ctx) + } + + @Nested + inner class KForkingSolverTestYices { + @Test + fun testCheckSat() = testCheckSat(::mkYicesForkingSolverManager) + + @Test + fun testModel() = testModel(::mkYicesForkingSolverManager) + + @Test + fun testUnsatCore() = testUnsatCore(::mkYicesForkingSolverManager) + + @Test + fun testUninterpretedSort() = testUninterpretedSort(::mkYicesForkingSolverManager) + + @Test + fun testScopedAssertions() = testScopedAssertions(::mkYicesForkingSolverManager) + + @Test + fun testLifeTime() = testLifeTime(::mkYicesForkingSolverManager) + + private fun mkYicesForkingSolverManager(ctx: KContext) = KYicesForkingSolverManager(ctx) + } + + @Nested + inner class KForkingSolverTestZ3 { + @Test + fun testCheckSat() = testCheckSat(::mkZ3ForkingSolverManager) + + @Test + fun testModel() = testModel(::mkZ3ForkingSolverManager) + + @Test + fun testUnsatCore() = testUnsatCore(::mkZ3ForkingSolverManager) + + @Test + fun testUninterpretedSort() = testUninterpretedSort(::mkZ3ForkingSolverManager) + + @Test + fun testScopedAssertions() = testScopedAssertions(::mkZ3ForkingSolverManager) + + @Test + fun testLifeTime() = testLifeTime(::mkZ3ForkingSolverManager) + + private fun mkZ3ForkingSolverManager(ctx: KContext) = KZ3ForkingSolverManager(ctx) + } + + private fun testCheckSat(createForkingSolverManager: (KContext) -> KForkingSolverManager<*>) = + KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> + createForkingSolverManager(ctx).use { man -> + man.createForkingSolver().use { parentSolver -> + with(ctx) { + val a by boolSort + val b by boolSort + val f = a and b + val neg = !a + + parentSolver.push() + + // * check children's assertions do not change parent's state + parentSolver.assert(f) + require(parentSolver.check() == KSolverStatus.SAT) + require(parentSolver.checkWithAssumptions(listOf(neg)) == KSolverStatus.UNSAT) + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + fork.assert(neg) + assertEquals(KSolverStatus.UNSAT, fork.check()) + } + + assertEquals(KSolverStatus.SAT, parentSolver.check()) + // * + + // * check parent's assertions translated into child solver + parentSolver.push() + assertEquals(KSolverStatus.UNSAT, parentSolver.fork().checkWithAssumptions(listOf(neg))) + parentSolver.assert(neg) + require(parentSolver.check() == KSolverStatus.UNSAT) + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.UNSAT, fork.check()) + } + parentSolver.pop() + // * + + // * check children independence + assertEquals(KSolverStatus.SAT, parentSolver.check()) + parentSolver.fork().also { fork1 -> + val fork2 = parentSolver.fork() + fork2.assert(neg) + assertEquals(KSolverStatus.UNSAT, fork2.check()) + assertEquals(KSolverStatus.SAT, fork1.check()) + + fork1.assert(neg) + assertEquals(KSolverStatus.UNSAT, fork1.check()) + assertEquals(KSolverStatus.SAT, parentSolver.fork().check()) + } + assertEquals(KSolverStatus.SAT, parentSolver.check()) + // * + } + } + } + } + + private fun testUnsatCore(createForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = + KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> + createForkingSolverManager(ctx).use { man -> + man.createForkingSolver().use { parentSolver -> + with(ctx) { + val a by boolSort + val b by boolSort + val f = a and b + val neg = !a + + // * check that unsat core is empty (non-tracked assertions) + parentSolver.push() + parentSolver.assert(f) + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + fork.assert(neg) + assertEquals(KSolverStatus.UNSAT, fork.check()) + assertTrue { fork.unsatCore().isEmpty() } + assertEquals(KSolverStatus.SAT, parentSolver.check()) // parent's state hasn't changed + } + parentSolver.pop() + // * + + // check tracked exprs are in unsat core + parentSolver.push() + parentSolver.assertAndTrack(f) + + parentSolver.fork().also { fork -> + fork.assertAndTrack(neg) + assertEquals(KSolverStatus.UNSAT, fork.check()) + assertContains(fork.unsatCore(), neg) + assertContains(fork.unsatCore(), f) + assertEquals(KSolverStatus.SAT, parentSolver.check()) // parent's state hasn't changed + } + // * + + // * check unsat core saves from parent to child + parentSolver.assert(neg) + require(parentSolver.check() == KSolverStatus.UNSAT) + require(neg !in parentSolver.unsatCore()) + require(f in parentSolver.unsatCore()) // only tracked f is in unsat core + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.UNSAT, fork.check()) + assertContains(fork.unsatCore(), f) + assertTrue { neg !in fork.unsatCore() } + } + } + } + } + } + + private fun testModel(createForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = + KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> + createForkingSolverManager(ctx).use { man -> + man.createForkingSolver().use { parentSolver -> + with(ctx) { + val a by boolSort + val b by boolSort + val f = a and !b + + parentSolver.assert(f) + + require(parentSolver.check() == KSolverStatus.SAT) + require(parentSolver.model().eval(a) == true.expr) + require(parentSolver.model().eval(b) == false.expr) + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + assertEquals(true.expr, fork.model().eval(a)) + assertEquals(false.expr, fork.model().eval(b)) + } + } + } + } + } + + private fun testScopedAssertions(createForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = + KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> + createForkingSolverManager(ctx).use { man -> + man.createForkingSolver().use { parent -> + with(ctx) { + val a by boolSort + val b by boolSort + val f = a and b + val neg = !a + + parent.push() + + parent.assertAndTrack(f) + require(parent.check() == KSolverStatus.SAT) + parent.push() + parent.assertAndTrack(neg) + + require(parent.check() == KSolverStatus.UNSAT) + + parent.fork().also { fork -> + assertEquals(KSolverStatus.UNSAT, fork.check()) + assertContains(fork.unsatCore(), f) + assertContains(fork.unsatCore(), neg) + + fork.pop() + assertEquals(KSolverStatus.SAT, fork.check()) + assertEquals(true.expr, fork.model().eval(a)) + assertEquals(true.expr, fork.model().eval(b)) + assertEquals(KSolverStatus.UNSAT, fork.checkWithAssumptions(listOf(neg))) + assertEquals(KSolverStatus.UNSAT, parent.check()) // check parent's state hasn't changed + + fork.fork().also { ffork -> + assertEquals(KSolverStatus.SAT, ffork.check()) + assertEquals(KSolverStatus.UNSAT, ffork.checkWithAssumptions(listOf(neg))) + + ffork.push() + ffork.assertAndTrack(neg) + assertEquals(KSolverStatus.UNSAT, ffork.check()) + assertContains(ffork.unsatCore(), f) + assertContains(ffork.unsatCore(), neg) + + assertEquals(KSolverStatus.SAT, fork.check()) // check parent's state hasn't changed + assertEquals(KSolverStatus.UNSAT, parent.check()) // check parent's state hasn't changed + + ffork.pop() + assertEquals(KSolverStatus.SAT, ffork.check()) + assertEquals(KSolverStatus.UNSAT, ffork.checkWithAssumptions(listOf(neg))) + } + } + + // check child's state is detached + val fork = parent.fork() + assertEquals(KSolverStatus.UNSAT, fork.check()) + parent.pop() + + assertEquals(KSolverStatus.SAT, parent.check()) + assertEquals(KSolverStatus.UNSAT, fork.check()) + + parent.pop() + + fork.pop() + fork.pop() + + fork.assert(neg) + assertEquals(KSolverStatus.SAT, fork.check()) + } + } + } + } + + @Suppress("LongMethod") + private fun testUninterpretedSort(createForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = + KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> + createForkingSolverManager(ctx).use { man -> + man.createForkingSolver().use { parentSolver -> + with(ctx) { + val uSort = mkUninterpretedSort("u") + val u1 by uSort + val u2 by uSort + + val eq12 = u1 eq u2 + + parentSolver.push() + + parentSolver.fork().also { fork -> + assertDoesNotThrow { fork.pop() } // check assertion levels saved + fork.assert(u1 neq u2) + assertEquals(KSolverStatus.SAT, fork.check()) + } + + parentSolver.assert(eq12) + + require(parentSolver.check() == KSolverStatus.SAT) + val pu1v = parentSolver.model().eval(u1) + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + fork.assert(u1 eq pu1v) + assertEquals(KSolverStatus.SAT, fork.check()) + assertEquals(pu1v, fork.model().eval(u1)) + } + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + fork.assert(u1 neq pu1v) + assertEquals(KSolverStatus.SAT, fork.check()) + assertNotEquals(pu1v, fork.model().eval(u1)) + } + + parentSolver.push().also { + val u5 by uSort + val pu5v = mkUninterpretedSortValue(uSort, 5) + parentSolver.assert(u5 eq pu5v) + + parentSolver.assert(u1 eq pu1v) + + parentSolver.check() + parentSolver.model().uninterpretedSortUniverse(uSort)?.also { universe -> + assertContains(universe, pu1v) + assertContains(universe, pu5v) + } + + parentSolver.pop() + } + + parentSolver.fork().also { fork -> + assertEquals(KSolverStatus.SAT, fork.check()) + fork.assert(u1 eq pu1v) + assertEquals(KSolverStatus.SAT, fork.check()) + assertEquals(pu1v, fork.model().eval(u1)) + + fork.fork().also { ff -> + assertEquals(KSolverStatus.SAT, ff.check()) + assertEquals(pu1v, ff.model().eval(u1)) + ff.model().uninterpretedSortUniverse(uSort)?.also { universe -> + assertContains(universe, pu1v) + } + } + } + + parentSolver.assert(u1 neq pu1v) + assertEquals(KSolverStatus.SAT, parentSolver.check()) + parentSolver.model().uninterpretedSortUniverse(uSort)?.also { universe -> + assertContains(universe, pu1v) + } + + } + } + } + } + + fun testLifeTime(createForkingSolverManager: (KContext) -> KForkingSolverManager<*>): Unit = + KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> + createForkingSolverManager(ctx).use { man -> + with(ctx) { + val parent = man.createForkingSolver() + val x by bv8Sort + val f = mkBvSignedGreaterExpr(x, mkBv(100, bv8Sort)) + + parent.assert(f) + parent.check().also { require(it == KSolverStatus.SAT) } + + val xVal = parent.model().eval(x) + + val fork = parent.fork().fork().fork() + parent.close() + + fork.assert(f and (x eq xVal)) + fork.check().also { assertEquals(KSolverStatus.SAT, it) } + assertEquals(xVal, fork.model().eval(x)) + } + } + } +} diff --git a/ksmt-test/src/test/kotlin/io/ksmt/test/MultiIndexedArrayTest.kt b/ksmt-test/src/test/kotlin/io/ksmt/test/MultiIndexedArrayTest.kt index 923850844..2e5420616 100644 --- a/ksmt-test/src/test/kotlin/io/ksmt/test/MultiIndexedArrayTest.kt +++ b/ksmt-test/src/test/kotlin/io/ksmt/test/MultiIndexedArrayTest.kt @@ -95,7 +95,7 @@ class MultiIndexedArrayTest { @Test fun testMultiIndexedArraysYicesWithZ3Oracle(): Unit = with(KContext(simplificationMode = NO_SIMPLIFY)) { oracleManager.createSolver(this, KZ3Solver::class).use { oracleSolver -> - KYicesContext().use { yicesNativeCtx -> + KYicesContext(this).use { yicesNativeCtx -> runMultiIndexedArraySamples(oracleSolver) { expr -> internalizeAndConvertYices(yicesNativeCtx, expr) } @@ -117,7 +117,7 @@ class MultiIndexedArrayTest { @Test fun testMultiIndexedArraysYicesWithYicesOracle(): Unit = with(KContext(simplificationMode = NO_SIMPLIFY)) { oracleManager.createSolver(this, KYicesSolver::class).use { oracleSolver -> - KYicesContext().use { yicesNativeCtx -> + KYicesContext(this).use { yicesNativeCtx -> runMultiIndexedArraySamples(oracleSolver) { expr -> internalizeAndConvertYices(yicesNativeCtx, expr) } diff --git a/ksmt-test/src/test/kotlin/io/ksmt/test/UninterpretedSortUniverseTest.kt b/ksmt-test/src/test/kotlin/io/ksmt/test/UninterpretedSortUniverseTest.kt new file mode 100644 index 000000000..71808ae4b --- /dev/null +++ b/ksmt-test/src/test/kotlin/io/ksmt/test/UninterpretedSortUniverseTest.kt @@ -0,0 +1,76 @@ +package io.ksmt.test + +import io.ksmt.KContext +import io.ksmt.solver.KSolver +import io.ksmt.solver.KSolverStatus +import io.ksmt.solver.bitwuzla.KBitwuzlaSolver +import io.ksmt.solver.cvc5.KCvc5Solver +import io.ksmt.solver.yices.KYicesSolver +import io.ksmt.solver.z3.KZ3Solver +import io.ksmt.utils.getValue +import org.junit.jupiter.api.Nested +import org.junit.jupiter.api.Test +import kotlin.test.assertContains +import kotlin.test.assertEquals +import kotlin.test.assertNotEquals +import kotlin.test.assertNotNull + +class UninterpretedSortUniverseTest { + + @Nested + inner class UninterpretedSortUniverseTestBitwuzla { + @Test + fun testUniverseContainsValue() = testUniverseContainsValue(::mkBitwuzlaSolver) + + private fun mkBitwuzlaSolver(ctx: KContext) = KBitwuzlaSolver(ctx) + } + + @Nested + inner class UninterpretedSortUniverseTestCvc5 { + @Test + fun testUniverseContainsValue() = testUniverseContainsValue(::mkCvc5Solver) + + private fun mkCvc5Solver(ctx: KContext) = KCvc5Solver(ctx) + } + + @Nested + inner class UninterpretedSortUniverseTestYices { + @Test + fun testUniverseContainsValue() = testUniverseContainsValue(::mkYicesSolver) + + private fun mkYicesSolver(ctx: KContext) = KYicesSolver(ctx) + } + + @Nested + inner class UninterpretedSortUniverseTestZ3 { + @Test + fun testUniverseContainsValue() = testUniverseContainsValue(::mkZ3Solver) + + private fun mkZ3Solver(ctx: KContext) = KZ3Solver(ctx) + } + + fun testUniverseContainsValue(mkSolver: (KContext) -> KSolver<*>): Unit = + KContext(simplificationMode = KContext.SimplificationMode.NO_SIMPLIFY).use { ctx -> + mkSolver(ctx).use { s -> + with(ctx) { + val u = mkUninterpretedSort("u") + val u1 by u + val uval5 = mkUninterpretedSortValue(u, 5) + + s.assert(u1 neq uval5) + assertEquals(KSolverStatus.SAT, s.check()) + val u1v = s.model().eval(u1) + assertNotEquals(uval5, u1v) + + s.assert(u1 neq u1v) + assertEquals(KSolverStatus.SAT, s.check()) + + val universe = s.model().uninterpretedSortUniverse(u) + assertNotNull(universe) + + assertContains(universe, u1v) + assertContains(universe, uval5) + } + } + } +} diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesContext.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesContext.kt index 3fc8cda18..054bd9aca 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesContext.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesContext.kt @@ -3,13 +3,12 @@ package io.ksmt.solver.yices import com.sri.yices.Terms import com.sri.yices.Types import com.sri.yices.Yices -import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap -import it.unimi.dsi.fastutil.ints.IntOpenHashSet -import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap +import io.ksmt.KContext import io.ksmt.decl.KDecl import io.ksmt.expr.KConst import io.ksmt.expr.KExpr import io.ksmt.expr.KInterpretedValue +import io.ksmt.expr.KUninterpretedSortValue import io.ksmt.solver.KSolverUnsupportedFeatureException import io.ksmt.solver.util.KExprIntInternalizerBase.Companion.NOT_INTERNALIZED import io.ksmt.solver.yices.TermUtils.addTerm @@ -19,37 +18,45 @@ import io.ksmt.solver.yices.TermUtils.funApplicationTerm import io.ksmt.solver.yices.TermUtils.mulTerm import io.ksmt.solver.yices.TermUtils.orTerm import io.ksmt.sort.KSort -import io.ksmt.utils.NativeLibraryLoader +import io.ksmt.sort.KUninterpretedSort +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap +import it.unimi.dsi.fastutil.ints.IntOpenHashSet +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap import java.math.BigInteger import java.util.concurrent.atomic.AtomicInteger -open class KYicesContext : AutoCloseable { - private var isClosed = false +open class KYicesContext(ctx: KContext) : AutoCloseable { + protected var isClosed = false - private val expressions = mkTermCache>() - private val yicesExpressions = mkTermReverseCache>() + protected open val expressions = mkTermCache>() + protected open val yicesExpressions = mkTermReverseCache>() - private val sorts = mkSortCache() - private val yicesSorts = mkSortReverseCache() + protected open val sorts = mkSortCache() + protected open val yicesSorts = mkSortReverseCache() - private val decls = mkTermCache>() - private val yicesDecls = mkTermReverseCache>() + protected open val decls = mkTermCache>() + protected open val yicesDecls = mkTermReverseCache>() - private val vars = mkTermCache>() - private val yicesVars = mkTermReverseCache>() + protected open val vars = mkTermCache>() + protected open val yicesVars = mkTermReverseCache>() - private val yicesTypes = mkSortSet() - private val yicesTerms = mkTermSet() + protected open val yicesTypes = mkSortSet() + protected open val yicesTerms = mkTermSet() val isActive: Boolean get() = !isClosed - fun findInternalizedExpr(expr: KExpr<*>): YicesTerm = expressions.getInt(expr) + fun findInternalizedExpr(expr: KExpr<*>): YicesTerm = expressions.getInt(expr).also { + if (it != NOT_INTERNALIZED) + uninterpretedSortValuesTracker.expressionUse(expr) + } + fun saveInternalizedExpr(expr: KExpr<*>, internalized: YicesTerm) { if (expressions.putIfAbsent(expr, internalized) == NOT_INTERNALIZED) { if (expr is KInterpretedValue<*> || expr is KConst<*>) { yicesExpressions.put(internalized, expr) } + uninterpretedSortValuesTracker.expressionSave(expr) } } @@ -175,9 +182,9 @@ open class KYicesContext : AutoCloseable { fun functionType(domain: YicesSortArray, range: YicesSort) = mkType { Types.functionType(domain, range) } fun newUninterpretedType(name: String) = mkType { Types.newUninterpretedType(name) } - val zero = mkTerm { Terms.intConst(0L) } - val one = mkTerm { Terms.intConst(1L) } - val minusOne = mkTerm { Terms.intConst(-1L) } + val zero by lazy { mkTerm { Terms.intConst(0L) } } + val one by lazy { mkTerm { Terms.intConst(1L) } } + val minusOne by lazy { mkTerm { Terms.intConst(-1L) } } private inline fun mkTerm(mk: () -> YicesTerm): YicesTerm = withGcGuard { val term = mk() @@ -294,7 +301,32 @@ open class KYicesContext : AutoCloseable { fun uninterpretedSortConst(sort: YicesSort, idx: Int) = mkTerm { Terms.mkConst(sort, idx) } - private var maxValueIndex = 0 + protected open var maxValueIndex = 0 + + /** + * Collects uninterpreted sort values usage for [KYicesModel.uninterpretedSortUniverse] + */ + protected open val uninterpretedSortValuesTracker = UninterpretedValuesTracker( + ctx, + ScopedArrayFrame(::HashSet), + ScopedArrayFrame(::HashMap), + Object2IntOpenHashMap>() + ) + + fun pushAssertionLevel() { + uninterpretedSortValuesTracker.push() + } + + fun popAssertionLevel(n: UInt) { + uninterpretedSortValuesTracker.pop(n) + } + + fun registerUninterpretedSortValue(value: KUninterpretedSortValue) { + uninterpretedSortValuesTracker.addToCurrentLevel(value) + } + + fun uninterpretedSortValues(sort: KUninterpretedSort) = + uninterpretedSortValuesTracker.getUninterpretedSortValues(sort) /** * Yices can produce different values with the same index. @@ -339,20 +371,6 @@ open class KYicesContext : AutoCloseable { } companion object { - init { - if (!Yices.isReady()) { - NativeLibraryLoader.load { os -> - when (os) { - NativeLibraryLoader.OS.LINUX -> listOf("libyices", "libyices2java") - NativeLibraryLoader.OS.WINDOWS -> listOf("libyices", "libyices2java") - NativeLibraryLoader.OS.MACOS -> listOf("libyices", "libyices2java") - } - } - Yices.init() - Yices.setReadyFlag(true) - } - } - private const val UNINTERPRETED_SORT_VALUE_SHIFT = 1 shl 30 private const val UNINTERPRETED_SORT_MAX_ALLOWED_VALUE = UNINTERPRETED_SORT_VALUE_SHIFT / 2 private const val UNINTERPRETED_SORT_MIN_ALLOWED_VALUE = -UNINTERPRETED_SORT_MAX_ALLOWED_VALUE @@ -420,7 +438,8 @@ open class KYicesContext : AutoCloseable { } } - private fun performGc() { + @JvmStatic + protected fun performGc() { // spin wait until [gcGuard] == [FREE] while (true) { if (gcGuard.compareAndSet(FREE, ON_GC)) { diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesExprInternalizer.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesExprInternalizer.kt index fd9dde704..4ef22f8be 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesExprInternalizer.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesExprInternalizer.kt @@ -960,7 +960,7 @@ open class KYicesExprInternalizer( yicesCtx.uninterpretedSortConst( sort.internalizeSort(), yicesCtx.uninterpretedSortValueIndex(valueIdx) - ) + ).also { yicesCtx.registerUninterpretedSortValue(expr) } } } diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingContext.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingContext.kt new file mode 100644 index 000000000..f3b65ba0e --- /dev/null +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingContext.kt @@ -0,0 +1,47 @@ +package io.ksmt.solver.yices + +import io.ksmt.KContext + +/** + * Yices Context that allows forking and resources sharing via [KYicesForkingSolverManager]. + * To track resources, we have to use a unique solver specific ID, related to them. + * @param [solver] is used as "specific ID" for resource tracking, + * because we can't use initialized [com.sri.yices.Context] here. + */ +class KYicesForkingContext( + ctx: KContext, + manager: KYicesForkingSolverManager, + solver: KYicesForkingSolver +) : KYicesContext(ctx) { + override val expressions = manager.getExpressionsCache(solver) + override val yicesExpressions = manager.getExpressionsReversedCache(solver) + + override val sorts = manager.getSortsCache(solver) + override val yicesSorts = manager.getSortsReversedCache(solver) + + override val decls = manager.getDeclsCache(solver) + override val yicesDecls = manager.getDeclsReversedCache(solver) + + override val vars = manager.getVarsCache(solver) + override val yicesVars = manager.getVarsReversedCache(solver) + + override val yicesTypes = manager.getTypesCache(solver) + override val yicesTerms = manager.getTermsCache(solver) + + private val maxValueIndexAtomic = manager.getMaxUninterpretedSortValueIdx(solver) + + override var maxValueIndex: Int + get() = maxValueIndexAtomic.get() + set(value) { + maxValueIndexAtomic.set(value) + } + + override val uninterpretedSortValuesTracker = manager.createUninterpretedValuesTracker(solver) + + override fun close() { + if (isClosed) return + isClosed = true + + performGc() + } +} diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt new file mode 100644 index 000000000..fd94e4799 --- /dev/null +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolver.kt @@ -0,0 +1,120 @@ +package io.ksmt.solver.yices + +import io.ksmt.KContext +import io.ksmt.expr.KExpr +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KSolverStatus +import io.ksmt.sort.KBoolSort +import kotlin.time.Duration + +class KYicesForkingSolver( + ctx: KContext, + private val manager: KYicesForkingSolverManager, + parent: KYicesForkingSolver?, +) : KForkingSolver, KYicesSolverBase(ctx) { + + override val yicesCtx: KYicesForkingContext by lazy { KYicesForkingContext(ctx, manager, this) } + + private val trackedAssertions = + ScopedLinkedFrame, YicesTerm>>>(::ArrayList, ::ArrayList) + private val yicesAssertions = ScopedLinkedFrame>(::HashSet, ::HashSet) + + override val currentScope: UInt + get() = trackedAssertions.currentScope + + private val ksmtConfig: KYicesForkingSolverConfigurationImpl by lazy { + parent?.ksmtConfig?.fork(config) ?: KYicesForkingSolverConfigurationImpl(config) + } + + private var assertionsInitiated = parent == null + + init { + if (parent != null) { + trackedAssertions.fork(parent.trackedAssertions) + yicesAssertions.fork(parent.yicesAssertions) + + ksmtConfig // force initialization + } + } + + private fun ensureAssertionsInitiated() { + if (assertionsInitiated) return + + yicesAssertions.stacked() + .zip(trackedAssertions.stacked()) + .asReversed() + .forEachIndexed { scope, (yicesAssertionFrame, _) -> + if (scope > 0) nativeContext.push() + + yicesAssertionFrame.forEach(nativeContext::assertFormula) + } + + assertionsInitiated = true + } + + override fun configure(configurator: KYicesSolverConfiguration.() -> Unit) { + requireActiveConfig() + ksmtConfig.configurator() + } + + /** + * Creates lazily initiated forked solver with shared cache, preserving parental assertions and configuration. + */ + override fun fork(): KForkingSolver = manager.createForkingSolver(this) + + override fun saveTrackedAssertion(track: YicesTerm, trackedExpr: KExpr) { + trackedAssertions.currentFrame += trackedExpr to track + } + + override fun collectTrackedAssertions(collector: (Pair, YicesTerm>) -> Unit) { + trackedAssertions.forEach { frame -> + frame.forEach(collector) + } + } + + override val hasTrackedAssertions: Boolean + get() = trackedAssertions.any { it.isNotEmpty() } + + override fun assert(expr: KExpr) = yicesTry { + yicesTry { ensureAssertionsInitiated() } + ctx.ensureContextMatch(expr) + + val yicesExpr = with(exprInternalizer) { expr.internalize() } + nativeContext.assertFormula(yicesExpr) + yicesAssertions.currentFrame += yicesExpr + } + + override fun assertAndTrack(expr: KExpr) { + yicesTry { ensureAssertionsInitiated() } + super.assertAndTrack(expr) + } + + override fun push() { + yicesTry { ensureAssertionsInitiated() } + super.push() + trackedAssertions.push() + yicesAssertions.push() + } + + override fun pop(n: UInt) { + yicesTry { ensureAssertionsInitiated() } + super.pop(n) + trackedAssertions.pop(n) + yicesAssertions.pop(n) + } + + override fun check(timeout: Duration): KSolverStatus { + yicesTry { ensureAssertionsInitiated() } + return super.check(timeout) + } + + override fun checkWithAssumptions(assumptions: List>, timeout: Duration): KSolverStatus { + yicesTry { ensureAssertionsInitiated() } + return super.checkWithAssumptions(assumptions, timeout) + } + + override fun close() { + super.close() + manager.close(this) + } +} diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt new file mode 100644 index 000000000..098cc6601 --- /dev/null +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesForkingSolverManager.kt @@ -0,0 +1,160 @@ +package io.ksmt.solver.yices + +import com.sri.yices.Yices +import io.ksmt.KAst +import io.ksmt.KContext +import io.ksmt.decl.KDecl +import io.ksmt.expr.KExpr +import io.ksmt.expr.KUninterpretedSortValue +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KForkingSolverManager +import io.ksmt.solver.util.KExprIntInternalizerBase +import io.ksmt.sort.KSort +import io.ksmt.sort.KUninterpretedSort +import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap +import it.unimi.dsi.fastutil.ints.IntOpenHashSet +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap +import java.util.IdentityHashMap +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger + +/** + * Responsible for creation and managing of [KYicesForkingSolver]. + * + * It's cheaper to create multiple copies of solvers with [KYicesForkingSolver.fork] + * instead of assertions transferring in [KYicesSolver] instances manually. + * + * All solvers created with one manager (via both [KYicesForkingSolver.fork] and [createForkingSolver]) + * use the same cache. + */ +class KYicesForkingSolverManager( + private val ctx: KContext +) : KForkingSolverManager { + + private val solvers = ConcurrentHashMap.newKeySet() + + private fun ensureSolverRegistered(s: KYicesForkingSolver) = check(s in solvers) { + "Solver is not registered by the manager." + } + + private val expressionsCache = ExpressionsCache().withNotInternalizedDefaultValue() + private val expressionsReversedCache = ExpressionsReversedCache() + private val sortsCache = SortsCache().withNotInternalizedDefaultValue() + private val sortsReversedCache = SortsReversedCache() + private val declsCache = DeclsCache().withNotInternalizedDefaultValue() + private val declsReversedCache = DeclsReversedCache() + private val varsCache = VarsCache().withNotInternalizedDefaultValue() + private val varsReversedCache = VarsReversedCache() + private val typesCache = TypesCache() + private val termsCache = TermsCache() + private val maxUninterpretedSortValueIndex = IdentityHashMap() + + private val scopedExpressions = IdentityHashMap() + private val scopedUninterpretedValues = IdentityHashMap() + private val expressionLevels = IdentityHashMap() + + internal fun getExpressionsCache(s: KYicesForkingSolver): ExpressionsCache = ensureSolverRegistered(s).let { + expressionsCache + } + internal fun getExpressionsReversedCache(s: KYicesForkingSolver) = ensureSolverRegistered(s).let { + expressionsReversedCache + } + internal fun getSortsCache(s: KYicesForkingSolver): SortsCache = ensureSolverRegistered(s).let { sortsCache } + internal fun getSortsReversedCache(s: KYicesForkingSolver): SortsReversedCache = ensureSolverRegistered(s).let { + sortsReversedCache + } + internal fun getDeclsCache(s: KYicesForkingSolver): DeclsCache = ensureSolverRegistered(s).let { declsCache } + internal fun getDeclsReversedCache(s: KYicesForkingSolver): DeclsReversedCache = ensureSolverRegistered(s).let { + declsReversedCache + } + internal fun getVarsCache(s: KYicesForkingSolver): VarsCache = ensureSolverRegistered(s).let { varsCache } + internal fun getVarsReversedCache(s: KYicesForkingSolver): VarsReversedCache = ensureSolverRegistered(s).let { + varsReversedCache + } + internal fun getTypesCache(s: KYicesForkingSolver): TypesCache = ensureSolverRegistered(s).let { typesCache } + internal fun getTermsCache(s: KYicesForkingSolver): TermsCache = ensureSolverRegistered(s).let { termsCache } + internal fun getMaxUninterpretedSortValueIdx(s: KYicesForkingSolver) = ensureSolverRegistered(s).let { + maxUninterpretedSortValueIndex.getValue(s) + } + + override fun createForkingSolver(): KForkingSolver = + KYicesForkingSolver(ctx, this, null).also { + solvers += it + maxUninterpretedSortValueIndex[it] = AtomicInteger(0) + scopedExpressions[it] = ScopedExpressions(::HashSet, ::HashSet) + scopedUninterpretedValues[it] = ScopedUninterpretedSortValues(::HashMap, ::HashMap) + expressionLevels[it] = ExpressionLevels() + } + + internal fun createForkingSolver(parent: KYicesForkingSolver) = KYicesForkingSolver(ctx, this, parent).also { + solvers += it + scopedExpressions[it] = ScopedExpressions(::HashSet, ::HashSet) + .apply { fork(scopedExpressions.getValue(parent)) } + scopedUninterpretedValues[it] = ScopedUninterpretedSortValues(::HashMap, ::HashMap) + .apply { fork(scopedUninterpretedValues.getValue(parent)) } + expressionLevels[it] = ExpressionLevels(expressionLevels.getValue(parent)) + + val parentMaxUninterpretedSortValueIdx = maxUninterpretedSortValueIndex.getValue(parent).get() + maxUninterpretedSortValueIndex[it] = AtomicInteger(parentMaxUninterpretedSortValueIdx) + } + + internal fun createUninterpretedValuesTracker(solver: KYicesForkingSolver) = UninterpretedValuesTracker( + ctx, + scopedExpressions.getValue(solver), + scopedUninterpretedValues.getValue(solver), + expressionLevels.getValue(solver) + ) + + /** + * Unregisters [solver] for this manager + */ + internal fun close(solver: KYicesForkingSolver) { + solvers -= solver + decRef(solver) + } + + override fun close() { + solvers.forEach(KYicesForkingSolver::close) + } + + private fun decRef(solver: KYicesForkingSolver) { + scopedExpressions -= solver + scopedUninterpretedValues -= solver + maxUninterpretedSortValueIndex -= solver + expressionLevels -= solver + + if (solvers.isEmpty()) { + expressionsCache.clear() + expressionsReversedCache.clear() + sortsCache.clear() + sortsReversedCache.clear() + declsCache.clear() + declsReversedCache.clear() + varsCache.clear() + varsReversedCache.clear() + typesCache.forEach(Yices::yicesDecrefType) + termsCache.forEach(Yices::yicesDecrefTerm) + typesCache.clear() + termsCache.clear() + } + } + + private fun Object2IntOpenHashMap.withNotInternalizedDefaultValue() = apply { + defaultReturnValue(KExprIntInternalizerBase.NOT_INTERNALIZED) + } +} + +private typealias ExpressionsCache = Object2IntOpenHashMap> +private typealias ExpressionsReversedCache = Int2ObjectOpenHashMap> +private typealias SortsCache = Object2IntOpenHashMap +private typealias SortsReversedCache = Int2ObjectOpenHashMap +private typealias DeclsCache = Object2IntOpenHashMap> +private typealias DeclsReversedCache = Int2ObjectOpenHashMap> +private typealias VarsCache = Object2IntOpenHashMap> +private typealias VarsReversedCache = Int2ObjectOpenHashMap> +private typealias TypesCache = IntOpenHashSet +private typealias TermsCache = IntOpenHashSet +private typealias ScopedExpressions = ScopedLinkedFrame>> +@Suppress("MaxLineLength") +private typealias ScopedUninterpretedSortValues = ScopedLinkedFrame>> +private typealias ExpressionLevels = Object2IntOpenHashMap> diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesModel.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesModel.kt index b643189f9..8e461e797 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesModel.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesModel.kt @@ -9,10 +9,10 @@ import io.ksmt.decl.KDecl import io.ksmt.decl.KFuncDecl import io.ksmt.expr.KExpr import io.ksmt.expr.KUninterpretedSortValue +import io.ksmt.solver.KModel import io.ksmt.solver.model.KFuncInterp import io.ksmt.solver.model.KFuncInterpEntryVarsFree import io.ksmt.solver.model.KFuncInterpVarsFree -import io.ksmt.solver.KModel import io.ksmt.solver.model.KModelEvaluator import io.ksmt.solver.model.KModelImpl import io.ksmt.sort.KArray2Sort @@ -62,14 +62,14 @@ class KYicesModel( private val interpretations = hashMapOf, KFuncInterp<*>>() private val funcInterpretationsToDo = arrayListOf>>() - override fun uninterpretedSortUniverse( - sort: KUninterpretedSort - ): Set? = uninterpretedSortUniverse.getOrPut(sort) { - val sortDependencies = uninterpretedSortDependencies[sort] ?: return null + override fun uninterpretedSortUniverse(sort: KUninterpretedSort) = uninterpretedSortUniverse.getOrPut(sort) { + val knownTrackedValues = yicesCtx.uninterpretedSortValues(sort) + + val sortDependencies = uninterpretedSortDependencies[sort] ?: return knownTrackedValues sortDependencies.forEach { interpretation(it) } - knownUninterpretedSortValues[sort]?.values?.toHashSet() ?: hashSetOf() + knownTrackedValues + (knownUninterpretedSortValues[sort]?.values?.toHashSet() ?: hashSetOf()) } private val evaluatorWithModelCompletion by lazy { KModelEvaluator(ctx, this, isComplete = true) } @@ -107,7 +107,7 @@ class KYicesModel( } } - private fun functionInterpretation(yval: YVal, decl: KFuncDecl): KFuncInterp { + private fun functionInterpretation(yval: YVal, decl: KFuncDecl): KFuncInterp { val functionChildren = model.expandFunction(yval) val default = if (yval.tag != YValTag.UNKNOWN) { getValue(functionChildren.value, decl.sort).uncheckedCast<_, KExpr>() @@ -163,7 +163,7 @@ class KYicesModel( declarations.forEach { interpretation(it) } val uninterpretedSortsUniverses = uninterpretedSorts.associateWith { - uninterpretedSortUniverse(it) ?: error("missed sort universe for $it") + uninterpretedSortUniverse(it) } return KModelImpl(ctx, interpretations.toMap(), uninterpretedSortsUniverses) diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolver.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolver.kt index 5d54a73ad..a1220a44b 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolver.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolver.kt @@ -1,234 +1,41 @@ package io.ksmt.solver.yices -import com.sri.yices.Config -import com.sri.yices.Context -import com.sri.yices.Status -import com.sri.yices.YicesException -import it.unimi.dsi.fastutil.ints.IntArrayList -import it.unimi.dsi.fastutil.ints.IntOpenHashSet import io.ksmt.KContext import io.ksmt.expr.KExpr -import io.ksmt.solver.KModel -import io.ksmt.solver.KSolver -import io.ksmt.solver.KSolverException -import io.ksmt.solver.KSolverStatus import io.ksmt.sort.KBoolSort -import java.util.Timer -import java.util.TimerTask -import kotlin.time.Duration -class KYicesSolver(private val ctx: KContext) : KSolver { - private val yicesCtx = KYicesContext() +class KYicesSolver(ctx: KContext) : KYicesSolverBase(ctx) { + override val yicesCtx = KYicesContext(ctx) - private val config = Config() - private val nativeContext by lazy { - Context(config).also { - config.close() - } - } - - private val exprInternalizer: KYicesExprInternalizer by lazy { - KYicesExprInternalizer(yicesCtx) - } - private val exprConverter: KYicesExprConverter by lazy { - KYicesExprConverter(ctx, yicesCtx) - } - - private var lastAssumptions: TrackedAssumptions? = null - private var lastCheckStatus = KSolverStatus.UNKNOWN - private var lastReasonOfUnknown: String? = null - - private var currentLevelTrackedAssertions = mutableListOf, YicesTerm>>() - private val trackedAssertions = mutableListOf(currentLevelTrackedAssertions) - - private val timer = Timer() - - override fun configure(configurator: KYicesSolverConfiguration.() -> Unit) { - require(config.isActive) { - "Solver instance has already been created" - } - - KYicesSolverConfigurationImpl(config).configurator() - } - - override fun assert(expr: KExpr) = yicesTry { - ctx.ensureContextMatch(expr) - - val yicesExpr = with(exprInternalizer) { expr.internalize() } - nativeContext.assertFormula(yicesExpr) - } - - override fun assertAndTrack(expr: KExpr) = yicesTry { - ctx.ensureContextMatch(expr) - - val trackVarExpr = ctx.mkFreshConst("track", ctx.boolSort) - val trackedExpr = with(ctx) { !trackVarExpr or expr } + private val trackedAssertions = ScopedArrayFrame, YicesTerm>>>(::ArrayList) + override val currentScope: UInt + get() = trackedAssertions.currentScope - assert(trackedExpr) + override val hasTrackedAssertions: Boolean + get() = trackedAssertions.any { it.isNotEmpty() } - val yicesTrackVar = with(exprInternalizer) { trackVarExpr.internalize() } - currentLevelTrackedAssertions += expr to yicesTrackVar + override fun saveTrackedAssertion(track: YicesTerm, trackedExpr: KExpr) { + trackedAssertions.currentFrame += trackedExpr to track } - override fun push(): Unit = yicesTry { - nativeContext.push() - - currentLevelTrackedAssertions = mutableListOf() - trackedAssertions.add(currentLevelTrackedAssertions) - } - - override fun pop(n: UInt) = yicesTry { - val currentScope = trackedAssertions.lastIndex.toUInt() - require(n <= currentScope) { - "Can not pop $n scope levels because current scope level is $currentScope" - } - - if (n == 0u) return - - repeat(n.toInt()) { - nativeContext.pop() - trackedAssertions.removeLast() - } - currentLevelTrackedAssertions = trackedAssertions.last() - } - - override fun check(timeout: Duration): KSolverStatus = yicesTryCheck { - if (trackedAssertions.any { it.isNotEmpty() }) { - return checkWithAssumptions(emptyList(), timeout) - } - - checkWithTimer(timeout) { - nativeContext.check() - }.processCheckResult() - } - - override fun checkWithAssumptions( - assumptions: List>, - timeout: Duration - ): KSolverStatus = yicesTryCheck { - ctx.ensureContextMatch(assumptions) - - val yicesAssumptions = TrackedAssumptions().also { lastAssumptions = it } - + override fun collectTrackedAssertions(collector: (Pair, YicesTerm>) -> Unit) { trackedAssertions.forEach { frame -> - frame.forEach { assertion -> - yicesAssumptions.assumeTrackedAssertion(assertion) - } - } - - with(exprInternalizer) { - assumptions.forEach { assumedExpr -> - yicesAssumptions.assumeAssumption(assumedExpr, assumedExpr.internalize()) - } - } - - checkWithTimer(timeout) { - nativeContext.checkWithAssumptions(yicesAssumptions.assumedTerms()) - }.processCheckResult() - } - - override fun model(): KModel = yicesTry { - require(lastCheckStatus == KSolverStatus.SAT) { - "Model are only available after SAT checks, current solver status: $lastCheckStatus" - } - val model = nativeContext.model - - return KYicesModel(model, ctx, yicesCtx, exprInternalizer, exprConverter) - } - - override fun unsatCore(): List> = yicesTry { - require(lastCheckStatus == KSolverStatus.UNSAT) { - "Unsat cores are only available after UNSAT checks" + frame.forEach(collector) } - - lastAssumptions?.resolveUnsatCore(nativeContext.unsatCore) ?: emptyList() } - override fun reasonOfUnknown(): String { - require(lastCheckStatus == KSolverStatus.UNKNOWN) { - "Unknown reason is only available after UNKNOWN checks" - } - - // There is no way to retrieve reason of unknown from Yices in general case. - return lastReasonOfUnknown ?: "unknown" - } - - override fun interrupt() = yicesTry { - nativeContext.stopSearch() - } - - private inline fun checkWithTimer(timeout: Duration, body: () -> T): T { - val task = StopSearchTask() - - if (timeout.isFinite()) { - timer.schedule(task, timeout.inWholeMilliseconds) - } - - return try { - body() - } finally { - task.cancel() - } - } - - private inline fun yicesTry(body: () -> T): T = try { - body() - } catch (ex: YicesException) { - throw KSolverException(ex) - } - - private inline fun yicesTryCheck(body: () -> KSolverStatus): KSolverStatus = try { - invalidateSolverState() - body() - } catch (ex: YicesException) { - lastReasonOfUnknown = ex.message - KSolverStatus.UNKNOWN.also { lastCheckStatus = it } - } - - private fun invalidateSolverState() { - lastCheckStatus = KSolverStatus.UNKNOWN - lastReasonOfUnknown = null - lastAssumptions = null - } - - private fun Status.processCheckResult() = when (this) { - Status.SAT -> KSolverStatus.SAT - Status.UNSAT -> KSolverStatus.UNSAT - else -> KSolverStatus.UNKNOWN - }.also { lastCheckStatus = it } - - override fun close() { - nativeContext.close() - yicesCtx.close() - timer.cancel() + override fun configure(configurator: KYicesSolverConfiguration.() -> Unit) { + requireActiveConfig() + KYicesSolverConfigurationImpl(config).configurator() } - private inner class StopSearchTask : TimerTask() { - override fun run() { - nativeContext.stopSearch() - } + override fun push() { + super.push() + trackedAssertions.push() } - private class TrackedAssumptions { - private val assumedExprs = arrayListOf, YicesTerm>>() - private val assumedTerms = IntArrayList() - - fun assumeTrackedAssertion(trackedAssertion: Pair, YicesTerm>) { - assumedExprs.add(trackedAssertion) - assumedTerms.add(trackedAssertion.second) - } - - fun assumeAssumption(expr: KExpr, term: YicesTerm) = - assumeTrackedAssertion(expr to term) - - fun assumedTerms(): YicesTermArray { - assumedTerms.trim() // Elements length now equal to size - return assumedTerms.elements() - } - - fun resolveUnsatCore(yicesUnsatCore: YicesTermArray): List> { - val unsatCoreTerms = IntOpenHashSet(yicesUnsatCore) - return assumedExprs.mapNotNull { (expr, term) -> expr.takeIf { unsatCoreTerms.contains(term) } } - } + override fun pop(n: UInt) { + super.pop(n) + trackedAssertions.pop(n) } } diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolverBase.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolverBase.kt new file mode 100644 index 000000000..d8ad8eb00 --- /dev/null +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolverBase.kt @@ -0,0 +1,232 @@ +package io.ksmt.solver.yices + +import com.sri.yices.Config +import com.sri.yices.Context +import com.sri.yices.Status +import com.sri.yices.Yices +import com.sri.yices.YicesException +import io.ksmt.KContext +import io.ksmt.expr.KExpr +import io.ksmt.solver.KModel +import io.ksmt.solver.KSolver +import io.ksmt.solver.KSolverException +import io.ksmt.solver.KSolverStatus +import io.ksmt.sort.KBoolSort +import io.ksmt.utils.NativeLibraryLoader +import it.unimi.dsi.fastutil.ints.IntArrayList +import it.unimi.dsi.fastutil.ints.IntOpenHashSet +import java.util.Timer +import java.util.TimerTask +import kotlin.time.Duration + +abstract class KYicesSolverBase(protected val ctx: KContext) : KSolver { + protected abstract val yicesCtx: KYicesContext + protected val nativeContext by lazy { Context(config).also { config.close() } } + protected val config by lazy { Config() } + + protected val exprInternalizer: KYicesExprInternalizer by lazy { KYicesExprInternalizer(yicesCtx) } + protected val exprConverter: KYicesExprConverter by lazy { KYicesExprConverter(ctx, yicesCtx) } + + private var lastAssumptions: TrackedAssumptions? = null + private var lastCheckStatus = KSolverStatus.UNKNOWN + private var lastReasonOfUnknown: String? = null + + protected abstract val currentScope: UInt + + protected abstract fun saveTrackedAssertion(track: YicesTerm, trackedExpr: KExpr) + protected abstract fun collectTrackedAssertions(collector: (Pair, YicesTerm>) -> Unit) + protected abstract val hasTrackedAssertions: Boolean + + private val timer = Timer() + + protected fun requireActiveConfig() = require(config.isActive) { + "Solver instance has already been created" + } + + override fun assert(expr: KExpr) = yicesTry { + ctx.ensureContextMatch(expr) + + val yicesExpr = with(exprInternalizer) { expr.internalize() } + nativeContext.assertFormula(yicesExpr) + } + + override fun assertAndTrack(expr: KExpr) = yicesTry { + ctx.ensureContextMatch(expr) + + val trackVarExpr = ctx.mkFreshConst("track", ctx.boolSort) + val trackedExpr = with(ctx) { !trackVarExpr or expr } + + assert(trackedExpr) + + val yicesTrackVar = with(exprInternalizer) { trackVarExpr.internalize() } + saveTrackedAssertion(yicesTrackVar, expr) + } + + override fun push(): Unit = yicesTry { + nativeContext.push() + yicesCtx.pushAssertionLevel() + } + + override fun pop(n: UInt) = yicesTry { + require(n <= currentScope) { + "Can not pop $n scope levels because current scope level is $currentScope" + } + + if (n == 0u) return + + repeat(n.toInt()) { + nativeContext.pop() + } + yicesCtx.popAssertionLevel(n) + } + + override fun check(timeout: Duration): KSolverStatus = if (hasTrackedAssertions) { + checkWithAssumptions(emptyList(), timeout) + } else yicesTryCheck { + checkWithTimer(timeout) { + nativeContext.check() + }.processCheckResult() + } + + override fun checkWithAssumptions( + assumptions: List>, + timeout: Duration + ): KSolverStatus = yicesTryCheck { + ctx.ensureContextMatch(assumptions) + + val yicesAssumptions = TrackedAssumptions().also { lastAssumptions = it } + + collectTrackedAssertions(yicesAssumptions::assumeTrackedAssertion) + + with(exprInternalizer) { + assumptions.forEach { assumedExpr -> + yicesAssumptions.assumeAssumption(assumedExpr, assumedExpr.internalize()) + } + } + + checkWithTimer(timeout) { + nativeContext.checkWithAssumptions(yicesAssumptions.assumedTerms()) + }.processCheckResult() + } + + override fun model(): KModel = yicesTry { + require(lastCheckStatus == KSolverStatus.SAT) { + "Model are only available after SAT checks, current solver status: $lastCheckStatus" + } + val model = nativeContext.model + + return KYicesModel(model, ctx, yicesCtx, exprInternalizer, exprConverter) + } + + override fun unsatCore(): List> = yicesTry { + require(lastCheckStatus == KSolverStatus.UNSAT) { + "Unsat cores are only available after UNSAT checks" + } + + lastAssumptions?.resolveUnsatCore(nativeContext.unsatCore) ?: emptyList() + } + + override fun reasonOfUnknown(): String { + require(lastCheckStatus == KSolverStatus.UNKNOWN) { + "Unknown reason is only available after UNKNOWN checks" + } + + // There is no way to retrieve reason of unknown from Yices in general case. + return lastReasonOfUnknown ?: "unknown" + } + + override fun interrupt() = yicesTry { + nativeContext.stopSearch() + } + + private inline fun checkWithTimer(timeout: Duration, body: () -> T): T { + val task = StopSearchTask() + + if (timeout.isFinite()) { + timer.schedule(task, timeout.inWholeMilliseconds) + } + + return try { + body() + } finally { + task.cancel() + } + } + + protected inline fun yicesTry(body: () -> T): T = try { + body() + } catch (ex: YicesException) { + throw KSolverException(ex) + } + + private inline fun yicesTryCheck(body: () -> KSolverStatus): KSolverStatus = try { + invalidateSolverState() + body() + } catch (ex: YicesException) { + lastReasonOfUnknown = ex.message + KSolverStatus.UNKNOWN.also { lastCheckStatus = it } + } + + private fun invalidateSolverState() { + lastCheckStatus = KSolverStatus.UNKNOWN + lastReasonOfUnknown = null + lastAssumptions = null + } + + private fun Status.processCheckResult() = when (this) { + Status.SAT -> KSolverStatus.SAT + Status.UNSAT -> KSolverStatus.UNSAT + else -> KSolverStatus.UNKNOWN + }.also { lastCheckStatus = it } + + override fun close() { + nativeContext.close() + yicesCtx.close() + timer.cancel() + } + + private inner class StopSearchTask : TimerTask() { + override fun run() { + nativeContext.stopSearch() + } + } + + private class TrackedAssumptions { + private val assumedExprs = arrayListOf, YicesTerm>>() + private val assumedTerms = IntArrayList() + + fun assumeTrackedAssertion(trackedAssertion: Pair, YicesTerm>) { + assumedExprs.add(trackedAssertion) + assumedTerms.add(trackedAssertion.second) + } + + fun assumeAssumption(expr: KExpr, term: YicesTerm) = + assumeTrackedAssertion(expr to term) + + fun assumedTerms(): YicesTermArray { + assumedTerms.trim() // Elements length now equal to size + return assumedTerms.elements() + } + + fun resolveUnsatCore(yicesUnsatCore: YicesTermArray): List> { + val unsatCoreTerms = IntOpenHashSet(yicesUnsatCore) + return assumedExprs.mapNotNull { (expr, term) -> expr.takeIf { unsatCoreTerms.contains(term) } } + } + } + + companion object { + init { + if (!Yices.isReady()) { + NativeLibraryLoader.load { os -> + when (os) { + NativeLibraryLoader.OS.LINUX -> listOf("libyices", "libyices2java") + NativeLibraryLoader.OS.WINDOWS -> listOf("libyices", "libyices2java") + NativeLibraryLoader.OS.MACOS -> listOf("libyices", "libyices2java") + } + } + Yices.init() + Yices.setReadyFlag(true) + } + } + } +} diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolverConfiguration.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolverConfiguration.kt index f97b72673..0d477b2a9 100644 --- a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolverConfiguration.kt +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/KYicesSolverConfiguration.kt @@ -31,6 +31,19 @@ class KYicesSolverConfigurationImpl(private val config: Config) : KYicesSolverCo } } +class KYicesForkingSolverConfigurationImpl(private val config: Config) : KYicesSolverConfiguration { + private val options = hashMapOf() + + override fun setYicesOption(option: String, value: String) { + config.set(option, value) + options[option] = value + } + + fun fork(config: Config) = KYicesForkingSolverConfigurationImpl(config).also { + options.forEach { (option, value) -> it.setYicesOption(option, value) } + } +} + class KYicesSolverUniversalConfiguration( private val builder: KSolverUniversalConfigurationBuilder ) : KYicesSolverConfiguration { diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/ScopedFrame.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/ScopedFrame.kt new file mode 100644 index 000000000..93f757e6f --- /dev/null +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/ScopedFrame.kt @@ -0,0 +1,149 @@ +package io.ksmt.solver.yices + +internal interface ScopedFrame { + val currentScope: UInt + val currentFrame: T + + fun getFrame(level: Int): T + + fun any(predicate: (T) -> Boolean): Boolean + + /** + * find value [V] in frame [T], and return it or null + */ + fun find(predicate: (T) -> V?): V? + fun forEach(action: (T) -> Unit) + + fun push() + fun pop(n: UInt = 1u) +} + +internal class ScopedArrayFrame( + currentFrame: T, + private inline val createNewFrame: () -> T +) : ScopedFrame { + constructor(createNewFrame: () -> T) : this(createNewFrame(), createNewFrame) + + private val frames = arrayListOf(currentFrame) + + override var currentFrame = currentFrame + private set + + override val currentScope: UInt + get() = frames.size.toUInt() + + override fun getFrame(level: Int): T = frames[level] + + override fun any(predicate: (T) -> Boolean): Boolean = frames.any(predicate) + + override fun find(predicate: (T) -> V?): V? { + frames.forEach { frame -> + predicate(frame)?.let { return it } + } + return null + } + + override fun forEach(action: (T) -> Unit) = frames.forEach(action) + + override fun push() { + currentFrame = createNewFrame() + frames += currentFrame + } + + override fun pop(n: UInt) { + repeat(n.toInt()) { frames.removeLast() } + currentFrame = frames.last() + } +} + +internal class ScopedLinkedFrame private constructor( + private var current: LinkedFrame, + private inline val createNewFrame: () -> T, + private inline val copyFrame: (T) -> T +) : ScopedFrame { + constructor( + currentFrame: T, + createNewFrame: () -> T, + copyFrame: (T) -> T + ) : this(LinkedFrame(currentFrame), createNewFrame, copyFrame) + + constructor( + createNewFrame: () -> T, + copyFrame: (T) -> T + ) : this(createNewFrame(), createNewFrame, copyFrame) + + override val currentFrame: T + get() = current.value + + override val currentScope: UInt + get() = current.scope + + override fun getFrame(level: Int): T { + if (level > current.scope.toInt() || level < 0) throw IllegalArgumentException("Level $level is out of scope") + var cur: LinkedFrame? = current + + while (cur != null && level < cur.scope.toInt()) { + cur = cur.previous + } + return cur!!.value + } + + override fun any(predicate: (T) -> Boolean): Boolean { + forEachReversed { frame -> + if (predicate(frame)) return true + } + return false + } + + fun stacked(): ArrayDeque = ArrayDeque().also { stack -> + forEachReversed { frame -> + stack.addLast(frame) + } + } + + override fun find(predicate: (T) -> V?): V? { + forEachReversed { frame -> + predicate(frame)?.let { return it } + } + return null + } + + override fun forEach(action: (T) -> Unit) = forEachReversed(action) + + override fun push() { + current = LinkedFrame(createNewFrame(), current) + } + + override fun pop(n: UInt) { + repeat(n.toInt()) { + current = current.previous ?: throw IllegalStateException("Can't pop the bottom scope") + } + recreateTopFrame() + } + + private fun recreateTopFrame() { + val newTopFrame = copyFrame(currentFrame) + current = LinkedFrame(newTopFrame, current.previous) + } + + fun fork(parent: ScopedLinkedFrame) { + current = parent.current + recreateTopFrame() + } + + private inline fun forEachReversed(action: (T) -> Unit) { + var cur: LinkedFrame? = current + while (cur != null) { + action(cur.value) + cur = cur.previous + } + } + + private class LinkedFrame( + val value: E, + val previous: LinkedFrame? = null + ) { + val scope: UInt = previous?.scope?.plus(1u) ?: 0u + } + +} diff --git a/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/UninterpretedValuesTracker.kt b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/UninterpretedValuesTracker.kt new file mode 100644 index 000000000..e8cc27a09 --- /dev/null +++ b/ksmt-yices/src/main/kotlin/io/ksmt/solver/yices/UninterpretedValuesTracker.kt @@ -0,0 +1,90 @@ +package io.ksmt.solver.yices + +import io.ksmt.KContext +import io.ksmt.expr.KExpr +import io.ksmt.expr.KUninterpretedSortValue +import io.ksmt.expr.transformer.KNonRecursiveTransformer +import io.ksmt.sort.KSort +import io.ksmt.sort.KUninterpretedSort +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap + +class UninterpretedValuesTracker internal constructor( + private val ctx: KContext, + private val scopedExpressions: ScopedFrame>>, + private val uninterpretedValues: ScopedFrame>>, + private val expressionLevels: Object2IntOpenHashMap> +) { + private var analyzer: ExprUninterpretedValuesAnalyzer = createNewAnalyzer() + + private fun createNewAnalyzer() = ExprUninterpretedValuesAnalyzer( + ctx, + scopedExpressions, + uninterpretedValues, + expressionLevels + ) + + fun expressionUse(expr: KExpr<*>) { + if (expr in scopedExpressions.currentFrame) return + analyzer.apply(expr) + } + + fun expressionSave(expr: KExpr<*>) { + if (scopedExpressions.currentFrame.add(expr)) { + expressionLevels.put(expr, scopedExpressions.currentScope.toInt()) + } + } + + fun addToCurrentLevel(value: KUninterpretedSortValue) { + analyzer.addToCurrentLevel(value) + } + + fun getUninterpretedSortValues(sort: KUninterpretedSort) = hashSetOf().apply { + uninterpretedValues.forEach { frame -> + frame[sort]?.also { this += it } + } + } + + fun push() { + scopedExpressions.push() + uninterpretedValues.push() + } + + fun pop(n: UInt) { + scopedExpressions.pop(n) + uninterpretedValues.pop(n) + + analyzer = createNewAnalyzer() + } + + private class ExprUninterpretedValuesAnalyzer( + ctx: KContext, + val scopedExpressions: ScopedFrame>>, + val uninterpretedValues: ScopedFrame>>, + val expressionLevels: Object2IntOpenHashMap> + ) : KNonRecursiveTransformer(ctx) { + + fun addToCurrentLevel(value: KUninterpretedSortValue) { + uninterpretedValues.currentFrame.getOrPut(value.sort) { hashSetOf() } += value + } + + override fun transformExpr(expr: KExpr): KExpr { + if (scopedExpressions.currentFrame.add(expr)) + expressionLevels[expr] = scopedExpressions.currentScope.toInt() + return super.transformExpr(expr) + } + + override fun transform(expr: KUninterpretedSortValue): KExpr { + addToCurrentLevel(expr) + return super.transform(expr) + } + + override fun exprTransformationRequired(expr: KExpr): Boolean { + val frameLevel = expressionLevels.getInt(expr) + if (frameLevel < scopedExpressions.currentScope.toInt()) { + // If expr is valid on its level we don't need to move it + return expr !in scopedExpressions.getFrame(frameLevel) + } + return super.exprTransformationRequired(expr) + } + } +} diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesForkingTracker.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesForkingTracker.kt new file mode 100644 index 000000000..e67fae0b4 --- /dev/null +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesForkingTracker.kt @@ -0,0 +1,42 @@ +package io.ksmt.solver.z3 + +import io.ksmt.KContext +import io.ksmt.expr.KUninterpretedSortValue + +/** + * Uninterpreted sort values tracker with ability to fork. + * On child tracker creation ([fork]), it will have the same [registeredUninterpretedSortValues] as its parent + * to prevent descriptors loss. [expressionLevels] and [valueTrackerFrames] are copied + * to restore parental state of caching. + * Also, all axioms are asserted lazily on [assertPendingUninterpretedValueConstraints] + */ +class ExpressionUninterpretedValuesForkingTracker : ExpressionUninterpretedValuesTracker { + private constructor( + ctx: KContext, + z3Ctx: KZ3Context, + registeredUninterpretedSortValues: HashMap + ) : super(ctx, z3Ctx, registeredUninterpretedSortValues) + + constructor(ctx: KContext, z3Ctx: KZ3Context) : super(ctx, z3Ctx) + + fun fork(z3Ctx: KZ3Context) = ExpressionUninterpretedValuesForkingTracker( + ctx, z3Ctx, registeredUninterpretedSortValues + ).also { child -> + child.expressionLevels += expressionLevels + + var isFirstFrame = true + valueTrackerFrames.forEach { frame -> + if (!isFirstFrame) { + child.pushAssertionLevel() + } + + if (frame.initialized) { + child.currentFrame.ensureInitialized() + child.currentFrame.currentLevelUninterpretedValues += frame.currentLevelUninterpretedValues + child.currentFrame.currentLevelExpressions += frame.currentLevelExpressions + } + + isFirstFrame = false + } + } +} diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesTracker.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesTracker.kt index 8c8a74e2f..6915acd4e 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesTracker.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ExpressionUninterpretedValuesTracker.kt @@ -3,13 +3,13 @@ package io.ksmt.solver.z3 import com.microsoft.z3.Native import com.microsoft.z3.Solver import com.microsoft.z3.solverAssert -import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap import io.ksmt.KContext import io.ksmt.expr.KExpr import io.ksmt.expr.KUninterpretedSortValue import io.ksmt.expr.transformer.KNonRecursiveTransformer import io.ksmt.sort.KSort import io.ksmt.sort.KUninterpretedSort +import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap /** * Uninterpreted sort values distinct constraints management. @@ -19,21 +19,25 @@ import io.ksmt.sort.KUninterpretedSort * 2. Assert distinct constraints ([assertPendingUninterpretedValueConstraints]) * that may be introduced during internalization. * */ -class ExpressionUninterpretedValuesTracker(val ctx: KContext, val z3Ctx: KZ3Context) { - private val expressionLevels = Object2IntOpenHashMap>().apply { +open class ExpressionUninterpretedValuesTracker protected constructor( + val ctx: KContext, + val z3Ctx: KZ3Context, + protected val registeredUninterpretedSortValues: HashMap +) { + constructor(ctx: KContext, z3Ctx: KZ3Context) : this(ctx, z3Ctx, hashMapOf()) + + protected val expressionLevels = Object2IntOpenHashMap>().apply { defaultReturnValue(Int.MAX_VALUE) // Level which is greater than any possible level } - private var currentFrame = ValueTrackerAssertionFrame( - ctx, this, expressionLevels, + protected var currentFrame = ValueTrackerAssertionFrame( + ctx, expressionLevels, level = 0, notAssertedConstraintsFromPreviousLevels = 0 ) + private set - private val valueTrackerFrames = arrayListOf(currentFrame) - - private val registeredUninterpretedSortValues = - hashMapOf() + protected val valueTrackerFrames = arrayListOf(currentFrame) /** * Skip any value tracking related actions until @@ -121,21 +125,22 @@ class ExpressionUninterpretedValuesTracker(val ctx: KContext, val z3Ctx: KZ3Cont z3Ctx.releaseTemporaryAst(constraintLhs) } - private data class UninterpretedSortValueDescriptor( + protected data class UninterpretedSortValueDescriptor( val value: KUninterpretedSortValue, val nativeUniqueValueDescriptor: Long, val nativeValueExpr: Long ) - private class ValueTrackerAssertionFrame( + protected inner class ValueTrackerAssertionFrame( val ctx: KContext, - val tracker: ExpressionUninterpretedValuesTracker, val expressionLevels: Object2IntOpenHashMap>, val level: Int, val notAssertedConstraintsFromPreviousLevels: Int ) { - private var initialized = false - private var lastAssertedConstraint = 0 + var initialized = false + private set + + var lastAssertedConstraint = 0 lateinit var currentLevelExpressions: MutableSet> lateinit var currentLevelUninterpretedValues: MutableList @@ -146,7 +151,7 @@ class ExpressionUninterpretedValuesTracker(val ctx: KContext, val z3Ctx: KZ3Cont * since we might not have any uninterpreted values on * a current assertion level. * */ - private fun ensureInitialized() { + fun ensureInitialized() { if (initialized) return currentLevelExpressions = hashSetOf() @@ -163,7 +168,7 @@ class ExpressionUninterpretedValuesTracker(val ctx: KContext, val z3Ctx: KZ3Cont val notAssertedConstraints = numberOfConstraints - lastAssertedConstraint val nextLevelRemainingConstraints = notAssertedConstraintsFromPreviousLevels + notAssertedConstraints return ValueTrackerAssertionFrame( - ctx, tracker, expressionLevels, + ctx, expressionLevels, level = level + 1, notAssertedConstraintsFromPreviousLevels = nextLevelRemainingConstraints ) @@ -215,7 +220,7 @@ class ExpressionUninterpretedValuesTracker(val ctx: KContext, val z3Ctx: KZ3Cont } fun addRegisteredValueToCurrentLevel(value: KUninterpretedSortValue) { - val descriptor = tracker.registeredUninterpretedSortValues[value] + val descriptor = registeredUninterpretedSortValues[value] ?: error("Value $value was not registered") addRegisteredValueToCurrentLevel(descriptor) } @@ -226,10 +231,10 @@ class ExpressionUninterpretedValuesTracker(val ctx: KContext, val z3Ctx: KZ3Cont currentLevelUninterpretedValues.add(descriptor) } - fun getFrame(level: Int) = tracker.valueTrackerFrames[level] + fun getFrame(level: Int) = valueTrackerFrames[level] } - private class ExprUninterpretedValuesAnalyzer( + protected class ExprUninterpretedValuesAnalyzer( ctx: KContext, val frame: ValueTrackerAssertionFrame ) : KNonRecursiveTransformer(ctx) { diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt index 834c28b7c..714f2b178 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Context.kt @@ -2,12 +2,14 @@ package io.ksmt.solver.z3 import com.microsoft.z3.Context import com.microsoft.z3.Solver +import com.microsoft.z3.Z3Exception import com.microsoft.z3.decRefUnsafe import com.microsoft.z3.incRefUnsafe import io.ksmt.KContext import io.ksmt.decl.KDecl import io.ksmt.expr.KExpr import io.ksmt.expr.KUninterpretedSortValue +import io.ksmt.solver.KSolverException import io.ksmt.solver.util.KExprLongInternalizerBase.Companion.NOT_INTERNALIZED import io.ksmt.sort.KSort import io.ksmt.sort.KUninterpretedSort @@ -17,33 +19,30 @@ import it.unimi.dsi.fastutil.longs.LongSet import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap @Suppress("TooManyFunctions") -class KZ3Context( +open class KZ3Context( ksmtCtx: KContext, - private val ctx: Context + private val ctx: Context, ) : AutoCloseable { constructor(ksmtCtx: KContext) : this(ksmtCtx, Context()) - private var isClosed = false + protected var isClosed = false - private val expressions = Object2LongOpenHashMap>().apply { - defaultReturnValue(NOT_INTERNALIZED) - } + protected open val expressions = Object2LongOpenHashMap>().apply { defaultReturnValue(NOT_INTERNALIZED) } + protected open val sorts = Object2LongOpenHashMap().apply { defaultReturnValue(NOT_INTERNALIZED) } + protected open val decls = Object2LongOpenHashMap>().apply { defaultReturnValue(NOT_INTERNALIZED) } - private val sorts = Object2LongOpenHashMap().apply { - defaultReturnValue(NOT_INTERNALIZED) - } + protected open val z3Expressions = Long2ObjectOpenHashMap>() + protected open val z3Sorts = Long2ObjectOpenHashMap() + protected open val z3Decls = Long2ObjectOpenHashMap>() + protected open val tmpNativeObjects = LongOpenHashSet() + protected open val converterNativeObjects = LongOpenHashSet() - private val decls = Object2LongOpenHashMap>().apply { - defaultReturnValue(NOT_INTERNALIZED) - } + protected open val uninterpretedSortValueInterpreter = hashMapOf() + protected open val uninterpretedSortValueDecls = Long2ObjectOpenHashMap() + protected open val uninterpretedSortValueInterpreters = LongOpenHashSet() - private val z3Expressions = Long2ObjectOpenHashMap>() - private val z3Sorts = Long2ObjectOpenHashMap() - private val z3Decls = Long2ObjectOpenHashMap>() - private val tmpNativeObjects = LongOpenHashSet() - private val converterNativeObjects = LongOpenHashSet() + open val uninterpretedValuesTracker = ExpressionUninterpretedValuesTracker(ksmtCtx, this) - val uninterpretedValuesTracker = ExpressionUninterpretedValuesTracker(ksmtCtx, this) @JvmField val nCtx: Long = ctx.nCtx() @@ -54,17 +53,17 @@ class KZ3Context( val isActive: Boolean get() = !isClosed + internal fun findInternalizedExprWithoutAnalysis(expr: KExpr<*>): Long { + val result = expressions.getLong(expr) + return if (result == NOT_INTERNALIZED) NOT_INTERNALIZED else result + } + /** * Find internalized expr. * Returns [NOT_INTERNALIZED] if expression was not found. * */ - fun findInternalizedExpr(expr: KExpr<*>): Long { - val result = expressions.getLong(expr) - if (result == NOT_INTERNALIZED) return NOT_INTERNALIZED - - uninterpretedValuesTracker.expressionUse(expr) - - return result + fun findInternalizedExpr(expr: KExpr<*>): Long = findInternalizedExprWithoutAnalysis(expr).also { + if (it != NOT_INTERNALIZED) uninterpretedValuesTracker.expressionUse(expr) } fun saveInternalizedExpr(expr: KExpr<*>, internalized: Long) { @@ -148,11 +147,6 @@ class KZ3Context( return ast } - private val uninterpretedSortValueInterpreter = hashMapOf() - - private val uninterpretedSortValueDecls = Long2ObjectOpenHashMap() - private val uninterpretedSortValueInterpreters = LongOpenHashSet() - fun saveUninterpretedSortValueDecl(decl: Long, value: KUninterpretedSortValue): Long { if (uninterpretedSortValueDecls.putIfAbsent(decl, value) == null) { incRefUnsafe(nCtx, decl) @@ -262,7 +256,6 @@ class KZ3Context( override fun close() { if (isClosed) return - isClosed = true uninterpretedSortValueInterpreter.clear() @@ -290,11 +283,24 @@ class KZ3Context( sorts.clear() z3Sorts.clear() - ctx.close() + z3Try { + isClosed = true + ctx.close() + } } - private fun LongSet.decRefAll() = - longIterator().forEachRemaining { - decRefUnsafe(nCtx, it) - } + private fun LongSet.decRefAll() = longIterator().forEachRemaining { + decRefUnsafe(nCtx, it) + } + + fun ensureActive() { + check(!isClosed) { "The context is already closed." } + } + + inline fun z3Try(body: () -> T): T = try { + ensureActive() + body() + } catch (ex: Z3Exception) { + throw KSolverException(ex) + } } diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingContext.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingContext.kt new file mode 100644 index 000000000..cabd6bac9 --- /dev/null +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingContext.kt @@ -0,0 +1,49 @@ +package io.ksmt.solver.z3 + +import com.microsoft.z3.Context +import io.ksmt.KContext + +class KZ3ForkingContext private constructor( + ksmtCtx: KContext, + private val ctx: Context, + manager: KZ3ForkingSolverManager, + parent: KZ3ForkingContext?, +) : KZ3Context(ksmtCtx, ctx) { + + constructor( + ksmtCtx: KContext, + ctx: Context, + manager: KZ3ForkingSolverManager + ) : this(ksmtCtx, ctx, manager, null) + + constructor(ksmtCtx: KContext, manager: KZ3ForkingSolverManager) : this(ksmtCtx, Context(), manager) + + // common for parent and child structures + override val expressions = with(manager) { getExpressionsCache() } + override val sorts = with(manager) { getSortsCache() } + override val decls = with(manager) { getDeclsCache() } + + override val z3Expressions = with(manager) { getExpressionsReversedCache() } + override val z3Sorts = with(manager) { getSortsReversedCache() } + override val z3Decls = with(manager) { getDeclsReversedCache() } + override val tmpNativeObjects = with(manager) { getTmpNativeObjectsCache() } + override val converterNativeObjects = with(manager) { getConverterNativeObjectsCache() } + + override val uninterpretedSortValueInterpreter = with(manager) { getUninterpretedSortValueInterpreter() } + override val uninterpretedSortValueDecls = with(manager) { getUninterpretedSortValueDecls() } + override val uninterpretedSortValueInterpreters = with(manager) { getUninterpretedSortValueInterpreters() } + + override val uninterpretedValuesTracker: ExpressionUninterpretedValuesForkingTracker = parent + ?.uninterpretedValuesTracker?.fork(this) + ?: ExpressionUninterpretedValuesForkingTracker(ksmtCtx, this) + + internal fun fork(ksmtCtx: KContext, manager: KZ3ForkingSolverManager): KZ3ForkingContext { + ensureActive() + return KZ3ForkingContext(ksmtCtx, ctx, manager, this) + } + + override fun close() { + if (isClosed) return + isClosed = true + } +} diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt new file mode 100644 index 000000000..82592bb72 --- /dev/null +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolver.kt @@ -0,0 +1,133 @@ +package io.ksmt.solver.z3 + +import com.microsoft.z3.solverAssert +import com.microsoft.z3.solverAssertAndTrack +import io.ksmt.KContext +import io.ksmt.expr.KExpr +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KSolverStatus +import io.ksmt.sort.KBoolSort +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap +import it.unimi.dsi.fastutil.longs.LongOpenHashSet +import kotlin.time.Duration + +open class KZ3ForkingSolver internal constructor( + ctx: KContext, + private val manager: KZ3ForkingSolverManager, + parent: KZ3ForkingSolver? +) : KZ3SolverBase(ctx), KForkingSolver { + final override val z3Ctx: KZ3ForkingContext = manager.createZ3ForkingContext(parent?.z3Ctx) + + private val trackedAssertions = ScopedLinkedFrame>>( + ::Long2ObjectOpenHashMap, ::Long2ObjectOpenHashMap + ) + private val z3Assertions = ScopedLinkedFrame(::LongOpenHashSet, ::LongOpenHashSet) + + private val isChild = parent != null + private var assertionsInitiated = !isChild + + private val config: KZ3ForkingSolverConfigurationImpl by lazy { + z3Ctx.z3Try { + z3Ctx.nativeContext.mkParams().let { + parent?.config?.fork(it)?.apply { setParameters(solver) } ?: KZ3ForkingSolverConfigurationImpl(it) + } + } + } + + init { + if (parent != null) { + // lightweight copying via copying of the linked list node + trackedAssertions.fork(parent.trackedAssertions) + z3Assertions.fork(parent.z3Assertions) + config // initialize child config + } + } + + override fun configure(configurator: KZ3SolverConfiguration.() -> Unit) { + config.configurator() + config.setParameters(solver) + } + + /** + * Creates lazily initiated forked solver with shared cache, preserving parental assertions and configuration. + */ + override fun fork(): KForkingSolver = manager.createForkingSolver(this) + + override fun saveTrackedAssertion(track: Long, trackedExpr: KExpr) { + trackedAssertions.currentFrame[track] = trackedExpr + } + + override fun findTrackedExprByTrack(track: Long): KExpr? = trackedAssertions.findNonNullValue { + it[track] + } + + /** + * Asserts parental (in case of child) assertions if already not + */ + private fun ensureAssertionsInitiated() { + if (assertionsInitiated) return + + z3Assertions.stacked() + .zip(trackedAssertions.stacked()) + .asReversed() + .forEachIndexed { scope, (z3AssertionFrame, trackedFrame) -> + if (scope > 0) { + solver.push() + currentScope++ + } + + z3AssertionFrame.forEach(solver::solverAssert) + trackedFrame.forEach { (track, expr) -> + /** tracked [expr] was previously internalized by parent */ + solver.solverAssertAndTrack(track, z3Ctx.findInternalizedExprWithoutAnalysis(expr)) + } + } + + assertionsInitiated = true + } + + override fun push() { + z3Ctx.z3Try { ensureAssertionsInitiated() } + super.push() + trackedAssertions.push() + z3Assertions.push() + } + + override fun pop(n: UInt) { + z3Ctx.z3Try { ensureAssertionsInitiated() } + super.pop(n) + trackedAssertions.pop(n) + z3Assertions.pop(n) + } + + override fun assert(expr: KExpr) = z3Ctx.z3Try { + ensureAssertionsInitiated() + ctx.ensureContextMatch(expr) + + val z3Expr = with(exprInternalizer) { expr.internalizeExpr() } + solver.solverAssert(z3Expr) + + z3Ctx.assertPendingAxioms(solver) + z3Assertions.currentFrame += z3Expr + } + + override fun assertAndTrack(expr: KExpr) { + z3Ctx.z3Try { ensureAssertionsInitiated() } + super.assertAndTrack(expr) + } + + override fun check(timeout: Duration): KSolverStatus { + z3Ctx.z3Try { ensureAssertionsInitiated() } + return super.check(timeout) + } + + override fun checkWithAssumptions(assumptions: List>, timeout: Duration): KSolverStatus { + z3Ctx.z3Try { ensureAssertionsInitiated() } + return super.checkWithAssumptions(assumptions, timeout) + } + + override fun close() { + super.close() + manager.close(this) + } +} diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt new file mode 100644 index 000000000..a3c2a6406 --- /dev/null +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3ForkingSolverManager.kt @@ -0,0 +1,158 @@ +package io.ksmt.solver.z3 + +import com.microsoft.z3.Context +import com.microsoft.z3.Z3Exception +import com.microsoft.z3.decRefUnsafe +import io.ksmt.KAst +import io.ksmt.KContext +import io.ksmt.decl.KDecl +import io.ksmt.expr.KExpr +import io.ksmt.expr.KUninterpretedSortValue +import io.ksmt.solver.KForkingSolver +import io.ksmt.solver.KForkingSolverManager +import io.ksmt.solver.KSolverException +import io.ksmt.solver.util.KExprLongInternalizerBase +import io.ksmt.sort.KSort +import io.ksmt.sort.KUninterpretedSort +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap +import it.unimi.dsi.fastutil.longs.LongOpenHashSet +import it.unimi.dsi.fastutil.longs.LongSet +import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap +import java.util.concurrent.ConcurrentHashMap + +/** + * Responsible for creation and managing of [KZ3ForkingSolver]. + * + * It's cheaper to create multiple copies of solvers with [KZ3ForkingSolver.fork] + * instead of assertions transferring in [KZ3Solver] instances. + * + * All created solvers with one manager (via both [KZ3ForkingSolver.fork] and [createForkingSolver]) + * use the same [Context], cache, and registered uninterpreted sort values. + */ +class KZ3ForkingSolverManager(private val ctx: KContext) : KForkingSolverManager { + private val z3Context by lazy { Context() } + private val solvers = ConcurrentHashMap.newKeySet() + + // shared cache + private val expressionsCache = ExpressionsCache().withNotInternalizedAsDefaultValue() + private val expressionsReversedCache = ExpressionsReversedCache() + private val sortsCache = SortsCache().withNotInternalizedAsDefaultValue() + private val sortsReversedCache = SortsReversedCache() + private val declsCache = DeclsCache().withNotInternalizedAsDefaultValue() + private val declsReversedCache = DeclsReversedCache() + + private val tmpNativeObjectsCache = TmpNativeObjectsCache() + private val converterNativeObjectsCache = ConverterNativeObjectsCache() + + private val uninterpretedSortValueInterpreter = UninterpretedSortValueInterpreterCache() + private val uninterpretedSortValueDecls = UninterpretedSortValueDecls() + private val uninterpretedSortValueInterpreters = UninterpretedSortValueInterpretersCache() + + internal fun KZ3Context.getExpressionsCache() = ensureContextMatches(nativeContext).let { expressionsCache } + internal fun KZ3Context.getExpressionsReversedCache() = ensureContextMatches(nativeContext) + .let { expressionsReversedCache } + + internal fun KZ3Context.getSortsCache() = ensureContextMatches(nativeContext).let { sortsCache } + internal fun KZ3Context.getSortsReversedCache() = ensureContextMatches(nativeContext).let { sortsReversedCache } + internal fun KZ3Context.getDeclsCache() = ensureContextMatches(nativeContext).let { declsCache } + internal fun KZ3Context.getDeclsReversedCache() = ensureContextMatches(nativeContext).let { declsReversedCache } + internal fun KZ3Context.getTmpNativeObjectsCache() = ensureContextMatches(nativeContext) + .let { tmpNativeObjectsCache } + + internal fun KZ3Context.getConverterNativeObjectsCache() = ensureContextMatches(nativeContext) + .let { converterNativeObjectsCache } + + internal fun KZ3Context.getUninterpretedSortValueInterpreter() = ensureContextMatches(nativeContext) + .let { uninterpretedSortValueInterpreter } + + internal fun KZ3Context.getUninterpretedSortValueDecls() = ensureContextMatches(nativeContext) + .let { uninterpretedSortValueDecls } + + internal fun KZ3Context.getUninterpretedSortValueInterpreters() = ensureContextMatches(nativeContext) + .let { uninterpretedSortValueInterpreters } + + override fun createForkingSolver(): KForkingSolver { + return KZ3ForkingSolver(ctx, this, null).also { solvers += it } + } + + internal fun createForkingSolver(parent: KZ3ForkingSolver): KForkingSolver { + return KZ3ForkingSolver(ctx, this, parent).also { solvers += it } + } + + internal fun createZ3ForkingContext(parentCtx: KZ3ForkingContext? = null) = parentCtx?.fork(ctx, this) + ?: KZ3ForkingContext(ctx, z3Context, this) + + /** + * unregister [solver] for this manager + */ + internal fun close(solver: KZ3ForkingSolver) { + solvers -= solver + closeContextIfStale() + } + + override fun close() { + solvers.forEach(KZ3ForkingSolver::close) + } + + private fun closeContextIfStale() { + if (solvers.isEmpty()) { + val nCtx = z3Context.nCtx() + + expressionsReversedCache.keys.decRefAll(nCtx) + expressionsReversedCache.clear() + expressionsCache.clear() + + sortsReversedCache.keys.decRefAll(nCtx) + sortsReversedCache.clear() + sortsCache.clear() + + declsReversedCache.keys.decRefAll(nCtx) + declsReversedCache.clear() + declsCache.clear() + + uninterpretedSortValueInterpreters.decRefAll(nCtx) + uninterpretedSortValueInterpreters.clear() + uninterpretedSortValueInterpreter.clear() + uninterpretedSortValueDecls.clear() + + converterNativeObjectsCache.decRefAll(nCtx) + converterNativeObjectsCache.clear() + tmpNativeObjectsCache.decRefAll(nCtx) + tmpNativeObjectsCache.clear() + + try { + ctx.close() + } catch (e: Z3Exception) { + throw KSolverException(e) + } + } + } + + private fun Object2LongOpenHashMap.withNotInternalizedAsDefaultValue() = apply { + defaultReturnValue(KExprLongInternalizerBase.NOT_INTERNALIZED) + } + + private fun LongSet.decRefAll(nCtx: Long) = longIterator().forEachRemaining { + decRefUnsafe(nCtx, it) + } + + private fun ensureContextMatches(ctx: Context) { + require(ctx == z3Context) { "Context is not registered by manager." } + } +} + +private typealias ExpressionsCache = Object2LongOpenHashMap> +private typealias ExpressionsReversedCache = Long2ObjectOpenHashMap> + +private typealias SortsCache = Object2LongOpenHashMap +private typealias SortsReversedCache = Long2ObjectOpenHashMap + +private typealias DeclsCache = Object2LongOpenHashMap> +private typealias DeclsReversedCache = Long2ObjectOpenHashMap> + +private typealias TmpNativeObjectsCache = LongOpenHashSet +private typealias ConverterNativeObjectsCache = LongOpenHashSet + +private typealias UninterpretedSortValueInterpreterCache = HashMap +private typealias UninterpretedSortValueDecls = Long2ObjectOpenHashMap +private typealias UninterpretedSortValueInterpretersCache = LongOpenHashSet diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Solver.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Solver.kt index f8bdfa43f..f76330cf8 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Solver.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3Solver.kt @@ -1,252 +1,29 @@ package io.ksmt.solver.z3 -import com.microsoft.z3.Solver -import com.microsoft.z3.Status -import com.microsoft.z3.Z3Exception -import com.microsoft.z3.solverAssert -import com.microsoft.z3.solverAssertAndTrack -import com.microsoft.z3.solverCheckAssumptions -import com.microsoft.z3.solverGetUnsatCore -import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap import io.ksmt.KContext import io.ksmt.expr.KExpr -import io.ksmt.solver.KModel import io.ksmt.solver.KSolver -import io.ksmt.solver.KSolverException -import io.ksmt.solver.KSolverStatus import io.ksmt.sort.KBoolSort -import io.ksmt.utils.NativeLibraryLoader -import java.lang.ref.PhantomReference -import java.lang.ref.ReferenceQueue -import java.util.IdentityHashMap -import kotlin.time.Duration -import kotlin.time.DurationUnit -open class KZ3Solver(private val ctx: KContext) : KSolver { - private val z3Ctx = KZ3Context(ctx) - private val solver = createSolver() +open class KZ3Solver(ctx: KContext) : KZ3SolverBase(ctx), KSolver { + override val z3Ctx: KZ3Context = KZ3Context(ctx) + private val trackedAssertions = ScopedArrayFrameOfLong2ObjectOpenHashMap>() - private var lastCheckStatus = KSolverStatus.UNKNOWN - private var lastReasonOfUnknown: String? = null - private var lastModel: KZ3Model? = null - private var lastUnsatCore: List>? = null - - private var currentScope: UInt = 0u - - @Suppress("LeakingThis") - private val contextCleanupActionHandler = registerContextForCleanup(this, z3Ctx) - - private val exprInternalizer by lazy { - createExprInternalizer(z3Ctx) - } - private val exprConverter by lazy { - createExprConverter(z3Ctx) + override fun saveTrackedAssertion(track: Long, trackedExpr: KExpr) { + trackedAssertions.currentFrame[track] = trackedExpr } - open fun createExprInternalizer(z3Ctx: KZ3Context): KZ3ExprInternalizer = KZ3ExprInternalizer(ctx, z3Ctx) - - open fun createExprConverter(z3Ctx: KZ3Context) = KZ3ExprConverter(ctx, z3Ctx) - - private fun createSolver(): Solver = z3Ctx.nativeContext.mkSolver() - - override fun configure(configurator: KZ3SolverConfiguration.() -> Unit) { - val params = z3Ctx.nativeContext.mkParams() - KZ3SolverConfigurationImpl(params).configurator() - solver.setParameters(params) + override fun findTrackedExprByTrack(track: Long): KExpr? = trackedAssertions.findNonNullValue { + it[track] } override fun push() { - solver.push() - z3Ctx.pushAssertionLevel() - currentScope++ + super.push() + trackedAssertions.push() } override fun pop(n: UInt) { - require(n <= currentScope) { - "Can not pop $n scope levels because current scope level is $currentScope" - } - if (n == 0u) return - - solver.pop(n.toInt()) - repeat(n.toInt()) { z3Ctx.popAssertionLevel() } - - currentScope -= n - } - - override fun assert(expr: KExpr) = z3Try { - ctx.ensureContextMatch(expr) - - val z3Expr = with(exprInternalizer) { expr.internalizeExpr() } - solver.solverAssert(z3Expr) - - z3Ctx.assertPendingAxioms(solver) - } - - private val trackedAssertions = Long2ObjectOpenHashMap>() - - override fun assertAndTrack(expr: KExpr) = z3Try { - ctx.ensureContextMatch(expr) - - val trackExpr = ctx.mkFreshConst("track", ctx.boolSort) - val z3Expr = with(exprInternalizer) { expr.internalizeExpr() } - val z3TrackVar = with(exprInternalizer) { trackExpr.internalizeExpr() } - - trackedAssertions.put(z3TrackVar, expr) - - solver.solverAssertAndTrack(z3Expr, z3TrackVar) - } - - override fun check(timeout: Duration): KSolverStatus = z3TryCheck { - solver.updateTimeout(timeout) - solver.check().processCheckResult() - } - - override fun checkWithAssumptions( - assumptions: List>, - timeout: Duration - ): KSolverStatus = z3TryCheck { - ctx.ensureContextMatch(assumptions) - - val z3Assumptions = with(exprInternalizer) { - LongArray(assumptions.size) { - val assumption = assumptions[it] - - /** - * Assumptions are trivially unsat and no check-sat is required. - * */ - if (assumption == ctx.falseExpr) { - lastUnsatCore = listOf(ctx.falseExpr) - lastCheckStatus = KSolverStatus.UNSAT - return KSolverStatus.UNSAT - } - - assumption.internalizeExpr() - } - } - - solver.updateTimeout(timeout) - - solver.solverCheckAssumptions(z3Assumptions).processCheckResult() - } - - override fun model(): KModel = z3Try { - require(lastCheckStatus == KSolverStatus.SAT) { - "Model are only available after SAT checks, current solver status: $lastCheckStatus" - } - - val model = lastModel ?: KZ3Model( - model = solver.model, - ctx = ctx, - z3Ctx = z3Ctx, - internalizer = exprInternalizer - ) - lastModel = model - - model - } - - override fun unsatCore(): List> = z3Try { - require(lastCheckStatus == KSolverStatus.UNSAT) { "Unsat cores are only available after UNSAT checks" } - - val unsatCore = lastUnsatCore ?: with(exprConverter) { - val solverUnsatCore = solver.solverGetUnsatCore() - solverUnsatCore.map { trackedAssertions.get(it) ?: it.convertExpr() } - } - lastUnsatCore = unsatCore - - unsatCore - } - - override fun reasonOfUnknown(): String = z3Try { - require(lastCheckStatus == KSolverStatus.UNKNOWN) { "Unknown reason is only available after UNKNOWN checks" } - lastReasonOfUnknown ?: solver.reasonUnknown - } - - override fun interrupt() = z3Try { - solver.interrupt() - } - - override fun close() { - unregisterContextCleanup(contextCleanupActionHandler) - z3Ctx.close() - } - - private fun Status?.processCheckResult() = when (this) { - Status.SATISFIABLE -> KSolverStatus.SAT - Status.UNSATISFIABLE -> KSolverStatus.UNSAT - Status.UNKNOWN -> KSolverStatus.UNKNOWN - null -> KSolverStatus.UNKNOWN - }.also { lastCheckStatus = it } - - private fun Solver.updateTimeout(timeout: Duration) { - val z3Timeout = if (timeout == Duration.INFINITE) { - UInt.MAX_VALUE.toInt() - } else { - timeout.toInt(DurationUnit.MILLISECONDS) - } - val params = z3Ctx.nativeContext.mkParams().apply { - add("timeout", z3Timeout) - } - setParameters(params) - } - - private inline fun z3Try(body: () -> T): T = try { - body() - } catch (ex: Z3Exception) { - throw KSolverException(ex) - } - - private fun invalidateSolverState() { - lastReasonOfUnknown = null - lastCheckStatus = KSolverStatus.UNKNOWN - lastModel = null - lastUnsatCore = null - } - - private inline fun z3TryCheck(body: () -> KSolverStatus): KSolverStatus = try { - invalidateSolverState() - body() - } catch (ex: Z3Exception) { - lastReasonOfUnknown = ex.message - KSolverStatus.UNKNOWN.also { lastCheckStatus = it } - } - - companion object { - init { - System.setProperty("z3.skipLibraryLoad", "true") - NativeLibraryLoader.load { os -> - when (os) { - NativeLibraryLoader.OS.LINUX -> listOf("libz3", "libz3java") - NativeLibraryLoader.OS.MACOS -> listOf("libz3", "libz3java") - NativeLibraryLoader.OS.WINDOWS -> listOf("vcruntime140", "vcruntime140_1", "libz3", "libz3java") - } - } - } - - private val cleanupHandlers = ReferenceQueue() - private val contextForCleanup = IdentityHashMap, KZ3Context>() - - /** Ensure Z3 native context is closed and all native memory is released. - * */ - private fun registerContextForCleanup(solver: KZ3Solver, context: KZ3Context): PhantomReference { - cleanupStaleContexts() - val cleanupHandler = PhantomReference(solver, cleanupHandlers) - contextForCleanup[cleanupHandler] = context - - return cleanupHandler - } - - private fun unregisterContextCleanup(handler: PhantomReference) { - contextForCleanup.remove(handler) - handler.clear() - cleanupStaleContexts() - } - - private fun cleanupStaleContexts() { - while (true) { - val handler = cleanupHandlers.poll() ?: break - contextForCleanup.remove(handler)?.close() - } - } + super.pop(n) + trackedAssertions.pop(n) } } diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverBase.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverBase.kt new file mode 100644 index 000000000..c95d0f636 --- /dev/null +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverBase.kt @@ -0,0 +1,249 @@ +package io.ksmt.solver.z3 + +import com.microsoft.z3.Solver +import com.microsoft.z3.Status +import com.microsoft.z3.Z3Exception +import com.microsoft.z3.solverAssert +import com.microsoft.z3.solverAssertAndTrack +import com.microsoft.z3.solverCheckAssumptions +import com.microsoft.z3.solverGetUnsatCore +import io.ksmt.KContext +import io.ksmt.expr.KExpr +import io.ksmt.solver.KModel +import io.ksmt.solver.KSolver +import io.ksmt.solver.KSolverStatus +import io.ksmt.sort.KBoolSort +import io.ksmt.utils.NativeLibraryLoader +import java.lang.ref.PhantomReference +import java.lang.ref.ReferenceQueue +import java.util.IdentityHashMap +import kotlin.time.Duration +import kotlin.time.DurationUnit + +abstract class KZ3SolverBase(protected val ctx: KContext) : KSolver { + protected abstract val z3Ctx: KZ3Context + protected val solver by lazy { createSolver() } + + protected var lastCheckStatus = KSolverStatus.UNKNOWN + protected var lastReasonOfUnknown: String? = null + protected var lastModel: KZ3Model? = null + protected var lastUnsatCore: List>? = null + + protected open var currentScope: UInt = 0u + + @Suppress("LeakingThis") + private val contextCleanupActionHandler = registerContextForCleanup(this, z3Ctx) + + protected val exprInternalizer by lazy { + createExprInternalizer(z3Ctx) + } + protected val exprConverter by lazy { + createExprConverter(z3Ctx) + } + + open fun createExprInternalizer(z3Ctx: KZ3Context): KZ3ExprInternalizer = KZ3ExprInternalizer(ctx, z3Ctx) + + open fun createExprConverter(z3Ctx: KZ3Context) = KZ3ExprConverter(ctx, z3Ctx) + + private fun createSolver(): Solver = z3Ctx.z3Try { z3Ctx.nativeContext.mkSolver() } + + override fun configure(configurator: KZ3SolverConfiguration.() -> Unit) = z3Ctx.z3Try { + val params = z3Ctx.nativeContext.mkParams() + KZ3SolverConfigurationImpl(params).configurator() + solver.setParameters(params) + } + + override fun push(): Unit = z3Ctx.z3Try { + solver.push() + z3Ctx.pushAssertionLevel() + currentScope++ + } + + override fun pop(n: UInt) = z3Ctx.z3Try { + require(n <= currentScope) { + "Can not pop $n scope levels because current scope level is $currentScope" + } + if (n == 0u) return + + solver.pop(n.toInt()) + repeat(n.toInt()) { z3Ctx.popAssertionLevel() } + + currentScope -= n + } + + override fun assert(expr: KExpr) = z3Ctx.z3Try { + ctx.ensureContextMatch(expr) + + val z3Expr = with(exprInternalizer) { expr.internalizeExpr() } + solver.solverAssert(z3Expr) + + z3Ctx.assertPendingAxioms(solver) + } + + protected abstract fun saveTrackedAssertion(track: Long, trackedExpr: KExpr) + protected abstract fun findTrackedExprByTrack(track: Long): KExpr? + + override fun assertAndTrack(expr: KExpr) = z3Ctx.z3Try { + ctx.ensureContextMatch(expr) + + val trackExpr = ctx.mkFreshConst("track", ctx.boolSort) + val z3Expr = with(exprInternalizer) { expr.internalizeExpr() } + val z3TrackVar = with(exprInternalizer) { trackExpr.internalizeExpr() } + + solver.solverAssertAndTrack(z3Expr, z3TrackVar) + saveTrackedAssertion(z3TrackVar, expr) + } + + override fun check(timeout: Duration): KSolverStatus = z3TryCheck { + solver.updateTimeout(timeout) + solver.check().processCheckResult() + } + + override fun checkWithAssumptions( + assumptions: List>, + timeout: Duration + ): KSolverStatus = z3TryCheck { + ctx.ensureContextMatch(assumptions) + + val z3Assumptions = with(exprInternalizer) { + LongArray(assumptions.size) { + val assumption = assumptions[it] + + /** + * Assumptions are trivially unsat and no check-sat is required. + * */ + if (assumption == ctx.falseExpr) { + lastUnsatCore = listOf(ctx.falseExpr) + lastCheckStatus = KSolverStatus.UNSAT + return KSolverStatus.UNSAT + } + + assumption.internalizeExpr() + } + } + + solver.updateTimeout(timeout) + + solver.solverCheckAssumptions(z3Assumptions).processCheckResult() + } + + override fun model(): KModel = z3Ctx.z3Try { + require(lastCheckStatus == KSolverStatus.SAT) { + "Model are only available after SAT checks, current solver status: $lastCheckStatus" + } + + val model = lastModel ?: KZ3Model( + model = solver.model, + ctx = ctx, + z3Ctx = z3Ctx, + internalizer = exprInternalizer + ) + lastModel = model + + model + } + + override fun unsatCore(): List> = z3Ctx.z3Try { + require(lastCheckStatus == KSolverStatus.UNSAT) { "Unsat cores are only available after UNSAT checks" } + + val unsatCore = lastUnsatCore ?: with(exprConverter) { + val solverUnsatCore = solver.solverGetUnsatCore() + solverUnsatCore.map { solverUnsatCoreExpr -> + findTrackedExprByTrack(solverUnsatCoreExpr) ?: solverUnsatCoreExpr.convertExpr() + } + } + lastUnsatCore = unsatCore + + unsatCore + } + + override fun reasonOfUnknown(): String = z3Ctx.z3Try { + require(lastCheckStatus == KSolverStatus.UNKNOWN) { "Unknown reason is only available after UNKNOWN checks" } + lastReasonOfUnknown ?: solver.reasonUnknown + } + + override fun interrupt() = z3Ctx.z3Try { + solver.interrupt() + } + + override fun close() { + unregisterContextCleanup(contextCleanupActionHandler) + z3Ctx.close() + } + + protected fun Status?.processCheckResult() = when (this) { + Status.SATISFIABLE -> KSolverStatus.SAT + Status.UNSATISFIABLE -> KSolverStatus.UNSAT + Status.UNKNOWN -> KSolverStatus.UNKNOWN + null -> KSolverStatus.UNKNOWN + }.also { lastCheckStatus = it } + + protected fun Solver.updateTimeout(timeout: Duration) { + val z3Timeout = if (timeout == Duration.INFINITE) { + UInt.MAX_VALUE.toInt() + } else { + timeout.toInt(DurationUnit.MILLISECONDS) + } + val params = z3Ctx.nativeContext.mkParams().apply { + add("timeout", z3Timeout) + } + setParameters(params) + } + + protected fun invalidateSolverState() { + lastReasonOfUnknown = null + lastCheckStatus = KSolverStatus.UNKNOWN + lastModel = null + lastUnsatCore = null + } + + protected inline fun z3TryCheck(body: () -> KSolverStatus): KSolverStatus = try { + invalidateSolverState() + body() + } catch (ex: Z3Exception) { + lastReasonOfUnknown = ex.message + KSolverStatus.UNKNOWN.also { lastCheckStatus = it } + } + + companion object { + init { + System.setProperty("z3.skipLibraryLoad", "true") + NativeLibraryLoader.load { os -> + when (os) { + NativeLibraryLoader.OS.LINUX -> listOf("libz3", "libz3java") + NativeLibraryLoader.OS.MACOS -> listOf("libz3", "libz3java") + NativeLibraryLoader.OS.WINDOWS -> listOf("vcruntime140", "vcruntime140_1", "libz3", "libz3java") + } + } + } + + private val cleanupHandlers = ReferenceQueue>() + private val contextForCleanup = IdentityHashMap>, KZ3Context>() + + /** Ensure Z3 native context is closed and all native memory is released. + * */ + private fun registerContextForCleanup( + solver: KSolver, + context: KZ3Context + ): PhantomReference> { + cleanupStaleContexts() + val cleanupHandler = PhantomReference(solver, cleanupHandlers) + contextForCleanup[cleanupHandler] = context + + return cleanupHandler + } + + private fun unregisterContextCleanup(handler: PhantomReference>) { + contextForCleanup.remove(handler) + handler.clear() + cleanupStaleContexts() + } + + private fun cleanupStaleContexts() { + while (true) { + val handler = cleanupHandlers.poll() ?: break + contextForCleanup.remove(handler)?.close() + } + } + } +} diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverConfiguration.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverConfiguration.kt index 92c2c559a..a3cc6dcc6 100644 --- a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverConfiguration.kt +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/KZ3SolverConfiguration.kt @@ -1,6 +1,7 @@ package io.ksmt.solver.z3 import com.microsoft.z3.Params +import com.microsoft.z3.Solver import io.ksmt.solver.KSolverConfiguration import io.ksmt.solver.KSolverUniversalConfigurationBuilder @@ -45,6 +46,44 @@ class KZ3SolverConfigurationImpl(private val params: Params) : KZ3SolverConfigur } } +class KZ3ForkingSolverConfigurationImpl(private val params: Params) : KZ3SolverConfiguration { + private val booleanOptions = hashMapOf() + private val intOptions = hashMapOf() + private val doubleOptions = hashMapOf() + private val stringOptions = hashMapOf() + + override fun setZ3Option(option: String, value: Boolean) { + params.add(option, value) + booleanOptions[option] = value + } + + override fun setZ3Option(option: String, value: Int) { + params.add(option, value) + intOptions[option] = value + } + + override fun setZ3Option(option: String, value: Double) { + params.add(option, value) + doubleOptions[option] = value + } + + override fun setZ3Option(option: String, value: String) { + params.add(option, value) + stringOptions[option] = value + } + + fun fork(params: Params): KZ3ForkingSolverConfigurationImpl = KZ3ForkingSolverConfigurationImpl(params).also { + booleanOptions.forEach { (option, value) -> it.setZ3Option(option, value) } + intOptions.forEach { (option, value) -> it.setZ3Option(option, value) } + doubleOptions.forEach { (option, value) -> it.setZ3Option(option, value) } + stringOptions.forEach { (option, value) -> it.setZ3Option(option, value) } + } + + fun setParameters(solver: Solver) { + solver.setParameters(params) + } +} + class KZ3SolverUniversalConfiguration( private val builder: KSolverUniversalConfigurationBuilder ) : KZ3SolverConfiguration { diff --git a/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ScopedFrame.kt b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ScopedFrame.kt new file mode 100644 index 000000000..22056a77a --- /dev/null +++ b/ksmt-z3/src/main/kotlin/io/ksmt/solver/z3/ScopedFrame.kt @@ -0,0 +1,113 @@ +package io.ksmt.solver.z3 + +import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap + +internal interface ScopedFrame { + val currentScope: UInt + val currentFrame: T + + fun push() + fun pop(n: UInt = 1u) +} + +internal class ScopedArrayFrameOfLong2ObjectOpenHashMap( + currentFrame: Long2ObjectOpenHashMap +) : ScopedFrame> { + constructor() : this(Long2ObjectOpenHashMap()) + + private val frames = arrayListOf(currentFrame) + + override var currentFrame = currentFrame + private set + + override val currentScope: UInt + get() = frames.size.toUInt() + + inline fun findNonNullValue(predicate: (Long2ObjectOpenHashMap) -> V?): V? { + frames.forEach { frame -> + predicate(frame)?.let { return it } + } + return null + } + + override fun push() { + currentFrame = Long2ObjectOpenHashMap() + frames += currentFrame + } + + override fun pop(n: UInt) { + repeat(n.toInt()) { frames.removeLast() } + currentFrame = frames.last() + } +} + +internal class ScopedLinkedFrame private constructor( + private var current: LinkedFrame, + private val createNewFrame: () -> T, + private val copyFrame: (T) -> T +) : ScopedFrame { + constructor( + currentFrame: T, + createNewFrame: () -> T, + copyFrame: (T) -> T + ) : this(LinkedFrame(currentFrame), createNewFrame, copyFrame) + + constructor( + createNewFrame: () -> T, + copyFrame: (T) -> T + ) : this(createNewFrame(), createNewFrame, copyFrame) + + override val currentFrame: T + get() = current.value + + override val currentScope: UInt + get() = current.scope + + fun stacked(): ArrayDeque = ArrayDeque().also { stack -> + forEachReversed { frame -> + stack.addLast(frame) + } + } + + inline fun findNonNullValue(predicate: (T) -> V?): V? { + forEachReversed { frame -> + predicate(frame)?.let { return it } + } + return null + } + + override fun push() { + current = LinkedFrame(createNewFrame(), current) + } + + override fun pop(n: UInt) { + current = current.previous ?: throw IllegalStateException("Can't pop the bottom scope") + recreateTopFrame() + } + + private fun recreateTopFrame() { + val newTopFrame = copyFrame(currentFrame) + current = LinkedFrame(newTopFrame, current.previous) + } + + fun fork(parent: ScopedLinkedFrame) { + current = parent.current + recreateTopFrame() + } + + private inline fun forEachReversed(action: (T) -> Unit) { + var cur: LinkedFrame? = current + while (cur != null) { + action(cur.value) + cur = cur.previous + } + } + + private class LinkedFrame( + val value: E, + val previous: LinkedFrame? = null + ) { + val scope: UInt = previous?.scope?.plus(1u) ?: 0u + } + +}