From 0c480965435916ec0c9975ed42271d33fdc5febd Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Fri, 1 Mar 2024 10:31:16 +0000 Subject: [PATCH] refactor: use SurrogatesBase and clean up code --- lib/SurrogatesSVM/src/SurrogatesSVM.jl | 87 +++++++++++--------------- 1 file changed, 38 insertions(+), 49 deletions(-) diff --git a/lib/SurrogatesSVM/src/SurrogatesSVM.jl b/lib/SurrogatesSVM/src/SurrogatesSVM.jl index 6cc23f38..d69b365e 100644 --- a/lib/SurrogatesSVM/src/SurrogatesSVM.jl +++ b/lib/SurrogatesSVM/src/SurrogatesSVM.jl @@ -1,11 +1,11 @@ module SurrogatesSVM -import Surrogates: AbstractSurrogate, add_point! -export SVMSurrogate - +using SurrogatesBase using LIBSVM -mutable struct SVMSurrogate{X, Y, M, L, U} <: AbstractSurrogate +export SVMSurrogate, update! + +mutable struct SVMSurrogate{X, Y, M, L, U} <: AbstractDeterministicSurrogate x::X y::Y model::M @@ -13,25 +13,28 @@ mutable struct SVMSurrogate{X, Y, M, L, U} <: AbstractSurrogate ub::U end -function SVMSurrogate(x, y, lb::Number, ub::Number) - xn = reshape(x, length(x), 1) - model = LIBSVM.fit!(SVC(), xn, y) - SVMSurrogate(xn, y, model, lb, ub) -end +""" + SVMSurrogate(x, y, lb, ub) -function (svmsurr::SVMSurrogate)(val::Number) - return LIBSVM.predict(svmsurr.model, [val]) -end +Builds a SVM Surrogate using [LIBSVM](https://github.com/JuliaML/LIBSVM.jl). -""" -SVMSurrogate(x,y,lb,ub) +## Arguments -Builds SVM surrogate. + - `x`: Input data points. + - `y`: Output data points. + - `lb`: Lower bound of input data points. + - `ub`: Upper bound of output data points. """ function SVMSurrogate(x, y, lb, ub) - X = Array{Float64, 2}(undef, length(x), length(x[1])) - for j in 1:length(x) - X[j, :] = vec(collect(x[j])) + X = Array{Float64, 2}(undef, length(x), length(first(x))) + if length(lb) == 1 + for j in eachindex(x) + X[j, 1] = x[j] + end + else + for j in eachindex(x) + X[j, :] = x[j] + end end model = LIBSVM.fit!(SVC(), X, y) SVMSurrogate(x, y, model, lb, ub) @@ -39,41 +42,27 @@ end function (svmsurr::SVMSurrogate)(val) n = length(val) - return LIBSVM.predict(svmsurr.model, reshape(collect(val), 1, n))[1] + return LIBSVM.predict(svmsurr.model, reshape(val, 1, n))[1] end -function add_point!(svmsurr::SVMSurrogate, x_new, y_new) +""" + update!(svmsurr::SVMSurrogate, x_new, y_new) + +## Arguments + + - `svmsurr`: Surrogate of type [`SVMSurrogate`](@ref). + - `x_new`: Vector of new data points to be added to the training set of SVMSurrogate. + - `y_new`: Vector of new output points to be added to the training set of SVMSurrogate. +""" +function update!(svmsurr::SVMSurrogate, x_new, y_new) + svmsurr.x = vcat(svmsurr.x, x_new) + svmsurr.y = vcat(svmsurr.y, y_new) if length(svmsurr.lb) == 1 - #1D - svmsurr.x = vcat(svmsurr.x, x_new) - svmsurr.y = vcat(svmsurr.y, y_new) - svmsurr.model = LIBSVM.fit!(SVC(), reshape(svmsurr.x, length(svmsurr.x), 1), - svmsurr.y) + svmsurr.model = LIBSVM.fit!( + SVC(), reshape(svmsurr.x, length(svmsurr.x), 1), svmsurr.y) else - n_previous = length(svmsurr.x) - a = vcat(svmsurr.x, x_new) - n_after = length(a) - dim_new = n_after - n_previous - n = length(svmsurr.x) - d = length(svmsurr.x[1]) - tot_dim = n + dim_new - X = Array{Float64, 2}(undef, tot_dim, d) - for j in 1:n - X[j, :] = vec(collect(svmsurr.x[j])) - end - if dim_new == 1 - X[n + 1, :] = vec(collect(x_new)) - else - i = 1 - for j in (n + 1):tot_dim - X[j, :] = vec(collect(x_new[i])) - i = i + 1 - end - end - svmsurr.x = vcat(svmsurr.x, x_new) - svmsurr.y = vcat(svmsurr.y, y_new) - svmsurr.model = LIBSVM.fit!(SVC(), X, svmsurr.y) + svmsurr.model = LIBSVM.fit!(SVC(), transpose(reduce(hcat, svmsurr.x)), svmsurr.y) end - nothing end + end # module