
Commit

Merge pull request #475 from sathvikbhagavan/sb/svm
refactor(SurrogatesSVM): use SurrogatesBase and cleanup code
ChrisRackauckas authored Mar 3, 2024
2 parents f760ab5 + 496aab4 commit bc9f174
Showing 3 changed files with 82 additions and 76 deletions.
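As context for the diffs that follow, here is a minimal before/after sketch of what the refactor means for callers. Only SVMSurrogate, add_point!, and update! are taken from this commit; the sampling setup and point values are illustrative assumptions.

using Surrogates, SurrogatesSVM

x = sample(20, 0.0, 10.0, SobolSample())   # 1-D training inputs (assumed setup)
y = 2 .* x .+ 1                            # SVC treats each distinct y value as a class label
s = SVMSurrogate(x, y, 0.0, 10.0)

# Before this commit (Surrogates API):
# add_point!(s, 3.1, 7.2)

# After this commit (SurrogatesBase API): new points and outputs are passed as vectors
update!(s, [3.1], [7.2])
s([3.1])    # the callable now takes a point as a vector rather than a scalar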
10 changes: 7 additions & 3 deletions lib/SurrogatesSVM/Project.toml
@@ -5,16 +5,20 @@ version = "0.1.0"

 [deps]
 LIBSVM = "b1bec4e5-fd48-53fe-b0cb-9723c09d164b"
-Surrogates = "6fc51010-71bc-11e9-0e15-a3fcc6593c49"
+SurrogatesBase = "89f642e6-4179-4274-8202-c11f4bd9a72c"

 [compat]
 LIBSVM = "0.8"
-Surrogates = "6"
+SafeTestsets = "0.1"
+Surrogates = "6.9"
+SurrogatesBase = "1.1"
+Test = "1"
 julia = "1.10"

 [extras]
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
+Surrogates = "6fc51010-71bc-11e9-0e15-a3fcc6593c49"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["SafeTestsets", "Test"]
+test = ["SafeTestsets", "Surrogates", "Test"]
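With SafeTestsets, Surrogates, and Test declared as test-only dependencies, the sub-package's test suite can run in isolation; a minimal sketch of one way to invoke it from a checkout of the repository (the activation path is an assumption, not part of this diff):

using Pkg
Pkg.activate("lib/SurrogatesSVM")   # assumed path relative to the Surrogates.jl repository root
Pkg.test()                          # SafeTestsets, Surrogates, and Test are pulled in via [targets]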
87 changes: 38 additions & 49 deletions lib/SurrogatesSVM/src/SurrogatesSVM.jl
@@ -1,79 +1,68 @@
 module SurrogatesSVM

-import Surrogates: AbstractSurrogate, add_point!
-export SVMSurrogate
+using SurrogatesBase
 using LIBSVM

-mutable struct SVMSurrogate{X, Y, M, L, U} <: AbstractSurrogate
+export SVMSurrogate, update!
+
+mutable struct SVMSurrogate{X, Y, M, L, U} <: AbstractDeterministicSurrogate
     x::X
     y::Y
     model::M
     lb::L
     ub::U
 end

-function SVMSurrogate(x, y, lb::Number, ub::Number)
-    xn = reshape(x, length(x), 1)
-    model = LIBSVM.fit!(SVC(), xn, y)
-    SVMSurrogate(xn, y, model, lb, ub)
-end
-
-function (svmsurr::SVMSurrogate)(val::Number)
-    return LIBSVM.predict(svmsurr.model, [val])
-end
-
-"""
-SVMSurrogate(x,y,lb,ub)
-
-Builds SVM surrogate.
-"""
+"""
+    SVMSurrogate(x, y, lb, ub)
+
+Builds a SVM Surrogate using [LIBSVM](https://github.com/JuliaML/LIBSVM.jl).
+
+## Arguments
+
+- `x`: Input data points.
+- `y`: Output data points.
+- `lb`: Lower bound of input data points.
+- `ub`: Upper bound of input data points.
+"""
 function SVMSurrogate(x, y, lb, ub)
-    X = Array{Float64, 2}(undef, length(x), length(x[1]))
-    for j in 1:length(x)
-        X[j, :] = vec(collect(x[j]))
+    X = Array{Float64, 2}(undef, length(x), length(first(x)))
+    if length(lb) == 1
+        for j in eachindex(x)
+            X[j, 1] = x[j]
+        end
+    else
+        for j in eachindex(x)
+            X[j, :] = x[j]
+        end
     end
     model = LIBSVM.fit!(SVC(), X, y)
     SVMSurrogate(x, y, model, lb, ub)
 end

 function (svmsurr::SVMSurrogate)(val)
     n = length(val)
-    return LIBSVM.predict(svmsurr.model, reshape(collect(val), 1, n))[1]
+    return LIBSVM.predict(svmsurr.model, reshape(val, 1, n))[1]
 end

-function add_point!(svmsurr::SVMSurrogate, x_new, y_new)
+"""
+    update!(svmsurr::SVMSurrogate, x_new, y_new)
+
+## Arguments
+
+- `svmsurr`: Surrogate of type [`SVMSurrogate`](@ref).
+- `x_new`: Vector of new data points to be added to the training set of SVMSurrogate.
+- `y_new`: Vector of new output points to be added to the training set of SVMSurrogate.
+"""
+function update!(svmsurr::SVMSurrogate, x_new, y_new)
+    svmsurr.x = vcat(svmsurr.x, x_new)
+    svmsurr.y = vcat(svmsurr.y, y_new)
     if length(svmsurr.lb) == 1
-        #1D
-        svmsurr.x = vcat(svmsurr.x, x_new)
-        svmsurr.y = vcat(svmsurr.y, y_new)
-        svmsurr.model = LIBSVM.fit!(SVC(), reshape(svmsurr.x, length(svmsurr.x), 1),
-            svmsurr.y)
+        svmsurr.model = LIBSVM.fit!(
+            SVC(), reshape(svmsurr.x, length(svmsurr.x), 1), svmsurr.y)
     else
-        n_previous = length(svmsurr.x)
-        a = vcat(svmsurr.x, x_new)
-        n_after = length(a)
-        dim_new = n_after - n_previous
-        n = length(svmsurr.x)
-        d = length(svmsurr.x[1])
-        tot_dim = n + dim_new
-        X = Array{Float64, 2}(undef, tot_dim, d)
-        for j in 1:n
-            X[j, :] = vec(collect(svmsurr.x[j]))
-        end
-        if dim_new == 1
-            X[n + 1, :] = vec(collect(x_new))
-        else
-            i = 1
-            for j in (n + 1):tot_dim
-                X[j, :] = vec(collect(x_new[i]))
-                i = i + 1
-            end
-        end
-        svmsurr.x = vcat(svmsurr.x, x_new)
-        svmsurr.y = vcat(svmsurr.y, y_new)
-        svmsurr.model = LIBSVM.fit!(SVC(), X, svmsurr.y)
+        svmsurr.model = LIBSVM.fit!(SVC(), transpose(reduce(hcat, svmsurr.x)), svmsurr.y)
     end
     nothing
 end

 end # module
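For the N-dimensional path, the refactored constructor takes x as a vector of points (each point itself a vector), and the callable evaluates a single point given as a vector. A usage sketch following the code above; the objective, sample count, and query values are assumptions:

using Surrogates, SurrogatesSVM

lb = [0.0, 0.0]
ub = [10.0, 10.0]
x = collect.(sample(30, lb, ub, RandomSample()))   # vector of 2-element vectors
y = [xi[1]^2 * xi[2] for xi in x]                  # distinct values act as class labels for SVC
s = SVMSurrogate(x, y, lb, ub)

s([5.0, 1.2])                     # predict at one point
update!(s, [[1.0, 1.0]], [1.0])   # note the extra brackets: new points arrive as a vector of points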
61 changes: 37 additions & 24 deletions lib/SurrogatesSVM/test/runtests.jl
@@ -1,29 +1,42 @@
 using SafeTestsets

 @safetestset "SVMSurrogate" begin
-    using Surrogates, LIBSVM
     using SurrogatesSVM
-
-    #1D
-    obj_1D = x -> 2 * x + 1
-    a = 0.0
-    b = 10.0
-    x = sample(5, a, b, SobolSample())
-    y = obj_1D.(x)
-    my_svm_1D = SVMSurrogate(x, y, a, b)
-    val = my_svm_1D(5.0)
-    add_point!(my_svm_1D, 3.1, 7.2)
-    add_point!(my_svm_1D, [3.2, 3.5], [7.4, 8.0])
-
-    #ND
-    obj_N = x -> x[1]^2 * x[2]
-    lb = [0.0, 0.0]
-    ub = [10.0, 10.0]
-    x = sample(100, lb, ub, RandomSample())
-    y = obj_N.(x)
-    my_svm_ND = SVMSurrogate(x, y, lb, ub)
-    val = my_svm_ND((5.0, 1.2))
-    add_point!(my_svm_ND, (1.0, 1.0), 1.0)
-    add_point!(my_svm_ND, [(1.2, 1.2), (1.5, 1.5)], [1.728, 3.375])
+    using Surrogates
+    using LIBSVM
+    using Test
+    @testset "1D" begin
+        obj_1D = x -> 2 * x + 1
+        a = 0.0
+        b = 10.0
+        x = sample(5, a, b, SobolSample())
+        y = obj_1D.(x)
+        svm = LIBSVM.fit!(SVC(), reshape(x, length(x), 1), y)
+        my_svm_1D = SVMSurrogate(x, y, a, b)
+        val = my_svm_1D([5.0])
+        @test LIBSVM.predict(svm, [5.0;;])[1] == val
+        update!(my_svm_1D, [3.1], [7.2])
+        update!(my_svm_1D, [3.2, 3.5], [7.4, 8.0])
+        svm = LIBSVM.fit!(SVC(), reshape(my_svm_1D.x, length(my_svm_1D.x), 1), my_svm_1D.y)
+        val = my_svm_1D([3.1])
+        @test LIBSVM.predict(svm, [3.1;;])[1] == val
+    end
+    @testset "ND" begin
+        obj_N = x -> x[1]^2 * x[2]
+        lb = [0.0, 0.0]
+        ub = [10.0, 10.0]
+        x = collect.(sample(100, lb, ub, RandomSample()))
+        y = obj_N.(x)
+        svm = LIBSVM.fit!(SVC(), transpose(reduce(hcat, x)), y)
+        my_svm_ND = SVMSurrogate(x, y, lb, ub)
+        x_test = [5.0, 1.2]
+        val = my_svm_ND(x_test)
+        @test LIBSVM.predict(svm, reshape(x_test, 1, 2))[1] == val
+        update!(my_svm_ND, [[1.0, 1.0]], [1.0])
+        update!(my_svm_ND, [[1.2, 1.2], [1.5, 1.5]], [1.728, 3.375])
+        svm = LIBSVM.fit!(SVC(), transpose(reduce(hcat, my_svm_ND.x)), my_svm_ND.y)
+        x_test = [1.0, 1.0]
+        val = my_svm_ND(x_test)
+        @test LIBSVM.predict(svm, reshape(x_test, 1, 2))[1] == val
+    end
 end
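The tests lean on two data-layout idioms worth spelling out; a small illustration with arbitrary values, not taken from the diff:

x = [[1.0, 2.0], [3.0, 4.0]]      # two 2-D points
reduce(hcat, x)                   # 2×2 matrix with one point per column
transpose(reduce(hcat, x))        # one point per row, the layout handed to LIBSVM.fit! above
[5.0;;]                           # 1×1 Matrix{Float64}: a single query row with one feature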
