Skip to content

Commit

Permalink
Add basic loss function
Browse files Browse the repository at this point in the history
  • Loading branch information
adamhutchings committed Oct 17, 2023
1 parent 96dcdbd commit 951e5d8
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 1 deletion.
46 changes: 46 additions & 0 deletions core/include/jml/math/loss_functions.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/**
* This offers loss functions for a model. Only one such loss function exists
* right now, but we may need more in the future.
*
* Each loss function needs an actual calculation (put in a jml::Vector, get out
* a double), and a derivative with respect to a particular entry. That is, if
* we change the i-th entry of the input vector, how much the loss will change
* by. This will be very important later for gradient descent.
*
* Author: Adam Hutchings
* Date: 10-17-23
*/

#pragma once

#include <functional>

#include <jml/jmldefs.h>
#include <jml/math/vector.hpp>

namespace jml {

typedef std::function<double(const jml::Vector&, const jml::Vector&)> LF;
typedef std::function<double(const jml::Vector&, const jml::Vector&, int i)> DL;

class JML_API LossFunction {

private:
LF loss;
DL dl;

public:
// Create a loss function with a given loss and derivative.
LossFunction(const LF& lf, const DL& dl);
double get_loss(const jml::Vector& actual, const jml::Vector& expected);
double get_loss_derivative(
const jml::Vector& actual, const jml::Vector& expected, int index
);

};

// Also, we provide an example set of loss functions (L^2 norm)
extern LF l2lf;
extern DL l2dl;

}
4 changes: 3 additions & 1 deletion core/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ endif

# TODO: pkg-config stuff

core_sources = ['src/math/matrix.cpp', 'src/math/vector.cpp', 'src/math/activation_functions.cpp', 'src/model.cpp', 'src/logger.cpp', 'capi/cjml.cpp']
core_sources = ['src/math/matrix.cpp', 'src/math/vector.cpp',
'src/math/activation_functions.cpp', 'src/model.cpp', 'src/logger.cpp',
'capi/cjml.cpp', 'src/math/loss_functions.cpp']
jmlcore = library('jmlcore',
core_sources,
include_directories: core_inc,
Expand Down
50 changes: 50 additions & 0 deletions core/src/math/loss_functions.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include <jml/math/loss_functions.hpp>

#include <cmath>

#include <jml/internal/logger.hpp>

namespace jml {

LossFunction::LossFunction(const LF& lf, const DL& dl) {
this->loss = lf;
this->dl = dl;
}

double LossFunction::get_loss(
const Vector& actual, const Vector& expected
) {
return this->loss(actual, expected);
}

double LossFunction::get_loss_derivative(
const Vector& actual, const Vector& expected, int index
) {
return this->dl(actual, expected, index);
}

LF l2lf = [](const Vector& actual, const Vector& expected) {
int a = actual.get_size(), e = expected.get_size();
if (a != e) {
LOGGER->log(Log(WARN)
<< "Tried to compare a vector of length " << a
<< "to a vector of length " << e << ".\n");
}
double total = 0;
double diff;
for (int i = 0; i < a; ++i) {
diff = actual.get_entry(i) - expected.get_entry(i);
diff *= diff;
total += diff;
}
return sqrt(total);
};

DL l2dl = [](const Vector& actual, const Vector& expected, int i) {
double l = l2lf(actual, expected);
double ret = 1.0 / (2 * l);
ret *= 2 * (actual.get_entry(i) - expected.get_entry(i));
return ret;
};

}

0 comments on commit 951e5d8

Please sign in to comment.