aurora: LVQ Neural Network
Clouke edited this page May 3, 2023
Build the network:
LVQNeuralNetwork lvq = new LVQNeuralNetworkBuilder()
    .learningRate(0.1)       // learning rate
    .learningRateStep(0.99)  // learning rate step
    .decayRate(0.0001)       // decay rate
    .decayStep(0.99)         // decay step
    .epochs(1_000_000)       // number of training epochs
    .inputSize(2)            // number of input features
    .outputSize(2)           // number of output classes
    .printing(Bar.CLASSIC)   // progress-bar style (Bar.CLASSIC is the default)
    .build();
Create a Data Set:
double[][] inputs = new double[][] {
    {0, 0},
    {0, 1},
    {1, 0},
    {1, 1}
};
int[] outputs = new int[] {
    0,
    1,
    1,
    0
};
Train the network:
lvq.train(inputs, outputs);
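This page does not describe what train does internally. For orientation, here is a minimal sketch of the classic LVQ1 update rule in plain Java, independent of aurora: each class gets one prototype (initialised from its first training sample, an assumption), and for every sample the nearest prototype moves toward the sample if its class matches and away from it otherwise. The class name Lvq1Sketch and the per-epoch multiplier lrFactor are hypothetical and are not claimed to match aurora's learningRateStep or decay settings.

```java
/**
 * Hypothetical LVQ1 sketch (not aurora's implementation): one prototype per
 * class, trained with the standard attract/repel update.
 */
class Lvq1Sketch {
    final double[][] prototypes; // prototypes[c] is the codebook vector for class c

    Lvq1Sketch(int inputSize, int classes) {
        prototypes = new double[classes][inputSize];
    }

    // Index of the prototype closest to x by squared Euclidean distance.
    int nearest(double[] x) {
        int best = 0;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int c = 0; c < prototypes.length; c++) {
            double d = 0;
            for (int i = 0; i < x.length; i++) {
                double diff = x[i] - prototypes[c][i];
                d += diff * diff;
            }
            if (d < bestDist) {
                bestDist = d;
                best = c;
            }
        }
        return best;
    }

    void train(double[][] inputs, int[] labels, double learningRate, double lrFactor, int epochs) {
        // Assumed initialisation: copy the first training sample of each class.
        boolean[] seen = new boolean[prototypes.length];
        for (int n = 0; n < inputs.length; n++) {
            if (!seen[labels[n]]) {
                prototypes[labels[n]] = inputs[n].clone();
                seen[labels[n]] = true;
            }
        }
        double lr = learningRate;
        for (int e = 0; e < epochs; e++) {
            for (int n = 0; n < inputs.length; n++) {
                int w = nearest(inputs[n]);
                // Attract the winner if its class is correct, repel it otherwise.
                double sign = (w == labels[n]) ? 1.0 : -1.0;
                for (int i = 0; i < inputs[n].length; i++) {
                    prototypes[w][i] += sign * lr * (inputs[n][i] - prototypes[w][i]);
                }
            }
            lr *= lrFactor; // assumed multiplicative decay after each epoch
        }
    }
}
```

With a single prototype per class, the boundary between two prototypes is a straight line, so this sketch can classify at most three of the four XOR points correctly; separating XOR with nearest-prototype rules needs more than one prototype per class.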
Printing: the progress output reports several attributes that describe the training process:
- Loss: the training error; a decreasing value means the model is improving
- Stage: the training progress as a percentage; it reaches 100 when training completes
- Accuracy: the model's accuracy score; you may use HyperparameterTuning to search for the best-scoring settings
- Epoch: the current iteration
[##############################=] Loss: 0.4737521042166306 | Stage: 99 | Accuracy: 0.75 | Epoch: 999999 (—)
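The sample line above is consistent with Stage being the completed fraction of epochs truncated to an integer percentage (epoch 999999 of 1,000,000 gives 99) and with Accuracy being the share of correctly classified samples (3 of 4 XOR points gives 0.75). The helpers below assume exactly those formulas; aurora's actual computation is not documented on this page, and TrainingStats is a hypothetical name.

```java
/** Hypothetical helpers showing how the printed Stage and Accuracy values could be derived. */
class TrainingStats {
    // Stage as an integer percentage of totalEpochs, truncated toward zero.
    // The long cast avoids int overflow when epoch * 100 exceeds Integer.MAX_VALUE.
    static int stage(int epoch, int totalEpochs) {
        return (int) ((long) epoch * 100 / totalEpochs);
    }

    // Accuracy: fraction of samples whose predicted class equals the true class.
    static double accuracy(int[] predicted, int[] actual) {
        int correct = 0;
        for (int i = 0; i < actual.length; i++) {
            if (predicted[i] == actual[i]) correct++;
        }
        return (double) correct / actual.length;
    }
}
```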
Classify an input:
double[] input = new double[]{0, 1};
int out = lvq.classify(input); // predicted class index
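In LVQ, classification assigns an input the label of its nearest prototype. The sketch below shows that rule with a labelled codebook that may hold several prototypes per class; whether aurora keeps one or several prototypes per class is not stated on this page, and NearestPrototype.classify is a hypothetical helper. With two prototypes per class placed on the corners, all four XOR points are classified correctly.

```java
/** Hypothetical nearest-prototype classifier (not aurora's API). */
class NearestPrototype {
    // Returns the label of the prototype nearest to x (squared Euclidean distance).
    static int classify(double[] x, double[][] prototypes, int[] prototypeLabels) {
        int best = -1;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int p = 0; p < prototypes.length; p++) {
            double d = 0;
            for (int i = 0; i < x.length; i++) {
                double diff = x[i] - prototypes[p][i];
                d += diff * diff;
            }
            if (d < bestDist) {
                bestDist = d;
                best = p;
            }
        }
        return prototypeLabels[best];
    }

    public static void main(String[] args) {
        // Two prototypes per class, placed on the XOR corners.
        double[][] prototypes = {{0, 0}, {1, 1}, {0, 1}, {1, 0}};
        int[] labels = {0, 0, 1, 1};
        System.out.println(classify(new double[]{0, 1}, prototypes, labels));     // 1
        System.out.println(classify(new double[]{0.9, 0.8}, prototypes, labels)); // 0
    }
}
```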
Create a Test Set:
Map<double[], double[]> data = new HashMap<>();
data.put(new double[]{0, 0}, new double[]{0, 0});
data.put(new double[]{0, 1}, new double[]{0, 1});
data.put(new double[]{1, 0}, new double[]{1, 0});
data.put(new double[]{1, 1}, new double[]{1, 1});
TestSet testSet = new TestSet(data);
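One caveat with a Map<double[], double[]>: Java arrays inherit identity-based equals and hashCode from Object, so iterating over the map works, but data.get(new double[]{0, 1}) returns null because the new array is a different instance. If you need content-based lookup on such a map, a linear scan with Arrays.equals is one option; findByContent below is a hypothetical helper, not part of aurora.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

class ArrayKeyLookup {
    // Finds a value by comparing key contents, since HashMap.get on a
    // double[] key only matches the exact same array instance.
    static double[] findByContent(Map<double[], double[]> data, double[] key) {
        for (Map.Entry<double[], double[]> e : data.entrySet()) {
            if (Arrays.equals(e.getKey(), key)) return e.getValue();
        }
        return null;
    }

    public static void main(String[] args) {
        Map<double[], double[]> data = new HashMap<>();
        data.put(new double[]{0, 1}, new double[]{0, 1});
        System.out.println(data.get(new double[]{0, 1}));                               // null
        System.out.println(Arrays.toString(findByContent(data, new double[]{0, 1}))); // [0.0, 1.0]
    }
}
```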
Evaluate:
Evaluation eval = lvq.evaluate(testSet);
eval.printSummary();
Prints:
Evaluation Summary of Type: Learning Vector Quantization
Accuracy: 1.0
Precision: 1.0
Recall: 1.0
F1 Score: 1.0
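For reference, the four summary metrics follow their standard definitions. The sketch computes them for a binary problem, treating label 1 as the positive class; how aurora aggregates them for multiple classes (macro, micro, or weighted averaging) is not stated on this page, and BinaryMetrics is a hypothetical name.

```java
/** Hypothetical binary-classification metrics with label 1 as the positive class. */
class BinaryMetrics {
    // Returns {accuracy, precision, recall, f1}.
    static double[] compute(int[] predicted, int[] actual) {
        int tp = 0, fp = 0, fn = 0, correct = 0;
        for (int i = 0; i < actual.length; i++) {
            if (predicted[i] == actual[i]) correct++;
            if (predicted[i] == 1 && actual[i] == 1) tp++;      // true positive
            else if (predicted[i] == 1 && actual[i] != 1) fp++; // false positive
            else if (predicted[i] != 1 && actual[i] == 1) fn++; // false negative
        }
        double accuracy = (double) correct / actual.length;
        double precision = (tp + fp == 0) ? 0.0 : (double) tp / (tp + fp);
        double recall = (tp + fn == 0) ? 0.0 : (double) tp / (tp + fn);
        // F1 is the harmonic mean of precision and recall.
        double f1 = (precision + recall == 0) ? 0.0 : 2 * precision * recall / (precision + recall);
        return new double[]{accuracy, precision, recall, f1};
    }
}
```

For example, predictions {0, 1, 1, 1} against labels {0, 1, 1, 0} give accuracy 0.75, precision 2/3, recall 1.0 and F1 0.8.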
Save your LVQ Neural Network Model:
Model model = lvq.toModel();
model.save("my_directory");
Load from file:
LVQNeuralNetworkModel model = null;
try (ModelLoader loader = new ModelLoader(new File("my_directory"))) {
    model = loader.load(LVQNeuralNetworkModel.class);
}
Load from URL:
LVQNeuralNetworkModel model = null;
try (ModelLoader loader = new ModelLoader(new URL("my_model_url"))) {
    model = loader.load(LVQNeuralNetworkModel.class);
} catch (MalformedURLException e) {
    throw new RuntimeException(e);
}