diff --git a/alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/programs/atoms/ModuleAtom.java b/alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/programs/atoms/ModuleAtom.java index 17cabfc4a..42fedae92 100644 --- a/alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/programs/atoms/ModuleAtom.java +++ b/alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/programs/atoms/ModuleAtom.java @@ -25,13 +25,34 @@ public interface ModuleAtom extends Atom { @Override ModuleAtom substitute(Substitution substitution); + @Override + ModuleAtom withTerms(List terms); + interface ModuleInstantiationMode { Optional requestedAnswerSets(); ModuleInstantiationMode ALL = Optional::empty; static ModuleInstantiationMode forNumAnswerSets(int answerSets) { - return () -> Optional.of(answerSets); + return new ModuleInstantiationMode() { + @Override + public Optional requestedAnswerSets() { + return Optional.of(answerSets); + } + + @Override + public int hashCode() { + return answerSets; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof ModuleInstantiationMode)) { + return false; + } + return ((ModuleInstantiationMode) obj).requestedAnswerSets().equals(this.requestedAnswerSets()); + } + }; } } diff --git a/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/atoms/ModuleAtomImpl.java b/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/atoms/ModuleAtomImpl.java index a971214c4..72178d72b 100644 --- a/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/atoms/ModuleAtomImpl.java +++ b/alpha-commons/src/main/java/at/ac/tuwien/kr/alpha/commons/programs/atoms/ModuleAtomImpl.java @@ -2,7 +2,6 @@ import at.ac.tuwien.kr.alpha.api.grounder.Substitution; import at.ac.tuwien.kr.alpha.api.programs.Predicate; -import at.ac.tuwien.kr.alpha.api.programs.atoms.Atom; import at.ac.tuwien.kr.alpha.api.programs.atoms.ModuleAtom; import at.ac.tuwien.kr.alpha.api.programs.literals.ModuleLiteral; import at.ac.tuwien.kr.alpha.api.programs.terms.Term; @@ -50,7 +49,7 @@ public ModuleInstantiationMode getInstantiationMode() { } @Override - public Atom withTerms(List terms) { + public ModuleAtom withTerms(List terms) { if (terms.size() != this.input.size() + this.output.size()) { throw new IllegalArgumentException( "Cannot apply term list " + terms + " to module atom " + this + ", terms has invalid size!"); diff --git a/alpha-commons/src/test/java/at/ac/tuwien/kr/alpha/commons/programs/atoms/ModuleAtomImplTest.java b/alpha-commons/src/test/java/at/ac/tuwien/kr/alpha/commons/programs/atoms/ModuleAtomImplTest.java new file mode 100644 index 000000000..efa2d1353 --- /dev/null +++ b/alpha-commons/src/test/java/at/ac/tuwien/kr/alpha/commons/programs/atoms/ModuleAtomImplTest.java @@ -0,0 +1,74 @@ +package at.ac.tuwien.kr.alpha.commons.programs.atoms; + +import java.util.List; + +import at.ac.tuwien.kr.alpha.api.programs.atoms.ModuleAtom; +import at.ac.tuwien.kr.alpha.commons.programs.terms.Terms; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +public class ModuleAtomImplTest { + + @Test + public void withTerms() { + ModuleAtom moduleAtom = new ModuleAtomImpl("someModule", + List.of(Terms.newVariable("X"), Terms.newVariable("Y")), + List.of(Terms.newVariable("Z")), ModuleAtom.ModuleInstantiationMode.ALL); + ModuleAtom newModuleAtom = moduleAtom.withTerms(List.of(Terms.newConstant(1), Terms.newConstant(2), Terms.newConstant(3))); + // Check correct construction of original atom (also check that withTerms didn't modify the original atom) + assertEquals(moduleAtom.getInput().size(), 2); + assertEquals(moduleAtom.getOutput().size(), 1); + assertEquals(moduleAtom.getModuleName(), "someModule"); + assertEquals(moduleAtom.getInstantiationMode(), ModuleAtom.ModuleInstantiationMode.ALL); + // Check terms of new atom + assertEquals(newModuleAtom.getInput().size(), 2); + assertEquals(newModuleAtom.getOutput().size(), 1); + assertEquals(newModuleAtom.getModuleName(), "someModule"); + assertEquals(newModuleAtom.getInstantiationMode(), ModuleAtom.ModuleInstantiationMode.ALL); + assertEquals(List.of(Terms.newConstant(1), Terms.newConstant(2)), newModuleAtom.getInput()); + assertEquals(List.of(Terms.newConstant(3)), newModuleAtom.getOutput()); + } + + @Test + public void withTermsNewTermsTooLong() { + ModuleAtom moduleAtom = new ModuleAtomImpl("someModule", + List.of(Terms.newVariable("X"), Terms.newVariable("Y")), + List.of(Terms.newVariable("Z")), ModuleAtom.ModuleInstantiationMode.ALL); + assertThrows(IllegalArgumentException.class, + () -> moduleAtom.withTerms( + List.of(Terms.newConstant(1), Terms.newConstant(2), + Terms.newConstant(3), Terms.newConstant(4)))); + } + + @Test + public void withTermsNewTermsTooShort() { + ModuleAtom moduleAtom = new ModuleAtomImpl("someModule", + List.of(Terms.newVariable("X"), Terms.newVariable("Y")), + List.of(Terms.newVariable("Z")), ModuleAtom.ModuleInstantiationMode.ALL); + assertThrows(IllegalArgumentException.class, + () -> moduleAtom.withTerms(List.of(Terms.newConstant(1)))); + } + + @Test + public void moduleAtomsEqual() { + ModuleAtom m1 = new ModuleAtomImpl("someModule", + List.of(Terms.newVariable("X"), Terms.newVariable("Y")), + List.of(Terms.newVariable("Z")), ModuleAtom.ModuleInstantiationMode.ALL); + ModuleAtom m2 = new ModuleAtomImpl("someModule", + List.of(Terms.newVariable("X"), Terms.newVariable("Y")), + List.of(Terms.newVariable("Z")), ModuleAtom.ModuleInstantiationMode.ALL); + assertEquals(m1, m2); + assertEquals(m1.hashCode(), m2.hashCode()); + + ModuleAtom m3 = new ModuleAtomImpl("someModule", + List.of(Terms.newVariable("X"), Terms.newVariable("Y")), + List.of(Terms.newVariable("Z")), ModuleAtom.ModuleInstantiationMode.forNumAnswerSets(3)); + ModuleAtom m4 = new ModuleAtomImpl("someModule", + List.of(Terms.newVariable("X"), Terms.newVariable("Y")), + List.of(Terms.newVariable("Z")), ModuleAtom.ModuleInstantiationMode.forNumAnswerSets(3)); + assertNotEquals(m1, m3); + assertEquals(m3, m4); + } + +}