Skip to content

Commit

Permalink
refactor: add scalar evaluation for SurrogatesSVM
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Mar 7, 2024
1 parent 6bd95e2 commit a20c8d5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion lib/SurrogatesSVM/src/SurrogatesSVM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ function SVMSurrogate(x, y, lb, ub)
SVMSurrogate(x, y, model, lb, ub)
end

function (svmsurr::SVMSurrogate)(val::Number)
return svmsurr([val])

Check warning on line 44 in lib/SurrogatesSVM/src/SurrogatesSVM.jl

View check run for this annotation

Codecov / codecov/patch

lib/SurrogatesSVM/src/SurrogatesSVM.jl#L43-L44

Added lines #L43 - L44 were not covered by tests
end

function (svmsurr::SVMSurrogate)(val)
n = length(val)
return LIBSVM.predict(svmsurr.model, reshape(val, 1, n))[1]
Expand All @@ -54,7 +58,7 @@ end
- `x_new`: Vector of new data points to be added to the training set of SVMSurrogate.
- `y_new`: Vector of new output points to be added to the training set of SVMSurrogate.
"""
function update!(svmsurr::SVMSurrogate, x_new, y_new)
function SurrogatesBase.update!(svmsurr::SVMSurrogate, x_new, y_new)

Check warning on line 61 in lib/SurrogatesSVM/src/SurrogatesSVM.jl

View check run for this annotation

Codecov / codecov/patch

lib/SurrogatesSVM/src/SurrogatesSVM.jl#L61

Added line #L61 was not covered by tests
svmsurr.x = vcat(svmsurr.x, x_new)
svmsurr.y = vcat(svmsurr.y, y_new)
if length(svmsurr.lb) == 1
Expand Down
2 changes: 1 addition & 1 deletion lib/SurrogatesSVM/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using SafeTestsets
update!(my_svm_1D, [3.1], [7.2])
update!(my_svm_1D, [3.2, 3.5], [7.4, 8.0])
svm = LIBSVM.fit!(SVC(), reshape(my_svm_1D.x, length(my_svm_1D.x), 1), my_svm_1D.y)
val = my_svm_1D([3.1])
val = my_svm_1D(3.1)
@test LIBSVM.predict(svm, [3.1;;])[1] == val
end
@testset "ND" begin
Expand Down

0 comments on commit a20c8d5

Please sign in to comment.