Skip to content
This repository has been archived by the owner on Mar 19, 2024. It is now read-only.

Commit

Permalink
precision/recall metrics
Browse files Browse the repository at this point in the history
Summary:
This commit adds precision/recall curve to the metrics api.
The `Meter` object is now exposed in python.

The precision/recall curve helps to decide the best threshold.
It can be retrieved from the model object as follows:

```python
ft = fasttext.load_model(model_file)
meter = ft.get_meter(test_file)

label = "__label__bakery"
y_scores, y_true = meter.score_vs_true(label)
precision, recall = meter.precision_recall_curve(label)
```

Reviewed By: EdouardGrave

Differential Revision: D19218524

fbshipit-source-id: 41a7c8e1aa991d076df04c5e497688daf0de4673
  • Loading branch information
Celebio authored and facebook-github-bot committed Apr 27, 2020
1 parent 6d7c77c commit 2cc7f54
Show file tree
Hide file tree
Showing 15 changed files with 421 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ meter.o: src/meter.cc src/meter.h
fasttext.o: src/fasttext.cc src/*.h
$(CXX) $(CXXFLAGS) -c src/fasttext.cc

# Link the CLI binary; main.cc is listed as a prerequisite so edits to it
# trigger a relink (it is compiled directly into the link command, not via OBJS).
fasttext: $(OBJS) src/fasttext.cc src/main.cc
	$(CXX) $(CXXFLAGS) $(OBJS) src/main.cc -o fasttext

clean:
Expand Down
18 changes: 18 additions & 0 deletions docs/autotune.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,21 @@ This is equivalent to manually optimize the f1-score we get when we test with `m
Sometimes, you may be interested in predicting more than one label. For example, if you were optimizing the hyperparameters manually to get the best score to predict two labels, you would test with `model.test("cooking.valid", k=2)`. You can also tell autotune to optimize the parameters by testing two labels with the `autotunePredictions` argument.
<!--END_DOCUSAURUS_CODE_TABS-->

You can also tell autotune to optimize for the best precision at a given recall, or the best recall at a given precision — either across all labels or for one specific label:

For example, in order to get the best precision at recall = `30%`:
```sh
>> ./fasttext supervised [...] -autotune-metric precisionAtRecall:30
```
And to get the best precision at recall = `30%` for the label `__label__baking`:
```sh
>> ./fasttext supervised [...] -autotune-metric precisionAtRecall:30:__label__baking
```

Similarly, you can use `recallAtPrecision`:
```sh
>> ./fasttext supervised [...] -autotune-metric recallAtPrecision:30
>> ./fasttext supervised [...] -autotune-metric recallAtPrecision:30:__label__baking
```


67 changes: 67 additions & 0 deletions python/fasttext_module/fasttext/FastText.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,66 @@
BOW = "<"
EOW = ">"

displayed_errors = {}


def eprint(*args, **kwargs):
    """Behave exactly like print(), but write to sys.stderr."""
    print(*args, file=sys.stderr, **kwargs)


class _Meter(object):
def __init__(self, fasttext_model, meter):
self.f = fasttext_model
self.m = meter

def score_vs_true(self, label):
"""Return scores and the gold of each sample for a specific label"""
label_id = self.f.get_label_id(label)
pair_list = self.m.scoreVsTrue(label_id)

if pair_list:
y_scores, y_true = zip(*pair_list)
else:
y_scores, y_true = ([], ())

return np.array(y_scores, copy=False), np.array(y_true, copy=False)

def precision_recall_curve(self, label=None):
"""Return precision/recall curve"""
if label:
label_id = self.f.get_label_id(label)
pair_list = self.m.precisionRecallCurveLabel(label_id)
else:
pair_list = self.m.precisionRecallCurve()

if pair_list:
precision, recall = zip(*pair_list)
else:
precision, recall = ([], ())

return np.array(precision, copy=False), np.array(recall, copy=False)

def precision_at_recall(self, recall, label=None):
"""Return precision for a given recall"""
if label:
label_id = self.f.get_label_id(label)
precision = self.m.precisionAtRecallLabel(label_id, recall)
else:
precision = self.m.precisionAtRecall(recall)

return precision

def recall_at_precision(self, precision, label=None):
"""Return recall for a given precision"""
if label:
label_id = self.f.get_label_id(label)
recall = self.m.recallAtPrecisionLabel(label_id, precision)
else:
recall = self.m.recallAtPrecision(precision)

return recall


class _FastText(object):
"""
This class defines the API to inspect models and should not be used to
Expand Down Expand Up @@ -100,6 +155,13 @@ def get_word_id(self, word):
"""
return self.f.getWordId(word)

def get_label_id(self, label):
"""
Given a label, get the label id within the dictionary.
Returns -1 if label is not in the dictionary.
"""
return self.f.getLabelId(label)

def get_subword_id(self, subword):
"""
Given a subword, return the index (within input matrix) it hashes to.
Expand Down Expand Up @@ -258,6 +320,11 @@ def test_label(self, path, k=1, threshold=0.0):
"""
return self.f.testLabel(path, k, threshold)

def get_meter(self, path, k=-1):
meter = _Meter(self, self.f.getMeter(path, k))

return meter

def quantize(
self,
input=None,
Expand Down
70 changes: 62 additions & 8 deletions python/fasttext_module/fasttext/pybind/fasttext_pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,15 @@ PYBIND11_MODULE(fasttext_pybind, m) {

py::enum_<fasttext::metric_name>(m, "metric_name")
.value("f1score", fasttext::metric_name::f1score)
.value("labelf1score", fasttext::metric_name::labelf1score)
.value("f1scoreLabel", fasttext::metric_name::f1scoreLabel)
.value("precisionAtRecall", fasttext::metric_name::precisionAtRecall)
.value(
"precisionAtRecallLabel",
fasttext::metric_name::precisionAtRecallLabel)
.value("recallAtPrecision", fasttext::metric_name::recallAtPrecision)
.value(
"recallAtPrecisionLabel",
fasttext::metric_name::recallAtPrecisionLabel)
.export_values();

m.def(
Expand Down Expand Up @@ -186,6 +194,34 @@ PYBIND11_MODULE(fasttext_pybind, m) {
sizeof(fasttext::real) * (int64_t)1});
});

py::class_<fasttext::Meter>(m, "Meter")
.def(py::init<bool>())
.def("scoreVsTrue", &fasttext::Meter::scoreVsTrue)
.def(
"precisionRecallCurveLabel",
py::overload_cast<int32_t>(
&fasttext::Meter::precisionRecallCurve, py::const_))
.def(
"precisionRecallCurve",
py::overload_cast<>(
&fasttext::Meter::precisionRecallCurve, py::const_))
.def(
"precisionAtRecallLabel",
py::overload_cast<int32_t, double>(
&fasttext::Meter::precisionAtRecall, py::const_))
.def(
"precisionAtRecall",
py::overload_cast<double>(
&fasttext::Meter::precisionAtRecall, py::const_))
.def(
"recallAtPrecisionLabel",
py::overload_cast<int32_t, double>(
&fasttext::Meter::recallAtPrecision, py::const_))
.def(
"recallAtPrecision",
py::overload_cast<double>(
&fasttext::Meter::recallAtPrecision, py::const_));

py::class_<fasttext::FastText>(m, "fasttext")
.def(py::init<>())
.def("getArgs", &fasttext::FastText::getArgs)
Expand Down Expand Up @@ -231,20 +267,33 @@ PYBIND11_MODULE(fasttext_pybind, m) {
[](fasttext::FastText& m, std::string s) { m.saveModel(s); })
.def(
"test",
[](fasttext::FastText& m,
const std::string filename,
int32_t k,
fasttext::real threshold) {
[](fasttext::FastText& m,
const std::string& filename,
int32_t k,
fasttext::real threshold) {
std::ifstream ifs(filename);
if (!ifs.is_open()) {
throw std::invalid_argument("Test file cannot be opened!");
}
fasttext::Meter meter;
fasttext::Meter meter(false);
m.test(ifs, k, threshold, meter);
ifs.close();
return std::tuple<int64_t, double, double>(
meter.nexamples(), meter.precision(), meter.recall());
})
.def(
"getMeter",
[](fasttext::FastText& m, const std::string& filename, int32_t k) {
std::ifstream ifs(filename);
if (!ifs.is_open()) {
throw std::invalid_argument("Test file cannot be opened!");
}
fasttext::Meter meter(true);
m.test(ifs, k, 0.0, meter);
ifs.close();

return meter;
})
.def(
"getSentenceVector",
[](fasttext::FastText& m,
Expand Down Expand Up @@ -397,7 +446,7 @@ PYBIND11_MODULE(fasttext_pybind, m) {
if (!ifs.is_open()) {
throw std::invalid_argument("Test file cannot be opened!");
}
fasttext::Meter meter;
fasttext::Meter meter(false);
m.test(ifs, k, threshold, meter);
std::shared_ptr<const fasttext::Dictionary> d = m.getDictionary();
std::unordered_map<std::string, py::dict> returnedValue;
Expand All @@ -412,14 +461,19 @@ PYBIND11_MODULE(fasttext_pybind, m) {
})
.def(
"getWordId",
[](fasttext::FastText& m, const std::string word) {
[](fasttext::FastText& m, const std::string& word) {
return m.getWordId(word);
})
.def(
"getSubwordId",
[](fasttext::FastText& m, const std::string word) {
return m.getSubwordId(word);
})
.def(
"getLabelId",
[](fasttext::FastText& m, const std::string& label) {
return m.getLabelId(label);
})
.def(
"getInputVector",
[](fasttext::FastText& m, fasttext::Vector& vec, int32_t ind) {
Expand Down
64 changes: 55 additions & 9 deletions src/args.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

namespace fasttext {
Expand Down Expand Up @@ -90,8 +91,16 @@ std::string Args::metricToString(metric_name mn) const {
switch (mn) {
case metric_name::f1score:
return "f1score";
case metric_name::labelf1score:
return "labelf1score";
case metric_name::f1scoreLabel:
return "f1scoreLabel";
case metric_name::precisionAtRecall:
return "precisionAtRecall";
case metric_name::precisionAtRecallLabel:
return "precisionAtRecallLabel";
case metric_name::recallAtPrecision:
return "recallAtPrecision";
case metric_name::recallAtPrecisionLabel:
return "recallAtPrecisionLabel";
}
return "Unknown metric name!"; // should never happen
}
Expand Down Expand Up @@ -388,22 +397,59 @@ void Args::setManual(const std::string& argName) {

// Parse the -autotune-metric argument into a metric_name. Supported forms:
//   "f1"                               -> f1score
//   "f1:<label>"                       -> f1scoreLabel
//   "precisionAtRecall:<v>"            -> precisionAtRecall
//   "precisionAtRecall:<v>:<label>"    -> precisionAtRecallLabel
//   (symmetrically for recallAtPrecision)
// Throws std::runtime_error on anything else.
metric_name Args::getAutotuneMetric() const {
  if (autotuneMetric.substr(0, 3) == "f1:") {
    return metric_name::f1scoreLabel;
  } else if (autotuneMetric == "f1") {
    return metric_name::f1score;
  } else if (autotuneMetric.substr(0, 18) == "precisionAtRecall:") {
    // A second ':' after the numeric value introduces an explicit label.
    size_t colon = autotuneMetric.find(":", 18);
    if (colon != std::string::npos) {
      return metric_name::precisionAtRecallLabel;
    }
    return metric_name::precisionAtRecall;
  } else if (autotuneMetric.substr(0, 18) == "recallAtPrecision:") {
    size_t colon = autotuneMetric.find(":", 18);
    if (colon != std::string::npos) {
      return metric_name::recallAtPrecisionLabel;
    }
    return metric_name::recallAtPrecision;
  }
  throw std::runtime_error("Unknown metric : " + autotuneMetric);
}

// Extract the optional label suffix of -autotune-metric
// ("f1:<label>" or "<metric>:<value>:<label>"). Returns the empty string
// when the metric applies to all labels; throws if a label-specific metric
// has an empty label.
std::string Args::getAutotuneMetricLabel() const {
  metric_name metric = getAutotuneMetric();
  std::string label;
  if (metric == metric_name::f1scoreLabel) {
    label = autotuneMetric.substr(3);
  } else if (
      metric == metric_name::precisionAtRecallLabel ||
      metric == metric_name::recallAtPrecisionLabel) {
    // Both prefixes ("precisionAtRecall:" / "recallAtPrecision:") are 18
    // characters long; the label follows the second ':'.
    size_t colon = autotuneMetric.find(":", 18);
    label = autotuneMetric.substr(colon + 1);
  } else {
    return label; // metric is not label-specific
  }

  if (label.empty()) {
    throw std::runtime_error("Empty metric label : " + autotuneMetric);
  }
  return label;
}

// Parse the numeric target out of "<metric>:<value>[:<label>]" and convert
// it from a percentage to a fraction in [0, 1]. Returns 0.0 for metrics
// that carry no value (plain "f1" / "f1:<label>").
double Args::getAutotuneMetricValue() const {
  metric_name metric = getAutotuneMetric();
  double value = 0.0;
  if (metric == metric_name::precisionAtRecallLabel ||
      metric == metric_name::precisionAtRecall ||
      metric == metric_name::recallAtPrecisionLabel ||
      metric == metric_name::recallAtPrecision) {
    // Both "precisionAtRecall:" and "recallAtPrecision:" are 18 chars,
    // so the value starts right after the first colon.
    size_t valueStart = 18;
    size_t valueEnd = autotuneMetric.find(":", valueStart);
    const std::string valueStr =
        autotuneMetric.substr(valueStart, valueEnd - valueStart);
    // std::stod keeps full double precision (stof would truncate to float).
    value = std::stod(valueStr) / 100.0;
  }
  return value;
}

int64_t Args::getAutotuneModelSize() const {
Expand Down
10 changes: 9 additions & 1 deletion src/args.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,14 @@ namespace fasttext {

enum class model_name : int { cbow = 1, sg, sup };
enum class loss_name : int { hs = 1, ns, softmax, ova };
// Autotune optimization targets; the *Label variants restrict the metric
// to one specific label (see Args::getAutotuneMetric).
enum class metric_name : int {
  f1score = 1,
  f1scoreLabel,
  precisionAtRecall,
  precisionAtRecallLabel,
  recallAtPrecision,
  recallAtPrecisionLabel
};

class Args {
protected:
Expand Down Expand Up @@ -81,6 +88,7 @@ class Args {
std::string lossToString(loss_name) const;
metric_name getAutotuneMetric() const;
std::string getAutotuneMetricLabel() const;
double getAutotuneMetricValue() const;
int64_t getAutotuneModelSize() const;

static constexpr double kUnlimitedModelSize = -1.0;
Expand Down
Loading

0 comments on commit 2cc7f54

Please sign in to comment.