-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
96dcdbd
commit 951e5d8
Showing
3 changed files
with
99 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
/** | ||
* This offers loss functions for a model. Only one such loss function exists | ||
* right now, but we may need more in the future. | ||
* | ||
* Each loss function needs an actual calculation (put in a jml::Vector, get out | ||
* a double), and a derivative with respect to a particular entry. That is, if | ||
* we change the i-th entry of the input vector, how much the loss will change | ||
* by. This will be very important later for gradient descent. | ||
* | ||
* Author: Adam Hutchings | ||
* Date: 10-17-23 | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <functional> | ||
|
||
#include <jml/jmldefs.h> | ||
#include <jml/math/vector.hpp> | ||
|
||
namespace jml { | ||
|
||
typedef std::function<double(const jml::Vector&, const jml::Vector&)> LF; | ||
typedef std::function<double(const jml::Vector&, const jml::Vector&, int i)> DL; | ||
|
||
class JML_API LossFunction { | ||
|
||
private: | ||
LF loss; | ||
DL dl; | ||
|
||
public: | ||
// Create a loss function with a given loss and derivative. | ||
LossFunction(const LF& lf, const DL& dl); | ||
double get_loss(const jml::Vector& actual, const jml::Vector& expected); | ||
double get_loss_derivative( | ||
const jml::Vector& actual, const jml::Vector& expected, int index | ||
); | ||
|
||
}; | ||
|
||
// Also, we provide an example set of loss functions (L^2 norm) | ||
extern LF l2lf; | ||
extern DL l2dl; | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
#include <jml/math/loss_functions.hpp> | ||
|
||
#include <cmath> | ||
|
||
#include <jml/internal/logger.hpp> | ||
|
||
namespace jml { | ||
|
||
LossFunction::LossFunction(const LF& lf, const DL& dl) { | ||
this->loss = lf; | ||
this->dl = dl; | ||
} | ||
|
||
double LossFunction::get_loss( | ||
const Vector& actual, const Vector& expected | ||
) { | ||
return this->loss(actual, expected); | ||
} | ||
|
||
double LossFunction::get_loss_derivative( | ||
const Vector& actual, const Vector& expected, int index | ||
) { | ||
return this->dl(actual, expected, index); | ||
} | ||
|
||
LF l2lf = [](const Vector& actual, const Vector& expected) { | ||
int a = actual.get_size(), e = expected.get_size(); | ||
if (a != e) { | ||
LOGGER->log(Log(WARN) | ||
<< "Tried to compare a vector of length " << a | ||
<< "to a vector of length " << e << ".\n"); | ||
} | ||
double total = 0; | ||
double diff; | ||
for (int i = 0; i < a; ++i) { | ||
diff = actual.get_entry(i) - expected.get_entry(i); | ||
diff *= diff; | ||
total += diff; | ||
} | ||
return sqrt(total); | ||
}; | ||
|
||
DL l2dl = [](const Vector& actual, const Vector& expected, int i) { | ||
double l = l2lf(actual, expected); | ||
double ret = 1.0 / (2 * l); | ||
ret *= 2 * (actual.get_entry(i) - expected.get_entry(i)); | ||
return ret; | ||
}; | ||
|
||
} |