Skip to content

Commit

Permalink
basic implementation of probability estimates
Browse files Browse the repository at this point in the history
  • Loading branch information
stropitek committed Jul 26, 2017
1 parent cbd9546 commit a152bb3
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ CXX = em++

CFLAGS = -Wall -Wconversion -O3 -fPIC --memory-init-file 0
BUILD_DIR=dist
EXPORTED_FUNCTIONS="['_parse_command_line', '_create_svm_nodes', '_add_instance', '_libsvm_train_problem', '_libsvm_train', '_libsvm_predict_one', '_get_svr_epsilon', '_svm_free_model', '_svm_get_svm_type', '_svm_get_nr_sv', '_svm_get_nr_class', '_svm_get_sv_indices', '_svm_get_labels', '_libsvm_cross_validation', '_free_problem', '_serialize_model', '_deserialize_model']"
EXPORTED_FUNCTIONS="['_parse_command_line', '_create_svm_nodes', '_add_instance', '_libsvm_train_problem', '_libsvm_train', '_libsvm_predict_one', '_libsvm_predict_one_probability', '_get_svr_epsilon', '_svm_free_model', '_svm_get_svm_type', '_svm_get_nr_sv', '_svm_get_nr_class', '_svm_get_sv_indices', '_svm_get_labels', '_libsvm_cross_validation', '_free_problem', '_serialize_model', '_deserialize_model']"

all: wasm asm

Expand Down
41 changes: 41 additions & 0 deletions examples/probabilities.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
const Kernel = new require('ml-kernel');
const SVM = require('../asm');
const range = require('lodash.range');

'use strict';

const gamma = 0.2;
const cost = 1;

function exec(SVM, precomputed) {
const data = require('ml-dataset-iris');
var trainData;

const features = data.getNumbers();
let labels = data.getClasses();
const classes = data.getDistinctClasses();
const c = {};
classes.forEach((v, idx) => c[v] = idx);
labels = labels.map(l => c[l]);


if (precomputed) {
const kernel = new Kernel('gaussian', {sigma: 1 / Math.sqrt(gamma)});
trainData = kernel.compute(features).addColumn(0, range(1, labels.length + 1));
} else {
trainData = features;
}

const svm = new SVM({
quiet: true,
cost: cost,
kernel: precomputed ? SVM.KERNEL_TYPES.PRECOMPUTED : SVM.KERNEL_TYPES.RBF,
gamma,
probabilityEstimates: true
});
svm.train(trainData, labels);
var pred = svm.predictProbability(trainData);
console.log(JSON.stringify(pred, null, 2));
}

exec(SVM, true);
13 changes: 12 additions & 1 deletion js-interfaces.c
Original file line number Diff line number Diff line change
Expand Up @@ -194,18 +194,29 @@ void free_problem(struct svm_problem* prob) {
free(prob);
}

double libsvm_predict_one(struct svm_model* model, double* data, int size) {
struct svm_node* init_node(double* data, int size) {
struct svm_node* node = Malloc(struct svm_node, size + 1);
for(int i=0; i<size; i++) {
node[i].index = i + 1;
node[i].value = data[i];
}
node[size].index = -1;
return node;
}

double libsvm_predict_one(struct svm_model* model, double* data, int size) {
struct svm_node* node = init_node(data, size);
double pred = svm_predict(model, node);
free(node);
return pred;
}

double libsvm_predict_one_probability(struct svm_model* model, double* data, int size, double* prob_estimates) {
struct svm_node* node = init_node(data, size);
double pred = svm_predict_probability(model, node, prob_estimates);
return pred;
}

struct svm_model* libsvm_train(double *data, double *labels, int nb_features, int nb_dimensions, const char* command) {
struct svm_problem* prob = create_svm_nodes(nb_features, nb_dimensions);
for(int i=0; i<nb_features; i++) {
Expand Down
37 changes: 37 additions & 0 deletions src/loadSVM.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module.exports = function (libsvm) {

/* eslint-disable camelcase */
const predict_one = libsvm.cwrap('libsvm_predict_one', 'number', ['number', 'array', 'number']);
const predict_one_probability = libsvm.cwrap('libsvm_predict_one_probability', 'number', ['number', 'array', 'number', 'number']);
const add_instance = libsvm.cwrap('add_instance', null, ['number', 'array', 'number', 'number', 'number']);
const create_svm_nodes = libsvm.cwrap('create_svm_nodes', 'number', ['number', 'number']);
const train_problem = libsvm.cwrap('libsvm_train_problem', 'number', ['number', 'string']);
Expand Down Expand Up @@ -131,6 +132,42 @@ module.exports = function (libsvm) {
return arr;
}

/**
* Predict the label with probability estimate of many samples.
* @param {Array<Array<number>>} samples - The samples to predict.
* @return {Array<object>} - An array of objects containing the prediction label and the probability estimates for each label
*/
predictProbability(samples) {
let arr = [];
for (let i = 0; i < samples.length; i++) {
arr.push(this.predictOneProbability(samples[i]));
}
return arr;
}

/** Predict the label with probability estimate.
* @param {Array<number>} sample
* @return {object} - An object containing the prediction label and the probability estimates for each label
*/

predictOneProbability(sample) {
const labels = this.getLabels();
const nbLabels = labels.length;
const estimates = libsvm._malloc(nbLabels * 8);
const prediction = predict_one_probability(this.model, new Uint8Array(new Float64Array(sample).buffer), sample.length, estimates);
const estimatesArr = Array.from(libsvm.HEAPF64.subarray(estimates / 8, estimates / 8 + nbLabels));
const result = {
prediction,
estimates: labels.map((label, idx) => ({
label,
probability: estimatesArr[idx]
}))
};
libsvm._free(estimates);
return result;
}


/**
* Get the array of labels from the model. Useful when creating an SVM instance with SVM.load
* @return {Array<number>} - The list of labels.
Expand Down

0 comments on commit a152bb3

Please sign in to comment.