Skip to content

Commit

Permalink
testing partially complete, minor changes overall
Browse files Browse the repository at this point in the history
  • Loading branch information
AstitvaAggarwal committed Nov 18, 2023
1 parent a2a2292 commit cf2faa4
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 446 deletions.
24 changes: 20 additions & 4 deletions src/PDE_BPINN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,11 @@ function LogDensityProblems.logdensity(Tar::PDELogTargetDensity, θ)
# + L2loss2(Tar, θ)
end

function L2loss2(Tar::PDELogTargetDensity, θ)
return Tar.full_loglikelihood(setparameters(Tar, θ),
Tar.allstd)
end
# function L2loss2(Tar::PDELogTargetDensity, θ)
# return Tar.full_loglikelihood(setparameters(Tar, θ),
# Tar.allstd)
# end

function setparameters(Tar::PDELogTargetDensity, θ)
names = Tar.names
ps_new = θ[1:(end - Tar.extraparams)]
Expand Down Expand Up @@ -291,6 +292,12 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
metric = Metric(nparameters)
hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff)

println("Current Physics Log-likelihood : ",
ℓπ.full_loglikelihood(setparameters(ℓπ, initial_θ),
ℓπ.allstd))
println("Current Prior Log-likelihood : ", priorlogpdf(ℓπ, initial_θ))
println("Current MSE against dataset Log-likelihood : ", L2LossData(ℓπ, initial_θ))

# parallel sampling option
if nchains != 1
# Cache to store the chains
Expand Down Expand Up @@ -333,6 +340,15 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
# return a chain(basic chain),samples and stats
matrix_samples = hcat(samples...)
mcmc_chain = MCMCChains.Chains(matrix_samples')

println("Sampling Complete.")
println("Current Physics Log-likelihood : ",
ℓπ.full_loglikelihood(setparameters(ℓπ, samples[end]),
ℓπ.allstd))
println("Current Prior Log-likelihood : ", priorlogpdf(ℓπ, samples[end]))
println("Current MSE against dataset Log-likelihood : ",
L2LossData(ℓπ, samples[end]))

return mcmc_chain, samples, stats
end
end
7 changes: 3 additions & 4 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ Cool function needed for converting vector of sampled parameters into ComponentV
the sampled parameters are of exotic type `Dual` due to ForwardDiff's autodiff tagging
"""
function vector_to_parameters(ps_new::AbstractVector,
ps::Union{ComponentArrays.ComponentVector, AbstractVector})
if ps isa ComponentArrays.ComponentVector
ps::Union{NamedTuple, ComponentArrays.ComponentVector, AbstractVector})
if (ps isa ComponentArrays.ComponentVector) || (ps isa NamedTuple)
@assert length(ps_new) == Lux.parameterlength(ps)
i = 1
function get_ps(x)
Expand Down Expand Up @@ -563,7 +563,6 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
println("Current Physics Log-likelihood : ", physloglikelihood(ℓπ, initial_θ))
println("Current Prior Log-likelihood : ", priorweights(ℓπ, initial_θ))
println("Current MSE against dataset Log-likelihood : ", L2LossData(ℓπ, initial_θ))
println("Current custom loss Log-likelihood : ", L2loss2(ℓπ, initial_θ))

Adaptor, Metric, targetacceptancerate = Adaptorkwargs[:Adaptor],
Adaptorkwargs[:Metric], Adaptorkwargs[:targetacceptancerate]
Expand Down Expand Up @@ -611,11 +610,11 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples,
adaptor; progress = progress, verbose = verbose)

println("Sampling Complete.")
println("Current Physics Log-likelihood : ", physloglikelihood(ℓπ, samples[end]))
println("Current Prior Log-likelihood : ", priorweights(ℓπ, samples[end]))
println("Current MSE against dataset Log-likelihood : ",
L2LossData(ℓπ, samples[end]))
println("Current custom loss Log-likelihood : ", L2loss2(ℓπ, samples[end]))

# return a chain(basic chain),samples and stats
matrix_samples = hcat(samples...)
Expand Down
Loading

0 comments on commit cf2faa4

Please sign in to comment.