aurora: Logistic Regression
Clouke edited this page May 4, 2023 · 1 revision
Create the model:
LogisticRegression lr = new LogisticRegressionBuilder()
    .inputSize(3) // must equal the length of each input array
    .outputSize(1) // must equal the length of each output array
    .learningRate(0.1) // the learning rate
    .epochs(1_000_000) // the number of training epochs
    .activation(ActivationFunction.SIGMOID) // the activation function; ActivationFunction.SIGMOID by default
    .printing(Bar.CLASSIC) // the progress printing style; Bar.CLASSIC by default
    .build();
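To illustrate what a model with three inputs, one output, and a sigmoid activation computes, here is a minimal plain-Java sketch of a single logistic unit. It does not use aurora, and the class and method names (`ForwardPassSketch`, `sigmoid`, `predict`) are hypothetical; aurora's internals may differ.

```java
// Hypothetical sketch of a single logistic unit; not aurora's implementation.
final class ForwardPassSketch {

    // Logistic (sigmoid) function: maps any real number into the range (0, 1).
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // One output value: weighted sum of the inputs plus a bias, passed through the sigmoid.
    static double predict(double[] weights, double bias, double[] input) {
        if (weights.length != input.length) {
            throw new IllegalArgumentException("input length must equal the number of weights");
        }
        double z = bias;
        for (int i = 0; i < input.length; i++) {
            z += weights[i] * input[i];
        }
        return sigmoid(z);
    }

    public static void main(String[] args) {
        double[] weights = {0.0, 0.0, 0.0}; // three weights, one per input value
        System.out.println(predict(weights, 0.0, new double[] {1, 0, 1})); // prints 0.5
    }
}
```

With all weights at zero the unit returns 0.5 for any input; training moves the weights so that the output approaches the target values.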
Create a Data Set:
double[][] inputs = new double[][] { // 3 input size
    {0, 0, 1},
    {0, 1, 1},
    {1, 0, 1},
    {1, 1, 1}
};
double[][] outputs = new double[][] { // 1 output size; same labels as the test set further down
    {0},
    {0},
    {1},
    {1}
};
Train the model:
lr.train(inputs, outputs);
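As a rough picture of what a training call like this may do, the following plain-Java sketch fits a single sigmoid unit with per-sample gradient descent. It is not aurora's code: the class `TrainingSketch`, its methods, and the choice of 10,000 epochs are assumptions for illustration. The labels used are the ones from the test set on this page (the output follows the first input value), because a single logistic unit can only fit linearly separable data.

```java
// Hypothetical sketch of logistic regression training; not aurora's implementation.
final class TrainingSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // The weight array holds one weight per input followed by a bias term in the last slot.
    static double predict(double[] w, double[] x) {
        double z = w[w.length - 1];
        for (int i = 0; i < x.length; i++) {
            z += w[i] * x[i];
        }
        return sigmoid(z);
    }

    // Per-sample gradient descent on the cross-entropy loss: w += learningRate * (target - prediction) * x.
    static double[] train(double[][] inputs, double[] targets, double learningRate, int epochs) {
        double[] w = new double[inputs[0].length + 1];
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int n = 0; n < inputs.length; n++) {
                double error = targets[n] - predict(w, inputs[n]);
                for (int i = 0; i < inputs[n].length; i++) {
                    w[i] += learningRate * error * inputs[n][i];
                }
                w[w.length - 1] += learningRate * error; // bias update
            }
        }
        return w;
    }

    public static void main(String[] args) {
        double[][] inputs = {{0, 0, 1}, {0, 1, 1}, {1, 0, 1}, {1, 1, 1}};
        double[] targets = {0, 0, 1, 1}; // assumption: labels taken from the test set
        double[] w = train(inputs, targets, 0.1, 10_000);
        for (double[] x : inputs) {
            System.out.println(predict(w, x)); // near 0 for the first two rows, near 1 for the last two
        }
    }
}
```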
Printing: Printing comes with attributes that provide information about the training process:
- Loss: The current error; a decreasing value means the model is improving.
- Stage: The current progress of the training process, which reaches 100 when training is complete.
- Accuracy: The accuracy score of the model; you may use HyperparameterTuning for the best score.
- Epoch: The current iteration.
[##################=============] Loss: 3.7998585945554775E-14 | Stage: 57 | Accuracy: 0.999999999999962 | Epoch: 570000 (\)
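The sample line shows Stage: 57 at Epoch: 570000 with `epochs(1_000_000)`, which is consistent with Stage being the percentage of epochs completed. The sketch below shows that calculation together with one common way to compute a loss. Both are assumptions: the page does not state how aurora derives these values, mean squared error is used only as an example, and the names `ProgressSketch`, `stage`, and `meanSquaredError` are hypothetical.

```java
// Hypothetical sketch of progress attributes; not aurora's implementation.
final class ProgressSketch {

    // Percentage of training completed, truncated to a whole number.
    static int stage(long epoch, long totalEpochs) {
        return (int) (epoch * 100 / totalEpochs);
    }

    // Mean squared error between predictions and targets (an example loss).
    static double meanSquaredError(double[] predicted, double[] expected) {
        double sum = 0.0;
        for (int i = 0; i < predicted.length; i++) {
            double diff = predicted[i] - expected[i];
            sum += diff * diff;
        }
        return sum / predicted.length;
    }

    public static void main(String[] args) {
        System.out.println(stage(570_000, 1_000_000)); // prints 57
        System.out.println(meanSquaredError(new double[] {0.5, 0.5}, new double[] {0, 1})); // prints 0.25
    }
}
```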
Predict:
double[] output = lr.predict(new double[] {1, 0, 0}); // output should be close to 0.9, since we trained the output to be 1
Create a Test Set:
Map<double[], double[]> data = new HashMap<>();
data.put(new double[] {0, 0, 1}, new double[] {0});
data.put(new double[] {0, 1, 1}, new double[] {0});
data.put(new double[] {1, 0, 1}, new double[] {1});
data.put(new double[] {1, 1, 1}, new double[] {1});
TestSet testSet = new TestSet(data);
Evaluate:
Evaluation eval = lr.evaluate(testSet);
eval.printSummary();
Prints:
Evaluation Summary of Type: Logistic Regression
Accuracy: 1.0
Precision: 1.0
Recall: 1.0
F1 Score: 1.0
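For reference, the four metrics in the summary can be computed from binary predictions as in the plain-Java sketch below. This uses the standard definitions with a 0.5 decision threshold; the threshold, the class `MetricsSketch`, and its method are assumptions, and aurora's `Evaluation` may compute these values differently.

```java
// Hypothetical sketch of binary classification metrics; not aurora's implementation.
final class MetricsSketch {

    // Returns {accuracy, precision, recall, f1} using a 0.5 decision threshold.
    static double[] evaluate(double[] predicted, double[] expected) {
        int tp = 0, fp = 0, fn = 0, tn = 0;
        for (int i = 0; i < predicted.length; i++) {
            boolean predictedPositive = predicted[i] >= 0.5;
            boolean actualPositive = expected[i] >= 0.5;
            if (predictedPositive && actualPositive) tp++;
            else if (predictedPositive) fp++;
            else if (actualPositive) fn++;
            else tn++;
        }
        double accuracy = (double) (tp + tn) / predicted.length;
        // Guard against division by zero when a class is never predicted or never present.
        double precision = (tp + fp) == 0 ? 0.0 : (double) tp / (tp + fp);
        double recall = (tp + fn) == 0 ? 0.0 : (double) tp / (tp + fn);
        double f1 = (precision + recall) == 0.0 ? 0.0 : 2 * precision * recall / (precision + recall);
        return new double[] {accuracy, precision, recall, f1};
    }

    public static void main(String[] args) {
        double[] predicted = {0.1, 0.2, 0.9, 0.8}; // example model outputs
        double[] expected = {0, 0, 1, 1};
        double[] m = evaluate(predicted, expected);
        System.out.println("Accuracy: " + m[0]);  // prints Accuracy: 1.0
        System.out.println("Precision: " + m[1]); // prints Precision: 1.0
        System.out.println("Recall: " + m[2]);    // prints Recall: 1.0
        System.out.println("F1 Score: " + m[3]);  // prints F1 Score: 1.0
    }
}
```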
Save your Logistic Regression Model:
Model model = lr.toModel();
model.save("my_directory");
Load from file:
LogisticRegressionModel model = null;
try (ModelLoader loader = new ModelLoader(new File("my_directory"))) {
    model = loader.load(LogisticRegressionModel.class);
}
Load from URL:
LogisticRegressionModel model = null;
try (ModelLoader loader = new ModelLoader(new URL("my_model_url"))) {
    model = loader.load(LogisticRegressionModel.class);
} catch (MalformedURLException e) {
    throw new RuntimeException(e);
}