Skip to content

Commit

Permalink
Model saving [beta]
Browse files Browse the repository at this point in the history
  • Loading branch information
satiracode committed Aug 4, 2024
1 parent 90b52f0 commit 5d3fa06
Show file tree
Hide file tree
Showing 37 changed files with 1,076 additions and 32 deletions.
3 changes: 2 additions & 1 deletion mlcore/src/main/java/org/owasp/netryx/mlcore/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import org.owasp.netryx.mlcore.metrics.ModelEvaluator;
import org.owasp.netryx.mlcore.params.HyperParameter;
import org.owasp.netryx.mlcore.prediction.Prediction;
import org.owasp.netryx.mlcore.serialize.MLComponent;

import java.util.List;

public interface Model {
public interface Model extends MLComponent {
void fit(DataFrame X, DataFrame y);

List<? extends Prediction> predict(DataFrame x);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package org.owasp.netryx.mlcore.encoder;

import org.owasp.netryx.mlcore.frame.DataFrame;
import org.owasp.netryx.mlcore.serialize.MLComponent;

public interface Encoder {
public interface Encoder extends MLComponent {
void fit(DataFrame df, String columnName);

DataFrame transform(DataFrame df);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import org.owasp.netryx.mlcore.frame.series.AbstractSeries;
import org.owasp.netryx.mlcore.frame.DataFrame;
import org.owasp.netryx.mlcore.frame.series.DoubleSeries;
import org.owasp.netryx.mlcore.serialize.component.StringDoubleMapComponent;
import org.owasp.netryx.mlcore.serialize.flag.MLFlag;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.*;

public class LabelEncoder implements Encoder {
Expand Down Expand Up @@ -48,4 +53,28 @@ public Map<String, Double> getLabelMapping() {
public String getColumnName() {
return columnName;
}
}

@Override
public void save(DataOutputStream out) throws IOException {
out.writeInt(MLFlag.START_ENCODER);

out.writeUTF(columnName);
new StringDoubleMapComponent(labelMapping).save(out);

out.writeInt(MLFlag.END_ENCODER);
}

@Override
public void load(DataInputStream in) throws IOException {
MLFlag.ensureStartEncoder(in.readInt());

this.columnName = in.readUTF();

var component = new StringDoubleMapComponent();
component.load(in);

this.labelMapping = component.getMap();

MLFlag.ensureEndEncoder(in.readInt());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import org.owasp.netryx.mlcore.frame.series.AbstractSeries;
import org.owasp.netryx.mlcore.frame.DataFrame;
import org.owasp.netryx.mlcore.frame.series.Series;
import org.owasp.netryx.mlcore.serialize.flag.MLFlag;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
Expand Down Expand Up @@ -49,4 +53,33 @@ public String getColumnName() {
public Set<String> getUniqueValues() {
return uniqueValues;
}

@Override
public void save(DataOutputStream out) throws IOException {
out.writeInt(MLFlag.START_ENCODER);
out.writeUTF(columnName);

var size = uniqueValues.size();

out.writeInt(size);
for (var value : uniqueValues)
out.writeUTF(value);

out.writeInt(MLFlag.END_ENCODER);
}

@Override
public void load(DataInputStream in) throws IOException {
MLFlag.ensureStartEncoder(in.readInt());

columnName = in.readUTF();

var size = in.readInt();
this.uniqueValues = new HashSet<>(size);

for (var i = 0; i < size; i++)
uniqueValues.add(in.readUTF());

MLFlag.ensureEndEncoder(in.readInt());
}
}
Original file line number Diff line number Diff line change
@@ -1,20 +1,30 @@
package org.owasp.netryx.mlcore.encoder.tfidf;

import org.owasp.netryx.mlcore.serialize.MLComponent;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;

public class NGram {
private final int minNgram;
private final int maxNgram;
private final Pattern tokenPattern;
public class NGram implements MLComponent {
private int minNgram;
private int maxNgram;
private Pattern tokenPattern;

public NGram(int minNgram, int maxNgram, String tokenPattern) {
this.minNgram = minNgram;
this.maxNgram = maxNgram;
this.tokenPattern = Pattern.compile(tokenPattern);
}

public NGram() {
this(0, 0, "");
}

public List<String> extractNgrams(String document) {
List<String> ngrams = new ArrayList<>();

Expand All @@ -37,4 +47,29 @@ public List<String> extractNgrams(String document) {
public static NGram create(int minNgram, int maxNgram, String tokenPattern) {
return new NGram(minNgram, maxNgram, tokenPattern);
}

@Override
public void save(DataOutputStream out) throws IOException {
out.writeInt(minNgram);
out.writeInt(maxNgram);

var pattern = tokenPattern.pattern().getBytes(StandardCharsets.UTF_8);
out.writeInt(pattern.length);
out.write(pattern);
}

@Override
public void load(DataInputStream in) throws IOException {
minNgram = in.readInt();
maxNgram = in.readInt();

var size = in.readInt();
var bytes = new byte[size];
var readBytes = in.read(bytes);

if (readBytes != size)
throw new IllegalArgumentException("Not a pattern: " + new String(bytes));

this.tokenPattern = Pattern.compile(new String(bytes));
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package org.owasp.netryx.mlcore.encoder.tfidf;

import org.owasp.netryx.mlcore.encoder.Encoder;
import org.owasp.netryx.mlcore.frame.series.AbstractSeries;
import org.owasp.netryx.mlcore.frame.DataFrame;
import org.owasp.netryx.mlcore.frame.series.AbstractSeries;
import org.owasp.netryx.mlcore.frame.series.DoubleSeries;
import org.owasp.netryx.mlcore.serialize.component.StringDoubleMapComponent;
import org.owasp.netryx.mlcore.serialize.flag.MLFlag;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -15,7 +20,8 @@
public class TfidfEncoder implements Encoder {
private String columnName;
private Map<String, Double> idfValues;
private final NGram nGram;
private NGram nGram;

private final Executor executor;

public TfidfEncoder(NGram nGram, int parallelism) {
Expand Down Expand Up @@ -65,7 +71,6 @@ private void computeIdfValues(Map<String, Integer> documentFrequencies, int numD
documentFrequencies.forEach((term, docFrequency) -> {
var idf = Math.log((double) (numDocuments + 1) / (1 + docFrequency));
idfValues.put(term, idf);
System.out.println("Term: " + term + ", Doc Frequency: " + docFrequency + ", IDF: " + idf);
});
}

Expand Down Expand Up @@ -109,9 +114,38 @@ private DataFrame createTransformedDataFrame(DataFrame df, List<Map<String, Doub
var termTfidfValues = tfidfVectors.stream()
.map(vector -> vector.getOrDefault(term, 0.0))
.collect(Collectors.toList());

newData.put(term, new DoubleSeries(termTfidfValues));
});

return new DataFrame(newData);
}

@Override
public void save(DataOutputStream out) throws IOException {
out.writeInt(MLFlag.START_ENCODER);

out.writeUTF(columnName);
new StringDoubleMapComponent(idfValues).save(out);
nGram.save(out);

out.writeInt(MLFlag.END_ENCODER);
}

@Override
public void load(DataInputStream in) throws IOException {
MLFlag.ensureStartEncoder(in.readInt());

columnName = in.readUTF();

var component = new StringDoubleMapComponent();
component.load(in);

idfValues = component.getMap();

this.nGram = new NGram();
nGram.load(in);

MLFlag.ensureEndEncoder(in.readInt());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class ARFFLoader implements DataFrameLoader {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

import org.ejml.simple.SimpleMatrix;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

public class LogisticLossFunction implements LossFunction {
@Override
public SimpleMatrix predict(SimpleMatrix X, SimpleMatrix coefficients) {
Expand Down Expand Up @@ -32,4 +36,14 @@ private SimpleMatrix sigmoid(SimpleMatrix z) {
}
return result;
}

@Override
public void save(DataOutputStream out) throws IOException {
// nothing to store
}

@Override
public void load(DataInputStream in) throws IOException {
// nothing to store
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package org.owasp.netryx.mlcore.loss;

import org.ejml.simple.SimpleMatrix;
import org.owasp.netryx.mlcore.serialize.MLComponent;

public interface LossFunction {
public interface LossFunction extends MLComponent {
SimpleMatrix predict(SimpleMatrix X, SimpleMatrix coefficients);

SimpleMatrix gradient(SimpleMatrix X, SimpleMatrix y, SimpleMatrix coefficients);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

import org.ejml.simple.SimpleMatrix;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

public class MeanSquaredError implements LossFunction {
@Override
public SimpleMatrix predict(SimpleMatrix X, SimpleMatrix coefficients) {
Expand All @@ -19,4 +23,14 @@ public double loss(SimpleMatrix y, SimpleMatrix predictions) {
var diff = y.minus(predictions);
return diff.elementMult(diff).elementSum() / y.getNumRows();
}

@Override
public void save(DataOutputStream out) throws IOException {
// nothing to store
}

@Override
public void load(DataInputStream in) throws IOException {
// nothing to store
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@
import org.owasp.netryx.mlcore.params.HyperParameter;
import org.owasp.netryx.mlcore.prediction.LabelPrediction;
import org.owasp.netryx.mlcore.regularization.Regularization;
import org.owasp.netryx.mlcore.serialize.component.MatrixComponent;
import org.owasp.netryx.mlcore.serialize.flag.MLFlag;
import org.owasp.netryx.mlcore.util.DataUtil;

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
Expand Down Expand Up @@ -74,4 +79,26 @@ public static LinearRegression create(Regularization regularizer) {
public static LinearRegression create() {
return create(null);
}

@Override
public void save(DataOutputStream out) throws IOException {
out.writeInt(MLFlag.START_MODEL);

new MatrixComponent(coefficients).save(out);
regularizer.save(out);

out.writeInt(MLFlag.END_MODEL);
}

@Override
public void load(DataInputStream in) throws IOException {
MLFlag.ensureStartModel(in.readInt());

var matrix = new MatrixComponent();
matrix.load(in);
this.coefficients = matrix.getMatrix();

regularizer.load(in);
MLFlag.ensureEndModel(in.readInt());
}
}
Loading

0 comments on commit 5d3fa06

Please sign in to comment.