Skip to content

Commit

Permalink
Removed the thread pool that analyzed different decision classes in parallel.
Browse files Browse the repository at this point in the history
  • Loading branch information
agudys committed Feb 13, 2024
1 parent a22ed84 commit 74d669c
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 148 deletions.
2 changes: 1 addition & 1 deletion adaa.analytics.rules/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ codeQuality {
}

sourceCompatibility = 1.8
version = '1.7.3'
version = '1.7.4'


jar {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,6 @@ public ClassificationSnC(AbstractFinder finder, InductionParameters params) {
public RuleSetBase run(ExampleSet dataset) {
Logger.log("ClassificationSnC.run()\n", Level.FINE);

ClassificationRuleSet finalRuleset = (ClassificationRuleSet) factory.create(dataset);

// use contrast attribute if specified
final Attribute outputAttr = (dataset.getAttributes().getSpecial(ContrastRule.CONTRAST_ATTRIBUTE_ROLE) == null)
? dataset.getAttributes().getLabel()
Expand All @@ -72,168 +70,129 @@ public RuleSetBase run(ExampleSet dataset) {

boolean weighted = (dataset.getAttributes().getWeight() != null);

ExecutorService pool = Executors.newFixedThreadPool(numClassThreads);
Semaphore mutex = new Semaphore(1);
AtomicInteger totalRules = new AtomicInteger(0);

// perform preprocessing
finder.preprocess(dataset);
ClassificationRuleSet ruleset = (ClassificationRuleSet) factory.create(dataset);

// add rulesets from all classes
double defaultClassP = 0;

// array of futures, each consisting of ruleset and P value
Map<Integer, Future<Pair<ClassificationRuleSet, Double>>> futures = new HashMap<Integer, Future<Pair<ClassificationRuleSet, Double>>>();

// iterate over all classes
for (int cid = 0; cid < mapping.size(); ++cid) {
final int classId = cid;
Logger.log("Class " + classId + " started\n" , Level.FINE);

Future<Pair<ClassificationRuleSet, Double>> future = pool.submit( () -> {
Logger.log("Class " + classId + " started\n" , Level.FINE);

preprocessClass(dataset, classId);

ClassificationRuleSet ruleset = (ClassificationRuleSet) factory.create(dataset);

IntegerBitSet positives = new IntegerBitSet(dataset.size());
IntegerBitSet negatives = new IntegerBitSet(dataset.size());
IntegerBitSet uncoveredPositives = new IntegerBitSet(dataset.size());
Set<Integer> uncovered = new HashSet<Integer>();

double weighted_P = 0;
double weighted_N = 0;

// at the beginning rule set does not cover any examples
for (int id = 0; id < dataset.size(); ++id) {
Example e = dataset.getExample(id);
double w = dataset.getAttributes().getWeight() == null ? 1.0 : e.getWeight();

if ((double)e.getValue(outputAttr) == classId) {
weighted_P += w;
positives.add(id);
} else {
weighted_N += w;
negatives.add(id);
}
preprocessClass(dataset, classId);

IntegerBitSet positives = new IntegerBitSet(dataset.size());
IntegerBitSet negatives = new IntegerBitSet(dataset.size());
IntegerBitSet uncoveredPositives = new IntegerBitSet(dataset.size());
Set<Integer> uncovered = new HashSet<Integer>();

double weighted_P = 0;
double weighted_N = 0;

// at the beginning rule set does not cover any examples
for (int id = 0; id < dataset.size(); ++id) {
Example e = dataset.getExample(id);
double w = !weighted ? 1.0 : e.getWeight();

if ((double)e.getValue(outputAttr) == classId) {
weighted_P += w;
positives.add(id);
} else {
weighted_N += w;
negatives.add(id);
}
uncoveredPositives.addAll(positives);
uncovered.addAll(positives);
uncovered.addAll(negatives);

boolean carryOn = uncoveredPositives.size() > 0;
double uncovered_p = weighted_P;

while (carryOn) {

Logger.log("Class " + classId + " uncovered positive weight:" +
uncovered_p + "/" + weighted_P + "\n", Level.FINE);
Rule rule = factory.create(
new CompoundCondition(),
new ElementaryCondition(outputAttr.getName(), new SingletonSet((double)classId, mapping.getValues())));

// rule covers everything at the beginning
rule.setWeighted_P(weighted_P);
rule.setWeighted_N(weighted_N);
rule.setWeighted_p(weighted_P);
rule.setWeighted_n(weighted_N);

rule.setCoveredPositives(new IntegerBitSet(dataset.size()));
rule.setCoveredNegatives(new IntegerBitSet(dataset.size()));
rule.getCoveredPositives().addAll(positives);
rule.getCoveredNegatives().addAll(negatives);
rule.setRuleOrderNum(ruleset.getRules().size());

rule.getConsequence().setCovering(positives);

double t = System.nanoTime();
carryOn = (finder.grow(rule, dataset, uncoveredPositives) > 0);
ruleset.setGrowingTime( ruleset.getGrowingTime() + (System.nanoTime() - t) / 1e9);

if (carryOn) {
if (params.isPruningEnabled()) {
Logger.log("Before prunning:" + rule.toString() + "\n" , Level.FINE);
t = System.nanoTime();
finder.prune(rule, dataset, uncoveredPositives);
ruleset.setPruningTime( ruleset.getPruningTime() + (System.nanoTime() - t) / 1e9);
}


Logger.log("Class " + classId + ", candidate rule " + ruleset.getRules().size() + ":" + rule.toString() + "\n", Level.FINE);

// remove covered examples
int previouslyUncovered = uncoveredPositives.size();
uncoveredPositives.removeAll(rule.getCoveredPositives());
uncovered.removeAll(rule.getCoveredPositives());
uncovered.removeAll(rule.getCoveredNegatives());

uncovered_p = 0;
for (int id : uncoveredPositives) {
Example e = dataset.getExample(id);
uncovered_p += dataset.getAttributes().getWeight() == null ? 1.0 : e.getWeight();
}

Logger.log("Uncovered positives" + uncovered_p + "\n", Level.FINER);


// stop if number of positive examples remaining is less than threshold
if (uncovered_p <= params.getMaximumUncoveredFraction() * weighted_P) {
carryOn = false;
}

// stop and ignore last rule if no new positive examples covered
if (uncoveredPositives.size() == previouslyUncovered) {
carryOn = false;
} else {
finder.postprocess(rule, dataset);
ruleset.addRule(rule);
mutex.acquire(1);
Logger.log( "\r" + StringUtils.repeat("\t", 10) + "\r", Level.INFO);
Logger.log("\t" + totalRules.incrementAndGet() + " rules" , Level.INFO);
mutex.release(1);
}
//report to operator command proxy
this.operatorCommandProxy.onNewRule(rule);
this.operatorCommandProxy.onProgressChange(dataset.size(), uncovered.size());
}
uncoveredPositives.addAll(positives);
uncovered.addAll(positives);
uncovered.addAll(negatives);

if (weighted_P > defaultClassP) {
defaultClassP = weighted_P;
ruleset.setDefaultClass(classId);
}

boolean carryOn = uncoveredPositives.size() > 0;
double uncovered_p = weighted_P;

while (carryOn) {

Logger.log("Class " + classId + " uncovered positive weight:" +
uncovered_p + "/" + weighted_P + "\n", Level.FINE);
Rule rule = factory.create(
new CompoundCondition(),
new ElementaryCondition(outputAttr.getName(), new SingletonSet((double) classId, mapping.getValues())));

// rule covers everything at the beginning
rule.setWeighted_P(weighted_P);
rule.setWeighted_N(weighted_N);
rule.setWeighted_p(weighted_P);
rule.setWeighted_n(weighted_N);

rule.setCoveredPositives(new IntegerBitSet(dataset.size()));
rule.setCoveredNegatives(new IntegerBitSet(dataset.size()));
rule.getCoveredPositives().addAll(positives);
rule.getCoveredNegatives().addAll(negatives);
rule.setRuleOrderNum(ruleset.getRules().size());

rule.getConsequence().setCovering(positives);

double t = System.nanoTime();
carryOn = (finder.grow(rule, dataset, uncoveredPositives) > 0);
ruleset.setGrowingTime(ruleset.getGrowingTime() + (System.nanoTime() - t) / 1e9);

if (carryOn) {
if (params.isPruningEnabled()) {
Logger.log("Before prunning:" + rule.toString() + "\n", Level.FINE);
t = System.nanoTime();
finder.prune(rule, dataset, uncoveredPositives);
ruleset.setPruningTime(ruleset.getPruningTime() + (System.nanoTime() - t) / 1e9);
}


Logger.log("Class " + classId + ", candidate rule " + ruleset.getRules().size() + ":" + rule.toString() + "\n", Level.FINE);

// remove covered examples
int previouslyUncovered = uncoveredPositives.size();
uncoveredPositives.removeAll(rule.getCoveredPositives());
uncovered.removeAll(rule.getCoveredPositives());
uncovered.removeAll(rule.getCoveredNegatives());

uncovered_p = 0;
for (int id : uncoveredPositives) {
Example e = dataset.getExample(id);
uncovered_p += dataset.getAttributes().getWeight() == null ? 1.0 : e.getWeight();
}
if (this.operatorCommandProxy.isRequestStop()) {

Logger.log("Uncovered positives" + uncovered_p + "\n", Level.FINER);

// stop if number of positive examples remaining is less than threshold
if (uncovered_p <= params.getMaximumUncoveredFraction() * weighted_P) {
carryOn = false;
}
}

return new Pair<ClassificationRuleSet, Double>(ruleset, weighted_P);
});

futures.put(classId, future);
}

// add rulesets from all classes
double defaultClassP = 0;

for (int classId = 0; classId < mapping.size(); ++classId) {
Pair<ClassificationRuleSet, Double> result;

try {
result = futures.get(classId).get();
ClassificationRuleSet partialSet = result.getFirst();
finalRuleset.getRules().addAll(partialSet.getRules());
finalRuleset.setGrowingTime( finalRuleset.getGrowingTime() + partialSet.getGrowingTime());
finalRuleset.setPruningTime( finalRuleset.getPruningTime() + partialSet.getPruningTime());

// set default class
if (result.getSecond() > defaultClassP) {
defaultClassP = result.getSecond();
finalRuleset.setDefaultClass(classId);
// stop and ignore last rule if no new positive examples covered
if (uncoveredPositives.size() == previouslyUncovered) {
carryOn = false;
} else {
finder.postprocess(rule, dataset);
ruleset.addRule(rule);
Logger.log("\r" + StringUtils.repeat("\t", 10) + "\r", Level.INFO);
Logger.log("\t" + ruleset.getRules().size() + " rules", Level.INFO);
}
//report to operator command proxy
this.operatorCommandProxy.onNewRule(rule);
this.operatorCommandProxy.onProgressChange(dataset.size(), uncovered.size());
}
if (this.operatorCommandProxy.isRequestStop()) {
carryOn = false;
}

} catch (InterruptedException | ExecutionException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}

}

pool.shutdown();

return finalRuleset;
return ruleset;
}

public void preprocessClass(ExampleSet dataset, int classId) {
Expand Down

0 comments on commit 74d669c

Please sign in to comment.