From 50567c2291d00120cd0bc8a06e9357be0564972a Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Fri, 22 Sep 2023 04:05:35 +0000 Subject: [PATCH] refactor: `SurrogatesRandomForest` RandomForest code to use `XGBoost` 2 --- lib/SurrogatesRandomForest/src/SurrogatesRandomForest.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/lib/SurrogatesRandomForest/src/SurrogatesRandomForest.jl b/lib/SurrogatesRandomForest/src/SurrogatesRandomForest.jl index fc77f0dda..69532f999 100644 --- a/lib/SurrogatesRandomForest/src/SurrogatesRandomForest.jl +++ b/lib/SurrogatesRandomForest/src/SurrogatesRandomForest.jl @@ -14,7 +14,7 @@ mutable struct RandomForestSurrogate{X, Y, B, L, U, N} <: AbstractSurrogate end function RandomForestSurrogate(x, y, lb::Number, ub::Number; num_round::Int = 1) - bst = xgboost(reshape(x, length(x), 1), num_round, label = y) + bst = xgboost((reshape(x, length(x), 1), y); num_round) RandomForestSurrogate(x, y, bst, lb, ub, num_round) end @@ -35,7 +35,7 @@ function RandomForestSurrogate(x, y, lb, ub; num_round::Int = 1) for j in 1:length(x) X[j, :] = vec(collect(x[j])) end - bst = xgboost(X, num_round, label = y) + bst = xgboost((X, y); num_round) RandomForestSurrogate(x, y, bst, lb, ub, num_round) end @@ -50,8 +50,7 @@ function add_point!(rndfor::RandomForestSurrogate, x_new, y_new) #1D rndfor.x = vcat(rndfor.x, x_new) rndfor.y = vcat(rndfor.y, y_new) - rndfor.bst = xgboost(reshape(rndfor.x, length(rndfor.x), 1), rndfor.num_round, - label = rndfor.y) + rndfor.bst = xgboost((reshape(rndfor.x, length(rndfor.x), 1), rndfor.y); num_round = rndfor.num_round) else n_previous = length(rndfor.x) a = vcat(rndfor.x, x_new) @@ -75,7 +74,7 @@ function add_point!(rndfor::RandomForestSurrogate, x_new, y_new) end rndfor.x = vcat(rndfor.x, x_new) rndfor.y = vcat(rndfor.y, y_new) - rndfor.bst = xgboost(X, rndfor.num_round, label = rndfor.y) + rndfor.bst = xgboost((X, rndfor.y); num_round = rndfor.num_round) end nothing end