diff --git a/README.md b/README.md index 4a59848..7f4d331 100644 --- a/README.md +++ b/README.md @@ -46,9 +46,21 @@ Tseitin's transformation: ``` # SMT solver -An SMT solver is implemented for linear real arithmetic (QF_LRA), linear integer arithmetic (QF_LIA), equality logic (QF_EQ), equality logic with uninterpreted functions (QF_EQUF) and bit vector arithmetic (QF_BV). +An SMT solver is implemented for non-linear real arithmetic (QF_NRA), linear real arithmetic (QF_LRA), linear integer arithmetic (QF_LIA), equality logic (QF_EQ), equality logic with uninterpreted functions (QF_EQUF) and bit vector arithmetic (QF_BV). ## Examples +### QF_NRA (⚠️ highly experimental ⚠️) +```c++ +> smt QF_NRA (x^2 + y^2 = 1) & (x^2 + y^3 = 1/2) +SAT: +x=(x^6-2x^4+2x^2-3/4, 3/4, 15/16) ≈ 0.8249554777467228; y=(x^9+1/2x^6+3/4x^3+1/8, -7/4, 7/4) ≈ -0.5651977173836394; +Time: 456ms +``` +```c++ +> smt QF_NRA (x*y > 0) & (y*z > 0) & (x*z > 0) & (x + y + z = 0) +UNSAT +Time: 7ms +``` ### QF_LRA ```c++ > smt QF_LRA (x<=-3 | x>=3) & (y=5) & (x+y>=12) diff --git a/src/main/java/me/paultristanwagner/satchecking/command/impl/SMTCommand.java b/src/main/java/me/paultristanwagner/satchecking/command/impl/SMTCommand.java index 971b31a..4314216 100644 --- a/src/main/java/me/paultristanwagner/satchecking/command/impl/SMTCommand.java +++ b/src/main/java/me/paultristanwagner/satchecking/command/impl/SMTCommand.java @@ -25,6 +25,7 @@ public SMTCommand() { "smt ", """ Available theories: + QF_NRA (Non-linear real arithmetic) (! highly experimental !), QF_LRA (Linear real arithmetic), QF_LIA (Linear integer arithmetic), QF_EQ (Equality logic), @@ -32,6 +33,8 @@ public SMTCommand() { QF_BV (Bitvector arithmetic) Examples: + smt QF_NRA (x^2 + y^2 = 1) & (x^2 + y^3 = 1/2) + smt QF_LRA (x <= 5) & (max(x)) smt QF_LIA (x <= 3/2) & (max(x)) diff --git a/src/main/java/me/paultristanwagner/satchecking/parse/ComparisonLexer.java b/src/main/java/me/paultristanwagner/satchecking/parse/ComparisonLexer.java new file mode 100644 index 0000000..d273788 --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/parse/ComparisonLexer.java @@ -0,0 +1,21 @@ +package me.paultristanwagner.satchecking.parse; + +import static me.paultristanwagner.satchecking.parse.TokenType.*; + +public class ComparisonLexer extends Lexer { + + public ComparisonLexer(String input) { + super(input); + + registerTokenTypes( + EQUALS, + NOT_EQUALS, + GREATER_EQUALS, + LOWER_EQUALS, + LESS_THAN, + GREATER_THAN + ); + + initialize(input); + } +} diff --git a/src/main/java/me/paultristanwagner/satchecking/parse/LinearConstraintParser.java b/src/main/java/me/paultristanwagner/satchecking/parse/LinearConstraintParser.java index 424feb0..2a27fc3 100644 --- a/src/main/java/me/paultristanwagner/satchecking/parse/LinearConstraintParser.java +++ b/src/main/java/me/paultristanwagner/satchecking/parse/LinearConstraintParser.java @@ -156,6 +156,6 @@ private static Number FRACTION(Lexer lexer) { long numerator = Long.parseLong(parts[0]); long denominator = Long.parseLong(parts[1]); - return Number.of(numerator, denominator); + return Number.number(numerator, denominator); } } diff --git a/src/main/java/me/paultristanwagner/satchecking/parse/MultivariatePolynomialConstraintParser.java b/src/main/java/me/paultristanwagner/satchecking/parse/MultivariatePolynomialConstraintParser.java new file mode 100644 index 0000000..8024350 --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/parse/MultivariatePolynomialConstraintParser.java @@ -0,0 +1,70 @@ +package me.paultristanwagner.satchecking.parse; + +import me.paultristanwagner.satchecking.sat.Result; +import me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomial; +import me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomialConstraint; + +import static me.paultristanwagner.satchecking.parse.TokenType.*; +import static me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomialConstraint.Comparison.GREATER_THAN; +import static me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomialConstraint.Comparison.GREATER_THAN_OR_EQUALS; +import static me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomialConstraint.Comparison.LESS_THAN; +import static me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomialConstraint.Comparison.LESS_THAN_OR_EQUALS; +import static me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomialConstraint.Comparison.NOT_EQUALS; +import static me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomialConstraint.multivariatePolynomialConstraint; + +public class MultivariatePolynomialConstraintParser implements Parser { + + @Override + public MultivariatePolynomialConstraint parse(String string) { + ParseResult result = parseWithRemaining(string); + if(!result.complete()) { + throw new SyntaxError("Expected end of input", string, result.charsRead()); + } + + return result.result(); + } + + @Override + public ParseResult parseWithRemaining(String string) { + Parser parser = new PolynomialParser(); + + ParseResult pResult = parser.parseWithRemaining(string); + MultivariatePolynomial p = pResult.result(); + + Lexer lexer = new ComparisonLexer(string.substring(pResult.charsRead())); + MultivariatePolynomialConstraint.Comparison comparison = parseComparison(lexer); + + ParseResult qResult = parser.parseWithRemaining(lexer.getRemaining()); + MultivariatePolynomial q = qResult.result(); + + MultivariatePolynomial d = p.subtract(q); + MultivariatePolynomialConstraint constraint = multivariatePolynomialConstraint(d, comparison); + + int charsRead = pResult.charsRead() + lexer.getCursor() + qResult.charsRead(); + return new ParseResult<>(constraint, charsRead, qResult.complete()); + } + + private MultivariatePolynomialConstraint.Comparison parseComparison(Lexer lexer) { + if(lexer.canConsume(TokenType.EQUALS)) { + lexer.consume(TokenType.EQUALS); + return MultivariatePolynomialConstraint.Comparison.EQUALS; + } else if(lexer.canConsume(TokenType.NOT_EQUALS)) { + lexer.consume(TokenType.NOT_EQUALS); + return NOT_EQUALS; + } else if(lexer.canConsume(TokenType.LESS_THAN)) { + lexer.consume(TokenType.LESS_THAN); + return LESS_THAN; + } else if(lexer.canConsume(TokenType.GREATER_THAN)) { + lexer.consume(TokenType.GREATER_THAN); + return GREATER_THAN; + } else if(lexer.canConsume(LOWER_EQUALS)) { + lexer.consume(LOWER_EQUALS); + return LESS_THAN_OR_EQUALS; + } else if(lexer.canConsume(GREATER_EQUALS)) { + lexer.consume(GREATER_EQUALS); + return GREATER_THAN_OR_EQUALS; + } else { + throw new SyntaxError("Expected comparison operator", lexer.getInput(), lexer.getCursor()); + } + } +} diff --git a/src/main/java/me/paultristanwagner/satchecking/parse/PolynomialLexer.java b/src/main/java/me/paultristanwagner/satchecking/parse/PolynomialLexer.java new file mode 100644 index 0000000..cdc93c5 --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/parse/PolynomialLexer.java @@ -0,0 +1,21 @@ +package me.paultristanwagner.satchecking.parse; + +import static me.paultristanwagner.satchecking.parse.TokenType.*; + +public class PolynomialLexer extends Lexer { + + public PolynomialLexer(String input) { + super(input); + + registerTokenTypes( + PLUS, + MINUS, + FRACTION, + DECIMAL, + IDENTIFIER, + TIMES, + POWER); + + initialize(input); + } +} diff --git a/src/main/java/me/paultristanwagner/satchecking/parse/PolynomialParser.java b/src/main/java/me/paultristanwagner/satchecking/parse/PolynomialParser.java new file mode 100644 index 0000000..b23c289 --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/parse/PolynomialParser.java @@ -0,0 +1,145 @@ +package me.paultristanwagner.satchecking.parse; + +import me.paultristanwagner.satchecking.theory.arithmetic.Number; +import me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomial; + +import java.util.Scanner; + +import static me.paultristanwagner.satchecking.parse.TokenType.*; +import static me.paultristanwagner.satchecking.theory.arithmetic.Number.number; +import static me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomial.constant; +import static me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomial.variable; + +public class PolynomialParser implements Parser { + + public static void main(String[] args){ + Scanner scanner = new Scanner(System.in); + PolynomialParser parser = new PolynomialParser(); + + String line; + while ((line = scanner.nextLine()) != null) { + MultivariatePolynomial polynomial = parser.parse(line); + System.out.println(polynomial); + } + } + + /* + * Grammar for polynomial constraints: + * ::= EQUALS 0 + * + * ::= PLUS + * | MINUS + * | + * + * ::= TIMES + * | + * + * ::= FRACTION + * | DECIMAL + * | IDENTIFIER + * | IDENTIFIER POWER INTEGER + */ + + public MultivariatePolynomial parse(String string) { + ParseResult result = parseWithRemaining(string); + + if (!result.complete()) { + throw new SyntaxError("Expected end of input", string, result.charsRead()); + } + + return result.result(); + } + + @Override + public ParseResult parseWithRemaining(String string) { + Lexer lexer = new PolynomialLexer(string); + + lexer.requireNextToken(); + + MultivariatePolynomial polynomial = parseTerm(lexer); + + return new ParseResult<>(polynomial, lexer.getCursor(), lexer.getCursor() == string.length()); + } + + private MultivariatePolynomial parseTerm(Lexer lexer) { + MultivariatePolynomial term1 = parseMonomial(lexer); + + while(lexer.canConsume(PLUS) || lexer.canConsume(MINUS)){ + if (lexer.canConsume(PLUS)) { + lexer.consume(PLUS); + MultivariatePolynomial term2 = parseMonomial(lexer); + term1 = term1.add(term2); + } else if (lexer.canConsume(MINUS)) { + lexer.consume(MINUS); + MultivariatePolynomial term2 = parseMonomial(lexer); + term1 = term1.subtract(term2); + } + } + + return term1; + } + + private MultivariatePolynomial parseMonomial(Lexer lexer) { + MultivariatePolynomial monomial1 = parseFactor(lexer); + + while (lexer.canConsumeEither(TIMES, IDENTIFIER)) { + if(lexer.canConsume(TIMES)) { + lexer.consume(TIMES); + } + MultivariatePolynomial monomial2 = parseFactor(lexer); + monomial1 = monomial1.multiply(monomial2); + } + + return monomial1; + } + + private int parseSign(Lexer lexer) { + int sign = 1; + while(lexer.canConsumeEither(PLUS, MINUS)){ + if (lexer.canConsume(PLUS)) { + lexer.consume(PLUS); + } else if (lexer.canConsume(MINUS)) { + lexer.consume(MINUS); + sign *= -1; + } + } + + return sign; + } + + private MultivariatePolynomial parseFactor(Lexer lexer) { + int sign = parseSign(lexer); + + if (lexer.canConsumeEither(DECIMAL, FRACTION)) { + String value = lexer.getLookahead().getValue(); + lexer.consumeEither(DECIMAL, FRACTION); + Number number = Number.parse(value); + + if(sign == -1) { + number = number.negate(); + } + + return constant(number); + } + + if (lexer.canConsume(IDENTIFIER)) { + String variable = lexer.getLookahead().getValue(); + lexer.consume(IDENTIFIER); + + if (lexer.canConsume(POWER)) { + lexer.consume(POWER); + lexer.require(DECIMAL); + String exponent = lexer.getLookahead().getValue(); + lexer.consume(DECIMAL); + + MultivariatePolynomial monomial = variable(variable).pow(Integer.parseInt(exponent)); + + return monomial.multiply(constant(number(sign))); + } + + return variable(variable).multiply(constant(number(sign))); + } + + throw new SyntaxError("Expected either a decimal or an identifier", lexer.getInput(), lexer.getCursor()); + } +} diff --git a/src/main/java/me/paultristanwagner/satchecking/parse/TheoryCNFParser.java b/src/main/java/me/paultristanwagner/satchecking/parse/TheoryCNFParser.java index 4517ec8..90a5c17 100644 --- a/src/main/java/me/paultristanwagner/satchecking/parse/TheoryCNFParser.java +++ b/src/main/java/me/paultristanwagner/satchecking/parse/TheoryCNFParser.java @@ -7,6 +7,7 @@ import me.paultristanwagner.satchecking.theory.EqualityFunctionConstraint; import me.paultristanwagner.satchecking.theory.LinearConstraint; import me.paultristanwagner.satchecking.theory.bitvector.constraint.BitVectorConstraint; +import me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomialConstraint; import java.util.ArrayList; import java.util.List; @@ -78,7 +79,10 @@ private T LITERAL(Lexer lexer) { ParseResult parseResult; try { - if (constraintClass == LinearConstraint.class) { + if(constraintClass == MultivariatePolynomialConstraint.class) { + MultivariatePolynomialConstraintParser multivariatePolynomialConstraintParser = new MultivariatePolynomialConstraintParser(); + parseResult = (ParseResult) multivariatePolynomialConstraintParser.parseWithRemaining(remaining); + } else if (constraintClass == LinearConstraint.class) { LinearConstraintParser linearConstraintParser = new LinearConstraintParser(); parseResult = (ParseResult) linearConstraintParser.parseWithRemaining(remaining); } else if (constraintClass == EqualityConstraint.class) { diff --git a/src/main/java/me/paultristanwagner/satchecking/parse/TokenType.java b/src/main/java/me/paultristanwagner/satchecking/parse/TokenType.java index b6799dc..66d8278 100644 --- a/src/main/java/me/paultristanwagner/satchecking/parse/TokenType.java +++ b/src/main/java/me/paultristanwagner/satchecking/parse/TokenType.java @@ -23,6 +23,7 @@ public class TokenType { static final TokenType MINUS = TokenType.of("-", "^-"); static final TokenType TIMES = TokenType.of("*", "^\\*"); static final TokenType DIVIDE = TokenType.of("/", "^\\/"); + static final TokenType POWER = TokenType.of("^", "^\\^"); static final TokenType REMAINDER = TokenType.of("*", "^\\%"); static final TokenType AND = TokenType.of("and", "^(&|&&|and|AND|∧)"); static final TokenType OR = TokenType.of("or", "^(\\|\\||\\||or|OR|∨)"); diff --git a/src/main/java/me/paultristanwagner/satchecking/smt/VariableAssignment.java b/src/main/java/me/paultristanwagner/satchecking/smt/VariableAssignment.java index 8211a34..ce18f8b 100644 --- a/src/main/java/me/paultristanwagner/satchecking/smt/VariableAssignment.java +++ b/src/main/java/me/paultristanwagner/satchecking/smt/VariableAssignment.java @@ -2,30 +2,35 @@ import java.util.*; -public class VariableAssignment { +public class VariableAssignment extends HashMap { - private final Map assignments = new HashMap<>(); + public VariableAssignment() { + } + + public VariableAssignment(Map assignments) { + this.putAll(assignments); + } public void assign(String variable, O value) { - assignments.put(variable, value); + this.put(variable, value); } public O getAssignment(String variable) { - return assignments.get(variable); + return this.get(variable); } public Set getVariables() { - return assignments.keySet(); + return this.keySet(); } @Override public String toString() { StringBuilder builder = new StringBuilder(); - List variables = new ArrayList<>(assignments.keySet()); + List variables = new ArrayList<>(this.keySet()); variables.sort(String::compareTo); for (String variable : variables) { - O value = assignments.get(variable); + O value = this.get(variable); builder.append(variable).append("=").append(value).append("; "); } diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/Theory.java b/src/main/java/me/paultristanwagner/satchecking/theory/Theory.java index a6df974..1151c57 100644 --- a/src/main/java/me/paultristanwagner/satchecking/theory/Theory.java +++ b/src/main/java/me/paultristanwagner/satchecking/theory/Theory.java @@ -5,12 +5,16 @@ import me.paultristanwagner.satchecking.smt.solver.LessLazySMTSolver; import me.paultristanwagner.satchecking.smt.solver.SMTSolver; import me.paultristanwagner.satchecking.theory.bitvector.constraint.BitVectorConstraint; +import me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomialConstraint; import me.paultristanwagner.satchecking.theory.solver.*; import java.util.List; public class Theory { + private static final String QF_NRA_NAME = "QF_NRA"; + public static final Theory QF_NRA = new Theory(QF_NRA_NAME, true); + private static final String QF_LRA_NAME = "QF_LRA"; public static final Theory QF_LRA = new Theory(QF_LRA_NAME, true); @@ -26,7 +30,7 @@ public class Theory { private static final String QF_BV_NAME = "QF_BV"; public static final Theory QF_BV = new Theory(QF_BV_NAME, true); - private final static List theories = List.of(QF_LRA, QF_LIA, QF_EQ, QF_EQUF, QF_BV); + private final static List theories = List.of(QF_NRA, QF_LRA, QF_LIA, QF_EQ, QF_EQUF, QF_BV); private final String name; private final boolean complete; @@ -49,6 +53,7 @@ public static Theory get(String name) { @SuppressWarnings("rawtypes") public TheoryCNFParser getCNFParser() { return switch (name) { + case QF_NRA_NAME -> new TheoryCNFParser<>(MultivariatePolynomialConstraint.class); case QF_LRA_NAME, QF_LIA_NAME -> new TheoryCNFParser<>(LinearConstraint.class); case QF_EQ_NAME -> new TheoryCNFParser<>(EqualityConstraint.class); case QF_EQUF_NAME -> new TheoryCNFParser<>(EqualityFunctionConstraint.class); @@ -60,6 +65,7 @@ public TheoryCNFParser getCNFParser() { @SuppressWarnings("rawtypes") public TheorySolver getTheorySolver() { return switch (name) { + case QF_NRA_NAME -> new NonLinearRealArithmeticSolver(); case QF_LRA_NAME -> new SimplexOptimizationSolver(); case QF_LIA_NAME -> new LinearIntegerSolver(); case QF_EQ_NAME -> new EqualityLogicSolver(); @@ -72,7 +78,7 @@ public TheorySolver getTheorySolver() { @SuppressWarnings("rawtypes") public SMTSolver getSMTSolver() { return switch (name) { - case QF_LRA_NAME, QF_LIA_NAME, QF_EQUF_NAME, QF_BV_NAME -> new FullLazySMTSolver<>(); + case QF_NRA_NAME, QF_LRA_NAME, QF_LIA_NAME, QF_EQUF_NAME, QF_BV_NAME -> new FullLazySMTSolver<>(); case QF_EQ_NAME -> new LessLazySMTSolver<>(); default -> throw new IllegalArgumentException("Unknown theory: " + name); }; diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/arithmetic/Float.java b/src/main/java/me/paultristanwagner/satchecking/theory/arithmetic/Float.java index 23d1b02..0f67c4c 100644 --- a/src/main/java/me/paultristanwagner/satchecking/theory/arithmetic/Float.java +++ b/src/main/java/me/paultristanwagner/satchecking/theory/arithmetic/Float.java @@ -1,5 +1,6 @@ package me.paultristanwagner.satchecking.theory.arithmetic; +import java.math.BigInteger; import java.util.Objects; public class Float implements Number { @@ -41,6 +42,11 @@ public Float multiply(Number other) { return new Float(value * otherFloat.value); } + @Override + public Float pow(int exponent) { + return new Float(Math.pow(value, exponent)); + } + @Override public Float divide(Number other) { if (!(other instanceof Float otherFloat)) { @@ -50,6 +56,20 @@ public Float divide(Number other) { return new Float(value / otherFloat.value); } + @Override + public Float midpoint(Number other) { + if (!(other instanceof Float otherFloat)) { + throw new IllegalArgumentException("Cannot add " + other + " to " + this); + } + + return new Float((value + otherFloat.value) / 2); + } + + @Override + public Number mediant(Number other) { + throw new UnsupportedOperationException(); + } + @Override public Float negate() { return new Float(-value); @@ -85,6 +105,26 @@ public Number floor() { return new Float(Math.floor(value)); } + @Override + public Number gcd(Number other) { + throw new UnsupportedOperationException(); + } + + @Override + public Number lcm(Number other) { + throw new UnsupportedOperationException(); + } + + @Override + public BigInteger getNumerator() { + throw new UnsupportedOperationException(); + } + + @Override + public BigInteger getDenominator() { + throw new UnsupportedOperationException(); + } + @Override public boolean lessThan(Number other) { if (!(other instanceof Float otherFloat)) { @@ -123,6 +163,29 @@ public int hashCode() { } public static Float parse(String string) { + if(string.contains("/")) { + return parseFraction(string); + } + return new Float(Double.parseDouble(string)); } + + private static Float parseFraction(String string) { + String[] parts = string.split("/"); + + String numerator = parts[0]; + String denominator = parts.length > 1 ? parts[1] : "1"; + + return new Float(Double.parseDouble(numerator) / Double.parseDouble(denominator)); + } + + @Override + public float approximateAsFloat() { + return (float) value; + } + + @Override + public double approximateAsDouble() { + return value; + } } diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/arithmetic/Number.java b/src/main/java/me/paultristanwagner/satchecking/theory/arithmetic/Number.java index 10b9009..2eef7b0 100644 --- a/src/main/java/me/paultristanwagner/satchecking/theory/arithmetic/Number.java +++ b/src/main/java/me/paultristanwagner/satchecking/theory/arithmetic/Number.java @@ -2,6 +2,8 @@ import me.paultristanwagner.satchecking.Config; +import java.math.BigInteger; + public interface Number { static Number ZERO() { @@ -28,7 +30,7 @@ static Number parse(String string) { } } - static Number of(long numerator, long denominator) { + static Number number(long numerator, long denominator) { if (Config.get().useFloats()) { return new Float((double) numerator / denominator); } else { @@ -36,14 +38,24 @@ static Number of(long numerator, long denominator) { } } + static Number number(long numerator) { + return number(numerator, 1); + } + Number add(Number other); Number subtract(Number other); Number multiply(Number other); + Number pow(int exponent); + Number divide(Number other); + Number midpoint(Number other); + + Number mediant(Number other); + Number negate(); Number abs(); @@ -58,6 +70,14 @@ static Number of(long numerator, long denominator) { Number floor(); + Number gcd(Number other); + + Number lcm(Number other); + + BigInteger getNumerator(); + + BigInteger getDenominator(); + boolean lessThan(Number other); boolean lessThanOrEqual(Number other); @@ -83,4 +103,18 @@ default boolean isNegative() { default boolean isNonNegative() { return isZero() || isPositive(); } + + default int sign() { + if (isZero()) { + return 0; + } else if (isPositive()) { + return 1; + } else { + return -1; + } + } + + float approximateAsFloat(); + + double approximateAsDouble(); } diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/arithmetic/Rational.java b/src/main/java/me/paultristanwagner/satchecking/theory/arithmetic/Rational.java index 094c60a..c971763 100644 --- a/src/main/java/me/paultristanwagner/satchecking/theory/arithmetic/Rational.java +++ b/src/main/java/me/paultristanwagner/satchecking/theory/arithmetic/Rational.java @@ -1,6 +1,8 @@ package me.paultristanwagner.satchecking.theory.arithmetic; +import java.math.BigDecimal; import java.math.BigInteger; +import java.math.MathContext; import java.util.Objects; import static java.math.BigInteger.TEN; @@ -76,6 +78,15 @@ public Rational multiply(Number other) { return new Rational(num, den); } + @Override + public Rational pow(int exponent) { + if (exponent < 0) { + return new Rational(denominator.pow(-exponent), numerator.pow(-exponent)); + } + + return new Rational(numerator.pow(exponent), denominator.pow(exponent)); + } + @Override public Rational divide(Number other) { if (!(other instanceof Rational otherExact)) { @@ -92,6 +103,24 @@ public Rational divide(Number other) { return new Rational(num, den); } + public Rational midpoint(Number other) { + if (!(other instanceof Rational otherExact)) { + throw new IllegalArgumentException("Cannot add " + other + " to " + this); + } + + return this.add(other).divide(new Rational(2)); + } + + public Rational mediant(Number other) { + if (!(other instanceof Rational otherExact)) { + throw new IllegalArgumentException("Cannot add " + other + " to " + this); + } + + BigInteger num = numerator.add(otherExact.numerator); + BigInteger den = denominator.add(otherExact.denominator); + return new Rational(num, den); + } + @Override public Rational negate() { return new Rational(numerator.negate(), denominator); @@ -146,13 +175,41 @@ public Number floor() { return new Rational(div); } + public Number gcd(Number other) { + if (!(other instanceof Rational otherExact)) { + throw new IllegalArgumentException("Cannot add " + other + " to " + this); + } + + if(!isInteger() || !otherExact.isInteger()) { + throw new IllegalArgumentException("Cannot calculate gcd of non-integer numbers"); + } + + return new Rational(numerator.gcd(otherExact.numerator)); + } + + @Override + public Number lcm(Number other) { + if (!(other instanceof Rational otherExact)) { + throw new IllegalArgumentException("Cannot add " + other + " to " + this); + } + + if(!isInteger() || !otherExact.isInteger()) { + throw new IllegalArgumentException("Cannot calculate lcm of non-integer numbers"); + } + + return new Rational(numerator.multiply(otherExact.numerator).divide(numerator.gcd(otherExact.numerator))); + } + @Override public boolean lessThan(Number other) { if (!(other instanceof Rational otherExact)) { throw new IllegalArgumentException("Cannot compare " + other + " to " + this); } - return numerator.multiply(otherExact.denominator).compareTo(otherExact.numerator.multiply(denominator)) < 0; + return numerator + .multiply(otherExact.denominator) + .compareTo(otherExact.numerator.multiply(denominator)) + < 0; } @Override @@ -226,6 +283,16 @@ private static Rational parseRational(String value) { return new Rational(new BigInteger(numerator), new BigInteger(denominator)); } + @Override + public BigInteger getNumerator() { + return numerator; + } + + @Override + public BigInteger getDenominator() { + return denominator; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -239,4 +306,18 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(numerator, denominator); } + + @Override + public float approximateAsFloat() { + BigDecimal numeratorDecimal = new BigDecimal(numerator); + BigDecimal denominatorDecimal = new BigDecimal(denominator); + return numeratorDecimal.divide(denominatorDecimal, MathContext.DECIMAL32).floatValue(); + } + + @Override + public double approximateAsDouble() { + BigDecimal numeratorDecimal = new BigDecimal(numerator); + BigDecimal denominatorDecimal = new BigDecimal(denominator); + return numeratorDecimal.divide(denominatorDecimal, MathContext.DECIMAL64).doubleValue(); + } } diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/CAD.java b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/CAD.java new file mode 100644 index 0000000..bb6c0b9 --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/CAD.java @@ -0,0 +1,212 @@ +package me.paultristanwagner.satchecking.theory.nonlinear; + +import me.paultristanwagner.satchecking.parse.Parser; +import me.paultristanwagner.satchecking.parse.PolynomialParser; +import me.paultristanwagner.satchecking.smt.VariableAssignment; +import me.paultristanwagner.satchecking.theory.arithmetic.Rational; + +import java.util.*; + +import static me.paultristanwagner.satchecking.theory.nonlinear.Cell.emptyCell; +import static me.paultristanwagner.satchecking.theory.nonlinear.Interval.IntervalBoundType.OPEN; +import static me.paultristanwagner.satchecking.theory.nonlinear.Interval.*; +import static me.paultristanwagner.satchecking.theory.nonlinear.RealAlgebraicNumber.realAlgebraicNumber; + +public class CAD { + + public static void main(String[] args) { + Parser parser = new PolynomialParser(); + MultivariatePolynomial p = parser.parse("x^2 + y^2 - 1"); + MultivariatePolynomial q = parser.parse("x^2 + y^3 - 1/2"); + + RealAlgebraicNumber x = realAlgebraicNumber(parser.parse("x^6-2x^4+2x^2-3/4").toUnivariatePolynomial(), Rational.parse("105/128"), Rational.parse("27/32")); + RealAlgebraicNumber y = realAlgebraicNumber(parser.parse("x^6-1x^4+x^2-1/4").toUnivariatePolynomial(), Rational.parse("1/2"), Rational.parse("3/4")); + + System.out.println(p); + System.out.println(q); + System.out.println(x); + System.out.println(y); + + + + } + + private List variables; + private Set polynomials; + + public CAD(Set polynomials) { + this.polynomials = polynomials; + + Set variablesSet = new HashSet<>(); + for (MultivariatePolynomial polynomial : polynomials) { + variablesSet.addAll(polynomial.variables); + } + this.variables = new ArrayList<>(variablesSet); + + + } + + public Set> compute( + Set constraints + ) { + return compute(constraints, false); + } + + public Set> compute( + Set constraints, + boolean onlyEqualities + ) { + this.polynomials = new HashSet<>(); + for (MultivariatePolynomialConstraint constraint : constraints) { + this.polynomials.add(constraint.getPolynomial()); + } + + Set variablesSet = new HashSet<>(); + for (MultivariatePolynomial polynomial : polynomials) { + variablesSet.addAll(polynomial.variables); + } + this.variables = new ArrayList<>(variablesSet); + + // phase 1: projection + Map> p = new HashMap<>(); + p.put(variables.size(), polynomials); + + for (int r = variables.size() - 1; r >= 1; r--) { + String variable = variables.get(r); + Set proj = mcCallumProjection(p.get(r + 1), variable); + p.put(r, proj); + + String previousVariable = variables.get(r); + p.get(r + 1).stream().filter(poly -> !poly.highestVariable().equals(previousVariable)); + } + + // phase 2: lifting + List> D = new ArrayList<>(); + D.add(List.of(emptyCell())); + + for (int i = 1; i <= variables.size(); i++) { + List D_i = new ArrayList<>(); + String variable = variables.get(i - 1); + + for (Cell R : D.get(i - 1)) { + Map s = R.chooseSamplePoint(); + + Set roots = new HashSet<>(); + for (MultivariatePolynomial polynomial : p.get(i)) { + MultivariatePolynomial substituted = polynomial.substitute(s); + Polynomial univariate = substituted.toUnivariatePolynomial(); + roots.addAll(univariate.isolateRoots()); + } + + // sort roots + List sortedRoots = new ArrayList<>(roots); + sortedRoots.sort((a, b) -> a.equals(b) ? 0 : a.lessThan(b) ? -1 : 1); // todo: make comparable + + // remove duplicates + Iterator iterator = sortedRoots.iterator(); + RealAlgebraicNumber previous = null; + while (iterator.hasNext()) { + RealAlgebraicNumber current = iterator.next(); + if (previous != null && previous.equals(current)) { + iterator.remove(); + } + previous = current; + } + + if (sortedRoots.isEmpty()) { + if (!onlyEqualities) { + D_i.add(R.extend(variable, unboundedInterval())); + } + } else { + if (!onlyEqualities) { + D_i.add(R.extend(variable, intervalLowerUnbounded(sortedRoots.get(0), OPEN))); + D_i.add(R.extend(variable, intervalUpperUnbounded(sortedRoots.get(sortedRoots.size() - 1), OPEN))); + + } + D_i.add(R.extend(variable, pointInterval(sortedRoots.get(sortedRoots.size() - 1)))); + } + + for (int j = 0; j < sortedRoots.size() - 1; j++) { + RealAlgebraicNumber a = sortedRoots.get(j); + RealAlgebraicNumber b = sortedRoots.get(j + 1); + + D_i.add(R.extend(variable, pointInterval(a))); + if (!onlyEqualities) { + D_i.add(R.extend(variable, interval(a, b, OPEN, OPEN))); + } + } + } + + D.add(D_i); + } + + List result = D.get(variables.size()); + Set> assignments = new HashSet<>(); + for (Cell cell : result) { + VariableAssignment assignment = new VariableAssignment<>(cell.chooseSamplePoint()); + assignments.add(assignment); + } + + return assignments; + } + + public Set mcCallumProjection( + Set polynomials, String variable) { + List result = new ArrayList<>(); + + for (MultivariatePolynomial p : polynomials) { + for (MultivariatePolynomial q : polynomials) { + if (p.equals(q)) { + continue; + } + + MultivariatePolynomial resultant = p.resultant(q, variable); + + if (resultant.isConstant()) { + continue; + } + + result.add(resultant); + } + } + + for (MultivariatePolynomial polynomial : polynomials) { + MultivariatePolynomial disc = polynomial.discriminant(variable); + if (disc.isConstant()) { + continue; + } + + result.add(disc); + } + + for (MultivariatePolynomial polynomial : polynomials) { + List coefficients = polynomial.getCoefficients(variable); + + for (MultivariatePolynomial coefficient : coefficients) { + if (coefficient.isConstant()) { + continue; + } + + result.add(coefficient); + } + } + + Set unique = new HashSet<>(); + + for (MultivariatePolynomial multivariatePolynomial : result) { + boolean contains = false; + for (MultivariatePolynomial added : unique) { + if (multivariatePolynomial.equals(added)) { + contains = true; + break; + } + } + + if (!contains) { + unique.add(multivariatePolynomial); + } + } + + return unique; + } +} diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/Cell.java b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/Cell.java new file mode 100644 index 0000000..0a6da5c --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/Cell.java @@ -0,0 +1,54 @@ +package me.paultristanwagner.satchecking.theory.nonlinear; + +import java.util.*; + +public class Cell { + + private final List variables; + private final List intervals; + + private Cell(List variables, List intervals) { + this.variables = variables; + this.intervals = intervals; + } + + public static Cell cell(List variables, List intervals) { + return new Cell(variables, intervals); + } + + public static Cell cell(List variables, Interval... intervalArray) { + return new Cell(variables, Arrays.asList(intervalArray)); + } + + public static Cell emptyCell() { + return new Cell(new ArrayList<>(), new ArrayList<>()); + } + + public Cell extend(String variable, Interval interval) { + List newVariables = new ArrayList<>(this.variables); + newVariables.add(variable); + + List newIntervals = new ArrayList<>(this.intervals); + newIntervals.add(interval); + + return cell(newVariables, newIntervals); + } + + public Map chooseSamplePoint() { + Map samplePoint = new HashMap<>(); + for (int i = 0; i < intervals.size(); i++) { + samplePoint.put(variables.get(i), intervals.get(i).chooseSample()); + } + + return samplePoint; + } + + public List getIntervals() { + return intervals; + } + + @Override + public String toString() { + return "Cell{" + "intervals=" + intervals + '}'; + } +} diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/Exponent.java b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/Exponent.java new file mode 100644 index 0000000..0770cce --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/Exponent.java @@ -0,0 +1,162 @@ +package me.paultristanwagner.satchecking.theory.nonlinear; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +public class Exponent implements Comparable { + + private List values; + + private Exponent(List values) { + this.values = values; + } + + public static Exponent exponent(List values) { + return new Exponent(values); + } + + public static Exponent exponent(Integer... values) { + return new Exponent(Arrays.asList(values)); + } + + public static Exponent constantExponent(int length) { + Integer[] zeros = new Integer[length]; + Arrays.fill(zeros, 0); + return exponent(zeros); + } + + @Override + public int compareTo(Exponent o) { + if (this.values.size() != o.values.size()) { + throw new IllegalStateException("Cannot compare exponents"); + } + + for (int r = values.size() - 1; r >= 0; r--) { + int comparison = Integer.compare(this.values.get(r), o.values.get(r)); + if (comparison != 0) { + return comparison; + } + } + + return 0; + } + + public Exponent add(Exponent other) { + if (this.values.size() != other.values.size()) { + throw new IllegalArgumentException("Exponent size does not match"); + } + + List values = new ArrayList<>(); + for (int i = 0; i < this.values.size(); i++) { + values.add(this.get(i) + other.get(i)); + } + + return exponent(values); + } + + public Exponent subtract(Exponent other) { + if (this.values.size() != other.values.size()) { + throw new IllegalArgumentException("Exponent size does not match"); + } + + List values = new ArrayList<>(); + for (int i = 0; i < this.values.size(); i++) { + int result = this.get(i) - other.get(i); + if (result < 0) { + throw new IllegalArgumentException("Cannot subtract exponent if result would be negative"); + } + + values.add(result); + } + + return exponent(values); + } + + public boolean divides(Exponent other) { + if (this.values.size() != other.values.size()) { + throw new IllegalArgumentException("Exponent size does not match"); + } + + for (int i = 0; i < this.values.size(); i++) { + if (this.get(i) > other.get(i)) { + return false; + } + } + + return true; + } + + public int get(int index) { + if (index < 0) { + throw new IllegalArgumentException("Cannot get negative exponent index"); + } else if (index >= values.size()) { + throw new IllegalArgumentException("Invalid exponent index"); + } + + return values.get(index); + } + + public int highestNonZeroIndex() { + for (int r = values.size() - 1; r >= 0; r--) { + if (values.get(r) > 0) { + return r; + } + } + + return -1; + } + + public boolean isConstantExponent() { + for (Integer exponent : values) { + if (exponent != 0) { + return false; + } + } + + return true; + } + + public static Exponent project( + Exponent from, List originVariables, List targetVariables) { + Integer[] newExponentsArray = new Integer[targetVariables.size()]; + Arrays.fill(newExponentsArray, 0); + for (int i = 0; i < from.values.size(); i++) { + String variable = originVariables.get(i); + int variableIndex = targetVariables.indexOf(variable); // todo: inefficient + + if (variableIndex == -1) { + continue; + } + + int exponent = from.get(i); + + newExponentsArray[variableIndex] = exponent; + } + + return exponent(newExponentsArray); + } + + public List getValues() { + return values; + } + + @Override + public String toString() { + return values.toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Exponent exponent = (Exponent) o; + return Objects.equals(values, exponent.values); + } + + @Override + public int hashCode() { + return Objects.hash(values); + } +} diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/Interval.java b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/Interval.java new file mode 100644 index 0000000..d309fd4 --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/Interval.java @@ -0,0 +1,336 @@ +package me.paultristanwagner.satchecking.theory.nonlinear; + +import me.paultristanwagner.satchecking.parse.PolynomialParser; +import me.paultristanwagner.satchecking.theory.arithmetic.Number; + +import static me.paultristanwagner.satchecking.theory.arithmetic.Number.ZERO; +import static me.paultristanwagner.satchecking.theory.arithmetic.Number.number; +import static me.paultristanwagner.satchecking.theory.nonlinear.Interval.IntervalBoundType.*; +import static me.paultristanwagner.satchecking.theory.nonlinear.RealAlgebraicNumber.realAlgebraicNumber; + +public class Interval { + + public static void main(String[] args) { + PolynomialParser parser = new PolynomialParser(); + Polynomial p = parser.parse("x^2").toUnivariatePolynomial(); + Interval interval = interval(number(-1), number(5), CLOSED, CLOSED); + System.out.println(p.evaluate(interval)); + } + + public enum IntervalBoundType { + UNBOUNDED, + OPEN, + CLOSED + } + + private final RealAlgebraicNumber lowerBound; + private final RealAlgebraicNumber upperBound; + private final IntervalBoundType lowerBoundType; + private final IntervalBoundType upperBoundType; + + private Interval( + RealAlgebraicNumber lowerBound, + RealAlgebraicNumber upperBound, + IntervalBoundType lowerBoundType, + IntervalBoundType upperBoundType) { + this.lowerBound = lowerBound; + this.upperBound = upperBound; + this.lowerBoundType = lowerBoundType; + this.upperBoundType = upperBoundType; + } + + public static Interval unboundedInterval() { + return new Interval(null, null, UNBOUNDED, UNBOUNDED); + } + + public static Interval intervalLowerUnbounded( + RealAlgebraicNumber upperBound, IntervalBoundType upperBoundType) { + return new Interval(null, upperBound, UNBOUNDED, upperBoundType); + } + + public static Interval intervalUpperUnbounded( + RealAlgebraicNumber lowerBound, IntervalBoundType lowerBoundType) { + return new Interval(lowerBound, null, lowerBoundType, UNBOUNDED); + } + + public static Interval intervalLowerUnbounded( + Number upperBound, IntervalBoundType upperBoundType) { + return intervalLowerUnbounded(realAlgebraicNumber(upperBound), upperBoundType); + } + + public static Interval intervalUpperUnbounded( + Number lowerBound, IntervalBoundType lowerBoundType) { + return intervalUpperUnbounded(realAlgebraicNumber(lowerBound), lowerBoundType); + } + + public static Interval interval( + RealAlgebraicNumber lowerBound, + RealAlgebraicNumber upperBound, + IntervalBoundType lowerBoundType, + IntervalBoundType upperBoundType) { + return new Interval(lowerBound, upperBound, lowerBoundType, upperBoundType); + } + + public static Interval interval( + Number lowerBound, + Number upperBound, + IntervalBoundType lowerBoundType, + IntervalBoundType upperBoundType) { + return interval( + realAlgebraicNumber(lowerBound), + realAlgebraicNumber(upperBound), + lowerBoundType, + upperBoundType); + } + + public static Interval pointInterval(RealAlgebraicNumber point) { + return new Interval(point, point, CLOSED, CLOSED); + } + + public static Interval pointInterval(Number number) { + return pointInterval(realAlgebraicNumber(number)); + } + + // todo: improve return values when rational numbers are possible + public RealAlgebraicNumber chooseSample() { + if (lowerBoundType == UNBOUNDED && upperBoundType == UNBOUNDED) { + return realAlgebraicNumber(ZERO()); + } else if (lowerBoundType == CLOSED) { + return lowerBound; // todo: here we might be able to return a rational number + } else if (upperBoundType == CLOSED) { + return upperBound; // todo: here we might be able to return a rational number + } + + if (lowerBoundType == UNBOUNDED && upperBoundType == OPEN) { + if (upperBound.isNumeric()) { + return realAlgebraicNumber(upperBound.numericValue().subtract(number(1))); + } else { + return realAlgebraicNumber(upperBound.getLowerBound()); + } + } + + if (lowerBoundType == OPEN && upperBoundType == UNBOUNDED) { + if (lowerBound.isNumeric()) { + return realAlgebraicNumber(lowerBound.numericValue().add(number(1))); + } else { + return realAlgebraicNumber(lowerBound.getUpperBound()); + } + } + + if (lowerBound.isNumeric() && upperBound.isNumeric()) { + Number rationalMidpoint = + lowerBound.numericValue().add(upperBound.numericValue()).divide(number(2)); + return realAlgebraicNumber(rationalMidpoint); + } + + if (!upperBound.isNumeric() && lowerBound.isNumeric()) { + if (upperBound.getLowerBound().equals(lowerBound.numericValue())) { + upperBound.refine(); + } + + return realAlgebraicNumber(upperBound.getLowerBound()); + } else if (!lowerBound.isNumeric() && upperBound.isNumeric()) { + if (lowerBound.getUpperBound().equals(upperBound.numericValue())) { + lowerBound.refine(); + } + + return realAlgebraicNumber(lowerBound.getUpperBound()); + } + + return realAlgebraicNumber(lowerBound.getUpperBound()); + } + + public Interval add(Interval other) { + if (lowerBoundType != CLOSED || other.lowerBoundType != CLOSED || upperBoundType != CLOSED || other.upperBoundType != CLOSED) { + throw new IllegalArgumentException("Can only add closed intervals"); + } + + if (!lowerBound.isNumeric() || !upperBound.isNumeric() || !other.lowerBound.isNumeric() || !other.upperBound.isNumeric()) { + throw new IllegalArgumentException("Can only add numeric intervals"); + } + + return interval( + lowerBound.numericValue().add(other.lowerBound.numericValue()), + upperBound.numericValue().add(other.upperBound.numericValue()), + CLOSED, + CLOSED); + } + + public Interval subtract(Interval other) { + if (lowerBoundType != CLOSED || other.lowerBoundType != CLOSED || upperBoundType != CLOSED || other.upperBoundType != CLOSED) { + throw new IllegalArgumentException("Can only subtract closed intervals"); + } + + if (!lowerBound.isNumeric() || !upperBound.isNumeric() || !other.lowerBound.isNumeric() || !other.upperBound.isNumeric()) { + throw new IllegalArgumentException("Can only subtract numeric intervals"); + } + + return interval( + lowerBound.numericValue().subtract(other.upperBound.numericValue()), + upperBound.numericValue().subtract(other.lowerBound.numericValue()), + CLOSED, + CLOSED); + } + + public Interval multiply(Number number) { + if (lowerBoundType != CLOSED || upperBoundType != CLOSED) { + throw new IllegalArgumentException("Can only multiply closed intervals"); + } + + if (!lowerBound.isNumeric() || !upperBound.isNumeric()) { + throw new IllegalArgumentException("Can only multiply numeric intervals"); + } + + Number lower = lowerBound.numericValue().multiply(number); + Number upper = upperBound.numericValue().multiply(number); + if (lower.greaterThan(upper)) { + Number temp = lower; + lower = upper; + upper = temp; + } + + return interval( + lower, + upper, + CLOSED, + CLOSED); + } + + public Interval multiply(Interval other) { + if (lowerBoundType != CLOSED || other.lowerBoundType != CLOSED || upperBoundType != CLOSED || other.upperBoundType != CLOSED) { + throw new IllegalArgumentException("Can only multiply closed intervals"); + } + + if (!lowerBound.isNumeric() || !upperBound.isNumeric() || !other.lowerBound.isNumeric() || !other.upperBound.isNumeric()) { + throw new IllegalArgumentException("Can only multiply numeric intervals"); + } + + Number[] products = { + lowerBound.numericValue().multiply(other.lowerBound.numericValue()), + lowerBound.numericValue().multiply(other.upperBound.numericValue()), + upperBound.numericValue().multiply(other.lowerBound.numericValue()), + upperBound.numericValue().multiply(other.upperBound.numericValue()) + }; + + Number lower = products[0]; + Number upper = products[0]; + + for (Number product : products) { + if (product.lessThan(lower)) { + lower = product; + } + + if (product.greaterThan(upper)) { + upper = product; + } + } + + return interval(lower, upper, CLOSED, CLOSED); + } + + public Interval pow(int exponent) { + if (lowerBoundType != CLOSED || upperBoundType != CLOSED) { + throw new IllegalArgumentException("Can only raise closed intervals to a power"); + } + + if (!lowerBound.isNumeric() || !upperBound.isNumeric()) { + throw new IllegalArgumentException("Can only raise numeric intervals to a power"); + } + + if (exponent < 0) { + throw new IllegalArgumentException("Cannot raise an interval to a negative power"); + } + + if (exponent == 0) { + return pointInterval(number(1)); + } + + // todo: use binary exponentiation + Interval result = this; + for (int i = 1; i < exponent; i++) { + result = result.multiply(this); + } + + if (exponent % 2 == 0 && lowerBound.isNegative()) { + result = interval(number(0), result.getUpperBound().numericValue(), CLOSED, CLOSED); + } + + return result; + } + + public boolean contains(RealAlgebraicNumber number) { + if (lowerBoundType == UNBOUNDED && upperBoundType == UNBOUNDED) { + return true; + } else if(lowerBoundType == OPEN && number.lessThanOrEqual(lowerBound)) { + return false; + } else if(upperBoundType == OPEN && number.greaterThanOrEqual(upperBound)) { + return false; + } else if(lowerBoundType == CLOSED && number.lessThan(lowerBound)) { + return false; + } else if(upperBoundType == CLOSED && number.greaterThan(upperBound)) { + return false; + } + + return true; + } + + public boolean contains(Number number) { + return contains(realAlgebraicNumber(number)); + } + + public boolean containsZero() { + return contains(ZERO()); + } + + public int sign() { + if (lowerBound.isZero() && upperBound.isZero()) { + return 0; + } + + if (containsZero()) { + throw new IllegalArgumentException("The interval has no unique sign"); + } + + if (lowerBound.isZero()) { + return upperBound.sign(); + } + + return lowerBound.sign(); + } + + public IntervalBoundType getLowerBoundType() { + return lowerBoundType; + } + + public IntervalBoundType getUpperBoundType() { + return upperBoundType; + } + + public RealAlgebraicNumber getLowerBound() { + return lowerBound; + } + + public RealAlgebraicNumber getUpperBound() { + return upperBound; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + switch (lowerBoundType) { + case UNBOUNDED -> sb.append("(-oo"); + case OPEN -> sb.append("(").append(lowerBound); + case CLOSED -> sb.append("[").append(lowerBound); + } + + sb.append(", "); + + switch (upperBoundType) { + case UNBOUNDED -> sb.append("oo)"); + case OPEN -> sb.append(upperBound).append(")"); + case CLOSED -> sb.append(upperBound).append("]"); + } + + return sb.toString(); + } +} diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/Matrix.java b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/Matrix.java new file mode 100644 index 0000000..4d4dcea --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/Matrix.java @@ -0,0 +1,103 @@ +package me.paultristanwagner.satchecking.theory.nonlinear; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomial.ZERO; + +public class Matrix { + + private final List variables; + private final int m; + private final int n; + private final Map, MultivariatePolynomial> entries; + + private Matrix(List variables, int m, int n, Map, MultivariatePolynomial> entries) { + this.variables = variables; + this.m = m; + this.n = n; + this.entries = entries; + } + + public static Matrix matrix(List variables, int m, int n, Map, MultivariatePolynomial> entries) { + return new Matrix(variables, m, n, entries); + } + + public boolean isSquare() { + return m == n; + } + + public MultivariatePolynomial minor(int i, int j) { + if (!isSquare()) { + throw new IllegalStateException("Cannot compute minor of non-square matrix"); + } + + List variables = new ArrayList<>(this.variables); + int n = this.n - 1; + Map, MultivariatePolynomial> entries = new HashMap<>(); + + for (List index : this.entries.keySet()) { + if(index.get(0) == i || index.get(1) == j) { + continue; + } + + List newIndex = new ArrayList<>(List.of( + index.get(0) < i ? index.get(0) : index.get(0) - 1, + index.get(1) < j ? index.get(1) : index.get(1) - 1 + )); + + entries.put(newIndex, this.entries.get(index)); + } + + Matrix subMatrix = matrix(variables, n, n, entries); + return subMatrix.determinant(); + } + + public MultivariatePolynomial determinant() { + if (!isSquare()) { + throw new IllegalStateException("Cannot compute determinant of non-square matrix"); + } + + if (n == 1) { + return entries.values().stream().findAny().orElseGet(MultivariatePolynomial::ZERO); + } + + if (n == 2) { + MultivariatePolynomial a = entries.getOrDefault(List.of(0, 0), ZERO()); + MultivariatePolynomial b = entries.getOrDefault(List.of(0, 1), ZERO()); + MultivariatePolynomial c = entries.getOrDefault(List.of(1, 0), ZERO()); + MultivariatePolynomial d = entries.getOrDefault(List.of(1, 1), ZERO()); + + MultivariatePolynomial ad = a.multiply(d); + MultivariatePolynomial bc = b.multiply(c); + return ad.subtract(bc); + } + + // Laplace's expansion along the 0-th column + int j = 0; + MultivariatePolynomial result = ZERO(); + for (int i = 0; i < n; i++) { + MultivariatePolynomial b = entries.getOrDefault(List.of(i, j), ZERO()); + if (b.isZero()) { + continue; + } + + MultivariatePolynomial minor = minor(i, j); + boolean positive = (i + j) % 2 == 0; + + MultivariatePolynomial bm = b.multiply(minor); + + if (positive) { + result = result.add(bm); + } else { + result = result.subtract(bm); + } + } + + return result; + } + + // todo: add toString method +} diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/MultivariatePolynomial.java b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/MultivariatePolynomial.java new file mode 100644 index 0000000..3f6240e --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/MultivariatePolynomial.java @@ -0,0 +1,811 @@ +package me.paultristanwagner.satchecking.theory.nonlinear; + +import me.paultristanwagner.satchecking.parse.Parser; +import me.paultristanwagner.satchecking.parse.PolynomialParser; +import me.paultristanwagner.satchecking.theory.arithmetic.Number; + +import java.util.*; +import java.util.stream.Collectors; + +import static me.paultristanwagner.satchecking.theory.arithmetic.Number.ONE; +import static me.paultristanwagner.satchecking.theory.nonlinear.Exponent.constantExponent; +import static me.paultristanwagner.satchecking.theory.nonlinear.Exponent.exponent; +import static me.paultristanwagner.satchecking.theory.nonlinear.Interval.IntervalBoundType.CLOSED; +import static me.paultristanwagner.satchecking.theory.nonlinear.Interval.interval; +import static me.paultristanwagner.satchecking.theory.nonlinear.Interval.pointInterval; +import static me.paultristanwagner.satchecking.theory.nonlinear.Matrix.matrix; +import static me.paultristanwagner.satchecking.theory.nonlinear.Polynomial.polynomial; + +public class MultivariatePolynomial { + + public static void main(String[] args) { + Parser parser = new PolynomialParser(); + Scanner scanner = new Scanner(System.in); + System.out.print("p: "); + MultivariatePolynomial p = parser.parse("t - y^5 - x^5 + x"); // parser.parse(scanner.nextLine()); + System.out.print("q: "); + MultivariatePolynomial q = parser.parse("x^6-1/2*x^4+2*x^2-1/2"); // parser.parse(scanner.nextLine()); + System.out.print("var: "); + String var = scanner.nextLine(); + + System.out.println("p = " + p); + System.out.println("q = " + q); + + System.out.println("Resultant[p, q, " + var + "] = " + p.resultant(q, var)); + } + + public Map coefficients; + public List variables; + + private MultivariatePolynomial(Map coefficients, List variables) { + this.coefficients = coefficients; + this.variables = variables; + } + + public static MultivariatePolynomial multivariatePolynomial( + Map coefficients, List variables) { + return new MultivariatePolynomial(coefficients, variables); + } + + public static MultivariatePolynomial constant(Number number) { + return multivariatePolynomial(Map.of(exponent(), number), new ArrayList<>()); + } + + public static MultivariatePolynomial variable(String variable) { + return multivariatePolynomial(Map.of(exponent(1), ONE()), List.of(variable)); + } + + public static MultivariatePolynomial monomial(Exponent exponent, List variables) { + return multivariatePolynomial(Map.of(exponent, ONE()), variables); + } + + public static MultivariatePolynomial ZERO() { + return multivariatePolynomial(new HashMap<>(), new ArrayList<>()); + } + + public String highestVariable() { + if (this.variables.isEmpty()) { + throw new IllegalStateException("There are no variables"); + } + + int highestVariableIndex = 0; + for (Exponent exponent : coefficients.keySet()) { + int variableIndex = exponent.highestNonZeroIndex(); + highestVariableIndex = Math.max(highestVariableIndex, variableIndex); + } + + return variables.get(highestVariableIndex); + } + + public int degree(String variable) { + int variableIndex = this.variables.indexOf(variable); + int highestExponent = 0; + + if (variableIndex == -1) { + return 0; + } + + for (Exponent exponent : this.coefficients.keySet()) { + if (this.coefficients.get(exponent).isZero()) { + continue; + } + + int value = exponent.get(variableIndex); + highestExponent = Math.max(highestExponent, value); + } + + return highestExponent; + } + + public Number leadingCoefficientNumber(String variable) { + Exponent leadMonomial = getLeadMonomial(variable); + return coefficients.getOrDefault(leadMonomial, Number.ZERO()); + } + + public MultivariatePolynomial leadingCoefficientInVariable(String variable) { + List coefficients = getCoefficients(variable); + return coefficients.get(coefficients.size() - 1); + } + + public boolean isZero() { + for (Exponent exponent : coefficients.keySet()) { + if (!coefficients.get(exponent).isZero()) { + return false; + } + } + + return true; + } + + public boolean isConstant() { + for (Exponent exponent : coefficients.keySet()) { + if (!coefficients.get(exponent).isZero() && !exponent.isConstantExponent()) { + return false; + } + } + + return true; + } + + public MultivariatePolynomial add(MultivariatePolynomial other) { + Set variablesSet = new HashSet<>(this.variables); + variablesSet.addAll(other.variables); + List newVariables = new ArrayList<>(variablesSet); + Map newCoefficients = new HashMap<>(); + + for (Exponent exponent : this.coefficients.keySet()) { + Exponent thisNewExponent = Exponent.project(exponent, this.variables, newVariables); + + Number c = newCoefficients.getOrDefault(thisNewExponent, Number.ZERO()); + c = c.add(this.coefficients.getOrDefault(exponent, Number.ZERO())); + newCoefficients.put(thisNewExponent, c); + } + + for (Exponent exponent : other.coefficients.keySet()) { + Exponent otherNewExponent = Exponent.project(exponent, other.variables, newVariables); + Number c = newCoefficients.getOrDefault(otherNewExponent, Number.ZERO()); + c = c.add(other.coefficients.getOrDefault(exponent, Number.ZERO())); + newCoefficients.put(otherNewExponent, c); + } + + return multivariatePolynomial(newCoefficients, newVariables); + } + + public MultivariatePolynomial negate() { + List variables = new ArrayList<>(this.variables); + Map coefficients = new HashMap<>(); + for (Exponent exponent : this.coefficients.keySet()) { + coefficients.put(exponent, this.coefficients.get(exponent).negate()); + } + + return multivariatePolynomial(coefficients, variables); + } + + public MultivariatePolynomial subtract(MultivariatePolynomial other) { + return this.add(other.negate()); + } + + public MultivariatePolynomial multiply(MultivariatePolynomial other) { + Set variablesSet = new HashSet<>(this.variables); + variablesSet.addAll(other.variables); + List newVariables = new ArrayList<>(variablesSet); + Map newCoefficients = new HashMap<>(); + + for (Exponent exponent : this.coefficients.keySet()) { + Exponent projectedExponent = Exponent.project(exponent, this.variables, newVariables); + for (Exponent otherExponent : other.coefficients.keySet()) { + Exponent projectedOtherExponent = + Exponent.project(otherExponent, other.variables, newVariables); + Exponent newExponent = projectedExponent.add(projectedOtherExponent); + + Number c = newCoefficients.getOrDefault(newExponent, Number.ZERO()); + c = c.add(this.coefficients.get(exponent).multiply(other.coefficients.get(otherExponent))); + + newCoefficients.put(newExponent, c); + } + } + + return multivariatePolynomial(newCoefficients, newVariables); + } + + // todo: use fast exponentiation + public MultivariatePolynomial pow(int exponent) { + if (exponent < 0) { + if (this.isConstant()) { + Number c = this.coefficients.getOrDefault(getLeadMonomial(), Number.ZERO()); + return constant(c.pow(-exponent)); + } + + throw new IllegalArgumentException("Exponent must be non-negative"); + } + + if (exponent == 0) { + return constant(ONE()); + } + + MultivariatePolynomial result = this; + for (int i = 1; i < exponent; i++) { + result = result.multiply(this); + } + + return result; + } + + public Exponent getLeadMonomial() { + if (this.variables.isEmpty()) { + return constantExponent(0); + } + + return getLeadMonomial(highestVariable()); + } + + public Exponent getLeadMonomial(String variable) { + Exponent leadMonomial = null; + + int variableIndex = variables.indexOf(variable); + + for (Exponent exponent : coefficients.keySet()) { + if (coefficients.get(exponent).isZero()) { + continue; + } + + if (variableIndex == -1) { + if (leadMonomial == null || exponent.compareTo(leadMonomial) > 0) { + leadMonomial = exponent; + } + continue; + } + + if (leadMonomial == null + || leadMonomial.get(variableIndex) < exponent.get(variableIndex) + || leadMonomial.get(variableIndex) == exponent.get(variableIndex) + && leadMonomial.compareTo(exponent) < 0) { + leadMonomial = exponent; + } + } + + if (leadMonomial == null) { + return constantExponent(variables.size()); + } + + return leadMonomial; + } + + public List pseudoDivision( + MultivariatePolynomial divisor, String variable) { + if (divisor.isZero()) { + throw new IllegalArgumentException("Divisor must not be zero"); + } + + if (this.isZero()) { + return List.of(ZERO(), ZERO()); + } + + int a = this.degree(variable); + int b = divisor.degree(variable); + + MultivariatePolynomial bLeadCoefficientInVariable = divisor.leadingCoefficientInVariable(variable); + + MultivariatePolynomial dividend = bLeadCoefficientInVariable.pow(a - b + 1).multiply(this); + + return dividend.divide(divisor, variable); + } + + public List divide(MultivariatePolynomial divisor, String variable) { + if (divisor.isZero()) { + throw new IllegalArgumentException("Divisor must not be zero"); + } + + if (this.isZero()) { + return List.of(ZERO(), ZERO()); + } + + if (this.degree(variable) < divisor.degree(variable)) { + return List.of(ZERO(), this); + } + + Exponent divisorLM = divisor.getLeadMonomial(variable); + Number divisorLC = divisor.leadingCoefficientNumber(variable); + + Set variablesSet = new HashSet<>(this.variables); + variablesSet.addAll(divisor.variables); + List newVariables = new ArrayList<>(variablesSet); + + MultivariatePolynomial quotient = ZERO(); + MultivariatePolynomial remainder = this; + while (!remainder.isZero() && remainder.degree(variable) >= divisor.degree(variable)) { + Exponent remainderLM = remainder.getLeadMonomial(variable); + + Exponent projectedDivisorLM = Exponent.project(divisorLM, divisor.variables, newVariables); + Exponent projectedRemainderLM = + Exponent.project(remainderLM, remainder.variables, newVariables); + + if (!projectedDivisorLM.divides(projectedRemainderLM)) { + break; + } + + Number remainderLC = remainder.leadingCoefficientNumber(variable); + + MultivariatePolynomial monomial = + monomial( + projectedRemainderLM.subtract(projectedDivisorLM), new ArrayList<>(this.variables)); + + Number highestCoefficientDivided = remainderLC.divide(divisorLC); + + MultivariatePolynomial factor = monomial.multiply(constant(highestCoefficientDivided)); + + quotient = quotient.add(factor); + remainder = remainder.subtract(divisor.multiply(factor)); + } + + return List.of(quotient, remainder); + } + + // todo: for testing + public boolean actuallyContainsVariable(String variable) { + if (!variables.contains(variable)) { + return false; + } + + for (Exponent exponent : coefficients.keySet()) { + if (coefficients.get(exponent).isZero()) { + continue; + } + + if (exponent.get(variables.indexOf(variable)) != 0) { + return true; + } + } + + return false; + } + + // todo: clean this up + // using pseudo-remainder sequence à la https://github.com/sympy/sympy + public MultivariatePolynomial resultant(MultivariatePolynomial other, String variable) { + int n = this.degree(variable); + int m = other.degree(variable); + + if (n < m) { + return other.resultant(this, variable); + } + + if (this.isZero() || other.isZero()) { + return ZERO(); + } + + int d = n - m; + MultivariatePolynomial b = (d + 1) % 2 == 0 ? constant(ONE()) : constant(ONE().negate()); + + List R = new ArrayList<>(); + R.add(this); + R.add(other); + + MultivariatePolynomial h = this.pseudoDivision(other, variable).get(1); + h = h.multiply(b); + + MultivariatePolynomial leadCoefficient = other.leadingCoefficientInVariable(variable); + + MultivariatePolynomial c = leadCoefficient.pow(d); + + c = c.negate(); + + MultivariatePolynomial f = this; + MultivariatePolynomial g = other; + + while (!h.isZero()) { + int k = h.degree(variable); + + R.add(h); + + f = g; + g = h; + int temp = m; + m = k; + d = temp - k; + + b = leadCoefficient.negate().multiply(c.pow(d)); + + h = f.pseudoDivision(g, variable).get(1); + h = h.divide(b, variable).get(0); + + leadCoefficient = g.leadingCoefficientInVariable(variable); + + if (d > 1) { + MultivariatePolynomial p = leadCoefficient.negate().pow(d); + MultivariatePolynomial q = c.pow(d - 1); + c = p.divide(q, variable).get(0); + } else { + c = leadCoefficient.negate(); + } + } + + MultivariatePolynomial resultant = R.get(R.size() - 1); + if (resultant.actuallyContainsVariable(variable)) { + return ZERO(); + } + + return R.get(R.size() - 1); + } + + public List getCoefficients(String variable) { + List newVariables = new ArrayList<>(variables); + newVariables.remove(variable); + + int highestExponent = degree(variable); + + MultivariatePolynomial[] coefficientsArray = new MultivariatePolynomial[highestExponent + 1]; + Arrays.fill(coefficientsArray, ZERO()); + + int variableIndex = variables.indexOf(variable); + for (Exponent exponent : this.coefficients.keySet()) { + if (this.coefficients.get(exponent).isZero()) { + continue; + } + + int variableExponent = variableIndex != -1 ? exponent.get(variableIndex) : 0; + + Number c = this.coefficients.get(exponent); + Exponent newExponent = Exponent.project(exponent, this.variables, newVariables); + + Map monomialCoefficients = new HashMap<>(Map.of(newExponent, c)); + MultivariatePolynomial monomial = multivariatePolynomial(monomialCoefficients, newVariables); + + coefficientsArray[variableExponent] = coefficientsArray[variableExponent].add(monomial); + } + + return Arrays.stream(coefficientsArray).collect(Collectors.toList()); + } + + public MultivariatePolynomial derivative(String variable) { + int variableIndex = this.variables.indexOf(variable); + + if (variableIndex == -1) { + return ZERO(); + } + + List newVariables = new ArrayList<>(this.variables); + Map coefficients = new HashMap<>(); + for (Exponent exponent : this.coefficients.keySet()) { + if (exponent.get(variableIndex) == 0) { + continue; + } + + List exponentValues = new ArrayList<>(exponent.getValues()); + exponentValues.set(variableIndex, exponent.get(variableIndex) - 1); + Exponent newExponent = exponent(exponentValues); + + coefficients.put(newExponent, this.coefficients.get(exponent)); + } + + return multivariatePolynomial(coefficients, newVariables); + } + + public MultivariatePolynomial resultantOld(MultivariatePolynomial other, String variable) { + List newVariables = new ArrayList<>(variables); + newVariables.remove(variable); + + List thisCoefficients = this.getCoefficients(variable); + List otherCoefficients = other.getCoefficients(variable); + + Map, MultivariatePolynomial> entries = new HashMap<>(); + + for (int i = 0; i < otherCoefficients.size() - 1; i++) { + for (int j = 0; j < thisCoefficients.size(); j++) { + if (thisCoefficients.get(j).isZero()) { + continue; + } + + entries.put(List.of(i, i + j), thisCoefficients.get(j)); + } + } + + for (int i = 0; i < thisCoefficients.size() - 1; i++) { + for (int j = 0; j < otherCoefficients.size(); j++) { + if (otherCoefficients.get(j).isZero()) { + continue; + } + + entries.put(List.of(otherCoefficients.size() + i - 1, i + j), otherCoefficients.get(j)); + } + } + + int n = thisCoefficients.size() + otherCoefficients.size() - 2; + Matrix sylvesterMatrix = matrix(newVariables, n, n, entries); + + return sylvesterMatrix.determinant(); + } + + public MultivariatePolynomial discriminant(String variable) { + MultivariatePolynomial derivative = this.derivative(variable); + + return this.resultant(derivative, variable); + } + + public MultivariatePolynomial substitute(Map substitution) { + MultivariatePolynomial current = this; + for (String variable : substitution.keySet()) { + RealAlgebraicNumber value = substitution.get(variable); + int variableIndex = current.variables.indexOf(variable); + + if (variableIndex == -1) { + continue; + } + + if (!value.isNumeric()) { + Polynomial ranPolynomial = value.getPolynomial(); + current = + current.resultant( + ranPolynomial.toMultivariatePolynomial(variable), + variable); // todo: this introduces incorrect roots + continue; + } + + Number rationalValue = value.numericValue(); + + List newVariables = new ArrayList<>(current.variables); + newVariables.remove(variable); + + Map newCoefficients = new HashMap<>(); + + for (Exponent exponent : current.coefficients.keySet()) { + int power = exponent.get(variableIndex); + + Number c = current.coefficients.get(exponent); + Exponent newExponent = Exponent.project(exponent, current.variables, newVariables); + + if (power != 0) { + c = c.multiply(rationalValue.pow(power)); + } + + Number prev = newCoefficients.getOrDefault(newExponent, Number.ZERO()); + newCoefficients.put(newExponent, prev.add(c)); + } + + current = multivariatePolynomial(newCoefficients, newVariables); + } + + return current; + } + + public Interval evaluate(Map substitution) { + Interval interval = null; + for (Exponent exponent : coefficients.keySet()) { + if (coefficients.get(exponent).isZero()) { + continue; + } + + Number coefficient = coefficients.get(exponent); + Interval monomialInterval = null; + for (int i = 0; i < exponent.getValues().size(); i++) { + String variable = variables.get(i); + int power = exponent.get(i); + Interval variableInterval = substitution.get(variable); + + if (variableInterval == null) { + throw new IllegalArgumentException("No interval for variable " + variable); + } + + if (monomialInterval == null) { + monomialInterval = variableInterval.pow(power); + } else { + monomialInterval = monomialInterval.multiply(variableInterval.pow(power)); + } + } + + if (interval == null) { + interval = monomialInterval.multiply(coefficient); + } else { + interval = interval.add(monomialInterval.multiply(coefficient)); + } + } + + return interval; + } + + public int evaluateSign(Map substitution) { + Map numericSubstitutions = new HashMap<>(); + substitution.forEach( + (variable, ran) -> { + if (ran.isNumeric()) { + numericSubstitutions.put(variable, ran); + } + }); + + MultivariatePolynomial substituted = this.substitute(numericSubstitutions); + + if (substituted.isZero()) { + return 0; + } else if (substituted.isConstant()) { + Exponent constantExponent = + substituted.coefficients.keySet().stream() + .filter(exponent -> !substituted.coefficients.get(exponent).isZero()) + .findAny() + .orElseThrow(); + return substituted.coefficients.get(constantExponent).sign(); + } + + Map algebraicSubstitutions = new HashMap<>(); + substitution.forEach( + (variable, ran) -> { + if (!ran.isNumeric()) { + algebraicSubstitutions.put(variable, ran); + } + }); + + if (algebraicSubstitutions.isEmpty()) { + throw new IllegalArgumentException("No algebraic substitutions"); + } + + MultivariatePolynomial current = substituted; + Iterator iterator = algebraicSubstitutions.keySet().iterator(); + String firstVariable = iterator.next(); + String freshVariable = firstVariable + "'"; // todo: make sure this is fresh + + Map intervalSubstitution = constructIntervalSubstitution(substitution); + Interval res_I = this.evaluate(intervalSubstitution); + if(!res_I.containsZero()) { + return res_I.sign(); + } + + RealAlgebraicNumber firstSubstitution = algebraicSubstitutions.get(firstVariable); + current = + variable(freshVariable) + .subtract(current) + .resultant( + firstSubstitution.getPolynomial().toMultivariatePolynomial(firstVariable), + firstVariable); + + while (iterator.hasNext()) { + String variable = iterator.next(); + RealAlgebraicNumber ran = algebraicSubstitutions.get(variable); + current = current.resultant(ran.getPolynomial().toMultivariatePolynomial(variable), variable); + } + + Polynomial univariate = current.toUnivariatePolynomial(); + + while (true) { + Set testRoots = + univariate.isolateRoots( + res_I.getLowerBound().numericValue(), res_I.getUpperBound().numericValue()); + + if (testRoots.isEmpty()) { + throw new IllegalStateException("There must be a root in the interval"); + } + + if (testRoots.size() == 1) { + RealAlgebraicNumber root = testRoots.iterator().next(); + return root.sign(); + } + + algebraicSubstitutions.values().forEach(RealAlgebraicNumber::refine); + + intervalSubstitution = constructIntervalSubstitution(substitution); + res_I = this.evaluate(intervalSubstitution); + if(!res_I.containsZero()) { + return res_I.sign(); + } + } + } + + private Map constructIntervalSubstitution(Map substitution) { + Map intervalSubstitution = new HashMap<>(); + + for (String variable : substitution.keySet()) { + RealAlgebraicNumber ran = substitution.get(variable); + + if (ran.isNumeric()) { + intervalSubstitution.put(variable, pointInterval(ran.numericValue())); + } else { + intervalSubstitution.put( + variable, interval(ran.getLowerBound(), ran.getUpperBound(), CLOSED, CLOSED)); + } + } + + return intervalSubstitution; + } + + public void prune() { + List unusedVariables = new ArrayList<>(); + for (String variable : variables) { + int degree = degree(variable); + if (degree == 0) { + unusedVariables.add(variable); + } + } + + List newVariables = new ArrayList<>(variables); + newVariables.removeAll(unusedVariables); + + Map newCoefficients = new HashMap<>(); + for (Exponent exponent : coefficients.keySet()) { + if (coefficients.get(exponent).isZero()) { + continue; + } + + Exponent newExponent = Exponent.project(exponent, variables, newVariables); + newCoefficients.put(newExponent, coefficients.get(exponent)); + } + + this.variables = newVariables; + this.coefficients = newCoefficients; + } + + public Polynomial toUnivariatePolynomial() { + // todo: this is a hack + prune(); + + if (this.variables.size() > 1) { + throw new IllegalArgumentException("Not a univariate polynomial"); + } + + if (this.variables.isEmpty()) { + return polynomial(coefficients.getOrDefault(constantExponent(0), Number.ZERO())); + } + + int degree = 0; + for (Exponent exponent : coefficients.keySet()) { + degree = Math.max(degree, exponent.get(0)); + } + + Number[] coefficientsArray = new Number[degree + 1]; + for (int i = 0; i <= degree; i++) { + Exponent exponent = exponent(i); + coefficientsArray[i] = coefficients.getOrDefault(exponent, Number.ZERO()); + } + + return polynomial(coefficientsArray); + } + + public Number getCoefficient(Exponent exponent) { + return coefficients.getOrDefault(exponent, Number.ZERO()); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + + List monomials = new ArrayList<>(coefficients.keySet()); + monomials.removeIf(monomial -> coefficients.get(monomial).isZero()); + monomials.sort(Comparator.reverseOrder()); + + for (int j = 0; j < monomials.size(); j++) { + Exponent exponent = monomials.get(j); + Number coefficient = coefficients.get(exponent); + + if (j != 0 && coefficient.isNonNegative()) { + sb.append(" + "); + } else if (coefficient.isNegative()) { + sb.append(" - "); + } + + if (exponent.isConstantExponent() || !coefficient.abs().isOne()) { + sb.append(coefficient.abs()); + + if (!exponent.isConstantExponent()) { + sb.append("*"); + } + } + + for (int i = 0; i < variables.size(); i++) { + if (exponent.get(i) == 0) { + continue; + } + + sb.append(variables.get(i)); + if (exponent.get(i) > 1) { + sb.append("^").append(exponent.get(i)); + } + + if (i < variables.size() - 1 && exponent.get(i + 1) != 0) { + sb.append("*"); + } + } + } + + if (sb.isEmpty()) { + return "0"; + } + + return sb.toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MultivariatePolynomial that = (MultivariatePolynomial) o; + + // todo: this is a hack + this.prune(); + that.prune(); + + return Objects.equals(coefficients, that.coefficients) + && Objects.equals(variables, that.variables); + } + + @Override + public int hashCode() { + return Objects.hash(coefficients, variables); + } +} diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/MultivariatePolynomialConstraint.java b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/MultivariatePolynomialConstraint.java new file mode 100644 index 0000000..9820c2c --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/MultivariatePolynomialConstraint.java @@ -0,0 +1,105 @@ +package me.paultristanwagner.satchecking.theory.nonlinear; + +import me.paultristanwagner.satchecking.theory.Constraint; + +public class MultivariatePolynomialConstraint implements Constraint { + + public enum Comparison { + EQUALS, + NOT_EQUALS, + LESS_THAN, + GREATER_THAN, + LESS_THAN_OR_EQUALS, + GREATER_THAN_OR_EQUALS; + + public boolean evaluateSign(int sign) { + return switch (this) { + case EQUALS -> sign == 0; + case NOT_EQUALS -> sign != 0; + case LESS_THAN -> sign < 0; + case GREATER_THAN -> sign > 0; + case LESS_THAN_OR_EQUALS -> sign <= 0; + case GREATER_THAN_OR_EQUALS -> sign >= 0; + }; + } + + public String toString() { + return switch (this) { + case EQUALS -> "="; + case NOT_EQUALS -> "!="; + case LESS_THAN -> "<"; + case GREATER_THAN -> ">"; + case LESS_THAN_OR_EQUALS -> "<="; + case GREATER_THAN_OR_EQUALS -> ">="; + }; + } + } + + private final MultivariatePolynomial polynomial; + private final Comparison comparison; + + private MultivariatePolynomialConstraint(MultivariatePolynomial polynomial, Comparison comparison) { + this.polynomial = polynomial; + this.comparison = comparison; + } + + public static MultivariatePolynomialConstraint multivariatePolynomialConstraint(MultivariatePolynomial polynomial, Comparison comparison) { + return new MultivariatePolynomialConstraint(polynomial, comparison); + } + + public static MultivariatePolynomialConstraint lessThanZero(MultivariatePolynomial polynomial) { + return new MultivariatePolynomialConstraint(polynomial, Comparison.LESS_THAN); + } + + public static MultivariatePolynomialConstraint greaterThanZero(MultivariatePolynomial polynomial) { + return new MultivariatePolynomialConstraint(polynomial, Comparison.GREATER_THAN); + } + + public static MultivariatePolynomialConstraint lessThanOrEqualsZero(MultivariatePolynomial polynomial) { + return new MultivariatePolynomialConstraint(polynomial, Comparison.LESS_THAN_OR_EQUALS); + } + + public static MultivariatePolynomialConstraint greaterThanOrEqualsZero(MultivariatePolynomial polynomial) { + return new MultivariatePolynomialConstraint(polynomial, Comparison.GREATER_THAN_OR_EQUALS); + } + + public static MultivariatePolynomialConstraint equalsZero(MultivariatePolynomial polynomial) { + return new MultivariatePolynomialConstraint(polynomial, Comparison.EQUALS); + } + + public static MultivariatePolynomialConstraint notEqualsZero(MultivariatePolynomial polynomial) { + return new MultivariatePolynomialConstraint(polynomial, Comparison.NOT_EQUALS); + } + + public static MultivariatePolynomialConstraint equals(MultivariatePolynomial polynomial, MultivariatePolynomial other) { + return equalsZero(polynomial.subtract(other)); + } + + public static MultivariatePolynomialConstraint notEquals(MultivariatePolynomial polynomial, MultivariatePolynomial other) { + return notEqualsZero(polynomial.subtract(other)); + } + + public static MultivariatePolynomialConstraint lessThan(MultivariatePolynomial polynomial, MultivariatePolynomial other) { + return lessThanZero(polynomial.subtract(other)); + } + + public static MultivariatePolynomialConstraint greaterThan(MultivariatePolynomial polynomial, MultivariatePolynomial other) { + return greaterThanZero(polynomial.subtract(other)); + } + + public static MultivariatePolynomialConstraint lessThanOrEquals(MultivariatePolynomial polynomial, MultivariatePolynomial other) { + return lessThanOrEqualsZero(polynomial.subtract(other)); + } + + public static MultivariatePolynomialConstraint greaterThanOrEquals(MultivariatePolynomial polynomial, MultivariatePolynomial other) { + return greaterThanOrEqualsZero(polynomial.subtract(other)); + } + + public MultivariatePolynomial getPolynomial() { + return polynomial; + } + + public Comparison getComparison() { + return comparison; + } +} diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/Polynomial.java b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/Polynomial.java new file mode 100644 index 0000000..e467760 --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/Polynomial.java @@ -0,0 +1,567 @@ +package me.paultristanwagner.satchecking.theory.nonlinear; + +import me.paultristanwagner.satchecking.parse.Parser; +import me.paultristanwagner.satchecking.parse.PolynomialParser; +import me.paultristanwagner.satchecking.theory.arithmetic.Number; +import me.paultristanwagner.satchecking.theory.arithmetic.Rational; + +import java.math.BigInteger; +import java.util.*; + +import static me.paultristanwagner.satchecking.theory.arithmetic.Number.*; +import static me.paultristanwagner.satchecking.theory.nonlinear.Exponent.exponent; +import static me.paultristanwagner.satchecking.theory.nonlinear.Interval.pointInterval; +import static me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomial.multivariatePolynomial; +import static me.paultristanwagner.satchecking.theory.nonlinear.RealAlgebraicNumber.realAlgebraicNumber; + +public class Polynomial { + + public static void main(String[] args) { + Parser parser = new PolynomialParser(); + Polynomial p = parser.parse("x^5+x^4+x^2+x+2").toUnivariatePolynomial(); + Polynomial q = parser.parse("3x^5-7x^3+3x^2").toUnivariatePolynomial(); + + System.out.println(p); + System.out.println(q); + System.out.println(p.pow(100).squareFreeFactorization()); + } + + private final int degree; + private final Number[] coefficients; // smallest degree first + + private Polynomial(Number[] coefficients) { + int degree = 0; + for (int i = 0; i < coefficients.length; i++) { + if (!coefficients[i].isZero()) { + degree = i; + } + } + + this.degree = degree; + this.coefficients = new Number[degree + 1]; + System.arraycopy(coefficients, 0, this.coefficients, 0, degree + 1); + } + + public static Polynomial polynomial(Number... coefficients) { + return new Polynomial(coefficients); + } + + public static Polynomial constant(Number constant) { + return polynomial(constant); + } + + public static Polynomial xToThePowerOf(int exponent) { + Number[] coefficients = new Number[exponent + 1]; + coefficients[exponent] = ONE(); + Arrays.fill(coefficients, 0, exponent, ZERO()); + return new Polynomial(coefficients); + } + + public int getDegree() { + return degree; + } + + public Number[] getCoefficients() { + return coefficients; + } + + public Number evaluate(Number x) { + Number result = ZERO(); + for (int i = 0; i < coefficients.length; i++) { + result = result.add(coefficients[i].multiply(x.pow(i))); + } + return result; + } + + public Number evaluate(RealAlgebraicNumber realAlgebraicNumber) { + if (realAlgebraicNumber.isNumeric()) { + return evaluate(realAlgebraicNumber.numericValue()); + } + + MultivariatePolynomial multivariatePolynomial = toMultivariatePolynomial("x"); + MultivariatePolynomial substituted = multivariatePolynomial.substitute(Map.of("x", realAlgebraicNumber)); + Polynomial univariate = substituted.toUnivariatePolynomial(); + + if (univariate.getDegree() != 0) { + return ZERO(); + } + + return univariate.getCoefficients()[0]; + } + + public Polynomial add(Polynomial other) { + int newDegree = Math.max(degree, other.degree); + Number[] newCoefficients = new Number[newDegree + 1]; + for (int i = 0; i <= newDegree; i++) { + Number coefficient = ZERO(); + if (i <= degree) { + coefficient = coefficient.add(coefficients[i]); + } + if (i <= other.degree) { + coefficient = coefficient.add(other.coefficients[i]); + } + newCoefficients[i] = coefficient; + } + return new Polynomial(newCoefficients); + } + + public Polynomial negate() { + Number[] newCoefficients = new Number[coefficients.length]; + for (int i = 0; i < coefficients.length; i++) { + newCoefficients[i] = coefficients[i].negate(); + } + return new Polynomial(newCoefficients); + } + + public Polynomial subtract(Polynomial other) { + return add(other.negate()); + } + + public Polynomial multiply(Polynomial other) { + int newDegree = degree + other.degree; + Number[] newCoefficients = new Number[newDegree + 1]; + for (int i = 0; i <= newDegree; i++) { + Number coefficient = ZERO(); + for (int j = 0; j <= i; j++) { + if (j <= degree && i - j <= other.degree) { + coefficient = coefficient.add(coefficients[j].multiply(other.coefficients[i - j])); + } + } + newCoefficients[i] = coefficient; + } + return new Polynomial(newCoefficients); + } + + public Polynomial pow(int exponent) { + if (exponent < 0) { + throw new IllegalArgumentException("Exponent must be non-negative"); + } + + if (exponent == 0) { + return polynomial(ONE()); + } + + // TODO: use binary exponentiation + Polynomial result = this; + for (int i = 1; i < exponent; i++) { + result = result.multiply(this); + } + return result; + } + + public Number getLeadingCoefficient() { + return coefficients[degree]; + } + + public boolean isZero() { + for (Number coefficient : coefficients) { + if (!coefficient.isZero()) { + return false; + } + } + return true; + } + + public boolean isOne() { + return degree == 0 && coefficients[0].isOne(); + } + + public boolean isConstant() { + return degree == 0; + } + + public Polynomial toIntegerPolynomial() { + BigInteger lcm = coefficients[0].getDenominator(); + for (int i = 1; i < coefficients.length; i++) { + BigInteger gcd = lcm.gcd(coefficients[i].getDenominator()); + lcm = lcm.multiply(coefficients[i].getDenominator()).divide(gcd); + } + + Number[] newCoefficients = new Number[coefficients.length]; + for (int i = 0; i < coefficients.length; i++) { + newCoefficients[i] = coefficients[i].multiply(new Rational(lcm)); + } + return new Polynomial(newCoefficients); + } + + public Number content() { + Number content = null; + + for (Number coefficient : coefficients) { + if (coefficient.isZero()) { + continue; + } + + if (!coefficient.isInteger()) { + throw new IllegalArgumentException("Cannot calculate content of non-integer coefficients"); + } + + if (content == null) { + content = coefficient; + } else { + content = content.gcd(coefficient); + } + } + + return content; + } + + public List pseudoDivision(Polynomial other) { + int a = degree; + int b = other.degree; + + Number bLC = other.getLeadingCoefficient(); + Number pow = bLC.pow(a - b + 1); + + return polynomial(pow).multiply(this).divide(other); + } + + public List divide(Polynomial other) { + if (other.isZero()) { + throw new ArithmeticException("Cannot divide by zero"); + } + + Polynomial q = polynomial(ZERO()); + Polynomial r = this; + int d = other.degree; + + while (!r.isZero() && r.degree >= d) { + Number c = r.getLeadingCoefficient().divide(other.getLeadingCoefficient()); + Polynomial s = xToThePowerOf(r.degree - d).multiply(polynomial(c)); + q = q.add(s); + Polynomial v = s.multiply(other); + r = r.subtract(v); + } + + return List.of(q, r); + } + + public Polynomial mod(Polynomial other) { + return divide(other).get(1); + } + + // todo: this can be done more efficiently + public Polynomial gcd(Polynomial other) { + if(this.isZero()) { + return other; + } else if(other.isZero()) { + return this; + } + + Polynomial thisInteger = toIntegerPolynomial(); + Polynomial otherInteger = other.toIntegerPolynomial(); + Polynomial nonNormalizedGcd = thisInteger.nonNormalizedGcd(otherInteger); + + return nonNormalizedGcd.divide(constant(nonNormalizedGcd.getLeadingCoefficient())).get(0); + } + + public Polynomial nonNormalizedGcdOld(Polynomial other) { + Polynomial a = this; + Polynomial b = other; + while (!b.isZero()) { + Polynomial r = a.mod(b); + a = b; + b = r; + } + + return a; + } + + public Polynomial nonNormalizedGcd(Polynomial other) { + Polynomial a = this.toIntegerPolynomial(); + Polynomial b = other.toIntegerPolynomial(); + while (!b.isZero()) { + Polynomial r = a.pseudoDivision(b).get(1).toIntegerPolynomial(); + if (r.isZero()) { + return b; + } + + Number content = r.content(); + r = r.divide(constant(content)).get(0); + a = b; + b = r; + } + + return a; + } + + public Polynomial getDerivative() { + if (degree == 0) { + return polynomial(ZERO()); + } + + Number[] newCoefficients = new Number[coefficients.length - 1]; + for (int i = 0; i < newCoefficients.length; i++) { + newCoefficients[i] = coefficients[i + 1].multiply(number(i + 1)); + } + return new Polynomial(newCoefficients); + } + + public List squareFreeFactorization() { + List result = new ArrayList<>(); + + Polynomial P = this; + Polynomial G = P.gcd(P.getDerivative()); + Polynomial C = P.divide(G).get(0); + Polynomial D = P.getDerivative().divide(G).get(0).subtract(C.getDerivative()); + + for (int i = 1; !C.isOne(); i++) { + P = C.gcd(D); + + C = C.divide(P).get(0); + D = D.divide(P).get(0).subtract(C.getDerivative()); + + result.add(P); + } + + return result; + } + + public boolean isSquareFree() { + List squareFreeFactors = squareFreeFactorization(); + for (int i = 0; i < squareFreeFactors.size(); i++) { + if ((i + 1) % 2 == 0 && !squareFreeFactors.get(i).isConstant()) { + return false; + } + } + + return true; + } + + public List sturmSequence() { + if (isConstant()) { + return List.of(this); + } + + List sturmSequence = new ArrayList<>(); + sturmSequence.add(this); + sturmSequence.add(this.getDerivative()); + while (true) { + Polynomial p = sturmSequence.get(sturmSequence.size() - 2); + Polynomial q = sturmSequence.get(sturmSequence.size() - 1); + Polynomial negRem = p.divide(q).get(1).negate(); + + if (negRem.isZero()) { + break; + } + + sturmSequence.add(negRem); + } + + return sturmSequence; + } + + private int sturmSequenceEvaluation(Number xi, List sturmSequence) { + int signChanges = 0; + int sign = 0; + + for (Polynomial p : sturmSequence) { + Number eval = p.evaluate(xi); + if (eval.isZero()) { + continue; + } + + if (eval.isPositive() && sign == -1 || eval.isNegative() && sign == 1) { + signChanges++; + } + + if (eval.isPositive()) { + sign = 1; + } else if (eval.isNegative()) { + sign = -1; + } + } + + return signChanges; + } + + public int numberOfRealRoots() { + Number cauchyBound = cauchyBound(); + + return numberOfRealRoots(cauchyBound.negate(), cauchyBound); + } + + public boolean hasRealRootAt(Number x) { + return evaluate(x).isZero(); + } + + public int numberOfRealRoots(Number a, Number b) { + return numberOfRealRoots(a, b, sturmSequence()); + } + + public int numberOfRealRoots(Number a, Number b, List sturmSequence) { + int roots = + sturmSequenceEvaluation(a, sturmSequence) - sturmSequenceEvaluation(b, sturmSequence); + if (hasRealRootAt(b)) { + roots--; + } + + return roots; + } + + public Number cauchyBound() { + Number highestRatio = Number.ZERO(); + + Number lcoeff = this.getLeadingCoefficient(); + for (int i = 0; i < degree; i++) { + Number c = coefficients[i]; + Number absRatio = c.divide(lcoeff).abs(); + + if (absRatio.greaterThan(highestRatio)) { + highestRatio = absRatio; + } + } + + return ONE().add(highestRatio); + } + + public Set isolateRoots() { + Number cauchyBound = cauchyBound(); + return isolateRoots(cauchyBound.negate(), cauchyBound); + } + + public Set isolateRoots(Number lowerBound, Number upperBound) { + if (isConstant()) { + return Set.of(); + } + + List squareFreeFactors = squareFreeFactorization(); + + if (squareFreeFactors.size() > 1) { + Set roots = new HashSet<>(); + for (Polynomial squareFreeFactor : squareFreeFactors) { + if (squareFreeFactor.isConstant()) { + continue; + } + + Set factorRoots = squareFreeFactor.isolateRoots(lowerBound, upperBound); + roots.addAll(factorRoots); + } + + return roots; + } + + int numberOfRealRoots = numberOfRealRoots(lowerBound, upperBound); + + if (numberOfRealRoots == 0) { + return Set.of(); + } + + if (numberOfRealRoots == 1) { + return Set.of(realAlgebraicNumber(this, lowerBound, upperBound)); + } + + Number split; + if (lowerBound.isNegative() && upperBound.isPositive() || lowerBound.isPositive() && upperBound.isNegative()) { + split = ZERO(); + } else { + split = lowerBound.midpoint(upperBound); + } + + Set leftRoots = new HashSet<>(isolateRoots(lowerBound, split)); + Set rightRoots = isolateRoots(split, upperBound); + leftRoots.addAll(rightRoots); + + if (hasRealRootAt(split)) { + leftRoots.add(realAlgebraicNumber(split)); + } + + return leftRoots; + } + + public MultivariatePolynomial toMultivariatePolynomial(String variable) { + List variables = List.of(variable); + Map coefficients = new HashMap<>(); + for (int i = 0; i <= degree; i++) { + coefficients.put(exponent(i), this.coefficients[i]); + } + + return multivariatePolynomial(coefficients, variables); + } + + public Set isolateRootsAsDoubles() { + Number epsilon = number(2).pow(-54); + + Set roots = isolateRoots(); + Set result = new HashSet<>(); + for (RealAlgebraicNumber root : roots) { + result.add(root.approximate(epsilon).approximateAsDouble()); + } + return result; + } + + @Override + public String toString() { + if (isZero()) { + return "0"; + } + + StringBuilder builder = new StringBuilder(); + for (int i = degree; i >= 0; i--) { + Number coefficient = coefficients[i]; + if (coefficient.isZero()) { + continue; + } + + if (i != degree && coefficient.isNonNegative()) { + builder.append("+"); + } + + if (coefficient.isOne()) { + if (i == 0) { + builder.append("1"); + } else if (i == 1) { + builder.append("x"); + } else { + builder.append("x^").append(i); + } + } else { + if (i == 0) { + builder.append(coefficient); + } else if (i == 1) { + builder.append(coefficient).append("x"); + } else { + builder.append(coefficient).append("x^").append(i); + } + } + } + return builder.toString(); + } + + public Interval evaluate(Interval interval) { + Interval current = null; + for (int i = 0; i < coefficients.length; i++) { + Number coefficient = coefficients[i]; + + Interval term; + if (i == 0) { + term = pointInterval(coefficient); + } else { + term = interval.pow(i).multiply(coefficient); + } + + if (current == null) { + current = term; + } else { + current = current.add(term); + } + } + + return current; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Polynomial that = (Polynomial) o; + return degree == that.degree && Arrays.equals(coefficients, that.coefficients); + } + + @Override + public int hashCode() { + int result = Objects.hash(degree); + result = 31 * result + Arrays.hashCode(coefficients); + return result; + } +} diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/RealAlgebraicNumber.java b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/RealAlgebraicNumber.java new file mode 100644 index 0000000..552805f --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/RealAlgebraicNumber.java @@ -0,0 +1,320 @@ +package me.paultristanwagner.satchecking.theory.nonlinear; + +import me.paultristanwagner.satchecking.theory.arithmetic.Number; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import static me.paultristanwagner.satchecking.theory.arithmetic.Number.ZERO; +import static me.paultristanwagner.satchecking.theory.arithmetic.Number.number; + +public class RealAlgebraicNumber { + + private Number value; + private Polynomial polynomial; + private Number lowerBound; + private Number upperBound; + + private RealAlgebraicNumber( + Number value, Polynomial polynomial, Number lowerBound, Number upperBound) { + this.value = value; + this.lowerBound = lowerBound; + this.upperBound = upperBound; + + if (polynomial == null) { + return; + } + + // finding minimal polynomial + // todo: improve efficiency + List squareFreeFactors = polynomial.squareFreeFactorization(); + for (Polynomial squareFreeFactor : squareFreeFactors) { + if (squareFreeFactor.isConstant()) { + continue; + } + + int numberRoots = squareFreeFactor.numberOfRealRoots(lowerBound, upperBound); + if (numberRoots > 0) { + this.polynomial = squareFreeFactor; + break; + } + } + + if (this.polynomial == null) { + throw new IllegalArgumentException("No real roots in interval"); + } + + if (this.polynomial.getDegree() == 1) { + this.value = + this.polynomial + .getCoefficients()[0] + .negate() + .divide(this.polynomial.getCoefficients()[1]); + this.polynomial = null; + this.lowerBound = null; + this.upperBound = null; + } + } + + public static RealAlgebraicNumber realAlgebraicNumber(Number value) { + return new RealAlgebraicNumber(value, null, value, value); + } + + public static RealAlgebraicNumber realAlgebraicNumber( + Polynomial polynomial, Number lowerBound, Number upperBound) { + return new RealAlgebraicNumber(null, polynomial, lowerBound, upperBound); + } + + public boolean isNumeric() { + return value != null || polynomial.getDegree() == 1; + } + + public Number numericValue() { + if (!isNumeric()) { + throw new IllegalStateException("Not a numeric value"); + } + + if (value != null) { + return value; + } + + return polynomial.getCoefficients()[0].negate().divide(polynomial.getCoefficients()[1]); + } + + public Number getLength() { + if (value != null) { + return ZERO(); + } + + return upperBound.subtract(lowerBound); + } + + public void refine() { + if (isNumeric()) { + return; + } + + Number mid = lowerBound.midpoint(upperBound); + + if (this.polynomial.hasRealRootAt(mid)) { + Number quarter = lowerBound.midpoint(mid); + Number threeQuarters = mid.midpoint(upperBound); + this.lowerBound = quarter; + this.upperBound = threeQuarters; + return; + } + + int numberRootsLeft = polynomial.numberOfRealRoots(lowerBound, mid); // todo: check right bound + + if (numberRootsLeft == 1) { + upperBound = mid; + } else { + lowerBound = mid; + } + } + + public void refine(Number epsilon) { + if (isNumeric()) { + return; + } + + while (getLength().greaterThanOrEqual(epsilon)) { + refine(); + } + } + + public Number approximate(Number epsilon) { + if (isNumeric()) { + return numericValue(); + } + + refine(epsilon); + return lowerBound.add(upperBound).divide(number(2)); + } + + public boolean isZero() { + if (isNumeric()) { + return numericValue().isZero(); + } + + return ZERO().greaterThan(lowerBound) && ZERO().lessThan(upperBound) && this.polynomial.hasRealRootAt(ZERO()); + } + + public boolean isPositive() { + if(isZero()) { + return false; + } else if (isNumeric()) { + return numericValue().isPositive(); + } + + while(true) { + if (this.lowerBound.isPositive()) { + return true; + } else if (this.upperBound.isNegative()) { + return false; + } + + this.refine(); + } + } + + public boolean isNegative() { + return !isZero() && !isPositive(); + } + + public int sign() { + if (isZero()) { + return 0; + } else if (isPositive()) { + return 1; + } else { + return -1; + } + } + + public double approximateAsDouble() { + return approximate(number(2).pow(-54)).approximateAsDouble(); + } + + public float approximateAsFloat() { + return approximate(number(2).pow(-24)).approximateAsFloat(); + } + + public boolean lessThan(RealAlgebraicNumber other) { + if (this.equals(other)) { + return false; + } + + if (this.isNumeric() && other.isNumeric()) { + return this.numericValue().lessThan(other.numericValue()); + } + + while (true) { + if (this.isNumeric() && !other.isNumeric()) { + if (this.numericValue().lessThanOrEqual(other.lowerBound)) { + return true; + } else if (this.numericValue().greaterThanOrEqual(other.upperBound)) { + return false; + } + + other.refine(); + } else if (!this.isNumeric() && other.isNumeric()) { + if (this.upperBound.lessThanOrEqual(other.numericValue())) { + return true; + } else if (this.lowerBound.greaterThanOrEqual(other.numericValue())) { + return false; + } + + this.refine(); + } else { + if (this.upperBound.lessThanOrEqual(other.lowerBound)) { + return true; + } else if (this.lowerBound.greaterThanOrEqual(other.upperBound)) { + return false; + } + + this.refine(); + other.refine(); + } + } + } + + public boolean greaterThan(RealAlgebraicNumber other) { + return other.lessThan(this); + } + + public boolean lessThanOrEqual(RealAlgebraicNumber other) { + return this.equals(other) || this.lessThan(other); + } + + public boolean greaterThanOrEqual(RealAlgebraicNumber other) { + return this.equals(other) || this.greaterThan(other); + } + + public Number getLowerBound() { + if (lowerBound == null) { + throw new IllegalStateException("No lower bound"); + } + + return lowerBound; + } + + public Number getUpperBound() { + if (upperBound == null) { + throw new IllegalStateException("No upper bound"); + } + + return upperBound; + } + + public Polynomial getPolynomial() { + if (polynomial == null) { + throw new IllegalStateException("No polynomial"); + } + + return polynomial; + } + + @Override + public String toString() { + if (value != null) { + return value.toString(); + } else { + return "(" + polynomial.toString() + ", " + lowerBound + ", " + upperBound + ") ≈ " + approximateAsDouble(); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + RealAlgebraicNumber other = (RealAlgebraicNumber) o; + + if (Objects.equals(value, other.value) && Objects.equals(polynomial, other.polynomial) && Objects.equals(lowerBound, other.lowerBound) && Objects.equals(upperBound, other.upperBound)) { + return true; + } + + if (this.isNumeric() && other.isNumeric()) { + return this.numericValue().equals(other.numericValue()); + } else if (this.isNumeric() && !other.isNumeric()) { + Number numericValue = this.numericValue(); + return numericValue.greaterThan(other.lowerBound) + && numericValue.lessThan(other.upperBound) + && other.polynomial.hasRealRootAt(numericValue); + } else if (!this.isNumeric() && other.isNumeric()) { + Number numericValue = other.numericValue(); + return numericValue.greaterThan(this.lowerBound) + && numericValue.lessThan(this.upperBound) + && this.polynomial.hasRealRootAt(numericValue); + } + + if (this.lowerBound.greaterThanOrEqual(other.upperBound) || this.upperBound.lessThanOrEqual(other.lowerBound)) { + return false; + } + + Number innerLowerBound = this.lowerBound.greaterThan(other.lowerBound) ? this.lowerBound : other.lowerBound; + Number innerUpperBound = this.upperBound.lessThan(other.upperBound) ? this.upperBound : other.upperBound; + + int thisInnerRoots = this.polynomial.numberOfRealRoots(innerLowerBound, innerUpperBound); + int otherInnerRoots = other.polynomial.numberOfRealRoots(innerLowerBound, innerUpperBound); + + if (thisInnerRoots != otherInnerRoots) { + return false; + } + + if (this.polynomial.equals(other.polynomial)) { + return true; + } + + Polynomial gcd = this.polynomial.gcd(other.polynomial); + // todo: we can potentially also evaluate the gcd at the interval bounds + return gcd.numberOfRealRoots(innerLowerBound, innerUpperBound) > 0; + } + + @Override + public int hashCode() { + return Objects.hash(value, polynomial, lowerBound, upperBound); + } +} diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/RealRootIsolator.java b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/RealRootIsolator.java new file mode 100644 index 0000000..18af358 --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/theory/nonlinear/RealRootIsolator.java @@ -0,0 +1,5 @@ +package me.paultristanwagner.satchecking.theory.nonlinear; + +public class RealRootIsolator { + +} diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/solver/NonLinearRealArithmeticSolver.java b/src/main/java/me/paultristanwagner/satchecking/theory/solver/NonLinearRealArithmeticSolver.java new file mode 100644 index 0000000..6e469c5 --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/theory/solver/NonLinearRealArithmeticSolver.java @@ -0,0 +1,68 @@ +package me.paultristanwagner.satchecking.theory.solver; + +import me.paultristanwagner.satchecking.smt.VariableAssignment; +import me.paultristanwagner.satchecking.theory.TheoryResult; +import me.paultristanwagner.satchecking.theory.nonlinear.CAD; +import me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomial; +import me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomialConstraint; +import me.paultristanwagner.satchecking.theory.nonlinear.RealAlgebraicNumber; + +import java.util.HashSet; +import java.util.Set; + +import static me.paultristanwagner.satchecking.theory.TheoryResult.satisfiable; +import static me.paultristanwagner.satchecking.theory.TheoryResult.unsatisfiable; +import static me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomialConstraint.Comparison.EQUALS; + +public class NonLinearRealArithmeticSolver implements TheorySolver { + + private final Set constraints = new HashSet<>(); + + @Override + public void clear() { + constraints.clear(); + } + + @Override + public void load(Set constraints) { + clear(); + this.constraints.addAll(constraints); + } + + @Override + public void addConstraint(MultivariatePolynomialConstraint constraint) { + constraints.add(constraint); + } + + @Override + public TheoryResult solve() { + boolean onlyEqualities = true; + Set polynomials = new HashSet<>(); + for (MultivariatePolynomialConstraint constraint : constraints) { + polynomials.add(constraint.getPolynomial()); + if (!constraint.getComparison().equals(EQUALS)) { + onlyEqualities = false; + } + } + + CAD cad = new CAD(polynomials); + Set> result = cad.compute(constraints, onlyEqualities); + + for (VariableAssignment realAlgebraicNumberVariableAssignment : result) { + boolean satisfied = true; + for (MultivariatePolynomialConstraint constraint : constraints) { + int sign = constraint.getPolynomial().evaluateSign(realAlgebraicNumberVariableAssignment); + if(!constraint.getComparison().evaluateSign(sign)) { + satisfied = false; + break; + } + } + + if(satisfied) { + return satisfiable(realAlgebraicNumberVariableAssignment); + } + } + + return unsatisfiable(constraints); + } +} diff --git a/src/test/java/MultivariatePolynomialTest.java b/src/test/java/MultivariatePolynomialTest.java new file mode 100644 index 0000000..2b895df --- /dev/null +++ b/src/test/java/MultivariatePolynomialTest.java @@ -0,0 +1,167 @@ +import me.paultristanwagner.satchecking.parse.Parser; +import me.paultristanwagner.satchecking.parse.PolynomialParser; +import me.paultristanwagner.satchecking.theory.arithmetic.Rational; +import me.paultristanwagner.satchecking.theory.nonlinear.Interval; +import me.paultristanwagner.satchecking.theory.nonlinear.MultivariatePolynomial; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static me.paultristanwagner.satchecking.theory.arithmetic.Number.number; +import static me.paultristanwagner.satchecking.theory.arithmetic.Rational.parse; +import static me.paultristanwagner.satchecking.theory.nonlinear.Exponent.exponent; +import static me.paultristanwagner.satchecking.theory.nonlinear.Interval.IntervalBoundType.CLOSED; +import static me.paultristanwagner.satchecking.theory.nonlinear.Interval.interval; +import static org.junit.jupiter.api.Assertions.*; + +public class MultivariatePolynomialTest { + + private Parser parser = new PolynomialParser(); + + @Test + public void testSimple() { + MultivariatePolynomial p = parser.parse("x^2 + 2*x + 1"); + + assertEquals(number(1), p.getCoefficient(exponent(2))); + assertEquals(number(2), p.getCoefficient(exponent(1))); + assertEquals(number(1), p.getCoefficient(exponent(0))); + + assertEquals(2, p.degree("x")); + } + + @Test + public void testAdd() { + MultivariatePolynomial p = parser.parse("x^2*y - 2*x*y^2*z + 1"); + MultivariatePolynomial q = parser.parse("x^2 - 10*y^3*z + 17"); + + MultivariatePolynomial result = p.add(q); + + assertEquals(number(18), result.getCoefficient(exponent(0, 0, 0))); + assertEquals(number(1), result.getCoefficient(exponent(2, 0, 0))); + assertEquals(number(1), result.getCoefficient(exponent(2, 1, 0))); + assertEquals(number(-2), result.getCoefficient(exponent(1, 2, 1))); + assertEquals(number(-10), result.getCoefficient(exponent(0, 3, 1))); + + assertEquals(2, result.degree("x")); + assertEquals(3, result.degree("y")); + assertEquals(1, result.degree("z")); + } + + @Test + public void testMultiply() { + MultivariatePolynomial p = parser.parse("x^2*y - 2*x*y^2*z + 1"); + MultivariatePolynomial q = parser.parse("x^2 - 10*y^3*z + 17"); + + MultivariatePolynomial result = p.multiply(q); + + assertEquals(number(17), result.getCoefficient(exponent(0, 0, 0))); + assertEquals(number(-10), result.getCoefficient(exponent(0, 3, 1))); + assertEquals(number(-2), result.getCoefficient(exponent(3, 2, 1))); + + assertEquals(4, result.degree("x")); + assertEquals(5, result.degree("y")); + assertEquals(2, result.degree("z")); + } + + @Test + public void testPseudoDivision() { + MultivariatePolynomial p = parser.parse("x^2 + y^2"); + MultivariatePolynomial q = parser.parse("x + y"); + + List result = p.pseudoDivision(q, "x"); + + MultivariatePolynomial quotient = result.get(0); + MultivariatePolynomial remainder = result.get(1); + + assertEquals(number(1), quotient.getCoefficient(exponent(1, 0))); + assertEquals(number(-1), quotient.getCoefficient(exponent(0, 1))); + + assertEquals(number(2), remainder.getCoefficient(exponent(0, 2))); + + q = parser.parse("2*x - 2*y"); + result = p.pseudoDivision(q, "x"); + + quotient = result.get(0); + remainder = result.get(1); + + assertEquals(number(0), quotient.getCoefficient(exponent(0, 0))); + assertEquals(number(2), quotient.getCoefficient(exponent(1, 0))); + assertEquals(number(2), quotient.getCoefficient(exponent(0, 1))); + assertEquals(number(0), quotient.getCoefficient(exponent(1, 1))); + + assertEquals(number(0), remainder.getCoefficient(exponent(0, 0))); + assertEquals(number(0), remainder.getCoefficient(exponent(0, 1))); + assertEquals(number(0), remainder.getCoefficient(exponent(2, 0))); + assertEquals(number(8), remainder.getCoefficient(exponent(0, 2))); + } + + @Test + public void testDivision() { + MultivariatePolynomial p = parser.parse("x^2 + y^2"); + MultivariatePolynomial q = parser.parse("2*x - 2*y"); + + List result = p.divide(q, "x"); + + MultivariatePolynomial quotient = result.get(0); + MultivariatePolynomial remainder = result.get(1); + + assertEquals(number(0), quotient.getCoefficient(exponent(0, 0))); + assertEquals(number(1, 2), quotient.getCoefficient(exponent(1, 0))); + assertEquals(number(1, 2), quotient.getCoefficient(exponent(0, 1))); + assertEquals(number(0), quotient.getCoefficient(exponent(1, 1))); + assertEquals(number(0), quotient.getCoefficient(exponent(0, 2))); + assertEquals(number(0), quotient.getCoefficient(exponent(2, 0))); + assertEquals(number(0), quotient.getCoefficient(exponent(2, 2))); + + assertEquals(number(2), remainder.getCoefficient(exponent(0, 2))); + assertEquals(number(0), remainder.getCoefficient(exponent(2, 0))); + } + + @Test + public void testResultant() { + MultivariatePolynomial p = parser.parse("x^2 + x + 1"); + MultivariatePolynomial q = parser.parse("z - x^7 + 1"); + + MultivariatePolynomial result = p.resultant(q, "x"); + + assertEquals(number(3), result.getCoefficient(exponent(0, 0))); + assertEquals(number(3), result.getCoefficient(exponent(0, 1))); + assertEquals(number(1), result.getCoefficient(exponent(0, 2))); + assertEquals(number(0), result.getCoefficient(exponent(0, 3))); + assertEquals(number(0), result.getCoefficient(exponent(0, 4))); + + p = parser.parse("x*z^2-y^3"); + q = parser.parse("-x*y^2 + x*z^2 - y^3"); + + result = p.resultant(q, "z"); + + assertEquals(number(0), result.getCoefficient(exponent(0, 0, 0))); + assertEquals(number(0), result.getCoefficient(exponent(0, 1, 0))); + assertEquals(number(0), result.getCoefficient(exponent(0, 2, 0))); + assertEquals(number(0), result.getCoefficient(exponent(0, 3, 0))); + assertEquals(number(0), result.getCoefficient(exponent(0, 4, 0))); + assertEquals(number(0), result.getCoefficient(exponent(1, 0, 0))); + assertEquals(number(0), result.getCoefficient(exponent(1, 1, 0))); + assertEquals(number(0), result.getCoefficient(exponent(1, 2, 0))); + assertEquals(number(-1), result.getCoefficient(exponent(2, 2, 0))); // todo: investigate why the sign is wrong + + } + + @Test + public void testIntervalEvaluation() { + Interval xInterval = interval(parse("118888613829676491/144115188075855872"), parse("59444306914838247/72057594037927936"), CLOSED, CLOSED); + Interval yInterval = interval(parse("10181696917598453/18014398509481984"), parse("10181696917598453/18014398509481984"), CLOSED, CLOSED); + + MultivariatePolynomial p = parser.parse("x^2 + y^3 - 1/2" ); + + Map intervalMap = Map.of("x", xInterval, "y", yInterval); + Interval result = p.evaluate(intervalMap); + + System.out.println(result); + + assertFalse(result.containsZero()); + assertEquals(1, result.sign()); + } +}