Skip to content

Commit

Permalink
Merge pull request #443 from sathvikbhagavan/sb/xgboost_bump
Browse files Browse the repository at this point in the history
Use XGBoost 2
  • Loading branch information
ChrisRackauckas authored Sep 22, 2023
2 parents 67a95be + 50567c2 commit 996f662
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ SurrogatesMOE = "0.1.0"
SurrogatesPolyChaos = "0.1.0"
SurrogatesRandomForest = "0.1.0"
SurrogatesSVM = "0.1.0"
XGBoost = "1.5"
XGBoost = "2"
Zygote = "0.6.49"
2 changes: 1 addition & 1 deletion lib/SurrogatesMOE/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Surrogates = "6.0"
SurrogatesFlux = "0.1.0"
SurrogatesPolyChaos = "0.1.0"
SurrogatesRandomForest = "0.1.0"
XGBoost = "1.5.2"
XGBoost = "2"

[extras]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Expand Down
2 changes: 1 addition & 1 deletion lib/SurrogatesRandomForest/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ XGBoost = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"

[compat]
Surrogates = "6"
XGBoost = "1.5"
XGBoost = "2"
julia = "1.6"

[extras]
Expand Down
9 changes: 4 additions & 5 deletions lib/SurrogatesRandomForest/src/SurrogatesRandomForest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ mutable struct RandomForestSurrogate{X, Y, B, L, U, N} <: AbstractSurrogate
end

function RandomForestSurrogate(x, y, lb::Number, ub::Number; num_round::Int = 1)
bst = xgboost(reshape(x, length(x), 1), num_round, label = y)
bst = xgboost((reshape(x, length(x), 1), y); num_round)
RandomForestSurrogate(x, y, bst, lb, ub, num_round)
end

Expand All @@ -35,7 +35,7 @@ function RandomForestSurrogate(x, y, lb, ub; num_round::Int = 1)
for j in 1:length(x)
X[j, :] = vec(collect(x[j]))
end
bst = xgboost(X, num_round, label = y)
bst = xgboost((X, y); num_round)
RandomForestSurrogate(x, y, bst, lb, ub, num_round)
end

Expand All @@ -50,8 +50,7 @@ function add_point!(rndfor::RandomForestSurrogate, x_new, y_new)
#1D
rndfor.x = vcat(rndfor.x, x_new)
rndfor.y = vcat(rndfor.y, y_new)
rndfor.bst = xgboost(reshape(rndfor.x, length(rndfor.x), 1), rndfor.num_round,
label = rndfor.y)
rndfor.bst = xgboost((reshape(rndfor.x, length(rndfor.x), 1), rndfor.y); num_round = rndfor.num_round)
else
n_previous = length(rndfor.x)
a = vcat(rndfor.x, x_new)
Expand All @@ -75,7 +74,7 @@ function add_point!(rndfor::RandomForestSurrogate, x_new, y_new)
end
rndfor.x = vcat(rndfor.x, x_new)
rndfor.y = vcat(rndfor.y, y_new)
rndfor.bst = xgboost(X, rndfor.num_round, label = rndfor.y)
rndfor.bst = xgboost((X, rndfor.y); num_round = rndfor.num_round)
end
nothing
end
Expand Down

0 comments on commit 996f662

Please sign in to comment.