From 496aab416ba48f4dc1179e133bf620671e26718c Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Fri, 1 Mar 2024 10:31:46 +0000 Subject: [PATCH] test: refactor and add tests --- lib/SurrogatesSVM/test/runtests.jl | 61 ++++++++++++++++++------------ 1 file changed, 37 insertions(+), 24 deletions(-) diff --git a/lib/SurrogatesSVM/test/runtests.jl b/lib/SurrogatesSVM/test/runtests.jl index 5919c8f2..a15f1027 100644 --- a/lib/SurrogatesSVM/test/runtests.jl +++ b/lib/SurrogatesSVM/test/runtests.jl @@ -1,29 +1,42 @@ using SafeTestsets @safetestset "SVMSurrogate" begin - using Surrogates, LIBSVM using SurrogatesSVM - - #1D - - obj_1D = x -> 2 * x + 1 - a = 0.0 - b = 10.0 - x = sample(5, a, b, SobolSample()) - y = obj_1D.(x) - my_svm_1D = SVMSurrogate(x, y, a, b) - val = my_svm_1D(5.0) - add_point!(my_svm_1D, 3.1, 7.2) - add_point!(my_svm_1D, [3.2, 3.5], [7.4, 8.0]) - - #ND - obj_N = x -> x[1]^2 * x[2] - lb = [0.0, 0.0] - ub = [10.0, 10.0] - x = sample(100, lb, ub, RandomSample()) - y = obj_N.(x) - my_svm_ND = SVMSurrogate(x, y, lb, ub) - val = my_svm_ND((5.0, 1.2)) - add_point!(my_svm_ND, (1.0, 1.0), 1.0) - add_point!(my_svm_ND, [(1.2, 1.2), (1.5, 1.5)], [1.728, 3.375]) + using Surrogates + using LIBSVM + using Test + @testset "1D" begin + obj_1D = x -> 2 * x + 1 + a = 0.0 + b = 10.0 + x = sample(5, a, b, SobolSample()) + y = obj_1D.(x) + svm = LIBSVM.fit!(SVC(), reshape(x, length(x), 1), y) + my_svm_1D = SVMSurrogate(x, y, a, b) + val = my_svm_1D([5.0]) + @test LIBSVM.predict(svm, [5.0;;])[1] == val + update!(my_svm_1D, [3.1], [7.2]) + update!(my_svm_1D, [3.2, 3.5], [7.4, 8.0]) + svm = LIBSVM.fit!(SVC(), reshape(my_svm_1D.x, length(my_svm_1D.x), 1), my_svm_1D.y) + val = my_svm_1D([3.1]) + @test LIBSVM.predict(svm, [3.1;;])[1] == val + end + @testset "ND" begin + obj_N = x -> x[1]^2 * x[2] + lb = [0.0, 0.0] + ub = [10.0, 10.0] + x = collect.(sample(100, lb, ub, RandomSample())) + y = obj_N.(x) + svm = LIBSVM.fit!(SVC(), transpose(reduce(hcat, x)), y) + my_svm_ND = SVMSurrogate(x, y, lb, ub) + x_test = [5.0, 1.2] + val = my_svm_ND(x_test) + @test LIBSVM.predict(svm, reshape(x_test, 1, 2))[1] == val + update!(my_svm_ND, [[1.0, 1.0]], [1.0]) + update!(my_svm_ND, [[1.2, 1.2], [1.5, 1.5]], [1.728, 3.375]) + svm = LIBSVM.fit!(SVC(), transpose(reduce(hcat, my_svm_ND.x)), my_svm_ND.y) + x_test = [1.0, 1.0] + val = my_svm_ND(x_test) + @test LIBSVM.predict(svm, reshape(x_test, 1, 2))[1] == val + end end