From 881ba9015a56b5dd6d277e7268f539556d667118 Mon Sep 17 00:00:00 2001 From: Paul Tristan Wagner Date: Mon, 4 Mar 2024 00:29:07 +0100 Subject: [PATCH] Generify constraints in linear real arithmetic (#7) --- .../parse/LinearConstraintLexer.java | 10 +- .../parse/LinearConstraintParser.java | 174 ++++++---------- .../satchecking/parse/LinearTermLexer.java | 20 ++ .../satchecking/parse/LinearTermParser.java | 126 ++++++++++++ .../satchecking/theory/LinearConstraint.java | 187 +++++++++--------- .../satchecking/theory/LinearTerm.java | 176 +++++++++++++++++ .../theory/solver/LinearIntegerSolver.java | 5 +- .../solver/SimplexFeasibilitySolver.java | 17 +- .../solver/SimplexOptimizationSolver.java | 127 ++++++------ src/test/java/LinearConstraintParserTest.java | 21 +- src/test/java/MultivariatePolynomialTest.java | 4 +- 11 files changed, 575 insertions(+), 292 deletions(-) create mode 100644 src/main/java/me/paultristanwagner/satchecking/parse/LinearTermLexer.java create mode 100644 src/main/java/me/paultristanwagner/satchecking/parse/LinearTermParser.java create mode 100644 src/main/java/me/paultristanwagner/satchecking/theory/LinearTerm.java diff --git a/src/main/java/me/paultristanwagner/satchecking/parse/LinearConstraintLexer.java b/src/main/java/me/paultristanwagner/satchecking/parse/LinearConstraintLexer.java index 703b8f6..acf221e 100644 --- a/src/main/java/me/paultristanwagner/satchecking/parse/LinearConstraintLexer.java +++ b/src/main/java/me/paultristanwagner/satchecking/parse/LinearConstraintLexer.java @@ -7,7 +7,15 @@ public class LinearConstraintLexer extends Lexer { public LinearConstraintLexer(String input) { super(input); - registerTokenTypes(MIN, MAX, FRACTION, DECIMAL, IDENTIFIER, EQUALS, LOWER_EQUALS, GREATER_EQUALS, PLUS, MINUS, LPAREN, RPAREN); + registerTokenTypes( + MIN, + MAX, + EQUALS, + LOWER_EQUALS, + GREATER_EQUALS, + LPAREN, + RPAREN + ); initialize(input); } diff --git a/src/main/java/me/paultristanwagner/satchecking/parse/LinearConstraintParser.java b/src/main/java/me/paultristanwagner/satchecking/parse/LinearConstraintParser.java index 2a27fc3..a19c68e 100644 --- a/src/main/java/me/paultristanwagner/satchecking/parse/LinearConstraintParser.java +++ b/src/main/java/me/paultristanwagner/satchecking/parse/LinearConstraintParser.java @@ -1,161 +1,111 @@ package me.paultristanwagner.satchecking.parse; import me.paultristanwagner.satchecking.theory.LinearConstraint; +import me.paultristanwagner.satchecking.theory.LinearConstraint.Bound; import me.paultristanwagner.satchecking.theory.LinearConstraint.MaximizingConstraint; import me.paultristanwagner.satchecking.theory.LinearConstraint.MinimizingConstraint; -import me.paultristanwagner.satchecking.theory.arithmetic.Number; +import me.paultristanwagner.satchecking.theory.LinearTerm; + +import java.util.Scanner; import static me.paultristanwagner.satchecking.parse.TokenType.*; +import static me.paultristanwagner.satchecking.parse.TokenType.GREATER_EQUALS; import static me.paultristanwagner.satchecking.theory.LinearConstraint.Bound.*; -import static me.paultristanwagner.satchecking.theory.arithmetic.Number.ONE; public class LinearConstraintParser implements Parser { + public static void main(String[] args) { + LinearConstraintParser parser = new LinearConstraintParser(); + + Scanner scanner = new Scanner(System.in); + String line; + while ((line = scanner.nextLine()) != null) { + try { + LinearConstraint constraint = parser.parse(line); + System.out.println(constraint); + System.out.println("lhs = " + constraint.getLeftHandSide()); + System.out.println("rhs = " + constraint.getRightHandSide()); + System.out.println("bound = " + constraint.getBound()); + System.out.println("lhs - rhs = " + constraint.getDifference()); + } catch (SyntaxError e) { + e.printWithContext(); + e.printStackTrace(); + } + } + } + /* * Grammar for Linear constraints: - * ::= '=' - * | '<=' - * | '>=' + * ::= '=' + * | '<=' '>=' * | MIN '(' ')' * | MAX '(' ')' * - * ::= [ ] [ ] IDENTIFIER - * | [ ] [ ] IDENTIFIER [ ] - * - * ::= '+' [ ] - * | '-' [ ] - * - * ::= FRACTION | DECIMAL - * */ @Override public ParseResult parseWithRemaining(String string) { Lexer lexer = new LinearConstraintLexer(string); - lexer.requireNextToken(); - - LinearConstraint lc = TERM(lexer); - - return new ParseResult<>(lc, lexer.getCursor(), lexer.getCursor() == string.length()); - } - - private static LinearConstraint TERM(Lexer lexer) { - LinearConstraint lc; boolean optimization = false; + boolean minimization = false; if (lexer.canConsume(MIN)) { - optimization = true; lexer.consume(MIN); - lexer.consume(LPAREN); - lc = new MinimizingConstraint(); - } else if (lexer.canConsume(MAX)) { + optimization = true; + minimization = true; + } else if (lexer.canConsume(MAX)) { lexer.consume(MAX); - lexer.consume(LPAREN); - lc = new MaximizingConstraint(); - } else { - lc = new LinearConstraint(); - } - Number coefficient = OPTIONAL_SIGNS(lexer).multiply(OPTIONAL_RATIONAL(lexer)); - Token variableToken = lexer.getLookahead(); - lexer.consume(IDENTIFIER); - String variable = variableToken.getValue(); - lc.setCoefficient(variable, coefficient); - - while (lexer.canConsumeEither(PLUS, MINUS, FRACTION, DECIMAL)) { - coefficient = OPTIONAL_SIGNS(lexer).multiply(OPTIONAL_RATIONAL(lexer)); - variableToken = lexer.getLookahead(); + optimization = true; + } - lexer.consume(IDENTIFIER); + LinearTerm lhs = TERM(lexer); - variable = variableToken.getValue(); - lc.setCoefficient(variable, coefficient); - } + LinearConstraint lc; + if(optimization) { + if(minimization) { + lc = new MinimizingConstraint(lhs); + } else { + lc = new MaximizingConstraint(lhs); + } - if (optimization) { lexer.consume(RPAREN); - return lc; + return new ParseResult<>(lc, lexer.getCursor(), lexer.getCursor() == string.length()); } - lexer.requireEither(EQUALS, LOWER_EQUALS, GREATER_EQUALS); - if (lexer.canConsume(EQUALS)) { - lexer.consume(EQUALS); - lc.setBound(EQUAL); - } else if (lexer.canConsume(LOWER_EQUALS)) { - lexer.consume(LOWER_EQUALS); - lc.setBound(UPPER); - } else { - lexer.consume(GREATER_EQUALS); - lc.setBound(LOWER); - } + Bound bound = BOUND(lexer); - Number value = OPTIONAL_SIGNS(lexer).multiply(RATIONAL(lexer)); - lc.setValue(value); + LinearTerm rhs = TERM(lexer); - return lc; - } + lc = new LinearConstraint(lhs, rhs, bound); - private static Number OPTIONAL_SIGNS(Lexer lexer) { - if (lexer.canConsumeEither(PLUS, MINUS)) { - return SIGNS(lexer); - } else { - return ONE(); - } + return new ParseResult<>(lc, lexer.getCursor(), lexer.getCursor() == string.length()); } - private static Number SIGNS(Lexer lexer) { - lexer.requireEither(PLUS, MINUS); - - Number sign = ONE(); - do { - if (lexer.canConsume(PLUS)) { - lexer.consume(PLUS); - } else { - lexer.consume(MINUS); - sign = sign.negate(); - } - } while (lexer.canConsumeEither(PLUS, MINUS)); - - return sign; - } + private static LinearTerm TERM(Lexer lexer) { + LinearTermParser parser = new LinearTermParser(); + ParseResult result = parser.parseWithRemaining(lexer.getRemaining()); - private static Number OPTIONAL_RATIONAL(Lexer lexer) { - if (lexer.canConsumeEither(FRACTION, DECIMAL)) { - return RATIONAL(lexer); - } + lexer.skip(result.charsRead()); - return ONE(); + return result.result(); } - private static Number RATIONAL(Lexer lexer) { - lexer.requireEither(FRACTION, DECIMAL); - - if (lexer.canConsume(FRACTION)) { - return FRACTION(lexer); + private static Bound BOUND(Lexer lexer) { + lexer.requireEither(EQUALS, LOWER_EQUALS, GREATER_EQUALS); + if (lexer.canConsume(EQUALS)) { + lexer.consume(EQUALS); + return EQUAL; + } else if (lexer.canConsume(LOWER_EQUALS)) { + lexer.consume(LOWER_EQUALS); + return LESS_EQUALS; } else { - return DECIMAL(lexer); + lexer.consume(GREATER_EQUALS); + return Bound.GREATER_EQUALS; } } - - private static Number DECIMAL(Lexer lexer) { - Token token = lexer.getLookahead(); - lexer.consume(DECIMAL); - - return Number.parse(token.getValue()); - } - - private static Number FRACTION(Lexer lexer) { - Token token = lexer.getLookahead(); - lexer.consume(FRACTION); - - String[] parts = token.getValue().split("/"); - - long numerator = Long.parseLong(parts[0]); - long denominator = Long.parseLong(parts[1]); - - return Number.number(numerator, denominator); - } } diff --git a/src/main/java/me/paultristanwagner/satchecking/parse/LinearTermLexer.java b/src/main/java/me/paultristanwagner/satchecking/parse/LinearTermLexer.java new file mode 100644 index 0000000..bebb757 --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/parse/LinearTermLexer.java @@ -0,0 +1,20 @@ +package me.paultristanwagner.satchecking.parse; + +import static me.paultristanwagner.satchecking.parse.TokenType.*; + +public class LinearTermLexer extends Lexer { + + public LinearTermLexer(String input) { + super(input); + + registerTokenTypes( + PLUS, + MINUS, + FRACTION, + DECIMAL, + IDENTIFIER + ); + + initialize(input); + } +} diff --git a/src/main/java/me/paultristanwagner/satchecking/parse/LinearTermParser.java b/src/main/java/me/paultristanwagner/satchecking/parse/LinearTermParser.java new file mode 100644 index 0000000..11ed630 --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/parse/LinearTermParser.java @@ -0,0 +1,126 @@ +package me.paultristanwagner.satchecking.parse; + +import me.paultristanwagner.satchecking.theory.LinearTerm; +import me.paultristanwagner.satchecking.theory.arithmetic.Number; + +import static me.paultristanwagner.satchecking.parse.TokenType.*; +import static me.paultristanwagner.satchecking.theory.arithmetic.Number.ONE; + +public class LinearTermParser implements Parser { + + /* + * Grammar for linear terms: + * ::= [ ] [ ] IDENTIFIER [ ] + * | [ ] [ ] + * + */ + + @Override + public ParseResult parseWithRemaining(String string) { + LinearTermLexer lexer = new LinearTermLexer(string); + + lexer.requireNextToken(); + + LinearTerm term = new LinearTerm(); + + Number value = OPTIONAL_SIGNS(lexer); + boolean explicitValue = false; + + if(lexer.canConsumeEither(DECIMAL, FRACTION)) { + explicitValue = true; + value = value.multiply(OPTIONAL_RATIONAL(lexer)); + } + + if(!explicitValue || lexer.canConsume(IDENTIFIER)) { + lexer.require(IDENTIFIER); + String identifier = lexer.getLookahead().getValue(); + lexer.consume(IDENTIFIER); + term.addCoefficient(identifier, value); + } else { + term.addConstant(value); + } + + while(lexer.canConsumeEither(PLUS, MINUS)) { + Number sign = SIGNS(lexer); + value = OPTIONAL_SIGNS(lexer); + explicitValue = false; + + if(lexer.canConsumeEither(DECIMAL, FRACTION)) { + explicitValue = true; + value = value.multiply(OPTIONAL_RATIONAL(lexer)); + } + + if(!explicitValue || lexer.canConsume(IDENTIFIER)) { + lexer.require(IDENTIFIER); + String identifier = lexer.getLookahead().getValue(); + lexer.consume(IDENTIFIER); + term.addCoefficient(identifier, sign.multiply(value)); + } else { + term.addConstant(sign.multiply(value)); + } + } + + return new ParseResult<>(term, lexer.getCursor(), lexer.getRemaining().isEmpty()); + } + + private static Number OPTIONAL_SIGNS(Lexer lexer) { + if (lexer.canConsumeEither(PLUS, MINUS)) { + return SIGNS(lexer); + } else { + return ONE(); + } + } + + private static Number SIGNS(Lexer lexer) { + lexer.requireEither(PLUS, MINUS); + + Number sign = ONE(); + do { + if (lexer.canConsume(PLUS)) { + lexer.consume(PLUS); + } else { + lexer.consume(MINUS); + sign = sign.negate(); + } + } while (lexer.canConsumeEither(PLUS, MINUS)); + + return sign; + } + + private static Number OPTIONAL_RATIONAL(Lexer lexer) { + if (lexer.canConsumeEither(FRACTION, DECIMAL)) { + return RATIONAL(lexer); + } + + return ONE(); + } + + private static Number RATIONAL(Lexer lexer) { + lexer.requireEither(FRACTION, DECIMAL); + + if (lexer.canConsume(FRACTION)) { + return FRACTION(lexer); + } else { + return DECIMAL(lexer); + } + } + + private static Number DECIMAL(Lexer lexer) { + Token token = lexer.getLookahead(); + lexer.consume(DECIMAL); + + return Number.parse(token.getValue()); + } + + private static Number FRACTION(Lexer lexer) { + Token token = lexer.getLookahead(); + lexer.consume(FRACTION); + + String[] parts = token.getValue().split("/"); + + long numerator = Long.parseLong(parts[0]); + long denominator = Long.parseLong(parts[1]); + + return Number.number(numerator, denominator); + } +} diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/LinearConstraint.java b/src/main/java/me/paultristanwagner/satchecking/theory/LinearConstraint.java index d725a4b..1f8a64b 100644 --- a/src/main/java/me/paultristanwagner/satchecking/theory/LinearConstraint.java +++ b/src/main/java/me/paultristanwagner/satchecking/theory/LinearConstraint.java @@ -3,81 +3,87 @@ import me.paultristanwagner.satchecking.smt.VariableAssignment; import me.paultristanwagner.satchecking.theory.arithmetic.Number; -import java.util.HashMap; import java.util.HashSet; -import java.util.Map; import java.util.Set; +import static me.paultristanwagner.satchecking.theory.LinearConstraint.Bound.*; import static me.paultristanwagner.satchecking.theory.arithmetic.Number.ZERO; public class LinearConstraint implements Constraint { - protected final Set variables; - protected final Map coefficients; - private Bound bound; - private Number value; + protected final LinearTerm lhs; + protected final LinearTerm rhs; + protected final LinearTerm difference; + protected final Bound bound; private LinearConstraint derivedFrom; public LinearConstraint() { - this.variables = new HashSet<>(); - this.coefficients = new HashMap<>(); - this.value = ZERO(); + this.lhs = new LinearTerm(); + this.rhs = new LinearTerm(); + this.difference = new LinearTerm(); + this.bound = EQUAL; } public LinearConstraint(LinearConstraint constraint) { - this.variables = new HashSet<>(constraint.variables); - this.coefficients = new HashMap<>(constraint.coefficients); + this.lhs = new LinearTerm(constraint.lhs); + this.rhs = new LinearTerm(constraint.rhs); + this.difference = new LinearTerm(constraint.difference); this.bound = constraint.bound; - this.value = constraint.value; this.derivedFrom = constraint; } - public void setCoefficient(String variable, Number coefficient) { - variables.add(variable); - coefficients.put(variable, coefficient); + public LinearConstraint(LinearTerm lhs, LinearTerm rhs, Bound bound) { + this.lhs = lhs; + this.rhs = rhs; + this.difference = lhs.subtract(rhs); + this.bound = bound; } - public Set getVariables() { - return variables; + public static LinearConstraint equal(LinearTerm lhs, LinearTerm rhs) { + return new LinearConstraint(lhs, rhs, EQUAL); } - public Map getCoefficients() { - return coefficients; + public static LinearConstraint lessThanOrEqual(LinearTerm lhs, LinearTerm rhs) { + return new LinearConstraint(lhs, rhs, GREATER_EQUALS); } - public Bound getBound() { - return bound; + public static LinearConstraint greaterThanOrEqual(LinearTerm lhs, LinearTerm rhs) { + return new LinearConstraint(lhs, rhs, LESS_EQUALS); } - public Number getValue() { - return value; - } - - public void setBound(Bound bound) { - this.bound = bound; + public Set getVariables() { + Set variables = new HashSet<>(lhs.getVariables()); + variables.addAll(rhs.getVariables()); + return variables; } - public void setValue(Number value) { - this.value = value; + public Bound getBound() { + return bound; } public void setDerivedFrom(LinearConstraint derivedFrom) { this.derivedFrom = derivedFrom; } + public boolean constrainsVariable(String variable) { + return difference.getCoefficients().getOrDefault(variable, ZERO()).isNonZero(); + } + public Number getBoundOn(String variable) { + Set variables = getVariables(); + if (variables.size() != 1) { throw new IllegalStateException("Constraint does not have exactly one variable"); } - if (!variables.contains(variable)) { + if (!getVariables().contains(variable)) { throw new IllegalArgumentException("Variable is not in constraint"); } - Number coefficient = coefficients.get(variable); + Number coefficient = difference.coefficients.get(variable); - return value.divide(coefficient); + return difference.getConstant().negate().divide(coefficient); } public LinearConstraint getRoot() { @@ -88,115 +94,100 @@ public LinearConstraint getRoot() { } public LinearConstraint offset(String variable, String substitute, Number offset) { - LinearConstraint constraint = new LinearConstraint(this); - if (!coefficients.containsKey(variable)) { - return this; - } - - Number coeff = coefficients.get(variable); - constraint.variables.remove(variable); - constraint.coefficients.remove(variable); - constraint.setCoefficient(substitute, coeff); - - constraint.value = value.subtract( - coeff.multiply(offset) - ); + LinearTerm lhs = this.lhs.offset(variable, substitute, offset); + LinearTerm rhs = this.rhs.offset(variable, substitute, offset); - return constraint; + return new LinearConstraint(lhs, rhs, this.bound); } public LinearConstraint positiveNegativeSubstitute( String variable, String positive, String negative) { - Number coeff = coefficients.get(variable); - LinearConstraint constraint = new LinearConstraint(this); - constraint.variables.remove(variable); - constraint.coefficients.remove(variable); - constraint.setCoefficient(positive, coeff); - constraint.setCoefficient(negative, coeff.negate()); + LinearTerm lhs = this.lhs.positiveNegativeSubstitute(variable, positive, negative); + LinearTerm rhs = this.rhs.positiveNegativeSubstitute(variable, positive, negative); + + return new LinearConstraint(lhs, rhs, this.bound); + } + + public LinearTerm getLeftHandSide() { + return lhs; + } + + public LinearTerm getRightHandSide() { + return rhs; + } - return constraint; + public LinearTerm getDifference() { + return difference; } public enum Bound { - LOWER, - UPPER, + GREATER_EQUALS, + LESS_EQUALS, EQUAL } @Override public String toString() { StringBuilder sb = new StringBuilder(); - sb.append(serializeTerm(coefficients)); - if (bound == Bound.EQUAL) { + sb.append(lhs); + + if (bound == EQUAL) { sb.append("="); - } else if (bound == Bound.LOWER) { + } else if (bound == GREATER_EQUALS) { sb.append(">="); } else { sb.append("<="); } - sb.append(value); + + sb.append(rhs); return sb.toString(); } public static class MaximizingConstraint extends LinearConstraint { - @Override - public String toString() { - return "max(" + serializeTerm(coefficients) + ")"; + public MaximizingConstraint(LinearTerm term) { + super(term, new LinearTerm(), EQUAL); } - } - public static class MinimizingConstraint extends LinearConstraint { + public LinearTerm getTerm() { + return lhs; + } @Override public String toString() { - return "min(" + serializeTerm(coefficients) + ")"; + return "max(" + lhs + ")"; } } - private static String serializeTerm(Map coefficients) { - StringBuilder sb = new StringBuilder(); - coefficients.forEach( - (variable, coefficient) -> { - if (coefficient.isNonNegative()) { - if (!sb.isEmpty()) { - sb.append("+"); - } - } else { - sb.append("-"); - } - - Number absolute = coefficient.abs(); - if (!absolute.isOne()) { - sb.append(absolute); - } - - sb.append(variable); - }); + public static class MinimizingConstraint extends LinearConstraint { - return sb.toString(); - } + public MinimizingConstraint(LinearTerm term) { + super(term, new LinearTerm(), EQUAL); + } - public Number evaluateTerm(VariableAssignment assignment) { - Number result = ZERO(); - for (String variable : variables) { - Number summand = coefficients.get(variable).multiply(assignment.getAssignment(variable)); - result = result.add(summand); + public LinearTerm getTerm() { + return lhs; + } + + @Override + public String toString() { + return "min(" + lhs + ")"; } - return result; } public boolean evaluate(VariableAssignment assignment) { - Number result = evaluateTerm(assignment); - if (bound == Bound.EQUAL) { - return result.equals(value); - } else if (bound == Bound.LOWER) { - return result.greaterThanOrEqual(value); + Number lhsValue = lhs.evaluate(assignment); + Number rhsValue = rhs.evaluate(assignment); + + if (bound == EQUAL) { + return lhsValue.equals(rhsValue); + } else if (bound == GREATER_EQUALS) { + return lhsValue.greaterThanOrEqual(rhsValue); } else { - return result.lessThanOrEqual(value); + return lhsValue.lessThanOrEqual(rhsValue); } } } diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/LinearTerm.java b/src/main/java/me/paultristanwagner/satchecking/theory/LinearTerm.java new file mode 100644 index 0000000..e5ad968 --- /dev/null +++ b/src/main/java/me/paultristanwagner/satchecking/theory/LinearTerm.java @@ -0,0 +1,176 @@ +package me.paultristanwagner.satchecking.theory; + +import me.paultristanwagner.satchecking.smt.VariableAssignment; +import me.paultristanwagner.satchecking.theory.arithmetic.Number; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import static me.paultristanwagner.satchecking.theory.arithmetic.Number.ZERO; + +public class LinearTerm { + + protected final Set variables; + protected final Map coefficients; + private Number constant; + + public LinearTerm() { + this.variables = new HashSet<>(); + this.coefficients = new HashMap<>(); + this.constant = ZERO(); + } + + public LinearTerm(LinearTerm term) { + this.variables = new HashSet<>(term.variables); + this.coefficients = new HashMap<>(term.coefficients); + this.constant = term.constant; + } + + public void setCoefficient(String variable, Number coefficient) { + variables.add(variable); + coefficients.put(variable, coefficient); + + if(coefficients.get(variable).equals(ZERO())) { + coefficients.remove(variable); + variables.remove(variable); + } + } + + public void addCoefficient(String variable, Number coefficient) { + if (coefficients.containsKey(variable)) { + coefficients.put(variable, coefficients.get(variable).add(coefficient)); + } else { + setCoefficient(variable, coefficient); + } + + if(coefficients.get(variable).equals(ZERO())) { + coefficients.remove(variable); + variables.remove(variable); + } + } + + public Set getVariables() { + return variables; + } + + public Map getCoefficients() { + return coefficients; + } + + public Number getConstant() { + return constant; + } + + public void setConstant(Number constant) { + this.constant = constant; + } + + public void addConstant(Number value) { + this.constant = this.constant.add(value); + } + + public LinearTerm add(LinearTerm term) { + LinearTerm result = new LinearTerm(this); + term.coefficients.forEach(result::addCoefficient); + result.addConstant(term.constant); + return result; + } + + public LinearTerm negate() { + LinearTerm result = new LinearTerm(); + this.coefficients.forEach((variable, coefficient) -> result.addCoefficient(variable, coefficient.negate())); + result.addConstant(this.constant.negate()); + return result; + } + + public LinearTerm subtract(LinearTerm term) { + return this.add(term.negate()); + } + + public LinearTerm offset(String variable, String substitute, Number offset) { + LinearTerm term = new LinearTerm(this); + if (!coefficients.containsKey(variable)) { + return this; + } + + Number coeff = coefficients.get(variable); + term.variables.remove(variable); + term.coefficients.remove(variable); + term.setCoefficient(substitute, coeff); + + term.constant = constant.add( + coeff.multiply(offset) + ); + + return term; + } + + public LinearTerm positiveNegativeSubstitute( + String variable, String positive, String negative) { + if(!coefficients.containsKey(variable) || coefficients.get(variable).equals(ZERO())) { + return this; + } + + Number coeff = coefficients.get(variable); + + LinearTerm term = new LinearTerm(this); + term.variables.remove(variable); + term.coefficients.remove(variable); + term.setCoefficient(positive, coeff); + term.setCoefficient(negative, coeff.negate()); + + return term; + } + + public Number evaluate(VariableAssignment assignment) { + Number result = ZERO(); + for (String variable : variables) { + Number summand = coefficients.get(variable).multiply(assignment.getAssignment(variable)); + result = result.add(summand); + } + result = result.add(constant); + return result; + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + coefficients.forEach( + (variable, coefficient) -> { + if (coefficient.isNonNegative()) { + if (!sb.isEmpty()) { + sb.append("+"); + } + } else { + sb.append("-"); + } + + Number absolute = coefficient.abs(); + if (!absolute.isOne()) { + sb.append(absolute); + } + + sb.append(variable); + }); + + if (!constant.equals(ZERO())) { + if (constant.isNonNegative()) { + if (!sb.isEmpty()) { + sb.append("+"); + } + } else { + sb.append("-"); + } + + sb.append(constant.abs()); + } + + if(sb.isEmpty()) { + sb.append("0"); + } + + return sb.toString(); + } +} diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/solver/LinearIntegerSolver.java b/src/main/java/me/paultristanwagner/satchecking/theory/solver/LinearIntegerSolver.java index 047be33..6d25104 100644 --- a/src/main/java/me/paultristanwagner/satchecking/theory/solver/LinearIntegerSolver.java +++ b/src/main/java/me/paultristanwagner/satchecking/theory/solver/LinearIntegerSolver.java @@ -38,7 +38,7 @@ public void addConstraint(LinearConstraint constraint) { @Override public TheoryResult solve() { - if (depth > MAXIMUM_BRANCH_DEPTH) { + /* if (depth > MAXIMUM_BRANCH_DEPTH) { return TheoryResult.unknown(); } @@ -142,6 +142,7 @@ public TheoryResult solve() { return TheoryResult.unknown(); } - return TheoryResult.unsatisfiable(constraints); + return TheoryResult.unsatisfiable(constraints); */ + return null; } } diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/solver/SimplexFeasibilitySolver.java b/src/main/java/me/paultristanwagner/satchecking/theory/solver/SimplexFeasibilitySolver.java index 0cea557..81aef0d 100644 --- a/src/main/java/me/paultristanwagner/satchecking/theory/solver/SimplexFeasibilitySolver.java +++ b/src/main/java/me/paultristanwagner/satchecking/theory/solver/SimplexFeasibilitySolver.java @@ -7,6 +7,7 @@ import java.util.*; +import static me.paultristanwagner.satchecking.theory.LinearConstraint.Bound.*; import static me.paultristanwagner.satchecking.theory.arithmetic.Number.ONE; import static me.paultristanwagner.satchecking.theory.arithmetic.Number.ZERO; @@ -70,16 +71,16 @@ public SimplexResult solve() { for (int j = 0; j < variableSet.size(); j++) { String variable = variables.get(j); - tableau[i][j] = constraint.getCoefficients().getOrDefault(variable, ZERO()); + tableau[i][j] = constraint.getDifference().getCoefficients().getOrDefault(variable, ZERO()); } - if (constraint.getBound() == LinearConstraint.Bound.EQUAL) { - lowerBounds.put(slackName, constraint.getValue()); - upperBounds.put(slackName, constraint.getValue()); - } else if (constraint.getBound() == LinearConstraint.Bound.UPPER) { - upperBounds.put(slackName, constraint.getValue()); - } else if (constraint.getBound() == LinearConstraint.Bound.LOWER) { - lowerBounds.put(slackName, constraint.getValue()); + if (constraint.getBound() == EQUAL) { + lowerBounds.put(slackName, constraint.getDifference().getConstant().negate()); + upperBounds.put(slackName, constraint.getDifference().getConstant().negate()); + } else if (constraint.getBound() == LESS_EQUALS) { + upperBounds.put(slackName, constraint.getDifference().getConstant().negate()); + } else if (constraint.getBound() == GREATER_EQUALS) { + lowerBounds.put(slackName, constraint.getDifference().getConstant().negate()); } } diff --git a/src/main/java/me/paultristanwagner/satchecking/theory/solver/SimplexOptimizationSolver.java b/src/main/java/me/paultristanwagner/satchecking/theory/solver/SimplexOptimizationSolver.java index 66cca42..453281f 100644 --- a/src/main/java/me/paultristanwagner/satchecking/theory/solver/SimplexOptimizationSolver.java +++ b/src/main/java/me/paultristanwagner/satchecking/theory/solver/SimplexOptimizationSolver.java @@ -1,21 +1,25 @@ package me.paultristanwagner.satchecking.theory.solver; +import static me.paultristanwagner.satchecking.theory.LinearConstraint.Bound.EQUAL; +import static me.paultristanwagner.satchecking.theory.LinearConstraint.Bound.LESS_EQUALS; +import static me.paultristanwagner.satchecking.theory.LinearConstraint.greaterThanOrEqual; +import static me.paultristanwagner.satchecking.theory.LinearConstraint.lessThanOrEqual; +import static me.paultristanwagner.satchecking.theory.arithmetic.Number.ONE; +import static me.paultristanwagner.satchecking.theory.arithmetic.Number.ZERO; + import com.google.common.collect.BiMap; import com.google.common.collect.HashBiMap; +import java.util.*; +import java.util.Map.Entry; import me.paultristanwagner.satchecking.smt.VariableAssignment; import me.paultristanwagner.satchecking.theory.LinearConstraint; import me.paultristanwagner.satchecking.theory.LinearConstraint.MaximizingConstraint; import me.paultristanwagner.satchecking.theory.LinearConstraint.MinimizingConstraint; +import me.paultristanwagner.satchecking.theory.LinearTerm; import me.paultristanwagner.satchecking.theory.SimplexResult; import me.paultristanwagner.satchecking.theory.arithmetic.Number; import org.apache.commons.lang3.tuple.Pair; -import java.util.*; -import java.util.Map.Entry; - -import static me.paultristanwagner.satchecking.theory.arithmetic.Number.ONE; -import static me.paultristanwagner.satchecking.theory.arithmetic.Number.ZERO; - public class SimplexOptimizationSolver implements TheorySolver { private final List allVariables; @@ -34,8 +38,12 @@ public class SimplexOptimizationSolver implements TheorySolver private int columns; private Number[][] tableau; - private final List originalConstraints; + private List originalConstraints; private List constraints; + + private List differences; + private Map origin; + private LinearConstraint originalObjective; private LinearConstraint objective; @@ -50,6 +58,8 @@ public SimplexOptimizationSolver() { this.originalConstraints = new ArrayList<>(); this.constraints = new ArrayList<>(); + this.differences = new ArrayList<>(); + this.origin = new HashMap<>(); this.substitutions = HashBiMap.create(); this.offsets = new HashMap<>(); @@ -58,7 +68,7 @@ public SimplexOptimizationSolver() { } public void maximize(LinearConstraint f) { - if (!(f instanceof MaximizingConstraint)) { + if (!(f instanceof MaximizingConstraint maximizingConstraint)) { throw new IllegalArgumentException("The objective function must be a maximizing constraint."); } @@ -66,12 +76,12 @@ public void maximize(LinearConstraint f) { throw new IllegalStateException("The objective function has already been set."); } - this.originalObjective = f; - this.objective = f; + this.originalObjective = maximizingConstraint; + this.objective = originalObjective; } public void minimize(LinearConstraint f) { - if (!(f instanceof MinimizingConstraint)) { + if (!(f instanceof MinimizingConstraint minimizingConstraint)) { throw new IllegalArgumentException("The objective function must be a minimizing constraint."); } @@ -79,12 +89,8 @@ public void minimize(LinearConstraint f) { throw new IllegalStateException("The objective function has already been set."); } - this.originalObjective = f; - this.objective = new LinearConstraint(f); - - for (String variable : objective.getVariables()) { - objective.setCoefficient(variable, objective.getCoefficients().get(variable).negate()); - } + this.originalObjective = minimizingConstraint; + this.objective = new MaximizingConstraint(minimizingConstraint.getTerm().negate()); } @Override @@ -99,6 +105,8 @@ public void clear() { this.originalConstraints.clear(); this.constraints.clear(); + this.origin.clear(); + this.differences.clear(); this.substitutions.clear(); this.offsets.clear(); @@ -124,19 +132,18 @@ public SimplexResult solve() { allVariables.addAll(tempSet); // Infer bounds - Pair, Map> inferedBounds = - inferBounds(); + Pair, Map> inferredBounds = inferBounds(); - SimplexResult result = checkBoundsConsistency(inferedBounds); + SimplexResult result = checkBoundsConsistency(inferredBounds); if (result != null && !result.isFeasible()) { return result; } List withoutLowerBounds = new ArrayList<>(allVariables); - withoutLowerBounds.removeAll(inferedBounds.getLeft().keySet()); + withoutLowerBounds.removeAll(inferredBounds.getLeft().keySet()); // Transform constraints, where a single variable has a bound other than zero - transformOffsetVariables(inferedBounds); + transformOffsetVariables(inferredBounds); // Replace unbounded variables replaceUnboundedVariables(withoutLowerBounds); @@ -290,28 +297,36 @@ private Pair, Map> infer Iterator iterator = constraints.iterator(); while (iterator.hasNext()) { LinearConstraint constraint = iterator.next(); - if (constraint.getCoefficients().size() != 1) { + // only evaluate bounds for constraints over a single variable + if (constraint.getVariables().size() != 1) { iterator.remove(); keptConstraints.add(constraint); continue; } + // get the variable that is constrained String constraintVariable = constraint.getVariables().iterator().next(); if (!variable.equals(constraintVariable)) continue; + // the bound on the Number bound = constraint.getBoundOn(constraintVariable); - - if (constraint.getBound() - != LinearConstraint.Bound.UPPER - == constraint.getCoefficients().get(constraintVariable).greaterThan(ZERO())) { - if (!lowerBounds.containsKey(variable) - || lowerBounds.get(variable).getBoundOn(variable).lessThan(ZERO())) { - lowerBounds.put(variable, constraint); + boolean isUpperBound = + constraint.getBound() + == LESS_EQUALS + == constraint + .getDifference() + .getCoefficients() + .get(constraintVariable) + .greaterThanOrEqual(ZERO()); + + // tightening bounds + if(isUpperBound) { + if(!upperBounds.containsKey(constraintVariable) || bound.lessThan(upperBounds.get(constraintVariable).getBoundOn(constraintVariable))) { + upperBounds.put(constraintVariable, constraint); } } else { - if (!upperBounds.containsKey(variable) - || upperBounds.get(variable).getBoundOn(variable).greaterThan(bound)) { - upperBounds.put(variable, constraint); + if(!lowerBounds.containsKey(constraintVariable) || bound.greaterThan(lowerBounds.get(constraintVariable).getBoundOn(constraintVariable))) { + lowerBounds.put(constraintVariable, constraint); } } } @@ -374,7 +389,7 @@ private void transformOffsetVariables( for (int i = 0; i < constraints.size(); i++) { LinearConstraint linearConstraint = constraints.get(i); - if (linearConstraint.getCoefficients().containsKey(variable)) { + if (linearConstraint.constrainsVariable(variable)) { LinearConstraint offsetConstraint = linearConstraint.offset(variable, substitute, bound); constraints.set(i, offsetConstraint); } @@ -402,7 +417,7 @@ private void replaceUnboundedVariables(List withoutLowerBounds) { for (int i = 0; i < constraints.size(); i++) { LinearConstraint linearConstraint = constraints.get(i); - if (linearConstraint.getCoefficients().containsKey(unboundedVariable)) { + if (linearConstraint.constrainsVariable(unboundedVariable)) { LinearConstraint positiveNegative = linearConstraint.positiveNegativeSubstitute(unboundedVariable, positive, negative); constraints.set(i, positiveNegative); @@ -410,7 +425,7 @@ private void replaceUnboundedVariables(List withoutLowerBounds) { } if (objective != null) { - if (objective.getCoefficients().containsKey(unboundedVariable)) { + if(objective.constrainsVariable(unboundedVariable)) { objective = objective.positiveNegativeSubstitute(unboundedVariable, positive, negative); } } @@ -434,25 +449,26 @@ private void createTableau() { if (objective != null) { for (int i = 0; i < allVariables.size(); i++) { String variable = allVariables.get(i); - tableau[0][i] = objective.getCoefficients().getOrDefault(variable, ZERO()).negate(); + // todo: abstract Minimizing and Maximizing Constraint into OptimizingConstraint + tableau[0][i] = objective.getLeftHandSide().getCoefficients().getOrDefault(variable, ZERO()).negate(); } } for (int i = 0; i < constraints.size(); i++) { LinearConstraint constraint = constraints.get(i); - if (constraint.getBound() == LinearConstraint.Bound.LOWER) { + if (constraint.getBound() == LinearConstraint.Bound.GREATER_EQUALS) { for (int j = 0; j < allVariables.size(); j++) { String variable = allVariables.get(j); - tableau[i + 1][j] = constraint.getCoefficients().getOrDefault(variable, ZERO()).negate(); + tableau[i + 1][j] = constraint.getDifference().getCoefficients().getOrDefault(variable, ZERO()).negate(); } - tableau[i + 1][allVariables.size()] = constraint.getValue().negate(); + tableau[i + 1][allVariables.size()] = constraint.getDifference().getConstant(); } else { for (int j = 0; j < allVariables.size(); j++) { String variable = allVariables.get(j); - tableau[i + 1][j] = constraint.getCoefficients().getOrDefault(variable, ZERO()); + tableau[i + 1][j] = constraint.getDifference().getCoefficients().getOrDefault(variable, ZERO()); } - tableau[i + 1][allVariables.size()] = constraint.getValue(); + tableau[i + 1][allVariables.size()] = constraint.getDifference().getConstant().negate(); } tableau[i + 1][nonBasicVariables.size() + i] = ONE(); } @@ -498,7 +514,8 @@ private VariableAssignment calculateSolution() { private Number calculateObjectiveValue() { Number objectiveValue = ZERO(); - for (Entry pair : originalObjective.getCoefficients().entrySet()) { + // todo: also access the optimizing constraint here properly + for (Entry pair : originalObjective.getLeftHandSide().getCoefficients().entrySet()) { Number value = getValue(pair.getKey()); objectiveValue = objectiveValue.add(pair.getValue().multiply(value)); } @@ -524,10 +541,10 @@ private Set calculateExplanation(int pivotRow) { } else { String actual = substitutions.inverse().getOrDefault(variable, variable); for (LinearConstraint originalConstraint : originalConstraints) { - if (originalConstraint.getBound() == LinearConstraint.Bound.UPPER) continue; + if (originalConstraint.getBound() == LESS_EQUALS) continue; // if (originalConstraint.getCoefficients().size() != 1) continue; - String onlyVariable = originalConstraint.getCoefficients().keySet().iterator().next(); + String onlyVariable = originalConstraint.getDifference().getCoefficients().keySet().iterator().next(); if (onlyVariable.equals(actual)) { explanation.add(originalConstraint.getRoot()); } @@ -575,25 +592,13 @@ public void addConstraint(LinearConstraint constraint) { return; } - if (constraint.getBound() == LinearConstraint.Bound.EQUAL) { - LinearConstraint first = new LinearConstraint(); - LinearConstraint second = new LinearConstraint(); + if (constraint.getBound() == EQUAL) { + LinearConstraint first = lessThanOrEqual(constraint.getLeftHandSide(), constraint.getRightHandSide()); + LinearConstraint second = greaterThanOrEqual(constraint.getLeftHandSide(), constraint.getRightHandSide()); + first.setDerivedFrom(constraint); second.setDerivedFrom(constraint); - first.setBound(LinearConstraint.Bound.UPPER); - second.setBound(LinearConstraint.Bound.LOWER); - - constraint - .getCoefficients() - .forEach( - (variable, coefficient) -> { - first.setCoefficient(variable, coefficient); - second.setCoefficient(variable, coefficient); - }); - first.setValue(constraint.getValue()); - second.setValue(constraint.getValue()); - constraints.add(first); constraints.add(second); } else { diff --git a/src/test/java/LinearConstraintParserTest.java b/src/test/java/LinearConstraintParserTest.java index 2e9f84e..9947c41 100644 --- a/src/test/java/LinearConstraintParserTest.java +++ b/src/test/java/LinearConstraintParserTest.java @@ -3,6 +3,7 @@ import me.paultristanwagner.satchecking.theory.arithmetic.Number; import org.junit.jupiter.api.Test; +import static me.paultristanwagner.satchecking.theory.LinearConstraint.Bound.*; import static org.junit.jupiter.api.Assertions.assertEquals; public class LinearConstraintParserTest { @@ -12,14 +13,20 @@ public class LinearConstraintParserTest { @Test public void testConstraints() { LinearConstraint lc0 = parser.parse("-31.17x+-+-101.0y=-+-27.156"); - assertEquals(Number.parse("-31.17"), lc0.getCoefficients().get("x")); - assertEquals(Number.parse("101"), lc0.getCoefficients().get("y")); - assertEquals(Number.parse("27.156"), lc0.getValue()); - assertEquals(LinearConstraint.Bound.EQUAL, lc0.getBound()); + assertEquals(Number.parse("-31.17"), lc0.getDifference().getCoefficients().get("x")); + assertEquals(Number.parse("101"), lc0.getDifference().getCoefficients().get("y")); + assertEquals(Number.parse("-27.156"), lc0.getDifference().getConstant()); + assertEquals(EQUAL, lc0.getBound()); LinearConstraint lc1 = parser.parse("a-b>=-1"); - assertEquals(Number.parse("1"), lc1.getCoefficients().get("a")); - assertEquals(Number.parse("-1"), lc1.getCoefficients().get("b")); - assertEquals(LinearConstraint.Bound.LOWER, lc1.getBound()); + assertEquals(Number.parse("1"), lc1.getDifference().getCoefficients().get("a")); + assertEquals(Number.parse("-1"), lc1.getDifference().getCoefficients().get("b")); + assertEquals(GREATER_EQUALS, lc1.getBound()); + + LinearConstraint lc2 = parser.parse("3x-2<=-2y+1"); + assertEquals(Number.parse("3"), lc2.getDifference().getCoefficients().get("x")); + assertEquals(Number.parse("2"), lc2.getDifference().getCoefficients().get("y")); + assertEquals(Number.parse("-3"), lc2.getDifference().getConstant()); + assertEquals(LESS_EQUALS, lc2.getBound()); } } diff --git a/src/test/java/MultivariatePolynomialTest.java b/src/test/java/MultivariatePolynomialTest.java index 2b895df..3dbbbf9 100644 --- a/src/test/java/MultivariatePolynomialTest.java +++ b/src/test/java/MultivariatePolynomialTest.java @@ -145,7 +145,7 @@ public void testResultant() { assertEquals(number(0), result.getCoefficient(exponent(1, 0, 0))); assertEquals(number(0), result.getCoefficient(exponent(1, 1, 0))); assertEquals(number(0), result.getCoefficient(exponent(1, 2, 0))); - assertEquals(number(-1), result.getCoefficient(exponent(2, 2, 0))); // todo: investigate why the sign is wrong + assertEquals(number(-1), result.getCoefficient(exponent(2, 2, 0))); } @@ -159,8 +159,6 @@ public void testIntervalEvaluation() { Map intervalMap = Map.of("x", xInterval, "y", yInterval); Interval result = p.evaluate(intervalMap); - System.out.println(result); - assertFalse(result.containsZero()); assertEquals(1, result.sign()); }