Skip to content

Commit

Permalink
refactor: use SurrogatesBase and clean up code
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Mar 1, 2024
1 parent 60c2d6f commit 0c48096
Showing 1 changed file with 38 additions and 49 deletions.
87 changes: 38 additions & 49 deletions lib/SurrogatesSVM/src/SurrogatesSVM.jl
Original file line number Diff line number Diff line change
@@ -1,79 +1,68 @@
module SurrogatesSVM

import Surrogates: AbstractSurrogate, add_point!
export SVMSurrogate

using SurrogatesBase
using LIBSVM

mutable struct SVMSurrogate{X, Y, M, L, U} <: AbstractSurrogate
export SVMSurrogate, update!

mutable struct SVMSurrogate{X, Y, M, L, U} <: AbstractDeterministicSurrogate
x::X
y::Y
model::M
lb::L
ub::U
end

function SVMSurrogate(x, y, lb::Number, ub::Number)
xn = reshape(x, length(x), 1)
model = LIBSVM.fit!(SVC(), xn, y)
SVMSurrogate(xn, y, model, lb, ub)
end
"""
SVMSurrogate(x, y, lb, ub)
function (svmsurr::SVMSurrogate)(val::Number)
return LIBSVM.predict(svmsurr.model, [val])
end
Builds a SVM Surrogate using [LIBSVM](https://github.com/JuliaML/LIBSVM.jl).
"""
SVMSurrogate(x,y,lb,ub)
## Arguments
Builds SVM surrogate.
- `x`: Input data points.
- `y`: Output data points.
- `lb`: Lower bound of input data points.
- `ub`: Upper bound of output data points.
"""
function SVMSurrogate(x, y, lb, ub)
X = Array{Float64, 2}(undef, length(x), length(x[1]))
for j in 1:length(x)
X[j, :] = vec(collect(x[j]))
X = Array{Float64, 2}(undef, length(x), length(first(x)))
if length(lb) == 1
for j in eachindex(x)
X[j, 1] = x[j]
end
else
for j in eachindex(x)
X[j, :] = x[j]
end
end
model = LIBSVM.fit!(SVC(), X, y)
SVMSurrogate(x, y, model, lb, ub)
end

function (svmsurr::SVMSurrogate)(val)
n = length(val)
return LIBSVM.predict(svmsurr.model, reshape(collect(val), 1, n))[1]
return LIBSVM.predict(svmsurr.model, reshape(val, 1, n))[1]
end

function add_point!(svmsurr::SVMSurrogate, x_new, y_new)
"""
update!(svmsurr::SVMSurrogate, x_new, y_new)
## Arguments
- `svmsurr`: Surrogate of type [`SVMSurrogate`](@ref).
- `x_new`: Vector of new data points to be added to the training set of SVMSurrogate.
- `y_new`: Vector of new output points to be added to the training set of SVMSurrogate.
"""
function update!(svmsurr::SVMSurrogate, x_new, y_new)
svmsurr.x = vcat(svmsurr.x, x_new)
svmsurr.y = vcat(svmsurr.y, y_new)
if length(svmsurr.lb) == 1
#1D
svmsurr.x = vcat(svmsurr.x, x_new)
svmsurr.y = vcat(svmsurr.y, y_new)
svmsurr.model = LIBSVM.fit!(SVC(), reshape(svmsurr.x, length(svmsurr.x), 1),
svmsurr.y)
svmsurr.model = LIBSVM.fit!(
SVC(), reshape(svmsurr.x, length(svmsurr.x), 1), svmsurr.y)
else
n_previous = length(svmsurr.x)
a = vcat(svmsurr.x, x_new)
n_after = length(a)
dim_new = n_after - n_previous
n = length(svmsurr.x)
d = length(svmsurr.x[1])
tot_dim = n + dim_new
X = Array{Float64, 2}(undef, tot_dim, d)
for j in 1:n
X[j, :] = vec(collect(svmsurr.x[j]))
end
if dim_new == 1
X[n + 1, :] = vec(collect(x_new))
else
i = 1
for j in (n + 1):tot_dim
X[j, :] = vec(collect(x_new[i]))
i = i + 1
end
end
svmsurr.x = vcat(svmsurr.x, x_new)
svmsurr.y = vcat(svmsurr.y, y_new)
svmsurr.model = LIBSVM.fit!(SVC(), X, svmsurr.y)
svmsurr.model = LIBSVM.fit!(SVC(), transpose(reduce(hcat, svmsurr.x)), svmsurr.y)
end
nothing
end

end # module

0 comments on commit 0c48096

Please sign in to comment.