Skip to content

Commit

Permalink
Robust stop criteria for fitting Platt transform (#277)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiong-zhang authored Jan 31, 2024
1 parent 22f01be commit 93c6c84
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 16 deletions.
35 changes: 31 additions & 4 deletions pecos/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2050,23 +2050,46 @@ def link_calibrator_methods(self):
corelib.fillprototype(
self.clib_float32.c_fit_platt_transform_f32,
c_uint32,
[c_uint64, POINTER(c_float), POINTER(c_float), POINTER(c_double)],
[
c_uint64,
POINTER(c_float),
POINTER(c_float),
POINTER(c_double),
c_uint64, # max_iter
c_double, # eps
],
)
corelib.fillprototype(
self.clib_float32.c_fit_platt_transform_f64,
c_uint32,
[c_uint64, POINTER(c_double), POINTER(c_double), POINTER(c_double)],
[
c_uint64,
POINTER(c_double),
POINTER(c_double),
POINTER(c_double),
c_uint64, # max_iter
c_double, # eps
],
)

def fit_platt_transform(self, logits, targets, clip_tgt_prob=True):
def fit_platt_transform(
self,
logits,
targets,
max_iter=100,
eps=1e-5,
clip_tgt_prob=True,
):
"""Python to C/C++ interface for platt transfrom fit.
Ref: https://www.csie.ntu.edu.tw/~cjlin/papers/plattprob.pdf
Args:
logits (ndarray): 1-d array of logit with length N.
targets (ndarray): 1-d array of target probability scores within [0, 1] with length N.
clip_tgt_prob (bool): whether to clip the target probability to
max_iter (int, optional): max number of iterations to train. Default 100
eps (float, optional): epsilon. Defaults to 1e-5
clip_tgt_prob (bool, optional): whether to clip the target probability to
[1/(prior0 + 2), 1 - 1/(prior1 + 2)]
where prior1 = sum(targets), prior0 = N - prior1
Returns:
Expand Down Expand Up @@ -2097,13 +2120,17 @@ def fit_platt_transform(self, logits, targets, clip_tgt_prob=True):
logits.ctypes.data_as(POINTER(c_float)),
tgt_prob.ctypes.data_as(POINTER(c_float)),
AB.ctypes.data_as(POINTER(c_double)),
max_iter,
eps,
)
elif tgt_prob.dtype == np.float64:
return_code = clib.clib_float32.c_fit_platt_transform_f64(
len(logits),
logits.ctypes.data_as(POINTER(c_double)),
tgt_prob.ctypes.data_as(POINTER(c_double)),
AB.ctypes.data_as(POINTER(c_double)),
max_iter,
eps,
)
else:
raise ValueError(f"Unsupported dtype: {tgt_prob.dtype}")
Expand Down
6 changes: 4 additions & 2 deletions pecos/core/libpecos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,9 +756,11 @@ extern "C" {
size_t num_samples, \
const VAL_TYPE* logits, \
const VAL_TYPE* tgt_probs, \
double* AB \
double* AB, \
size_t max_iter, \
double eps \
) { \
return pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1]); \
return pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1], max_iter, eps); \
}
C_FIT_PLATT_TRANSFORM(_f32, float32_t)
C_FIT_PLATT_TRANSFORM(_f64, float64_t)
Expand Down
22 changes: 12 additions & 10 deletions pecos/core/utils/newton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ namespace pecos {
// https://github.com/cjlin1/libsvm/blob/master/svm.cpp

template <typename value_type>
uint32_t fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B) {
uint32_t fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B, size_t max_iter, double eps) {
// define the return code
enum {
SUCCESS=0,
Expand All @@ -288,10 +288,8 @@ namespace pecos {
};

// hyper parameters
int max_iter = 100; // Maximal number of iterations
double min_step = 1e-10; // Minimal step taken in line search
double sigma = 1e-12; // For numerically strict PD of Hessian
double eps = 1e-5;

// calculate prior of B
double prior1 = 0;
Expand All @@ -300,7 +298,6 @@ namespace pecos {
}
double prior0 = double(num_samples) - prior1;


// Initial Point and Initial Fun Value
A = 0.0; B = log((prior0 + 1.0) / (prior1 + 1.0));
double fval = 0.0;
Expand All @@ -313,7 +310,7 @@ namespace pecos {
fval += (tgt_probs[i] - 1) * fApB + log(1 + exp(fApB));
}
}
int iter;
size_t iter = 0;
for (iter = 0; iter < max_iter; iter++) {
// Update Gradient and Hessian (use H' = H + sigma I)
double h11 = sigma;
Expand Down Expand Up @@ -342,16 +339,22 @@ namespace pecos {
g2 += d1;
}

// Stopping Criteria
if (fabs(g1) < eps && fabs(g2) < eps)
break;

// Finding Newton direction: -inv(H') * g
double det = h11 * h22 - h21 * h21;
double dA = -(h22 * g1 - h21 * g2) / det;
double dB = -(-h21 * g1 + h11 * g2) / det;
double gd = g1 * dA + g2 * dB;

// Stopping Criteria
if (fabs(g1) < eps && fabs(g2) < eps) {
break;
}
// additional stop criteria to handle the case when det is large
if (fabs(dA) < eps && fabs(dB) < eps) {
break;
}

// Line Search
double stepsize = 1.0;

Expand All @@ -370,8 +373,7 @@ namespace pecos {
}
}
// Check sufficient decrease
if (newf < fval + 0.0001 * stepsize * gd)
{
if (newf < fval + 0.0001 * stepsize * gd) {
A = newA;
B = newB;
fval = newf;
Expand Down

0 comments on commit 93c6c84

Please sign in to comment.