diff --git a/lib/SurrogatesRandomForest/src/SurrogatesRandomForest.jl b/lib/SurrogatesRandomForest/src/SurrogatesRandomForest.jl index fc77f0dda..69532f999 100644 --- a/lib/SurrogatesRandomForest/src/SurrogatesRandomForest.jl +++ b/lib/SurrogatesRandomForest/src/SurrogatesRandomForest.jl @@ -14,7 +14,7 @@ mutable struct RandomForestSurrogate{X, Y, B, L, U, N} <: AbstractSurrogate end function RandomForestSurrogate(x, y, lb::Number, ub::Number; num_round::Int = 1) - bst = xgboost(reshape(x, length(x), 1), num_round, label = y) + bst = xgboost((reshape(x, length(x), 1), y); num_round) RandomForestSurrogate(x, y, bst, lb, ub, num_round) end @@ -35,7 +35,7 @@ function RandomForestSurrogate(x, y, lb, ub; num_round::Int = 1) for j in 1:length(x) X[j, :] = vec(collect(x[j])) end - bst = xgboost(X, num_round, label = y) + bst = xgboost((X, y); num_round) RandomForestSurrogate(x, y, bst, lb, ub, num_round) end @@ -50,8 +50,7 @@ function add_point!(rndfor::RandomForestSurrogate, x_new, y_new) #1D rndfor.x = vcat(rndfor.x, x_new) rndfor.y = vcat(rndfor.y, y_new) - rndfor.bst = xgboost(reshape(rndfor.x, length(rndfor.x), 1), rndfor.num_round, - label = rndfor.y) + rndfor.bst = xgboost((reshape(rndfor.x, length(rndfor.x), 1), rndfor.y); num_round = rndfor.num_round) else n_previous = length(rndfor.x) a = vcat(rndfor.x, x_new) @@ -75,7 +74,7 @@ function add_point!(rndfor::RandomForestSurrogate, x_new, y_new) end rndfor.x = vcat(rndfor.x, x_new) rndfor.y = vcat(rndfor.y, y_new) - rndfor.bst = xgboost(X, rndfor.num_round, label = rndfor.y) + rndfor.bst = xgboost((X, rndfor.y); num_round = rndfor.num_round) end nothing end