Skip to content

Commit

Permalink
refactor: use SurrogatesBase in SurrogatesAbstractGPs
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Jul 23, 2024
1 parent 554e9ba commit 33be85d
Showing 1 changed file with 14 additions and 20 deletions.
34 changes: 14 additions & 20 deletions lib/SurrogatesAbstractGPs/src/SurrogatesAbstractGPs.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
module SurrogatesAbstractGPs

import Surrogates: add_point!, AbstractSurrogate, std_error_at_point, _check_dimension
export AbstractGPSurrogate, var_at_point, logpdf_surrogate
using SurrogatesBase, AbstractGPs

using AbstractGPs
export AbstractGPSurrogate, logpdf_surrogate, update!, finite_posterior

mutable struct AbstractGPSurrogate{X, Y, GP, GP_P, S} <: AbstractSurrogate
mutable struct AbstractGPSurrogate{X, Y, GP, GP_P, S} <: AbstractStochasticSurrogate
x::X
y::Y
gp::GP
Expand All @@ -20,29 +19,24 @@ end

# predictor
function (g::AbstractGPSurrogate)(val)
# Check to make sure dimensions of input matches expected dimension of surrogate
_check_dimension(g, val)

return only(mean(g.gp_posterior([val])))
end

# for add point
# copies of x and y need to be made because we get
#"Error: cannot resize array with shared data " if we push! directly to x and y
function add_point!(g::AbstractGPSurrogate, new_x, new_y)
if new_x in g.x
println("Adding a sample that already exists, cannot build AbstracgGPSurrogate.")
return
function SurrogatesBase.update!(g::AbstractGPSurrogate, new_x, new_y)
for x in new_x
in(x, g.x) &&
error("Adding a sample that already exists, cannot update AbstractGPSurrogate!")
end
x_copy = copy(g.x)
push!(x_copy, new_x)
y_copy = copy(g.y)
push!(y_copy, new_y)
updated_posterior = posterior(g.gp(x_copy, g.Σy), y_copy)
g.x, g.y, g.gp_posterior = x_copy, y_copy, updated_posterior
g.x = vcat(g.x, new_x)
g.y = vcat(g.y, new_y)
g.gp_posterior = posterior(g.gp(g.x, g.Σy), g.y)
nothing
end

function SurrogatesBase.finite_posterior(g::AbstractGPSurrogate, xs)
g.gp_posterior(xs)
end

function std_error_at_point(g::AbstractGPSurrogate, val)
return sqrt(only(var(g.gp_posterior([val]))))
end
Expand Down

0 comments on commit 33be85d

Please sign in to comment.