Skip to content

Commit

Permalink
refactor: SurrogatesRandomForest RandomForest code to use XGBoost 2
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Sep 22, 2023
1 parent 400c7f4 commit 50567c2
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions lib/SurrogatesRandomForest/src/SurrogatesRandomForest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ mutable struct RandomForestSurrogate{X, Y, B, L, U, N} <: AbstractSurrogate
end

function RandomForestSurrogate(x, y, lb::Number, ub::Number; num_round::Int = 1)
bst = xgboost(reshape(x, length(x), 1), num_round, label = y)
bst = xgboost((reshape(x, length(x), 1), y); num_round)
RandomForestSurrogate(x, y, bst, lb, ub, num_round)
end

Expand All @@ -35,7 +35,7 @@ function RandomForestSurrogate(x, y, lb, ub; num_round::Int = 1)
for j in 1:length(x)
X[j, :] = vec(collect(x[j]))
end
bst = xgboost(X, num_round, label = y)
bst = xgboost((X, y); num_round)
RandomForestSurrogate(x, y, bst, lb, ub, num_round)
end

Expand All @@ -50,8 +50,7 @@ function add_point!(rndfor::RandomForestSurrogate, x_new, y_new)
#1D
rndfor.x = vcat(rndfor.x, x_new)
rndfor.y = vcat(rndfor.y, y_new)
rndfor.bst = xgboost(reshape(rndfor.x, length(rndfor.x), 1), rndfor.num_round,
label = rndfor.y)
rndfor.bst = xgboost((reshape(rndfor.x, length(rndfor.x), 1), rndfor.y); num_round = rndfor.num_round)
else
n_previous = length(rndfor.x)
a = vcat(rndfor.x, x_new)
Expand All @@ -75,7 +74,7 @@ function add_point!(rndfor::RandomForestSurrogate, x_new, y_new)
end
rndfor.x = vcat(rndfor.x, x_new)
rndfor.y = vcat(rndfor.y, y_new)
rndfor.bst = xgboost(X, rndfor.num_round, label = rndfor.y)
rndfor.bst = xgboost((X, rndfor.y); num_round = rndfor.num_round)
end
nothing
end
Expand Down

0 comments on commit 50567c2

Please sign in to comment.