Skip to content

Commit

Permalink
Merge pull request #477 from sathvikbhagavan/sb/rf
Browse files Browse the repository at this point in the history
refactor: SurrogatesRandomForest with SurrogatesBase
  • Loading branch information
ChrisRackauckas authored Mar 12, 2024
2 parents 26b5705 + a20c8d5 commit 43a4b3e
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 79 deletions.
9 changes: 7 additions & 2 deletions lib/SurrogatesRandomForest/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,21 @@ version = "0.1.1"

[deps]
Surrogates = "6fc51010-71bc-11e9-0e15-a3fcc6593c49"
SurrogatesBase = "89f642e6-4179-4274-8202-c11f4bd9a72c"
XGBoost = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"

[compat]
Surrogates = "6"
SafeTestsets = "0.1"
Surrogates = "6.9"
SurrogatesBase = "1.1"
Test = "1"
XGBoost = "2"
julia = "1.10"

[extras]
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Surrogates = "6fc51010-71bc-11e9-0e15-a3fcc6593c49"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["SafeTestsets", "Test"]
test = ["SafeTestsets", "Surrogates", "Test"]
85 changes: 37 additions & 48 deletions lib/SurrogatesRandomForest/src/SurrogatesRandomForest.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
module SurrogatesRandomForest

import Surrogates: add_point!, AbstractSurrogate, _check_dimension
export RandomForestSurrogate
using SurrogatesBase
using XGBoost: xgboost, predict

using XGBoost
mutable struct RandomForestSurrogate{X, Y, B, L, U, N} <: AbstractSurrogate
export RandomForestSurrogate, update!

mutable struct RandomForestSurrogate{X, Y, B, L, U, N} <:
SurrogatesBase.AbstractDeterministicSurrogate
x::X
y::Y
bst::B
Expand All @@ -13,69 +15,56 @@ mutable struct RandomForestSurrogate{X, Y, B, L, U, N} <: AbstractSurrogate
num_round::N
end

function RandomForestSurrogate(x, y, lb::Number, ub::Number; num_round::Int = 1)
bst = xgboost((reshape(x, length(x), 1), y); num_round)
RandomForestSurrogate(x, y, bst, lb, ub, num_round)
end

function (rndfor::RandomForestSurrogate)(val::Number)
# Check to make sure dimensions of input matches expected dimension of surrogate
_check_dimension(rndfor, val)
return XGBoost.predict(rndfor.bst, reshape([val], 1, 1))[1]
end

"""
RandomForestSurrogate(x,y,lb,ub,num_round)
RandomForestSurrogate(x, y, lb, ub, num_round)
Build Random forest surrogate. num_round is the number of trees.
## Arguments
- `x`: Input data points.
- `y`: Output data points.
- `lb`: Lower bound of input data points.
- `ub`: Upper bound of output data points.
## Keyword Arguments
- `num_round`: number of rounds of training.
"""
function RandomForestSurrogate(x, y, lb, ub; num_round::Int = 1)
X = Array{Float64, 2}(undef, length(x), length(x[1]))
for j in 1:length(x)
X[j, :] = vec(collect(x[j]))
if length(lb) == 1
for j in eachindex(x)
X[j, 1] = x[j]
end
else
for j in eachindex(x)
X[j, :] = x[j]
end
end
bst = xgboost((X, y); num_round)
RandomForestSurrogate(x, y, bst, lb, ub, num_round)
end

function (rndfor::RandomForestSurrogate)(val::Number)
return rndfor([val])
end

function (rndfor::RandomForestSurrogate)(val)
# Check to make sure dimensions of input matches expected dimension of surrogate
_check_dimension(rndfor, val)
return XGBoost.predict(rndfor.bst, reshape(collect(val), 1, length(val)))[1]
return predict(rndfor.bst, reshape(val, length(val), 1))[1]
end

function add_point!(rndfor::RandomForestSurrogate, x_new, y_new)
function SurrogatesBase.update!(rndfor::RandomForestSurrogate, x_new, y_new)
rndfor.x = vcat(rndfor.x, x_new)
rndfor.y = vcat(rndfor.y, y_new)
if length(rndfor.lb) == 1
#1D
rndfor.x = vcat(rndfor.x, x_new)
rndfor.y = vcat(rndfor.y, y_new)
rndfor.bst = xgboost((reshape(rndfor.x, length(rndfor.x), 1), rndfor.y);
num_round = rndfor.num_round)
else
n_previous = length(rndfor.x)
a = vcat(rndfor.x, x_new)
n_after = length(a)
dim_new = n_after - n_previous
n = length(rndfor.x)
d = length(rndfor.x[1])
tot_dim = n + dim_new
X = Array{Float64, 2}(undef, tot_dim, d)
for j in 1:n
X[j, :] = vec(collect(rndfor.x[j]))
end
if dim_new == 1
X[n + 1, :] = vec(collect(x_new))
else
i = 1
for j in (n + 1):tot_dim
X[j, :] = vec(collect(x_new[i]))
i = i + 1
end
end
rndfor.x = vcat(rndfor.x, x_new)
rndfor.y = vcat(rndfor.y, y_new)
rndfor.bst = xgboost((X, rndfor.y); num_round = rndfor.num_round)
rndfor.bst = xgboost(
(transpose(reduce(hcat, rndfor.x)), rndfor.y); num_round = rndfor.num_round)
end
nothing
end

end # module
57 changes: 30 additions & 27 deletions lib/SurrogatesRandomForest/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,35 @@
using SafeTestsets

@safetestset "RandomForestSurrogate" begin
using Surrogates, XGBoost
using Surrogates: sample, SobolSample
using Surrogates
using SurrogatesRandomForest

#1D
obj_1D = x -> 3 * x + 1
x = [1.0, 2.0, 3.0, 4.0, 5.0]
y = obj_1D.(x)
a = 0.0
b = 10.0
num_round = 2
my_forest_1D = RandomForestSurrogate(x, y, a, b, num_round = 2)
my_forest_kwarg = RandomForestSurrogate(x, y, a, b)
val = my_forest_1D(3.5)
add_point!(my_forest_1D, 6.0, 19.0)
add_point!(my_forest_1D, [7.0, 8.0], obj_1D.([7.0, 8.0]))

#ND
lb = [0.0, 0.0, 0.0]
ub = [10.0, 10.0, 10.0]
x = sample(5, lb, ub, SobolSample())
obj_ND = x -> x[1] * x[2]^2 * x[3]
y = obj_ND.(x)
my_forest_ND = RandomForestSurrogate(x, y, lb, ub, num_round = 2)
my_forest_kwarg = RandomForestSurrogate(x, y, lb, ub)
val = my_forest_ND((1.0, 1.0, 1.0))
add_point!(my_forest_ND, (1.0, 1.0, 1.0), 1.0)
add_point!(my_forest_ND, [(1.2, 1.2, 1.0), (1.5, 1.5, 1.0)], [1.728, 3.375])
using Test
using XGBoost: xgboost, predict
@testset "1D" begin
obj_1D = x -> 3 * x + 1
x = [1.0, 2.0, 3.0, 4.0, 5.0]
y = obj_1D.(x)
a = 0.0
b = 10.0
num_round = 2
my_forest_1D = RandomForestSurrogate(x, y, a, b; num_round = 2)
xgboost1 = xgboost((reshape(x, length(x), 1), y); num_round = 2)
val = my_forest_1D(3.5)
@test predict(xgboost1, [3.5;;])[1] == val
update!(my_forest_1D, [6.0], [19.0])
update!(my_forest_1D, [7.0, 8.0], obj_1D.([7.0, 8.0]))
end
@testset "ND" begin
lb = [0.0, 0.0, 0.0]
ub = [10.0, 10.0, 10.0]
x = collect.(sample(5, lb, ub, SobolSample()))
obj_ND = x -> x[1] * x[2]^2 * x[3]
y = obj_ND.(x)
my_forest_ND = RandomForestSurrogate(x, y, lb, ub; num_round = 2)
xgboostND = xgboost((reduce(hcat, x)', y); num_round = 2)
val = my_forest_ND([1.0, 1.0, 1.0])
@test predict(xgboostND, reshape([1.0, 1.0, 1.0], 3, 1))[1] == val
update!(my_forest_ND, [[1.0, 1.0, 1.0]], [1.0])
update!(my_forest_ND, [[1.2, 1.2, 1.0], [1.5, 1.5, 1.0]], [1.728, 3.375])
end
end
6 changes: 5 additions & 1 deletion lib/SurrogatesSVM/src/SurrogatesSVM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ function SVMSurrogate(x, y, lb, ub)
SVMSurrogate(x, y, model, lb, ub)
end

function (svmsurr::SVMSurrogate)(val::Number)
return svmsurr([val])
end

function (svmsurr::SVMSurrogate)(val)
n = length(val)
return LIBSVM.predict(svmsurr.model, reshape(val, 1, n))[1]
Expand All @@ -54,7 +58,7 @@ end
- `x_new`: Vector of new data points to be added to the training set of SVMSurrogate.
- `y_new`: Vector of new output points to be added to the training set of SVMSurrogate.
"""
function update!(svmsurr::SVMSurrogate, x_new, y_new)
function SurrogatesBase.update!(svmsurr::SVMSurrogate, x_new, y_new)
svmsurr.x = vcat(svmsurr.x, x_new)
svmsurr.y = vcat(svmsurr.y, y_new)
if length(svmsurr.lb) == 1
Expand Down
2 changes: 1 addition & 1 deletion lib/SurrogatesSVM/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ using SafeTestsets
update!(my_svm_1D, [3.1], [7.2])
update!(my_svm_1D, [3.2, 3.5], [7.4, 8.0])
svm = LIBSVM.fit!(SVC(), reshape(my_svm_1D.x, length(my_svm_1D.x), 1), my_svm_1D.y)
val = my_svm_1D([3.1])
val = my_svm_1D(3.1)
@test LIBSVM.predict(svm, [3.1;;])[1] == val
end
@testset "ND" begin
Expand Down

0 comments on commit 43a4b3e

Please sign in to comment.