From 33be85df015f2e25655b2654236c90fce15dc236 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Tue, 23 Jul 2024 07:39:56 +0000 Subject: [PATCH] refactor: use SurrogatesBase in SurrogatesAbstractGPs --- .../src/SurrogatesAbstractGPs.jl | 34 ++++++++----------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/lib/SurrogatesAbstractGPs/src/SurrogatesAbstractGPs.jl b/lib/SurrogatesAbstractGPs/src/SurrogatesAbstractGPs.jl index 11eab4432..04ef521aa 100644 --- a/lib/SurrogatesAbstractGPs/src/SurrogatesAbstractGPs.jl +++ b/lib/SurrogatesAbstractGPs/src/SurrogatesAbstractGPs.jl @@ -1,11 +1,10 @@ module SurrogatesAbstractGPs -import Surrogates: add_point!, AbstractSurrogate, std_error_at_point, _check_dimension -export AbstractGPSurrogate, var_at_point, logpdf_surrogate +using SurrogatesBase, AbstractGPs -using AbstractGPs +export AbstractGPSurrogate, logpdf_surrogate, update!, finite_posterior -mutable struct AbstractGPSurrogate{X, Y, GP, GP_P, S} <: AbstractSurrogate +mutable struct AbstractGPSurrogate{X, Y, GP, GP_P, S} <: AbstractStochasticSurrogate x::X y::Y gp::GP @@ -20,29 +19,24 @@ end # predictor function (g::AbstractGPSurrogate)(val) - # Check to make sure dimensions of input matches expected dimension of surrogate - _check_dimension(g, val) - return only(mean(g.gp_posterior([val]))) end -# for add point -# copies of x and y need to be made because we get -#"Error: cannot resize array with shared data " if we push! directly to x and y -function add_point!(g::AbstractGPSurrogate, new_x, new_y) - if new_x in g.x - println("Adding a sample that already exists, cannot build AbstracgGPSurrogate.") - return +function SurrogatesBase.update!(g::AbstractGPSurrogate, new_x, new_y) + for x in new_x + in(x, g.x) && + error("Adding a sample that already exists, cannot update AbstractGPSurrogate!") end - x_copy = copy(g.x) - push!(x_copy, new_x) - y_copy = copy(g.y) - push!(y_copy, new_y) - updated_posterior = posterior(g.gp(x_copy, g.Σy), y_copy) - g.x, g.y, g.gp_posterior = x_copy, y_copy, updated_posterior + g.x = vcat(g.x, new_x) + g.y = vcat(g.y, new_y) + g.gp_posterior = posterior(g.gp(g.x, g.Σy), g.y) nothing end +function SurrogatesBase.finite_posterior(g::AbstractGPSurrogate, xs) + g.gp_posterior(xs) +end + function std_error_at_point(g::AbstractGPSurrogate, val) return sqrt(only(var(g.gp_posterior([val])))) end