Skip to content

Commit

Permalink
Merge branch 'develop' of https://github.com/dice-group/PruneCEL into…
Browse files Browse the repository at this point in the history
… develop
  • Loading branch information
Quannz committed Oct 30, 2024
2 parents eee9613 + 4817280 commit 5ec9ead
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 22 deletions.
50 changes: 31 additions & 19 deletions src/main/java/org/dice_research/cel/PruneCEL.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
import org.dice_research.cel.refine.suggest.ExtendedSuggestor;
import org.dice_research.cel.refine.suggest.SelectionScores;
import org.dice_research.cel.refine.suggest.sparql.SparqlBasedSuggestor;
import org.dice_research.cel.score.AccuracyCalculator;
import org.dice_research.cel.score.AvoidingPickySolutionsDecorator;
import org.dice_research.cel.score.F1MeasureCalculator;
import org.dice_research.cel.score.LengthBasedRefinementScorer;
import org.dice_research.cel.score.ScoreCalculator;
import org.dice_research.cel.score.ScoreCalculatorFactory;
Expand Down Expand Up @@ -233,10 +233,10 @@ public void setDebugMode(boolean debugMode) {

public static void main(String[] args) throws Exception {
// XXX Set SPARQL endpoint
// String endpoint = "http://localhost:9080/sparql";
String endpoint = "http://localhost:9080/sparql";
// String endpoint = "http://localhost:3030/exp-bench/sparql";
// String endpoint = "http://localhost:3030/family/sparql";
String endpoint = "http://dice-quan.cs.uni-paderborn.de:9050/sparql";
// String endpoint = "http://dice-quan.cs.uni-paderborn.de:9050/sparql";
// QALD9-plus-wikidata
// String endpoint = "http://dice-quan.cs.uni-paderborn.de:9070/sparql";
// Family
Expand All @@ -246,17 +246,18 @@ public static void main(String[] args) throws Exception {

ScoreCalculatorFactory factory = null;
// XXX Choose either F1 or balanced accuracy
// factory = new F1MeasureCalculator.Factory();
factory = new F1MeasureCalculator.Factory();
// factory = new BalancedAccuracyCalculator.Factory();
factory = new AccuracyCalculator.Factory();
// factory = new AccuracyCalculator.Factory();

// Punish long expressions
factory = new LengthBasedRefinementScorer.Factory(factory);
// XXX (Optional) avoid choosing solutions that work only for a single example
factory = new AvoidingPickySolutionsDecorator.Factory(factory);

boolean useCache = true;
boolean debugMode = true;
boolean debugMode = false;
boolean skipNonImproving = true;

try (SparqlBasedSuggestor suggestor = SparqlBasedSuggestor.create(endpoint, logic, useCache)) {
suggestor.addToClassBlackList(OWL2.NamedIndividual.getURI());
Expand All @@ -273,11 +274,11 @@ public static void main(String[] args) throws Exception {
// XXX Max iterations of the refinement
// cel.setMaxIterations(1000);
// XXX Maximum time (in ms)
cel.setMaxTime(600000);
cel.setMaxTime(60000);
// XXX (Optional) try to avoid refining expressions that have not been created
// in a promising way (i.e., just added a class to an existing expression
// without changing the accuracy of the expression)
cel.setSkipNonImprovingStmts(true);
cel.setSkipNonImprovingStmts(skipNonImproving);
// XXX Keep this commented for now
// cel.activateRecursiveIteration(suggestor, 1.0, 0.5);
cel.setDebugMode(debugMode);
Expand All @@ -291,14 +292,25 @@ public static void main(String[] args) throws Exception {
// Collection<LearningProblem> problems =
// reader.readProblems("/home/micha/Downloads/TandF_deeppavlov_reverse.json");
// Collection<LearningProblem> problems = reader.readProblems("/home/micha/Downloads/CousinTrain_Fold_3.json");
Collection<LearningProblem> problems = reader
.readProblems("/home/micha/Downloads/TandF_ganswer_reverse.json");
// Collection<LearningProblem> problems = reader
// .readProblems("/home/micha/Downloads/TandF_ganswer_reverse.json");
// Collection<LearningProblem> problems =
// reader.readProblems("LPs/QA/TandF_MST5_reverse.json");
Collection<LearningProblem> problems = reader.readProblems("/home/micha/Downloads/AuntTrain_Fold_2.json");

// DEBUG CODE!!!
// ClassExpression ce;
// ce = new Junction(false,
ClassExpression ce;
ce = new Junction(true, new Junction(false,
new SimpleQuantifiedRole(true, "http://www.benchmark.org/family#married", false,
new NamedClass("http://www.benchmark.org/family#Brother")),
new Junction(true,
new SimpleQuantifiedRole(true, "http://www.benchmark.org/family#hasSibling", false,
new Junction(false, new NamedClass("http://www.benchmark.org/family#Mother"),
new NamedClass("http://www.benchmark.org/family#Father"))),
new NamedClass("http://www.benchmark.org/family#Daughter"))),
new NamedClass("http://www.benchmark.org/family#Female"));

// new Junction(false,
// new Junction(true, new NamedClass("http://www.benchmark.org/family#Son", true),
// new Junction(false, new NamedClass("http://www.benchmark.org/family#Grandfather", true),
// new NamedClass("http://www.benchmark.org/family#Father", true)),
Expand Down Expand Up @@ -334,13 +346,13 @@ public static void main(String[] args) throws Exception {
// new NamedClass("http://www.w3.org/2004/02/skos/core#Concept"),
// new NamedClass("http://dbpedia.org/ontology/Agent"))))));

// LearningProblem prob = problems.iterator().next();
// ScoreCalculator scoreCalculator = factory.create(prob.getPositiveExamples().size(),
// prob.getNegativeExamples().size());
// RefinementOperator rho = new SuggestorBasedRefinementOperator(suggestor, logic, scoreCalculator,
// prob.getPositiveExamples(), prob.getNegativeExamples());
// Set<ScoredClassExpression> expressions = rho.refine(ce, System.currentTimeMillis() + 60000);
// System.out.println(expressions.size());
LearningProblem prob = problems.iterator().next();
ScoreCalculator scoreCalculator = factory.create(prob.getPositiveExamples().size(),
prob.getNegativeExamples().size());
RefinementOperator rho = new SuggestorBasedRefinementOperator(suggestor, logic, scoreCalculator,
prob.getPositiveExamples(), prob.getNegativeExamples());
Set<ScoredClassExpression> expressions = rho.refine(ce, System.currentTimeMillis() + 60000);
System.out.println(expressions.size());

// System.out.println(suggestor.suggestClass(prob.getPositiveExamples(), prob.getNegativeExamples(), ce));
// System.out.println(suggestor.scoreExpression(ce, prob.getPositiveExamples(), prob.getNegativeExamples()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ public ClassExpression preprocess(ClassExpression ce) {
// Check them
ClassExpression[] newExpressions = Arrays.stream(subExpressions).filter(checker)
.toArray(ClassExpression[]::new);
if (subExpressions.length != newExpressions.length) {
System.out.println(subExpressions.length + " vs. " + newExpressions.length);
}
if (newExpressions.length == 0) {
return subExpressions[0];
} else if (newExpressions.length == 1) {
Expand Down

0 comments on commit 5ec9ead

Please sign in to comment.