From a20c8d5c14c81c24b87949f912a4af5f462c8afa Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Thu, 7 Mar 2024 09:30:19 +0000 Subject: [PATCH] refactor: add scalar evaluation for SurrogatesSVM --- lib/SurrogatesSVM/src/SurrogatesSVM.jl | 6 +++++- lib/SurrogatesSVM/test/runtests.jl | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/lib/SurrogatesSVM/src/SurrogatesSVM.jl b/lib/SurrogatesSVM/src/SurrogatesSVM.jl index d69b365e..a8320c8a 100644 --- a/lib/SurrogatesSVM/src/SurrogatesSVM.jl +++ b/lib/SurrogatesSVM/src/SurrogatesSVM.jl @@ -40,6 +40,10 @@ function SVMSurrogate(x, y, lb, ub) SVMSurrogate(x, y, model, lb, ub) end +function (svmsurr::SVMSurrogate)(val::Number) + return svmsurr([val]) +end + function (svmsurr::SVMSurrogate)(val) n = length(val) return LIBSVM.predict(svmsurr.model, reshape(val, 1, n))[1] @@ -54,7 +58,7 @@ end - `x_new`: Vector of new data points to be added to the training set of SVMSurrogate. - `y_new`: Vector of new output points to be added to the training set of SVMSurrogate. """ -function update!(svmsurr::SVMSurrogate, x_new, y_new) +function SurrogatesBase.update!(svmsurr::SVMSurrogate, x_new, y_new) svmsurr.x = vcat(svmsurr.x, x_new) svmsurr.y = vcat(svmsurr.y, y_new) if length(svmsurr.lb) == 1 diff --git a/lib/SurrogatesSVM/test/runtests.jl b/lib/SurrogatesSVM/test/runtests.jl index a15f1027..00240227 100644 --- a/lib/SurrogatesSVM/test/runtests.jl +++ b/lib/SurrogatesSVM/test/runtests.jl @@ -18,7 +18,7 @@ using SafeTestsets update!(my_svm_1D, [3.1], [7.2]) update!(my_svm_1D, [3.2, 3.5], [7.4, 8.0]) svm = LIBSVM.fit!(SVC(), reshape(my_svm_1D.x, length(my_svm_1D.x), 1), my_svm_1D.y) - val = my_svm_1D([3.1]) + val = my_svm_1D(3.1) @test LIBSVM.predict(svm, [3.1;;])[1] == val end @testset "ND" begin