Skip to content

Commit

Permalink
Evolog Modules: Permit 0..n input predicates in module definitions.
Browse files Browse the repository at this point in the history
  • Loading branch information
madmike200590 committed Jul 24, 2024
1 parent 90d3faf commit 761331c
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 30 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
package at.ac.tuwien.kr.alpha.api.programs;

import at.ac.tuwien.kr.alpha.api.programs.modules.Module;
import at.ac.tuwien.kr.alpha.api.programs.rules.NormalRule;

import java.util.List;

/**
* A {@link Program} consisting only of facts and {@link NormalRule}s, i.e. no disjunctive- or choice-rules, and no aggregates in rule bodies.
*
* Copyright (c) 2021, the Alpha Team.
*/
public interface NormalProgram extends Program<NormalRule> {

List<Module> getModules();

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public interface Module {

String getName();

Predicate getInputSpec();
Set<Predicate> getInputSpec();

Set<Predicate> getOutputSpec();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import at.ac.tuwien.kr.alpha.api.programs.InlineDirectives;
import at.ac.tuwien.kr.alpha.api.programs.NormalProgram;
import at.ac.tuwien.kr.alpha.api.programs.atoms.Atom;
import at.ac.tuwien.kr.alpha.api.programs.modules.Module;
import at.ac.tuwien.kr.alpha.api.programs.rules.NormalRule;

/**
Expand All @@ -14,8 +15,16 @@
*/
class NormalProgramImpl extends AbstractProgram<NormalRule> implements NormalProgram {

NormalProgramImpl(List<NormalRule> rules, List<Atom> facts, InlineDirectives inlineDirectives) {
private final List<Module> modules;

NormalProgramImpl(List<NormalRule> rules, List<Atom> facts, InlineDirectives inlineDirectives, List<Module> modules) {
super(rules, facts, inlineDirectives);
this.modules = modules;
}

@Override
public List<Module> getModules() {
return modules;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,19 @@ public static InputProgramBuilder builder(InputProgram program) {
}

public static NormalProgram newNormalProgram(List<NormalRule> rules, List<Atom> facts, InlineDirectives inlineDirectives) {
return new NormalProgramImpl(rules, facts, inlineDirectives);
return new NormalProgramImpl(rules, facts, inlineDirectives, Collections.emptyList());
}

public static NormalProgram newNormalProgram(List<NormalRule> rules, List<Atom> facts, InlineDirectives inlineDirectives, List<Module> modules) {
return new NormalProgramImpl(rules, facts, inlineDirectives, modules);
}

public static NormalProgram toNormalProgram(InputProgram inputProgram) {
List<NormalRule> normalRules = new ArrayList<>();
for (Rule<Head> r : inputProgram.getRules()) {
normalRules.add(Rules.toNormalRule(r));
}
return new NormalProgramImpl(normalRules, inputProgram.getFacts(), inputProgram.getInlineDirectives());
return new NormalProgramImpl(normalRules, inputProgram.getFacts(), inputProgram.getInlineDirectives(), inputProgram.getModules());
}

public static InlineDirectives newInlineDirectives() {
Expand All @@ -81,6 +85,7 @@ public InputProgramBuilder(InputProgram prog) {
this.addFacts(prog.getFacts());
this.addInlineDirectives(prog.getInlineDirectives());
this.addTestCases(prog.getTestCases());
this.addModules(prog.getModules());
}

public InputProgramBuilder() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
class ModuleImpl implements Module {

private final String name;
private final Predicate inputSpec;
private final Set<Predicate> inputSpec;
private final Set<Predicate> outputSpec;
private final InputProgram implementation;

ModuleImpl(String name, Predicate inputSpec, Set<Predicate> outputSpec, InputProgram implementation) {
ModuleImpl(String name, Set<Predicate> inputSpec, Set<Predicate> outputSpec, InputProgram implementation) {
this.name = name;
this.inputSpec = inputSpec;
this.outputSpec = outputSpec;
Expand All @@ -26,7 +26,7 @@ public String getName() {
}

@Override
public Predicate getInputSpec() {
public Set<Predicate> getInputSpec() {
return this.inputSpec;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ private Modules() {
throw new AssertionError("Cannot instantiate utility class!");
}

public static Module newModule(final String name, final Predicate inputSpec, final Set<Predicate> outputSpec, final InputProgram implementation) {
public static Module newModule(final String name, final Set<Predicate> inputSpec, final Set<Predicate> outputSpec, final InputProgram implementation) {
return new ModuleImpl(name, inputSpec, outputSpec, implementation);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ public static NormalRule toNormalRule(Rule<Head> rule) {
if (!(rule.getHead() instanceof NormalHead)) {
throw Util.oops("Trying to construct a NormalRule from rule with non-normal head! Head type is: " + rule.getHead().getClass().getSimpleName());
}

}
return newNormalRule(rule.isConstraint() ? null : (NormalHead) rule.getHead(), new LinkedHashSet<>(rule.getBody()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,5 +121,5 @@ test_assert_all : TEST_ASSERT_ALL CURLY_OPEN statements? CURLY_CLOSE;

test_assert_some : TEST_ASSERT_SOME CURLY_OPEN statements? CURLY_CLOSE;

module_signature : predicate_spec ARROW CURLY_OPEN ('*' | predicate_specs) CURLY_CLOSE;
module_signature : CURLY_OPEN predicate_specs? CURLY_CLOSE ARROW CURLY_OPEN ('*' | predicate_specs) CURLY_CLOSE;

Original file line number Diff line number Diff line change
Expand Up @@ -335,18 +335,18 @@ public Object visitDirective_module(ASPCore2Parser.Directive_moduleContext ctx)
}
// directive_module: SHARP DIRECTIVE_MODULE id PAREN_OPEN module_signature PAREN_CLOSE CURLY_OPEN statements CURLY_CLOSE;
String name = visitId(ctx.id());
ImmutablePair<Predicate, Set<Predicate>> moduleSignature = visitModule_signature(ctx.module_signature());
ImmutablePair<Set<Predicate>, Set<Predicate>> moduleSignature = visitModule_signature(ctx.module_signature());
startNestedProgram();
visitStatements(ctx.statements());
InputProgram moduleImplementation = endNestedProgram();
currentLevelProgramBuilder.addModule(Modules.newModule(name, moduleSignature.getLeft(), moduleSignature.getRight(), moduleImplementation));
return null;
}

public ImmutablePair<Predicate, Set<Predicate>> visitModule_signature(ASPCore2Parser.Module_signatureContext ctx) {
Predicate inputPredicate = visitPredicate_spec(ctx.predicate_spec());
Set<Predicate> outputPredicates = ctx.predicate_specs() != null ? visitPredicate_specs(ctx.predicate_specs()) : Collections.emptySet();
return ImmutablePair.of(inputPredicate, outputPredicates);
public ImmutablePair<Set<Predicate>, Set<Predicate>> visitModule_signature(ASPCore2Parser.Module_signatureContext ctx) {
Set<Predicate> inputPredicates = ctx.predicate_specs(0) != null ? visitPredicate_specs(ctx.predicate_specs(0)) : Collections.emptySet();
Set<Predicate> outputPredicates = ctx.predicate_specs(1) != null ? visitPredicate_specs(ctx.predicate_specs(1)) : Collections.emptySet();
return ImmutablePair.of(inputPredicates, outputPredicates);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,17 @@ public class ParserTest {
private static final String UNIT_TEST_KEYWORDS_AS_IDS =
"assert(a) :- given(b). # test test(expect: 1) { given { given(b). } assertForAll { :- not assert(a). :- assertForSome(b).}}";

private static final String MODULE_SIMPLE = "#module aSimpleModule(input/1 => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }";
private static final String MODULE_SIMPLE = "#module aSimpleModule({input/1} => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }";

private static final String MODULE_OUTPUT_ALL = "#module mod(in/1 => {*}) { a(X). b(X) :- a(X).}";
private static final String MODULE_OUTPUT_ALL = "#module mod({in/1} => {*}) { a(X). b(X) :- a(X).}";

private static final String MODULE_WITH_REGULAR_STMTS = "p(a). p(b). q(X) :- p(X). #module aSimpleModule(input/1 => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }";
private static final String MODULE_WITH_REGULAR_STMTS = "p(a). p(b). q(X) :- p(X). #module aSimpleModule({input/1} => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }";

private static final String MODULE_MULTIPLE_DEFINITIONS = "a. b(5). #module aSimpleModule(input/1 => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). } q(Y) :- r(S, Y), t(S). #module anotherModule(input/1 => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }";
private static final String MODULE_MULTIPLE_DEFINITIONS = "a. b(5). #module aSimpleModule({input/1} => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). } q(Y) :- r(S, Y), t(S). #module anotherModule({input/1} => {out1/2, out2/3}) { p(a). p(b). q(X) :- p(X). }";

private static final String MODULE_EMPTY_INPUT_SPEC = "#module someModule({} => {*}) {p(a).}";

private static final String MODULE_MULTIPLE_INPUTS = "#module someModule({input/1, input2/2} => {out1/2}) {p(a).}";

private static final String MODULE_LITERAL = "p(a). q(b). r(X) :- p(X), q(Y), #mod[X, Y](X).";

Expand Down Expand Up @@ -332,9 +336,11 @@ public void simpleModule() {
assertEquals(1, modules.size());
Module module = modules.get(0);
assertEquals("aSimpleModule", module.getName());
Predicate inputSpec = module.getInputSpec();
assertEquals("input", inputSpec.getName());
assertEquals(1, inputSpec.getArity());
Set<Predicate> inputSpec = module.getInputSpec();
assertEquals(1, inputSpec.size());
Predicate inputPredicate = inputSpec.iterator().next();
assertEquals("input", inputPredicate.getName());
assertEquals(1, inputPredicate.getArity());
Set<Predicate> outputSpec = module.getOutputSpec();
assertEquals(2, outputSpec.size());
assertTrue(outputSpec.contains(Predicates.getPredicate("out1", 2)));
Expand All @@ -352,9 +358,11 @@ public void moduleOutputAll() {
assertEquals(1, modules.size());
Module module = modules.get(0);
assertEquals("mod", module.getName());
Predicate inputSpec = module.getInputSpec();
assertEquals("in", inputSpec.getName());
assertEquals(1, inputSpec.getArity());
Set<Predicate> inputSpec = module.getInputSpec();
assertEquals(1, inputSpec.size());
Predicate inputPredicate = inputSpec.iterator().next();
assertEquals("in", inputPredicate.getName());
assertEquals(1, inputPredicate.getArity());
assertTrue(module.getOutputSpec().isEmpty());
InputProgram implementation = module.getImplementation();
assertEquals(1, implementation.getFacts().size());
Expand All @@ -371,9 +379,11 @@ public void moduleAndRegularStmts() {
assertEquals(1, modules.size());
Module module = modules.get(0);
assertEquals("aSimpleModule", module.getName());
Predicate inputSpec = module.getInputSpec();
assertEquals("input", inputSpec.getName());
assertEquals(1, inputSpec.getArity());
Set<Predicate> inputSpec = module.getInputSpec();
assertEquals(1, inputSpec.size());
Predicate inputPredicate = inputSpec.iterator().next();
assertEquals("input", inputPredicate.getName());
assertEquals(1, inputPredicate.getArity());
Set<Predicate> outputSpec = module.getOutputSpec();
assertEquals(2, outputSpec.size());
assertTrue(outputSpec.contains(Predicates.getPredicate("out1", 2)));
Expand All @@ -397,13 +407,42 @@ public void multipleModuleDefinitions() {
@Test
public void invalidNestedModule() {
assertThrows(IllegalStateException.class, () ->
parser.parse("#module aSimpleModule(input/1 => {out1/2, out2/3}) { p(a). p(b). #module anotherModule(input/1 => {out1/2, out2/3}) { p(a). p(b). } }"));
parser.parse("#module aSimpleModule({input/1} => {out1/2, out2/3}) { p(a). p(b). #module anotherModule({input/1} => {out1/2, out2/3}) { p(a). p(b). } }"));
}

@Test
public void invalidNestedTest() {
assertThrows(IllegalStateException.class, () ->
parser.parse("#module mod(foo/1 => {*}) { #test test(expect: 1) { given { b. } assertForAll { :- a. } } }"));
parser.parse("#module mod({foo/1} => {*}) { #test test(expect: 1) { given { b. } assertForAll { :- a. } } }"));
}

@Test
public void emptyInputSpec() {
InputProgram prog = parser.parse(MODULE_EMPTY_INPUT_SPEC);
List<Module> modules = prog.getModules();
assertEquals(1, modules.size());
Module module = modules.get(0);
assertEquals("someModule", module.getName());
Set<Predicate> inputSpec = module.getInputSpec();
assertTrue(inputSpec.isEmpty());
Set<Predicate> outputSpec = module.getOutputSpec();
assertTrue(outputSpec.isEmpty());
}

@Test
public void multipleInputs() {
InputProgram prog = parser.parse(MODULE_MULTIPLE_INPUTS);
List<Module> modules = prog.getModules();
assertEquals(1, modules.size());
Module module = modules.get(0);
assertEquals("someModule", module.getName());
Set<Predicate> inputSpec = module.getInputSpec();
assertEquals(2, inputSpec.size());
assertTrue(inputSpec.contains(Predicates.getPredicate("input", 1)));
assertTrue(inputSpec.contains(Predicates.getPredicate("input2", 2)));
Set<Predicate> outputSpec = module.getOutputSpec();
assertEquals(1, outputSpec.size());
assertTrue(outputSpec.contains(Predicates.getPredicate("out1", 2)));
}

@Test
Expand Down

0 comments on commit 761331c

Please sign in to comment.