Skip to content

Commit

Permalink
refactor(SurrogatesPolyChaos): update it to use SurrogatesBase
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Apr 5, 2024
1 parent 09179ac commit be446a5
Showing 1 changed file with 32 additions and 61 deletions.
93 changes: 32 additions & 61 deletions lib/SurrogatesPolyChaos/src/SurrogatesPolyChaos.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
module SurrogatesPolyChaos

import Surrogates: AbstractSurrogate, add_point!, _check_dimension
export PolynomialChaosSurrogate

using SurrogatesBase
using PolyChaos

mutable struct PolynomialChaosSurrogate{X, Y, L, U, C, O, N} <: AbstractSurrogate
export PolynomialChaosSurrogate, update!

mutable struct PolynomialChaosSurrogate{X, Y, L, U, C, O, N} <: AbstractDeterministicSurrogate
x::X
y::Y
lb::L
Expand All @@ -15,33 +15,41 @@ mutable struct PolynomialChaosSurrogate{X, Y, L, U, C, O, N} <: AbstractSurrogat
num_of_multi_indexes::N
end

function _calculatepce_coeff(x, y, num_of_multi_indexes, op::AbstractCanonicalOrthoPoly)
n = length(x)
A = zeros(eltype(x), n, num_of_multi_indexes)
for i in 1:n
A[i, :] = PolyChaos.evaluate(x[i], op)
end
return (A' * A) \ (A' * y)
end

function PolynomialChaosSurrogate(x, y, lb::Number, ub::Number;
op::AbstractCanonicalOrthoPoly = GaussOrthoPoly(2))
function PolynomialChaosSurrogate(x, y, lb, ub;
op = MultiOrthoPoly([GaussOrthoPoly(2) for j in 1:length(lb)], 2))
n = length(x)
d = length(lb)
poly_degree = op.deg
num_of_multi_indexes = 1 + poly_degree
num_of_multi_indexes = binomial(d + poly_degree, poly_degree)
if n < 2 + 3 * num_of_multi_indexes
throw("To avoid numerical problems, it's strongly suggested to have at least $(2+3*num_of_multi_indexes) samples")
end
coeff = _calculatepce_coeff(x, y, num_of_multi_indexes, op)
return PolynomialChaosSurrogate(x, y, lb, ub, coeff, op, num_of_multi_indexes)
end

function (pcND::PolynomialChaosSurrogate)(val)
sum = zero(eltype(val))
for i in 1:(pcND.num_of_multi_indexes)
sum = sum +
pcND.coeff[i] *
first(PolyChaos.evaluate(pcND.ortopolys.ind[i, :], collect(val),
pcND.ortopolys))
end
return sum
end

function (pc::PolynomialChaosSurrogate)(val::Number)
# Check to make sure dimensions of input matches expected dimension of surrogate
_check_dimension(pc, val)
return pc([val])
end

return sum([pc.coeff[i] * PolyChaos.evaluate(val, pc.ortopolys)[i]
for i in 1:(pc.num_of_multi_indexes)])
function _calculatepce_coeff(x, y, num_of_multi_indexes, op::AbstractCanonicalOrthoPoly)
n = length(x)
A = zeros(eltype(x), n, num_of_multi_indexes)
for i in 1:n
A[i, :] = PolyChaos.evaluate(x[i], op)
end
return (A' * A) \ (A' * y)
end

function _calculatepce_coeff(x, y, num_of_multi_indexes, op::MultiOrthoPoly)
Expand All @@ -58,48 +66,11 @@ function _calculatepce_coeff(x, y, num_of_multi_indexes, op::MultiOrthoPoly)
return (A' * A) \ (A' * y)
end

function PolynomialChaosSurrogate(x, y, lb, ub;
op::MultiOrthoPoly = MultiOrthoPoly([GaussOrthoPoly(2)
for j in 1:length(lb)],
2))
n = length(x)
d = length(lb)
poly_degree = op.deg
num_of_multi_indexes = binomial(d + poly_degree, poly_degree)
if n < 2 + 3 * num_of_multi_indexes
throw("To avoid numerical problems, it's strongly suggested to have at least $(2+3*num_of_multi_indexes) samples")
end
coeff = _calculatepce_coeff(x, y, num_of_multi_indexes, op)
return PolynomialChaosSurrogate(x, y, lb, ub, coeff, op, num_of_multi_indexes)
end

function (pcND::PolynomialChaosSurrogate)(val)
# Check to make sure dimensions of input matches expected dimension of surrogate
_check_dimension(pcND, val)

sum = zero(eltype(val[1]))
for i in 1:(pcND.num_of_multi_indexes)
sum = sum +
pcND.coeff[i] *
first(PolyChaos.evaluate(pcND.ortopolys.ind[i, :], collect(val),
pcND.ortopolys))
end
return sum
end

function add_point!(polych::PolynomialChaosSurrogate, x_new, y_new)
if length(polych.lb) == 1
#1D
polych.x = vcat(polych.x, x_new)
polych.y = vcat(polych.y, y_new)
polych.coeff = _calculatepce_coeff(polych.x, polych.y, polych.num_of_multi_indexes,
polych.ortopolys)
else
polych.x = vcat(polych.x, x_new)
polych.y = vcat(polych.y, y_new)
polych.coeff = _calculatepce_coeff(polych.x, polych.y, polych.num_of_multi_indexes,
function SurrogatesBase.update!(polych::PolynomialChaosSurrogate, x_new, y_new)
polych.x = vcat(polych.x, x_new)
polych.y = vcat(polych.y, y_new)
polych.coeff = _calculatepce_coeff(polych.x, polych.y, polych.num_of_multi_indexes,
polych.ortopolys)
end
nothing
end

Expand Down

0 comments on commit be446a5

Please sign in to comment.