Skip to content

Commit

Permalink
Inverse problem solving now works for Lux chains
Browse files Browse the repository at this point in the history
  • Loading branch information
AstitvaAggarwal committed Nov 13, 2023
1 parent 4b93047 commit 805380c
Show file tree
Hide file tree
Showing 2 changed files with 380 additions and 54 deletions.
147 changes: 103 additions & 44 deletions src/PDE_BPINN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ mutable struct PDELogTargetDensity{
dataset::D
priors::P
allstd::Vector{Vector{Float64}}
autodiff::Bool
names::Tuple
physdt::Float64
extraparams::Int
init_params::I
full_loglikelihood::F
Phi::PH

function PDELogTargetDensity(dim, strategy, dataset,
priors, allstd, autodiff, physdt, extraparams,
priors, allstd, names, physdt, extraparams,
init_params::AbstractVector, full_loglikelihood, Phi)
new{
typeof(strategy),
Expand All @@ -33,15 +33,15 @@ mutable struct PDELogTargetDensity{
dataset,
priors,
allstd,
autodiff,
names,
physdt,
extraparams,
init_params,
full_loglikelihood,
Phi)
end
function PDELogTargetDensity(dim, strategy, dataset,
priors, allstd, autodiff, physdt, extraparams,
priors, allstd, names, physdt, extraparams,
init_params::NamedTuple, full_loglikelihood, Phi)
new{
typeof(strategy),
Expand All @@ -55,7 +55,7 @@ mutable struct PDELogTargetDensity{
dataset,
priors,
allstd,
autodiff,
names,
physdt,
extraparams,
init_params,
Expand Down Expand Up @@ -91,37 +91,83 @@ function LogDensityProblems.logdensity(Tar::PDELogTargetDensity, θ)
# println("2 : ", length(L2LossData(Tar, θ).partials))
# println("3 : ", length(priorlogpdf(Tar, θ).partials))

return Tar.full_loglikelihood(vcat(vector_to_parameters(θ[1:(end - Tar.extraparams)],
Tar.init_params[1:(end - Tar.extraparams)]), θ[(end - Tar.extraparams + 1):end]),
Tar.allstd) +
L2LossData(Tar, θ) + priorlogpdf(Tar, θ)
# println(length(initial_nnθ))
# println(length(pinnrep.flat_init_params))
# println(initial_nnθ)
# println(pinnrep.flat_init_params)
# println(typeof(θ) <: AbstractVector)
# println(length(θ))
# println(typeof(θ[1:(end - Tar.extraparams)]) <: AbstractVector)
# println(length(θ[1:(end - Tar.extraparams)]))
# println(length(vector_to_parameters(θ[1:(end - Tar.extraparams)],
# Tar.init_params[1:(end - Tar.extraparams)])))

# Tar.full_loglikelihood(vcat(vector_to_parameters(θ[1:(end - Tar.extraparams)],
# Tar.init_params), θ[(end - Tar.extraparams + 1):end]),
# Tar.allstd)

# θ = reduce(vcat, θ)
# yuh = vcat(vector_to_parameters(θ[1:(end - Tar.extraparams)],
# Tar.init_params),
# adapt(typeof(vector_to_parameters(θ[1:(end - Tar.extraparams)],
# Tar.init_params)), θ[(end - Tar.extraparams + 1):end]))

# yuh = ComponentArrays.ComponentArray(;
# # u = vector_to_parameters(θ[1:(end - Tar.extraparams)], Tar.init_params),
# depvar = vector_to_parameters(θ[1:(end - Tar.extraparams)], Tar.init_params),
# p = θ[(end - Tar.extraparams + 1):end])

return Tar.full_loglikelihood(setLuxparameters(Tar, θ),
Tar.allstd) + priorlogpdf(Tar, θ)
# +L2LossData(Tar, θ)
# + L2loss2(Tar, θ)
end

function setLuxparameters(Tar::PDELogTargetDensity, θ)
a = ComponentArrays.ComponentArray(NamedTuple{Tar.names}(i for i in [
vector_to_parameters(θ[1:(end - Tar.extraparams)],
Tar.init_params),
]))

b = θ[(end - Tar.extraparams + 1):end]

ComponentArrays.ComponentArray(;
depvar = a,
p = b)
end
LogDensityProblems.dimension(Tar::PDELogTargetDensity) = Tar.dim

function LogDensityProblems.capabilities(::PDELogTargetDensity)
LogDensityProblems.LogDensityOrder{1}()
end

function L2loss2(Tar::PDELogTargetDensity, θ)
return logpdf(MvNormal(pde(phi, Tar.dataset[end], θ)), zeros(length(pde_eqs)))
end
# L2 losses loglikelihood(needed mainly for ODE parameter estimation)
function L2LossData(Tar::PDELogTargetDensity, θ)
return logpdf(MvNormal(Tar.Phi[1](Tar.dataset[end]',
vector_to_parameters(θ[1:(end - Tar.extraparams)],
Tar.init_params))[1,
:], ones(length(Tar.dataset[end])) .* Tar.allstd[3][1]), zeros(length(Tar.dataset[end])))
# matrix(each row corresponds to vector u's rows)
if Tar.dataset isa Vector{Nothing} || Tar.extraparams == 0
return 0
else
nn = [phi(Tar.dataset[end]', θ[1:(length(θ) - Tar.extraparams)])
for phi in Tar.Phi]

L2logprob = 0
for i in 1:(length(Tar.dataset) - 1)
# for u[i] ith vector must be added to dataset,nn[1,:] is the dx in lotka_volterra
L2logprob += logpdf(MvNormal(nn[i][:],
ones(length(Tar.dataset[end])) .* Tar.allstd[3]),
Tar.dataset[i])
end

return L2logprob
end
# if Tar.dataset isa Vector{Nothing} || Tar.extraparams == 0
# return 0
# else
# nn = [phi(Tar.dataset[end]', θ[1:(length(θ) - Tar.extraparams)])
# for phi in Tar.Phi]

# L2logprob = 0
# for i in 1:(length(Tar.dataset) - 1)
# # for u[i] ith vector must be added to dataset,nn[1,:] is the dx in lotka_volterra
# L2logprob += logpdf(MvNormal(nn[i][:],
# ones(length(Tar.dataset[end])) .* Tar.allstd[3]),
# Tar.dataset[i])
# end

# return L2logprob
# end
return 0
end

# priors for NN parameters + ODE constants
Expand Down Expand Up @@ -188,8 +234,7 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
init_params = nothing, draw_samples = 1000,
physdt = 1 / 20.0, bcstd = [0.01], l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, autodiff = false,
Kernel = HMC,
param = [], nchains = 1, Kernel = HMC,
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
Expand All @@ -207,35 +252,45 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
# for new L2 loss
# discretization.additional_loss =

# remove inv params
initial_nnθ = pinnrep.flat_init_params[1:(end - length(param))]
# converting vector of parameters to ComponentArray for runtimegenerated functions
names = ntuple(i -> pinnrep.depvars[i], length(discretization.chain))

if nchains > Threads.nthreads()
throw(error("number of chains is greater than available threads"))
elseif nchains < 1
throw(error("number of chains must be greater than 1"))
end

if chain isa Lux.AbstractExplicitLayer
# Lux chain(using component array later as vector_to_parameter need namedtuple,AHMC uses Float64)
initial_θ = collect(Float64, vcat(ComponentArrays.ComponentArray(initial_nnθ)))
# namedtuple form of Lux params required for RuntimeGeneratedFunctions
initial_nnθ, st = Lux.setup(Random.default_rng(), chain)
initial_nnθ = pinnrep.flat_init_params[1:(end - length(param))]
if discretization.multioutput
if chain[1] isa Lux.AbstractExplicitLayer
# Lux chain(using component array later as vector_to_parameter need namedtuple,AHMC uses Float64)
initial_θ = collect(Float64, vcat(ComponentArrays.ComponentArray(initial_nnθ)))
# namedtuple form of Lux params required for RuntimeGeneratedFunctions
initial_nnθ, st = Lux.setup(Random.default_rng(), chain[1])
else
# remove inv params take only NN params
initial_θ = collect(Float64, initial_nnθ)
end
else
# flat_init_params contains also inv params
# initial_θ = collect(Float64, initial_nnθ[1:(length(initial_nnθ) - length(param))])
initial_θ = collect(Float64, initial_nnθ)
if chain isa Lux.AbstractExplicitLayer
# Lux chain(using component array later as vector_to_parameter need namedtuple,AHMC uses Float64)
initial_θ = collect(Float64, vcat(ComponentArrays.ComponentArray(initial_nnθ)))
# namedtuple form of Lux params required for RuntimeGeneratedFunctions
initial_nnθ, st = Lux.setup(Random.default_rng(), chain)
else
# remove inv params take only NN params
initial_nnθ = pinnrep.flat_init_params[1:(end - length(param))]
initial_θ = collect(Float64, initial_nnθ)
end
end

# adding ode parameter estimation
#ode parameter estimation
nparameters = length(initial_θ)

# println(Tar.Phi(initial_θ))

ninv = length(param)
priors = [MvNormal(priorsNNw[1] * ones(nparameters), priorsNNw[2] * ones(nparameters))]

# append Ode params to all paramvector
# append Ode params to all paramvector - initial_θ
if ninv > 0
# shift ode params(initialise ode params by prior means)
# check if means or user speified is better
Expand All @@ -252,14 +307,14 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
dataset,
priors,
[phystd, bcstd, l2std],
autodiff,
names,
physdt,
ninv,
initial_nnθ,
full_weighted_loglikelihood,
Phi)

println(ℓπ.full_loglikelihood(initial_θ, ℓπ.allstd))
println(ℓπ.full_loglikelihood(setLuxparameters(ℓπ, initial_θ), ℓπ.allstd))
println(priorlogpdf(ℓπ, initial_θ))
println(L2LossData(ℓπ, initial_θ))

Expand Down Expand Up @@ -309,6 +364,10 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples,
adaptor; progress = progress, verbose = verbose)

println(ℓπ.full_loglikelihood(setLuxparameters(ℓπ, samples[end]),
ℓπ.allstd))
println(priorlogpdf(ℓπ, samples[end]))
println(L2LossData(ℓπ, samples[end]))
# return a chain(basic chain),samples and stats
matrix_samples = hcat(samples...)
mcmc_chain = MCMCChains.Chains(matrix_samples')
Expand Down
Loading

0 comments on commit 805380c

Please sign in to comment.