Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forking solver implementation #123

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
package io.ksmt.solver.bitwuzla

import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap
import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap
import io.ksmt.KContext
import io.ksmt.decl.KDecl
import io.ksmt.decl.KFuncDecl
@@ -15,13 +12,10 @@ import io.ksmt.expr.KExistentialQuantifier
import io.ksmt.expr.KExpr
import io.ksmt.expr.KFunctionApp
import io.ksmt.expr.KFunctionAsArray
import io.ksmt.expr.KUninterpretedSortValue
import io.ksmt.expr.KUniversalQuantifier
import io.ksmt.expr.transformer.KNonRecursiveTransformer
import io.ksmt.solver.KSolverException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm
import org.ksmt.solver.bitwuzla.bindings.Native
import io.ksmt.solver.util.KExprLongInternalizerBase.Companion.NOT_INTERNALIZED
import io.ksmt.sort.KArray2Sort
import io.ksmt.sort.KArray3Sort
@@ -37,6 +31,13 @@ import io.ksmt.sort.KRealSort
import io.ksmt.sort.KSort
import io.ksmt.sort.KSortVisitor
import io.ksmt.sort.KUninterpretedSort
import it.unimi.dsi.fastutil.longs.Long2ObjectOpenHashMap
import it.unimi.dsi.fastutil.objects.Object2IntOpenHashMap
import it.unimi.dsi.fastutil.objects.Object2LongOpenHashMap
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm
import org.ksmt.solver.bitwuzla.bindings.Native

open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable {
private var isClosed = false
@@ -433,6 +434,11 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable {
return super.transform(expr)
}

override fun transform(expr: KUninterpretedSortValue): KExpr<KUninterpretedSort> {
registerDeclIfNotIgnored(expr.decl)
return super.transform(expr)
}

private val quantifiedVarsScopeOwner = arrayListOf<KExpr<*>>()
private val quantifiedVarsScope = arrayListOf<Set<KDecl<*>>?>()

@@ -474,7 +480,7 @@ open class KBitwuzlaContext(val ctx: KContext) : AutoCloseable {
override fun transform(expr: KExistentialQuantifier): KExpr<KBoolSort> =
expr.transformQuantifier(expr.bounds, expr.body)

override fun transform(expr: KUniversalQuantifier): KExpr<KBoolSort> =
override fun transform(expr: KUniversalQuantifier): KExpr<KBoolSort> =
expr.transformQuantifier(expr.bounds, expr.body)
}

Original file line number Diff line number Diff line change
@@ -151,17 +151,11 @@ import io.ksmt.expr.KUnaryMinusArithExpr
import io.ksmt.expr.KUninterpretedSortValue
import io.ksmt.expr.KUniversalQuantifier
import io.ksmt.expr.KXorExpr
import io.ksmt.expr.rewrite.simplify.rewriteBvAddNoUnderflowExpr
import io.ksmt.expr.rewrite.simplify.rewriteBvMulNoUnderflowExpr
import io.ksmt.expr.rewrite.simplify.rewriteBvNegNoOverflowExpr
import io.ksmt.expr.rewrite.simplify.rewriteBvSubNoUnderflowExpr
import io.ksmt.solver.KSolverUnsupportedFeatureException
import org.ksmt.solver.bitwuzla.bindings.Bitwuzla
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaKind
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaRoundingMode
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm
import org.ksmt.solver.bitwuzla.bindings.Native
import io.ksmt.solver.bitwuzla.KBitwuzlaExprInternalizer.BvOverflowCheckMode.OVERFLOW
import io.ksmt.solver.bitwuzla.KBitwuzlaExprInternalizer.BvOverflowCheckMode.UNDERFLOW
import io.ksmt.solver.util.KExprLongInternalizerBase
import io.ksmt.sort.KArithSort
import io.ksmt.sort.KArray2Sort
@@ -186,7 +180,13 @@ import io.ksmt.sort.KRealSort
import io.ksmt.sort.KSort
import io.ksmt.sort.KSortVisitor
import io.ksmt.sort.KUninterpretedSort
import org.ksmt.solver.bitwuzla.bindings.Bitwuzla
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaKind
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaRoundingMode
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaSort
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTermArray
import org.ksmt.solver.bitwuzla.bindings.Native
import java.math.BigInteger

@Suppress("LargeClass")
@@ -726,7 +726,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL
override fun <T : KBvSort> transform(expr: KBvAddNoOverflowExpr<T>) = with(expr) {
transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm ->
if (isSigned) {
mkBvAddSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.OVERFLOW)
mkBvAddSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, OVERFLOW)
} else {
val overflowCheck = Native.bitwuzlaMkTerm2(
bitwuzla, BitwuzlaKind.BITWUZLA_KIND_BV_UADD_OVERFLOW, a0, a1
@@ -738,20 +738,20 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL

override fun <T : KBvSort> transform(expr: KBvAddNoUnderflowExpr<T>) = with(expr) {
transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm ->
mkBvAddSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.UNDERFLOW)
mkBvAddSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, UNDERFLOW)
}
}

override fun <T : KBvSort> transform(expr: KBvSubNoOverflowExpr<T>) = with(expr) {
transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm ->
mkBvSubSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.OVERFLOW)
mkBvSubSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, OVERFLOW)
}
}

override fun <T : KBvSort> transform(expr: KBvSubNoUnderflowExpr<T>) = with(expr) {
if (isSigned) {
transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm ->
mkBvSubSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.UNDERFLOW)
mkBvSubSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, UNDERFLOW)
}
} else {
transform {
@@ -776,7 +776,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL
override fun <T : KBvSort> transform(expr: KBvMulNoOverflowExpr<T>) = with(expr) {
transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm ->
if (isSigned) {
mkBvMulSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.OVERFLOW)
mkBvMulSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, OVERFLOW)
} else {
val overflowCheck = Native.bitwuzlaMkTerm2(
bitwuzla, BitwuzlaKind.BITWUZLA_KIND_BV_UMUL_OVERFLOW, a0, a1
@@ -788,7 +788,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL

override fun <T : KBvSort> transform(expr: KBvMulNoUnderflowExpr<T>) = with(expr) {
transform(arg0, arg1) { a0: BitwuzlaTerm, a1: BitwuzlaTerm ->
mkBvMulSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, BvOverflowCheckMode.UNDERFLOW)
mkBvMulSignedNoOverflowTerm(arg0.sort.sizeBits.toInt(), a0, a1, UNDERFLOW)
}
}

@@ -813,7 +813,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL
a1,
BitwuzlaKind.BITWUZLA_KIND_BV_SADD_OVERFLOW
) { a0Sign, a1Sign ->
if (mode == BvOverflowCheckMode.OVERFLOW) {
if (mode == OVERFLOW) {
// Both positive
mkAndTerm(longArrayOf(mkNotTerm(a0Sign), mkNotTerm(a1Sign)))
} else {
@@ -833,7 +833,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL
a1,
BitwuzlaKind.BITWUZLA_KIND_BV_SSUB_OVERFLOW
) { a0Sign, a1Sign ->
if (mode == BvOverflowCheckMode.OVERFLOW) {
if (mode == OVERFLOW) {
// Positive sub negative
mkAndTerm(longArrayOf(mkNotTerm(a0Sign), a1Sign))
} else {
@@ -853,7 +853,7 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL
a1,
BitwuzlaKind.BITWUZLA_KIND_BV_SMUL_OVERFLOW
) { a0Sign, a1Sign ->
if (mode == BvOverflowCheckMode.OVERFLOW) {
if (mode == OVERFLOW) {
// Overflow is possible when sign bits are equal
mkEqTerm(bitwuzlaCtx.ctx.boolSort, a0Sign, a1Sign)
} else {
@@ -1401,6 +1401,8 @@ open class KBitwuzlaExprInternalizer(val bitwuzlaCtx: KBitwuzlaContext) : KExprL
}

override fun transform(expr: KUninterpretedSortValue): KExpr<KUninterpretedSort> = expr.transform {
// register it for uninterpreted sort universe
bitwuzlaCtx.registerDeclaration(expr.decl)
Native.bitwuzlaMkBvValueUint32(
bitwuzla,
expr.sort.internalizeSort(),
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
package io.ksmt.solver.bitwuzla

import io.ksmt.KContext
import io.ksmt.expr.KExpr
import io.ksmt.solver.KForkingSolver
import io.ksmt.solver.KSolverStatus
import io.ksmt.sort.KBoolSort
import kotlin.time.Duration

class KBitwuzlaForkingSolver(
private val ctx: KContext,
private val manager: KBitwuzlaForkingSolverManager,
parent: KBitwuzlaForkingSolver?
) : KBitwuzlaSolverBase(ctx),
KForkingSolver<KBitwuzlaSolverConfiguration> {

private val assertions = ScopedLinkedFrame<MutableList<KExpr<KBoolSort>>>(::ArrayList, ::ArrayList)
private val trackToExprFrames =
ScopedLinkedFrame<MutableList<Pair<KExpr<KBoolSort>, KExpr<KBoolSort>>>>(::ArrayList, ::ArrayList)

private val config: KBitwuzlaForkingSolverConfigurationImpl

init {
if (parent != null) {
config = parent.config.fork(bitwuzlaCtx.bitwuzla)
assertions.fork(parent.assertions)
trackToExprFrames.fork(parent.trackToExprFrames)
} else {
config = KBitwuzlaForkingSolverConfigurationImpl(bitwuzlaCtx.bitwuzla)
}
}

override fun configure(configurator: KBitwuzlaSolverConfiguration.() -> Unit) {
config.configurator()
}

override fun fork(): KForkingSolver<KBitwuzlaSolverConfiguration> = manager.mkForkingSolver(this)

private var assertionsInitiated = parent == null

private fun ensureAssertionsInitiated() {
if (assertionsInitiated) return

assertions.stacked().zip(trackToExprFrames.stacked())
.asReversed()
.forEachIndexed { scope, (assertionsFrame, trackedExprsFrame) ->
if (scope > 0) super.push()

assertionsFrame.forEach { assertion ->
internalizeAndAssertWithAxioms(assertion)
}

trackedExprsFrame.forEach { (track, trackedExpr) ->
super.registerTrackForExpr(trackedExpr, track)
}
}
assertionsInitiated = true
}

override fun assert(expr: KExpr<KBoolSort>) = bitwuzlaCtx.bitwuzlaTry {
ctx.ensureContextMatch(expr)
ensureAssertionsInitiated()

internalizeAndAssertWithAxioms(expr)
assertions.currentFrame += expr
}

override fun assertAndTrack(expr: KExpr<KBoolSort>) {
bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() }
super.assertAndTrack(expr)
}

override fun registerTrackForExpr(expr: KExpr<KBoolSort>, track: KExpr<KBoolSort>) {
super.registerTrackForExpr(expr, track)
trackToExprFrames.currentFrame += track to expr
}

override fun push() {
bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() }
super.push()
assertions.push()
trackToExprFrames.push()
}

override fun pop(n: UInt) {
bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() }
super.pop(n)
assertions.pop(n)
trackToExprFrames.pop(n)
}

override fun check(timeout: Duration): KSolverStatus {
bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() }
return super.check(timeout)
}

override fun checkWithAssumptions(assumptions: List<KExpr<KBoolSort>>, timeout: Duration): KSolverStatus {
bitwuzlaCtx.bitwuzlaTry { ensureAssertionsInitiated() }
return super.checkWithAssumptions(assumptions, timeout)
}

override fun close() {
super.close()
manager.close(this)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package io.ksmt.solver.bitwuzla

import io.ksmt.KContext
import io.ksmt.solver.KForkingSolver
import io.ksmt.solver.KForkingSolverManager
import java.util.concurrent.ConcurrentHashMap

class KBitwuzlaForkingSolverManager(private val ctx: KContext) : KForkingSolverManager<KBitwuzlaSolverConfiguration> {
private val solvers = ConcurrentHashMap.newKeySet<KBitwuzlaForkingSolver>()

override fun mkForkingSolver(): KForkingSolver<KBitwuzlaSolverConfiguration> {
return KBitwuzlaForkingSolver(ctx, this, null).also {
solvers += it
}
}

internal fun mkForkingSolver(parent: KBitwuzlaForkingSolver) = KBitwuzlaForkingSolver(ctx, this, parent).also {
solvers += it
}

internal fun close(solver: KBitwuzlaForkingSolver) {
solvers -= solver
}

override fun close() {
solvers.forEach(KBitwuzlaForkingSolver::close)
}
}
Original file line number Diff line number Diff line change
@@ -2,18 +2,15 @@ package io.ksmt.solver.bitwuzla

import io.ksmt.KContext
import io.ksmt.decl.KDecl
import io.ksmt.decl.KUninterpretedSortValueDecl
import io.ksmt.expr.KExpr
import io.ksmt.expr.KUninterpretedSortValue
import io.ksmt.solver.KModel
import io.ksmt.solver.KSolverUnsupportedFeatureException
import io.ksmt.solver.model.KFuncInterp
import io.ksmt.solver.model.KFuncInterpEntryVarsFree
import io.ksmt.solver.model.KFuncInterpEntryVarsFreeOneAry
import io.ksmt.solver.model.KFuncInterpVarsFree
import io.ksmt.solver.KModel
import io.ksmt.solver.KSolverUnsupportedFeatureException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm
import org.ksmt.solver.bitwuzla.bindings.FunValue
import org.ksmt.solver.bitwuzla.bindings.Native
import io.ksmt.solver.model.KFuncInterpWithVars
import io.ksmt.solver.model.KModelEvaluator
import io.ksmt.solver.model.KModelImpl
@@ -23,6 +20,10 @@ import io.ksmt.sort.KSort
import io.ksmt.sort.KUninterpretedSort
import io.ksmt.utils.mkFreshConstDecl
import io.ksmt.utils.uncheckedCast
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaNativeException
import org.ksmt.solver.bitwuzla.bindings.BitwuzlaTerm
import org.ksmt.solver.bitwuzla.bindings.FunValue
import org.ksmt.solver.bitwuzla.bindings.Native

open class KBitwuzlaModel(
private val ctx: KContext,
@@ -77,7 +78,12 @@ open class KBitwuzlaModel(
* to ensure that [uninterpretedSortValueContext] contains
* all possible values for the given sort.
* */
sortDependency.forEach { interpretation(it) }
sortDependency.forEach {
if (it is KUninterpretedSortValueDecl) {
val value = ctx.mkUninterpretedSortValue(it.sort, it.valueIdx)
uninterpretedSortValueContext.registerValue(value)
} else interpretation(it)
}

uninterpretedSortValueContext.currentSortUniverse(sort)
}
Loading