diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 152ce77e69..41d48793d3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -18,6 +18,7 @@ jobs: matrix: group: - ODEBPINN + - PDEBPINN - NNPDE1 - NNPDE2 - AdaptiveLoss diff --git a/docs/src/tutorials/Lotka_Volterra_BPINNs.md b/docs/src/tutorials/Lotka_Volterra_BPINNs.md index 5937f8d0dc..cbbfe3d4db 100644 --- a/docs/src/tutorials/Lotka_Volterra_BPINNs.md +++ b/docs/src/tutorials/Lotka_Volterra_BPINNs.md @@ -108,7 +108,7 @@ plot!(solution, labels = ["true x" "true y"]) We can see the estimated ODE parameters by - ```@example bpinn -sol_pestim.estimated_ode_params +sol_pestim.estimated_de_params ``` We can see it is close to the true values of the parameters. diff --git a/src/BPINN_ode.jl b/src/BPINN_ode.jl index 5d937e4cf1..7a013371e7 100644 --- a/src/BPINN_ode.jl +++ b/src/BPINN_ode.jl @@ -148,21 +148,24 @@ end BPINN Solution contains the original solution from AdvancedHMC.jl sampling(BPINNstats contains fields related to that) > ensemblesol is the Probabilistic Estimate(MonteCarloMeasurements.jl Particles type) of Ensemble solution from All Neural Network's(made using all sampled parameters) output's. > estimated_nn_params - Probabilistic Estimate of NN params from sampled weights,biases -> estimated_ode_params - Probabilistic Estimate of ODE params from sampled unknown ode paramters +> estimated_de_params - Probabilistic Estimate of DE params from sampled unknown DE paramters """ -struct BPINNsolution{O <: BPINNstats, E, - NP <: Vector{<:MonteCarloMeasurements.Particles{<:Float64}}, - OP <: Union{Vector{Nothing}, - Vector{<:MonteCarloMeasurements.Particles{<:Float64}}}} + +struct BPINNsolution{O <: BPINNstats, E, NP, OP, P} original::O ensemblesol::E estimated_nn_params::NP - estimated_ode_params::OP - - function BPINNsolution(original, ensemblesol, estimated_nn_params, estimated_ode_params) + estimated_de_params::OP + timepoints::P + + function BPINNsolution(original, + ensemblesol, + estimated_nn_params, + estimated_de_params, + timepoints) new{typeof(original), typeof(ensemblesol), typeof(estimated_nn_params), - typeof(estimated_ode_params)}(original, ensemblesol, estimated_nn_params, - estimated_ode_params) + typeof(estimated_de_params), typeof(timepoints)}(original, ensemblesol, estimated_nn_params, + estimated_de_params, timepoints) end end @@ -260,14 +263,14 @@ function DiffEqBase.__solve(prob::DiffEqBase.ODEProblem, end nnparams = length(θinit) - estimnnparams = [Particles(reduce(hcat, samples)[i, :]) for i in 1:nnparams] + estimnnparams = [Particles(reduce(hcat, samples[(end - numensemble):end])[i, :]) for i in 1:nnparams] if ninv == 0 estimated_params = [nothing] else - estimated_params = [Particles(reduce(hcat, samples[(end - ninv + 1):end])[i, :]) + estimated_params = [Particles(reduce(hcat, samples[(end - numensemble):end])[i, :]) for i in (nnparams + 1):(nnparams + ninv)] end - BPINNsolution(fullsolution, ensemblecurves, estimnnparams, estimated_params) + BPINNsolution(fullsolution, ensemblecurves, estimnnparams, estimated_params, t) end \ No newline at end of file diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index 945093ea04..6f7dbaf839 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -52,19 +52,21 @@ include("discretize.jl") include("neural_adapter.jl") include("advancedHMC_MCMC.jl") include("BPINN_ode.jl") +include("PDE_BPINN.jl") export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE, - KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem, - 
KolmogorovParamDomain, NNParamKolmogorov, - PhysicsInformedNN, discretize, - GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining, - WeightedIntervalTraining, - build_loss_function, get_loss_function, - generate_training_sets, get_variables, get_argument, get_bounds, - get_phi, get_numeric_derivative, get_numeric_integral, - build_symbolic_equation, build_symbolic_loss_function, symbolic_discretize, - AbstractAdaptiveLoss, NonAdaptiveLoss, GradientScaleAdaptiveLoss, - MiniMaxAdaptiveLoss, - LogOptions, ahmc_bayesian_pinn_ode, BNNODE + KolmogorovPDEProblem, NNKolmogorov, NNStopping, ParamKolmogorovPDEProblem, + KolmogorovParamDomain, NNParamKolmogorov, + PhysicsInformedNN, discretize, + GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining, + WeightedIntervalTraining, + build_loss_function, get_loss_function, + generate_training_sets, get_variables, get_argument, get_bounds, + get_phi, get_numeric_derivative, get_numeric_integral, + build_symbolic_equation, build_symbolic_loss_function, symbolic_discretize, + AbstractAdaptiveLoss, NonAdaptiveLoss, GradientScaleAdaptiveLoss, + MiniMaxAdaptiveLoss, LogOptions, + ahmc_bayesian_pinn_ode, BNNODE, ahmc_bayesian_pinn_pde, vector_to_parameters, + BPINNsolution end # module diff --git a/src/PDE_BPINN.jl b/src/PDE_BPINN.jl new file mode 100644 index 0000000000..504e347eeb --- /dev/null +++ b/src/PDE_BPINN.jl @@ -0,0 +1,485 @@ +mutable struct PDELogTargetDensity{ + ST <: AbstractTrainingStrategy, + D <: Union{Nothing, Vector{<:Matrix{<:Real}}}, + P <: Vector{<:Distribution}, + I, + F, + PH, +} + dim::Int64 + strategy::ST + dataset::D + priors::P + allstd::Vector{Vector{Float64}} + names::Tuple + extraparams::Int + init_params::I + full_loglikelihood::F + Φ::PH + + function PDELogTargetDensity(dim, strategy, dataset, + priors, allstd, names, extraparams, + init_params::AbstractVector, full_loglikelihood, Φ) + new{ + typeof(strategy), + typeof(dataset), + typeof(priors), + typeof(init_params), + typeof(full_loglikelihood), + typeof(Φ), + }(dim, + strategy, + dataset, + priors, + allstd, + names, + extraparams, + init_params, + full_loglikelihood, + Φ) + end + function PDELogTargetDensity(dim, strategy, dataset, + priors, allstd, names, extraparams, + init_params::Union{NamedTuple, ComponentArrays.ComponentVector}, + full_loglikelihood, Φ) + new{ + typeof(strategy), + typeof(dataset), + typeof(priors), + typeof(init_params), + typeof(full_loglikelihood), + typeof(Φ), + }(dim, + strategy, + dataset, + priors, + allstd, + names, + extraparams, + init_params, + full_loglikelihood, + Φ) + end +end + +function LogDensityProblems.logdensity(Tar::PDELogTargetDensity, θ) + # for parameter estimation neccesarry to use multioutput case + return Tar.full_loglikelihood(setparameters(Tar, θ), + Tar.allstd) + priorlogpdf(Tar, θ) + L2LossData(Tar, θ) + # + L2loss2(Tar, θ) +end + +# function L2loss2(Tar::PDELogTargetDensity, θ) +# return Tar.full_loglikelihood(setparameters(Tar, θ), +# Tar.allstd) +# end + +function setparameters(Tar::PDELogTargetDensity, θ) + names = Tar.names + ps_new = θ[1:(end - Tar.extraparams)] + ps = Tar.init_params + + if (ps[names[1]] isa ComponentArrays.ComponentVector) + # multioutput case for Lux chains, for each depvar ps would contain Lux ComponentVectors + # which we use for mapping current ahmc sampled vector of parameters onto NNs + + i = 0 + Luxparams = [vector_to_parameters(ps_new[((i += length(ps[x])) - length(ps[x]) + 1):i], + ps[x]) for x in names] + + else + # multioutput Flux + Luxparams = 
θ + end + + if (Luxparams isa AbstractVector) && (Luxparams[1] isa ComponentArrays.ComponentVector) + # multioutput Lux + a = ComponentArrays.ComponentArray(NamedTuple{Tar.names}(i for i in Luxparams)) + + if Tar.extraparams > 0 + b = θ[(end - Tar.extraparams + 1):end] + + return ComponentArrays.ComponentArray(; + depvar = a, + p = b) + else + return ComponentArrays.ComponentArray(; + depvar = a) + end + else + # multioutput fLux case + return vector_to_parameters(Luxparams, ps) + end +end + +LogDensityProblems.dimension(Tar::PDELogTargetDensity) = Tar.dim + +function LogDensityProblems.capabilities(::PDELogTargetDensity) + LogDensityProblems.LogDensityOrder{1}() +end + +# L2 losses loglikelihood(needed mainly for ODE parameter estimation) +function L2LossData(Tar::PDELogTargetDensity, θ) + Φ = Tar.Φ + init_params = Tar.init_params + dataset = Tar.dataset + sumt = 0 + L2stds = Tar.allstd[3] + # each dep var has a diff dataset depending on its indep var and thier domains + # these datasets are matrices of first col-dep var and remaining cols-all indep var + # Tar.init_params is needed to contruct a vector of parameters into a ComponentVector + + # dataset of form Vector[matrix_x, matrix_y, matrix_z] + # matrix_i is of form [i,indvar1,indvar2,..] (needed in case if heterogenous domains) + + # Phi is the trial solution for each NN in chain array + # Creating logpdf( MvNormal(Phi(t,θ),std), dataset[i] ) + # dataset[i][:, 2:end] -> indepvar cols of a particular depvar's dataset + # dataset[i][:, 1] -> depvar col of depvar's dataset + + if Tar.extraparams > 0 + if Tar.init_params isa ComponentArrays.ComponentVector + for i in eachindex(Φ) + sumt += logpdf(MvNormal(Φ[i](dataset[i][:, 2:end]', + vector_to_parameters(θ[1:(end - Tar.extraparams)], + init_params)[Tar.names[i]])[1, + :], + LinearAlgebra.Diagonal(abs2.(ones(size(dataset[i])[1]) .* + L2stds[i]))), + dataset[i][:, 1]) + end + sumt + else + # Flux case needs subindexing wrt Tar.names indices(hence stored in Tar.names) + for i in eachindex(Φ) + sumt += logpdf(MvNormal(Φ[i](dataset[i][:, 2:end]', + vector_to_parameters(θ[1:(end - Tar.extraparams)], + init_params)[Tar.names[2][i]])[1, + :], + LinearAlgebra.Diagonal(abs2.(ones(size(dataset[i])[1]) .* + L2stds[i]))), + dataset[i][:, 1]) + end + sumt + end + else + return 0 + end +end + +# priors for NN parameters + ODE constants +function priorlogpdf(Tar::PDELogTargetDensity, θ) + allparams = Tar.priors + # Vector of ode parameters priors + invpriors = allparams[2:end] + + # nn weights + nnwparams = allparams[1] + + if Tar.extraparams > 0 + invlogpdf = sum(logpdf(invpriors[length(θ) - i + 1], θ[i]) + for i in (length(θ) - Tar.extraparams + 1):length(θ); init = 0.0) + + return (invlogpdf + + + logpdf(nnwparams, θ[1:(length(θ) - Tar.extraparams)])) + else + return logpdf(nnwparams, θ) + end +end + +function integratorchoice(Integratorkwargs, initial_ϵ) + Integrator = Integratorkwargs[:Integrator] + if Integrator == JitteredLeapfrog + jitter_rate = Integratorkwargs[:jitter_rate] + Integrator(initial_ϵ, jitter_rate) + elseif Integrator == TemperedLeapfrog + tempering_rate = Integratorkwargs[:tempering_rate] + Integrator(initial_ϵ, tempering_rate) + else + Integrator(initial_ϵ) + end +end + +function adaptorchoice(Adaptor, mma, ssa) + if Adaptor != AdvancedHMC.NoAdaptation() + Adaptor(mma, ssa) + else + AdvancedHMC.NoAdaptation() + end +end + +function inference(samples, pinnrep, saveats, numensemble, ℓπ) + domains = pinnrep.domains + phi = pinnrep.phi + dict_depvar_input = 
pinnrep.dict_depvar_input + depvars = pinnrep.depvars + + names = ℓπ.names + initial_nnθ = ℓπ.init_params + ninv = ℓπ.extraparams + + ranges = Dict([Symbol(d.variables) => infimum(d.domain):dx:supremum(d.domain) + for (d, dx) in zip(domains, saveats)]) + inputs = [dict_depvar_input[i] for i in depvars] + + span = [[ranges[indvar] for indvar in input] for input in inputs] + timepoints = [hcat(vec(map(points -> collect(points), + Iterators.product(span[i]...)))...) + for i in eachindex(phi)] + + # order of range's domains must match chain's inputs and dep_vars + samples = samples[(end - numensemble):end] + nnparams = length(samples[1][1:(end - ninv)]) + # get rows-ith param and col-ith sample value + estimnnparams = [Particles(reduce(hcat, samples)[i, :]) + for i in 1:nnparams] + + # PDE params + if ninv == 0 + estimated_params = [nothing] + else + estimated_params = [Particles(reduce(hcat, samples)[i, :]) + for i in (nnparams + 1):(nnparams + ninv)] + end + + # names is an indicator of type of chain + if names[1] != 1 + # getting parameter ranges in case of Lux chains + Luxparams = [] + i = 0 + for x in names + len = length(initial_nnθ[x]) + push!(Luxparams, (i + 1):(i + len)) + i += len + end + + # convert to format directly usable by lux + estimatedLuxparams = [vector_to_parameters(estimnnparams[Luxparams[i]], + initial_nnθ[names[i]]) for i in eachindex(phi)] + + # infer predictions(preds) each row - NN, each col - ith sample + samplesn = reduce(hcat, samples) + preds = [] + for j in eachindex(phi) + push!(preds, + [phi[j](timepoints[j], + vector_to_parameters(samplesn[:, i][Luxparams[j]], + initial_nnθ[names[j]])) for i in 1:numensemble]) + end + + # note here no of samples referse to numensemble and points is the no of points in each dep_vars discretization + # each phi will give output in single domain of depvar(so we have each row as a vector of vector outputs) + # so we get after reduce a single matrix of n rows(samples), and j cols(points) + ensemblecurves = [Particles(reduce(vcat, preds[i])) for i in eachindex(phi)] + + return ensemblecurves, estimatedLuxparams, estimated_params, timepoints + else + # get intervals for parameters corresponding to flux chains + Fluxparams = names[2] + + # convert to format directly usable by Flux + estimatedFluxparams = [estimnnparams[Fluxparams[i]] for i in eachindex(phi)] + + # infer predictions(preds) each row - NN, each col - ith sample + samplesn = reduce(hcat, samples) + preds = [] + for j in eachindex(phi) + push!(preds, + [phi[j](timepoints[j], samplesn[:, i][Fluxparams[j]]) for i in 1:numensemble]) + end + + ensemblecurves = [Particles(reduce(vcat, preds[i])) for i in eachindex(phi)] + + return ensemblecurves, estimatedFluxparams, estimated_params, timepoints + end +end + +# priors: pdf for W,b + pdf for ODE params +function ahmc_bayesian_pinn_pde(pde_system, discretization; + draw_samples = 1000, + bcstd = [0.01], l2std = [0.05], + phystd = [0.05], priorsNNw = (0.0, 2.0), + param = [], nchains = 1, Kernel = HMC(0.1, 30), + Adaptorkwargs = (Adaptor = StanHMCAdaptor, + Metric = DiagEuclideanMetric, targetacceptancerate = 0.8), + Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0], + numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false) + pinnrep = symbolic_discretize(pde_system, discretization) + dataset_pde, dataset_bc = discretization.dataset + + if ((dataset_bc isa Nothing) && (dataset_pde isa Nothing)) + dataset = nothing + elseif dataset_bc isa Nothing + dataset = dataset_pde + elseif dataset_pde 
isa Nothing + dataset = dataset_bc + else + dataset = [vcat(dataset_pde[i], dataset_bc[i]) for i in eachindex(dataset_pde)] + end + + if discretization.param_estim && isempty(param) + throw(UndefVarError(:param)) + elseif discretization.param_estim && dataset isa Nothing + throw(UndefVarError(:dataset)) + elseif discretization.param_estim && length(l2std) != length(pinnrep.depvars) + throw(error("L2 stds length must match number of dependant variables")) + end + + # for physics loglikelihood + full_weighted_loglikelihood = pinnrep.loss_functions.full_loss_function + chain = discretization.chain + + if length(pinnrep.domains) != length(saveats) + throw(error("Number of independant variables must match saveat inference discretization steps")) + end + + # NN solutions for loglikelihood which is used for L2lossdata + Φ = pinnrep.phi + + # for new L2 loss + # discretization.additional_loss = + + if nchains < 1 + throw(error("number of chains must be greater than or equal to 1")) + end + + # remove inv params take only NN params, AHMC uses Float64 + initial_nnθ = pinnrep.flat_init_params[1:(end - length(param))] + initial_θ = collect(Float64, initial_nnθ) + + # contains only NN parameters + initial_nnθ = pinnrep.init_params + + if (discretization.multioutput && chain[1] isa Lux.AbstractExplicitLayer) + # converting vector of parameters to ComponentArray for runtimegenerated functions + names = ntuple(i -> pinnrep.depvars[i], length(chain)) + else + # Flux multioutput + i = 0 + temp = [] + for j in eachindex(initial_nnθ) + len = length(initial_nnθ[j]) + push!(temp, (i + 1):(i + len)) + i += len + end + names = tuple(1, temp) + end + + #ode parameter estimation + nparameters = length(initial_θ) + ninv = length(param) + priors = [ + MvNormal(priorsNNw[1] * ones(nparameters), + LinearAlgebra.Diagonal(abs2.(priorsNNw[2] .* ones(nparameters)))), + ] + + # append Ode params to all paramvector - initial_θ + if ninv > 0 + # shift ode params(initialise ode params by prior means) + # check if means or user speified is better + initial_θ = vcat(initial_θ, [Distributions.params(param[i])[1] for i in 1:ninv]) + priors = vcat(priors, param) + nparameters += ninv + end + + # vector in case of N-dimensional domains + strategy = discretization.strategy + + # dimensions would be total no of params,initial_nnθ for Lux namedTuples + ℓπ = PDELogTargetDensity(nparameters, + strategy, + dataset, + priors, + [phystd, bcstd, l2std], + names, + ninv, + initial_nnθ, + full_weighted_loglikelihood, + Φ) + + Adaptor, Metric, targetacceptancerate = Adaptorkwargs[:Adaptor], + Adaptorkwargs[:Metric], Adaptorkwargs[:targetacceptancerate] + + # Define Hamiltonian system (nparameters ~ dimensionality of the sampling space) + metric = Metric(nparameters) + hamiltonian = Hamiltonian(metric, ℓπ, ForwardDiff) + + @info("Current Physics Log-likelihood : ", + ℓπ.full_loglikelihood(setparameters(ℓπ, initial_θ), + ℓπ.allstd)) + @info("Current Prior Log-likelihood : ", priorlogpdf(ℓπ, initial_θ)) + @info("Current MSE against dataset Log-likelihood : ", L2LossData(ℓπ, initial_θ)) + + # parallel sampling option + if nchains != 1 + + # Cache to store the chains + bpinnsols = Vector{Any}(undef, nchains) + + Threads.@threads for i in 1:nchains + # each chain has different initial NNparameter values(better posterior exploration) + initial_θ = vcat(randn(nparameters - ninv), + initial_θ[(nparameters - ninv + 1):end]) + initial_ϵ = find_good_stepsize(hamiltonian, initial_θ) + integrator = integratorchoice(Integratorkwargs, initial_ϵ) + adaptor = 
adaptorchoice(Adaptor, MassMatrixAdaptor(metric), + StepSizeAdaptor(targetacceptancerate, integrator)) + Kernel = AdvancedHMC.make_kernel(Kernel, integrator) + samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor; + progress = progress, verbose = verbose) + + # return a chain(basic chain),samples and stats + matrix_samples = hcat(samples...) + mcmc_chain = MCMCChains.Chains(matrix_samples') + + fullsolution = BPINNstats(mcmc_chain, samples, stats) + ensemblecurves, estimnnparams, estimated_params, timepoints = inference(samples, + pinnrep, + saveat, + numensemble, + ℓπ) + + bpinnsols[i] = BPINNsolution(fullsolution, + ensemblecurves, + estimnnparams, + estimated_params, + timepoints) + end + return bpinnsols + else + initial_ϵ = find_good_stepsize(hamiltonian, initial_θ) + integrator = integratorchoice(Integratorkwargs, initial_ϵ) + adaptor = adaptorchoice(Adaptor, MassMatrixAdaptor(metric), + StepSizeAdaptor(targetacceptancerate, integrator)) + + Kernel = AdvancedHMC.make_kernel(Kernel, integrator) + samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, + adaptor; progress = progress, verbose = verbose) + + # return a chain(basic chain),samples and stats + matrix_samples = hcat(samples...) + mcmc_chain = MCMCChains.Chains(matrix_samples') + + @info("Sampling Complete.") + @info("Current Physics Log-likelihood : ", + ℓπ.full_loglikelihood(setparameters(ℓπ, samples[end]), + ℓπ.allstd)) + @info("Current Prior Log-likelihood : ", priorlogpdf(ℓπ, samples[end])) + @info("Current MSE against dataset Log-likelihood : ", + L2LossData(ℓπ, samples[end])) + + fullsolution = BPINNstats(mcmc_chain, samples, stats) + ensemblecurves, estimnnparams, estimated_params, timepoints = inference(samples, + pinnrep, + saveats, + numensemble, + ℓπ) + + return BPINNsolution(fullsolution, + ensemblecurves, + estimnnparams, + estimated_params, + timepoints) + end +end \ No newline at end of file diff --git a/src/advancedHMC_MCMC.jl b/src/advancedHMC_MCMC.jl index df174e2538..0a9c569cc4 100644 --- a/src/advancedHMC_MCMC.jl +++ b/src/advancedHMC_MCMC.jl @@ -65,9 +65,11 @@ mutable struct LogTargetDensity{C, S, ST <: AbstractTrainingStrategy, I, end """ -cool function to convert parameter's vector to ComponentArray of parameters (for Lux Chain: vector of samples -> Lux ComponentArrays) +function needed for converting vector of sampled parameters into ComponentVector in case of Lux chain output, derivatives +the sampled parameters are of exotic type `Dual` due to ForwardDiff's autodiff tagging """ -function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple) +function vector_to_parameters(ps_new::AbstractVector, + ps::Union{NamedTuple, ComponentArrays.ComponentVector}) @assert length(ps_new) == Lux.parameterlength(ps) i = 1 function get_ps(x) @@ -78,6 +80,8 @@ function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple) return Functors.fmap(get_ps, ps) end +vector_to_parameters(ps_new::AbstractVector, ps::AbstractVector) = ps_new + function LogDensityProblems.logdensity(Tar::LogTargetDensity, θ) return physloglikelihood(Tar, θ) + priorweights(Tar, θ) + L2LossData(Tar, θ) end @@ -552,6 +556,10 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain; end end + @info("Current Physics Log-likelihood : ", physloglikelihood(ℓπ, initial_θ)) + @info("Current Prior Log-likelihood : ", priorweights(ℓπ, initial_θ)) + @info("Current MSE against dataset Log-likelihood : ", L2LossData(ℓπ, initial_θ)) + Adaptor, Metric, targetacceptancerate = 
Adaptorkwargs[:Adaptor], Adaptorkwargs[:Metric], Adaptorkwargs[:targetacceptancerate] @@ -598,6 +606,12 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain; samples, stats = sample(hamiltonian, Kernel, initial_θ, draw_samples, adaptor; progress = progress, verbose = verbose) + @info("Sampling Complete.") + @info("Current Physics Log-likelihood : ", physloglikelihood(ℓπ, samples[end])) + @info("Current Prior Log-likelihood : ", priorweights(ℓπ, samples[end])) + @info("Current MSE against dataset Log-likelihood : ", + L2LossData(ℓπ, samples[end])) + # return a chain(basic chain),samples and stats matrix_samples = hcat(samples...) mcmc_chain = MCMCChains.Chains(matrix_samples') diff --git a/src/discretize.jl b/src/discretize.jl index 4308a79b4e..400be5d2c1 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -401,7 +401,7 @@ to the PDE. For more information, see `discretize` and `PINNRepresentation`. """ function SciMLBase.symbolic_discretize(pde_system::PDESystem, - discretization::PhysicsInformedNN) + discretization::AbstractPINN) eqs = pde_system.eqs bcs = pde_system.bcs chain = discretization.chain @@ -567,7 +567,6 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, strategy, datafree_pde_loss_functions, datafree_bc_loss_functions) - # setup for all adaptive losses num_pde_losses = length(pde_loss_functions) num_bc_losses = length(bc_loss_functions) @@ -586,88 +585,188 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, pde_loss_functions, bc_loss_functions) - function full_loss_function(θ, p) + function get_likelihood_estimate_function(discretization::PhysicsInformedNN) + function full_loss_function(θ, p) + # the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them + pde_losses = [pde_loss_function(θ) for pde_loss_function in pde_loss_functions] + bc_losses = [bc_loss_function(θ) for bc_loss_function in bc_loss_functions] - # the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them - pde_losses = [pde_loss_function(θ) for pde_loss_function in pde_loss_functions] - bc_losses = [bc_loss_function(θ) for bc_loss_function in bc_loss_functions] + # this is kind of a hack, and means that whenever the outer function is evaluated the increment goes up, even if it's not being optimized + # that's why we prefer the user to maintain the increment in the outer loop callback during optimization + ChainRulesCore.@ignore_derivatives if self_increment + iteration[1] += 1 + end - # this is kind of a hack, and means that whenever the outer function is evaluated the increment goes up, even if it's not being optimized - # that's why we prefer the user to maintain the increment in the outer loop callback during optimization - ChainRulesCore.@ignore_derivatives if self_increment - iteration[1] += 1 - end + ChainRulesCore.@ignore_derivatives begin + reweight_losses_func(θ, pde_losses, + bc_losses) + end + + weighted_pde_losses = adaloss.pde_loss_weights .* pde_losses + weighted_bc_losses = adaloss.bc_loss_weights .* bc_losses + + sum_weighted_pde_losses = sum(weighted_pde_losses) + sum_weighted_bc_losses = sum(weighted_bc_losses) + weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_bc_losses + + full_weighted_loss = if additional_loss isa Nothing + weighted_loss_before_additional + else + function _additional_loss(phi, θ) + (θ_, p_) = if (param_estim == true) + if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || + (!(phi isa Vector) && phi.f 
isa Optimisers.Restructure) + # Isa Flux Chain + θ[1:(end - length(default_p))], θ[(end - length(default_p) + 1):end] + else + θ.depvar, θ.p + end + else + θ, nothing + end + return additional_loss(phi, θ_, p_) + end + weighted_additional_loss_val = adaloss.additional_loss_weights[1] * + _additional_loss(phi, θ) + weighted_loss_before_additional + weighted_additional_loss_val + end - ChainRulesCore.@ignore_derivatives begin reweight_losses_func(θ, pde_losses, - bc_losses) end + ChainRulesCore.@ignore_derivatives begin + if iteration[1] % log_frequency == 0 + logvector(pinnrep.logger, pde_losses, "unweighted_loss/pde_losses", + iteration[1]) + logvector(pinnrep.logger, + bc_losses, + "unweighted_loss/bc_losses", + iteration[1]) + logvector(pinnrep.logger, weighted_pde_losses, + "weighted_loss/weighted_pde_losses", + iteration[1]) + logvector(pinnrep.logger, weighted_bc_losses, + "weighted_loss/weighted_bc_losses", + iteration[1]) + if !(additional_loss isa Nothing) + logscalar(pinnrep.logger, weighted_additional_loss_val, + "weighted_loss/weighted_additional_loss", iteration[1]) + end + logscalar(pinnrep.logger, sum_weighted_pde_losses, + "weighted_loss/sum_weighted_pde_losses", iteration[1]) + logscalar(pinnrep.logger, sum_weighted_bc_losses, + "weighted_loss/sum_weighted_bc_losses", iteration[1]) + logscalar(pinnrep.logger, full_weighted_loss, + "weighted_loss/full_weighted_loss", + iteration[1]) + logvector(pinnrep.logger, adaloss.pde_loss_weights, + "adaptive_loss/pde_loss_weights", + iteration[1]) + logvector(pinnrep.logger, adaloss.bc_loss_weights, + "adaptive_loss/bc_loss_weights", + iteration[1]) + end + end - weighted_pde_losses = adaloss.pde_loss_weights .* pde_losses - weighted_bc_losses = adaloss.bc_loss_weights .* bc_losses + return full_weighted_loss + end - sum_weighted_pde_losses = sum(weighted_pde_losses) - sum_weighted_bc_losses = sum(weighted_bc_losses) - weighted_loss_before_additional = sum_weighted_pde_losses + sum_weighted_bc_losses + return full_loss_function + end - full_weighted_loss = if additional_loss isa Nothing - weighted_loss_before_additional + function get_likelihood_estimate_function(discretization::BayesianPINN) + dataset_pde, dataset_bc = discretization.dataset + + # required as Physics loss also needed on the discrete dataset domain points + # data points are discrete and so by default GridTraining loss applies + # passing placeholder dx with GridTraining, it uses data points irl + datapde_loss_functions, databc_loss_functions = if (!(dataset_bc isa Nothing)||!(dataset_pde isa Nothing)) + merge_strategy_with_loglikelihood_function(pinnrep, + GridTraining(0.1), + datafree_pde_loss_functions, + datafree_bc_loss_functions, train_sets_pde = dataset_pde, train_sets_bc = dataset_bc) else - function _additional_loss(phi, θ) - (θ_, p_) = if (param_estim == true) - if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || - (!(phi isa Vector) && phi.f isa Optimisers.Restructure) - # Isa Flux Chain - θ[1:(end - length(default_p))], θ[(end - length(default_p) + 1):end] + (nothing, nothing) + end + + function full_loss_function(θ, allstd::Vector{Vector{Float64}}) + stdpdes, stdbcs, stdextra = allstd + # the aggregation happens on cpu even if the losses are gpu, probably fine since it's only a few of them + pde_loglikelihoods = [logpdf(Normal(0, stdpdes[i]), pde_loss_function(θ)) + for (i, pde_loss_function) in enumerate(pde_loss_functions)] + + bc_loglikelihoods = [logpdf(Normal(0, stdbcs[j]), bc_loss_function(θ)) + for (j, bc_loss_function) in 
enumerate(bc_loss_functions)] + + if !(datapde_loss_functions isa Nothing) + pde_loglikelihoods += [logpdf(Normal(0, stdpdes[j]), pde_loss_function(θ)) + for (j, pde_loss_function) in enumerate(datapde_loss_functions)] + + end + + if !(databc_loss_functions isa Nothing) + bc_loglikelihoods += [logpdf(Normal(0, stdbcs[j]), bc_loss_function(θ)) + for (j, bc_loss_function) in enumerate(databc_loss_functions)] + end + + # this is kind of a hack, and means that whenever the outer function is evaluated the increment goes up, even if it's not being optimized + # that's why we prefer the user to maintain the increment in the outer loop callback during optimization + ChainRulesCore.@ignore_derivatives if self_increment + iteration[1] += 1 + end + + ChainRulesCore.@ignore_derivatives begin + reweight_losses_func(θ, pde_loglikelihoods, + bc_loglikelihoods) + end + + weighted_pde_loglikelihood = adaloss.pde_loss_weights .* pde_loglikelihoods + weighted_bc_loglikelihood = adaloss.bc_loss_weights .* bc_loglikelihoods + + sum_weighted_pde_loglikelihood = sum(weighted_pde_loglikelihood) + sum_weighted_bc_loglikelihood = sum(weighted_bc_loglikelihood) + weighted_loglikelihood_before_additional = sum_weighted_pde_loglikelihood + + sum_weighted_bc_loglikelihood + + full_weighted_loglikelihood = if additional_loss isa Nothing + weighted_loglikelihood_before_additional + else + function _additional_loss(phi, θ) + (θ_, p_) = if (param_estim == true) + if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || + (!(phi isa Vector) && phi.f isa Optimisers.Restructure) + # Isa Flux Chain + θ[1:(end - length(default_p))], + θ[(end - length(default_p) + 1):end] + else + θ.depvar, θ.p + end else - θ.depvar, θ.p + θ, nothing end - else - θ, nothing + return additional_loss(phi, θ_, p_) end - return additional_loss(phi, θ_, p_) + + _additional_loglikelihood = logpdf(Normal(0, stdextra), + _additional_loss(phi, θ)) + + weighted_additional_loglikelihood = adaloss.additional_loss_weights[1] * + _additional_loglikelihood + + weighted_loglikelihood_before_additional + weighted_additional_loglikelihood end - weighted_additional_loss_val = adaloss.additional_loss_weights[1] * - _additional_loss(phi, θ) - weighted_loss_before_additional + weighted_additional_loss_val + + return full_weighted_loglikelihood end - ChainRulesCore.@ignore_derivatives begin if iteration[1] % log_frequency == 0 - logvector(pinnrep.logger, pde_losses, "unweighted_loss/pde_losses", - iteration[1]) - logvector(pinnrep.logger, bc_losses, "unweighted_loss/bc_losses", iteration[1]) - logvector(pinnrep.logger, weighted_pde_losses, - "weighted_loss/weighted_pde_losses", - iteration[1]) - logvector(pinnrep.logger, weighted_bc_losses, - "weighted_loss/weighted_bc_losses", - iteration[1]) - if !(additional_loss isa Nothing) - logscalar(pinnrep.logger, weighted_additional_loss_val, - "weighted_loss/weighted_additional_loss", iteration[1]) - end - logscalar(pinnrep.logger, sum_weighted_pde_losses, - "weighted_loss/sum_weighted_pde_losses", iteration[1]) - logscalar(pinnrep.logger, sum_weighted_bc_losses, - "weighted_loss/sum_weighted_bc_losses", iteration[1]) - logscalar(pinnrep.logger, full_weighted_loss, - "weighted_loss/full_weighted_loss", - iteration[1]) - logvector(pinnrep.logger, adaloss.pde_loss_weights, - "adaptive_loss/pde_loss_weights", - iteration[1]) - logvector(pinnrep.logger, adaloss.bc_loss_weights, - "adaptive_loss/bc_loss_weights", - iteration[1]) - end end - - return full_weighted_loss + return full_loss_function end + full_loss_function = 
get_likelihood_estimate_function(discretization) pinnrep.loss_functions = PINNLossFunctions(bc_loss_functions, pde_loss_functions, - full_loss_function, additional_loss, - datafree_pde_loss_functions, - datafree_bc_loss_functions) + full_loss_function, additional_loss, + datafree_pde_loss_functions, + datafree_bc_loss_functions) return pinnrep + end """ diff --git a/src/pinn_types.jl b/src/pinn_types.jl index 428f03fc36..ea66f725fd 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -150,6 +150,136 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN end end + +""" +```julia +BayesianPINN(chain, + strategy; + init_params = nothing, + phi = nothing, + param_estim = false, + additional_loss = nothing, + adaptive_loss = nothing, + logger = nothing, + log_options = LogOptions(), + iteration = nothing, + dataset=nothing, + kwargs...) where {iip} +``` + +## Positional Arguments + +* `chain`: a vector of Flux.jl or Lux.jl chains with a d-dimensional input and a + 1-dimensional output corresponding to each of the dependent variables. Note that this + specification respects the order of the dependent variables as specified in the PDESystem. +* `strategy`: determines which training strategy will be used. See the Training Strategy + documentation for more details. + +## Keyword Arguments + +* `init_params`: the initial parameters of the neural networks. This should match the + specification of the chosen `chain` library. For example, if a Flux.chain is used, then + `init_params` should match `Flux.destructure(chain)[1]` in shape. If `init_params` is not + given, then the neural network default parameters are used. Note that for Lux, the default + will convert to Float64. +* `phi`: a trial solution, specified as `phi(x,p)` where `x` is the coordinates vector for + the dependent variable and `p` are the weights of the phi function (generally the weights + of the neural network defining `phi`). By default, this is generated from the `chain`. This + should only be used to more directly impose functional information in the training problem, + for example imposing the boundary condition by the test function formulation. +* `adaptive_loss`: the choice for the adaptive loss function. See the + [adaptive loss page](@ref adaptive_loss) for more details. Defaults to no adaptivity. +* `additional_loss`: a function `additional_loss(phi, θ, p_)` where `phi` are the neural + network trial solutions, `θ` are the weights of the neural network(s), and `p_` are the + hyperparameters . If `param_estim = true`, then `θ` additionally + contains the parameters of the differential equation appended to the end of the vector. +* `param_estim`: whether the parameters of the differential equation should be included in + the values sent to the `additional_loss` function. Defaults to `false`. +* `logger`: ?? needs docs +* `log_options`: ?? why is this separate from the logger? +* `iteration`: used to control the iteration counter??? +* `kwargs`: Extra keyword arguments. 
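* `dataset`: a pair `[pde_datasets, bc_datasets]` of measurement data used in the data
  log-likelihood, where each entry is either `nothing` or a `Vector` of matrices (one per
  dependent variable) whose first column is the dependent variable and whose remaining
  columns are the corresponding independent variables. Required when `param_estim = true`;
  defaults to `(nothing, nothing)`.

## Example

A minimal constructor sketch for an inverse problem, mirroring the tests added in this PR
(the chain size and the `u_measured`/`t_measured` arrays are illustrative placeholders):

```julia
chain = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 1))
# hypothetical measurements of u(t): u_measured in column 1, t_measured in column 2
dataset = [hcat(u_measured, t_measured)]
discretization = BayesianPINN([chain], GridTraining([0.02]);
    param_estim = true, dataset = [dataset, nothing])
```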
+""" +struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN + chain::Any + strategy::T + init_params::P + phi::PH + derivative::DER + param_estim::PE + additional_loss::AL + adaptive_loss::ADA + logger::LOG + log_options::LogOptions + iteration::Vector{Int64} + self_increment::Bool + multioutput::Bool + dataset::D + kwargs::K + + @add_kwonly function BayesianPINN(chain, + strategy; + init_params = nothing, + phi = nothing, + derivative = nothing, + param_estim = false, + additional_loss = nothing, + adaptive_loss = nothing, + logger = nothing, + log_options = LogOptions(), + iteration = nothing, + dataset = nothing, + kwargs...) + multioutput = chain isa AbstractArray + + if phi === nothing + if multioutput + _phi = Phi.(chain) + else + _phi = Phi(chain) + end + else + _phi = phi + end + + if derivative === nothing + _derivative = numeric_derivative + else + _derivative = derivative + end + + if iteration isa Vector{Int64} + self_increment = false + else + iteration = [1] + self_increment = true + end + + if dataset isa Nothing + dataset = (nothing, nothing) + end + + new{typeof(strategy), typeof(init_params), typeof(_phi), typeof(_derivative), + typeof(param_estim), + typeof(additional_loss), typeof(adaptive_loss), typeof(logger), typeof(dataset), + typeof(kwargs)}(chain, + strategy, + init_params, + _phi, + _derivative, + param_estim, + additional_loss, + adaptive_loss, + logger, + log_options, + iteration, + self_increment, + multioutput, + dataset, + kwargs) + end +end + """ `PINNRepresentation`` diff --git a/src/training_strategies.jl b/src/training_strategies.jl index 8af7358753..ca66f6b203 100644 --- a/src/training_strategies.jl +++ b/src/training_strategies.jl @@ -16,6 +16,40 @@ struct GridTraining{T} <: AbstractTrainingStrategy dx::T end +# include dataset points in pde_residual loglikelihood (BayesianPINN) +function merge_strategy_with_loglikelihood_function(pinnrep::PINNRepresentation, + strategy::GridTraining, + datafree_pde_loss_function, + datafree_bc_loss_function; train_sets_pde = nothing,train_sets_bc=nothing) + @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep + + eltypeθ = eltype(pinnrep.flat_init_params) + + # is vec as later each _set in pde_train_sets are coloumns as points transformed to vector of points (pde_train_sets must be rowwise) + pde_loss_functions = if !(train_sets_pde isa Nothing) + pde_train_sets = [train_set[:, 2:end] for train_set in train_sets_pde] + pde_train_sets = adapt.(parameterless_type(ComponentArrays.getdata(flat_init_params)), + pde_train_sets) + [get_loss_function(_loss, _set, eltypeθ, strategy) + for (_loss, _set) in zip(datafree_pde_loss_function, + pde_train_sets)] + else + nothing + end + + bc_loss_functions = if !(train_sets_bc isa Nothing) + bcs_train_sets = [train_set[:, 2:end] for train_set in train_sets_bc] + bcs_train_sets = adapt.(parameterless_type(ComponentArrays.getdata(flat_init_params)), + bcs_train_sets) + [get_loss_function(_loss, _set, eltypeθ, strategy) + for (_loss, _set) in zip(datafree_bc_loss_function, bcs_train_sets)] + else + nothing + end + + pde_loss_functions, bc_loss_functions +end + function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::GridTraining, datafree_pde_loss_function, diff --git a/test/BPINN_PDE_tests.jl b/test/BPINN_PDE_tests.jl new file mode 100644 index 0000000000..dc742a7af2 --- /dev/null +++ b/test/BPINN_PDE_tests.jl @@ -0,0 +1,201 @@ +using Test, MCMCChains, Lux, ModelingToolkit +import ModelingToolkit: Interval, 
infimum, supremum +using ForwardDiff, Distributions, OrdinaryDiffEq +using Flux, AdvancedHMC, Statistics, Random, Functors +using NeuralPDE, MonteCarloMeasurements +using ComponentArrays + +Random.seed!(100) + +# Cospit example +@parameters t +@variables u(..) + +Dt = Differential(t) + +eqs = Dt(u(t)) - cos(2 * π * t) ~ 0 +bcs = [u(0) ~ 0.0] +domains = [t ∈ Interval(0.0, 2.0)] + +chainf = Flux.Chain(Flux.Dense(1, 6, tanh), Flux.Dense(6, 1)) |> Flux.f64 +init1, re1 = Flux.destructure(chainf) +chainl = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 1)) +initl, st = Lux.setup(Random.default_rng(), chainl) + +@named pde_system = PDESystem(eqs, bcs, domains, [t], [u(t)]) + +# non adaptive case +discretization = NeuralPDE.BayesianPINN([chainl], GridTraining([0.01])) + +sol1 = ahmc_bayesian_pinn_pde(pde_system, + discretization; + draw_samples = 1500, + bcstd = [0.02], + phystd = [0.01], + priorsNNw = (0.0, 1.0), + saveats = [1 / 50.0]) + +discretization = NeuralPDE.BayesianPINN([chainf], GridTraining([0.01])) +sol2 = ahmc_bayesian_pinn_pde(pde_system, + discretization; + draw_samples = 1500, + bcstd = [0.01], + phystd = [0.005], + priorsNNw = (0.0, 1.0), + saveats = [1 / 50.0]) + +analytic_sol_func(u0, t) = u0 + sin(2 * π * t) / (2 * π) +ts = vec(sol1.timepoints[1]) +u_real = [analytic_sol_func(0.0, t) for t in ts] +u_predict = pmean(sol1.ensemblesol[1]) +@test u_predict≈u_real atol=0.5 +@test mean(u_predict .- u_real) < 0.1 + +ts = vec(sol2.timepoints[1]) +u_real = [analytic_sol_func(0.0, t) for t in ts] +u_predict = pmean(sol2.ensemblesol[1]) +@test u_predict≈u_real atol=0.5 +@test mean(u_predict .- u_real) < 0.1 + +## Example 1, 1D ode +@parameters θ +@variables u(..) +Dθ = Differential(θ) + +# 1D ODE +eq = Dθ(u(θ)) ~ θ^3 + 2 * θ + (θ^2) * ((1 + 3 * (θ^2)) / (1 + θ + (θ^3))) - + u(θ) * (θ + ((1 + 3 * (θ^2)) / (1 + θ + θ^3))) + +# Initial and boundary conditions +bcs = [u(0.0) ~ 1.0] + +# Space and time domains +domains = [θ ∈ Interval(0.0, 1.0)] + +# Neural network +chain = Lux.Chain(Lux.Dense(1, 12, Flux.σ), Lux.Dense(12, 1)) + +discretization = NeuralPDE.BayesianPINN([chain], + GridTraining([0.01])) + +@named pde_system = PDESystem(eq, bcs, domains, [θ], [u]) + +sol1 = ahmc_bayesian_pinn_pde(pde_system, + discretization; + draw_samples = 500, + bcstd = [0.1], + phystd = [0.05], + priorsNNw = (0.0, 10.0), + saveats = [1 / 100.0]) + +analytic_sol_func(t) = exp(-(t^2) / 2) / (1 + t + t^3) + t^2 +ts = sol1.timepoints[1] +u_real = vec([analytic_sol_func(t) for t in ts]) +u_predict = pmean(sol1.ensemblesol[1]) +@test u_predict≈u_real atol=0.8 + +# example 3 (3 degree ODE) +@parameters x +@variables u(..), Dxu(..), Dxxu(..), O1(..), O2(..) 
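# The third-order equation Dxxx(u(x)) ~ cos(pi*x) is handled by introducing the lower
# derivatives Dxu and Dxxu as extra dependent variables (each with its own network),
# writing the equation as Dx(Dxxu(x)) ~ cos(pi*x), and coupling the variables through
# the relations Dxu(x) ~ Dx(u(x)) + ep*O1(x) and Dxxu(x) ~ Dx(Dxu(x)) + ep*O2(x) below,
# where O1, O2 are small auxiliary outputs scaled by the tiny constant ep.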
+Dxxx = Differential(x)^3 +Dx = Differential(x) + +# ODE +eq = Dx(Dxxu(x)) ~ cos(pi * x) + +# Initial and boundary conditions +ep = (cbrt(eps(eltype(Float64))))^2 / 6 + +bcs = [u(0.0) ~ 0.0, + u(1.0) ~ cos(pi), + Dxu(1.0) ~ 1.0, + Dxu(x) ~ Dx(u(x)) + ep * O1(x), + Dxxu(x) ~ Dx(Dxu(x)) + ep * O2(x)] + +# Space and time domains +domains = [x ∈ Interval(0.0, 1.0)] + +# Neural network +chain = [ + Lux.Chain(Lux.Dense(1, 10, Lux.tanh), Lux.Dense(10, 10, Lux.tanh), + Lux.Dense(10, 1)), Lux.Chain(Lux.Dense(1, 10, Lux.tanh), Lux.Dense(10, 10, Lux.tanh), + Lux.Dense(10, 1)), Lux.Chain(Lux.Dense(1, 10, Lux.tanh), Lux.Dense(10, 10, Lux.tanh), + Lux.Dense(10, 1)), Lux.Chain(Lux.Dense(1, 4, Lux.tanh), Lux.Dense(4, 1)), + Lux.Chain(Lux.Dense(1, 4, Lux.tanh), Lux.Dense(4, 1))] + +discretization = NeuralPDE.BayesianPINN(chain, GridTraining(0.01)) + +@named pde_system = PDESystem(eq, bcs, domains, [x], + [u(x), Dxu(x), Dxxu(x), O1(x), O2(x)]) + +sol1 = ahmc_bayesian_pinn_pde(pde_system, + discretization; + draw_samples = 200, + bcstd = [0.01, 0.01, 0.01, 0.01, 0.01], + phystd = [0.005], + priorsNNw = (0.0, 10.0), + saveats = [1 / 100.0]) + +analytic_sol_func(x) = (π * x * (-x + (π^2) * (2 * x - 3) + 1) - sin(π * x)) / (π^3) + +u_predict = pmean(sol1.ensemblesol[1]) +xs = vec(sol1.timepoints[1]) +u_real = [analytic_sol_func(x) for x in xs] +@test u_predict≈u_real atol=0.5 + +# diff_u = abs.(u_real .- u_predict) +# plot(xs, u_real) +# plot!(xs, u_predict) +# plot!(xs, diff_u) + +# 2D Poissons equation +@parameters x y +@variables u(..) +Dxx = Differential(x)^2 +Dyy = Differential(y)^2 + +# 2D PDE +eq = Dxx(u(x, y)) + Dyy(u(x, y)) ~ -sin(pi * x) * sin(pi * y) + +# Boundary conditions +bcs = [u(0, y) ~ 0.0, u(1, y) ~ 0.0, + u(x, 0) ~ 0.0, u(x, 1) ~ 0.0] + +# Space and time domains +domains = [x ∈ Interval(0.0, 1.0), + y ∈ Interval(0.0, 1.0)] + +# Neural network +dim = 2 # number of dimensions +chain = Lux.Chain(Lux.Dense(dim, 9, Lux.σ), Lux.Dense(9, 9, Lux.σ), Lux.Dense(9, 1)) + +# Discretization +dx = 0.05 +discretization=NeuralPDE.BayesianPINN([chain], GridTraining(dx)) + +@named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)]) + +sol1 = ahmc_bayesian_pinn_pde(pde_system, + discretization; + draw_samples = 200, + bcstd = [0.003, 0.003, 0.003, 0.003], + phystd = [0.003], + priorsNNw = (0.0, 10.0), + saveats = [1 / 100.0, 1 / 100.0]) + +xs = sol1.timepoints[1] +analytic_sol_func(x, y) = (sin(pi * x) * sin(pi * y)) / (2pi^2) + +u_predict = pmean(sol1.ensemblesol[1]) +u_real = [analytic_sol_func(xs[:, i][1], xs[:, i][2]) for i in 1:length(xs[1, :])] +diff_u = abs.(u_predict .- u_real) +@test u_predict≈u_real atol=1.5 + +# using Plots, StatsPlots +# plotly() +# plot(sol1.timepoints[1][1, :], +# sol1.timepoints[1][2, :], +# pmean(sol1.ensemblesol[1]), +# linetype = :contourf) +# plot(sol1.timepoints[1][1, :], sol1.timepoints[1][2, :], u_real, linetype = :contourf) +# plot(sol1.timepoints[1][1, :], sol1.timepoints[1][2, :], diff_u, linetype = :contourf) \ No newline at end of file diff --git a/test/BPINN_PDEinvsol_tests.jl b/test/BPINN_PDEinvsol_tests.jl new file mode 100644 index 0000000000..3521c8c913 --- /dev/null +++ b/test/BPINN_PDEinvsol_tests.jl @@ -0,0 +1,195 @@ +using Test, MCMCChains, Lux, ModelingToolkit +import ModelingToolkit: Interval, infimum, supremum +using ForwardDiff, Distributions, OrdinaryDiffEq +using Flux, AdvancedHMC, Statistics, Random, Functors +using NeuralPDE, MonteCarloMeasurements +using ComponentArrays + +Random.seed!(100) + +# Cos(pit) periodic curve (Parameter Estimation) 
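# Parameter-estimation workflow exercised in the examples below:
#   - the unknown equation parameter is declared in the PDESystem with a default value,
#   - BayesianPINN is constructed with param_estim = true and dataset = [dataset, nothing]
#     (i.e. [PDE datasets, BC datasets]; each dataset matrix stores the dependent variable
#     in column 1 and the independent variables in the remaining columns),
#   - a prior for the parameter is passed to ahmc_bayesian_pinn_pde via `param`, and the
#     posterior estimate is read back from sol.estimated_de_params.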
+println("Example 1, 2d Periodic System") +@parameters t, p +@variables u(..) + +Dt = Differential(t) +eqs = Dt(u(t)) - cos(p * t) ~ 0 +bcs = [u(0) ~ 0.0] +domains = [t ∈ Interval(0.0, 2.0)] + +chainf = Flux.Chain(Flux.Dense(1, 6, tanh), Flux.Dense(6, 1)) |> Flux.f64 +init1, re1 = Flux.destructure(chainf) +chainl = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 1)) +initl, st = Lux.setup(Random.default_rng(), chainl) + +@named pde_system = PDESystem(eqs, + bcs, + domains, + [t], + [u(t)], + [p], + defaults = Dict([p => 4.0])) + +analytic_sol_func1(u0, t) = u0 + sin(2 * π * t) / (2 * π) +timepoints = collect(0.0:(1 / 100.0):2.0) +u = [analytic_sol_func1(0.0, timepoint) for timepoint in timepoints] +u = u .+ (u .* 0.2) .* randn(size(u)) +dataset = [hcat(u, timepoints)] + +# plot(dataset[1][:, 2], dataset[1][:, 1]) +# plot!(timepoints, u) + +# checking all training strategies +discretization = NeuralPDE.BayesianPINN([chainl], + StochasticTraining(200), + param_estim = true, dataset = [dataset, nothing]) + +ahmc_bayesian_pinn_pde(pde_system, + discretization; + draw_samples = 1500, + bcstd = [0.05], + phystd = [0.01], l2std = [0.01], + priorsNNw = (0.0, 1.0), + saveats = [1 / 50.0], + param = [LogNormal(6.0, 0.5)]) + +discretization = NeuralPDE.BayesianPINN([chainl], + QuasiRandomTraining(200), + param_estim = true, dataset = [dataset, nothing]) + +ahmc_bayesian_pinn_pde(pde_system, + discretization; + draw_samples = 1500, + bcstd = [0.05], + phystd = [0.01], l2std = [0.01], + priorsNNw = (0.0, 1.0), + saveats = [1 / 50.0], + param = [LogNormal(6.0, 0.5)]) + +discretization = NeuralPDE.BayesianPINN([chainl], + QuadratureTraining(), param_estim = true, dataset = [dataset, nothing]) + +ahmc_bayesian_pinn_pde(pde_system, + discretization; + draw_samples = 1500, + bcstd = [0.05], + phystd = [0.01], l2std = [0.01], + priorsNNw = (0.0, 1.0), + saveats = [1 / 50.0], + param = [LogNormal(6.0, 0.5)]) + +discretization = NeuralPDE.BayesianPINN([chainl], + GridTraining([0.02]), + param_estim = true, dataset = [dataset, nothing]) + +sol1 = ahmc_bayesian_pinn_pde(pde_system, + discretization; + draw_samples = 1500, + bcstd = [0.05], + phystd = [0.01], l2std = [0.01], + priorsNNw = (0.0, 1.0), + saveats = [1 / 50.0], + param = [LogNormal(6.0, 0.5)]) + +discretization = NeuralPDE.BayesianPINN([chainf], + GridTraining([0.02]), param_estim = true, dataset = [dataset, nothing]) + +sol2 = ahmc_bayesian_pinn_pde(pde_system, + discretization; + draw_samples = 1500, + bcstd = [0.03], + phystd = [0.01], l2std = [0.01], + priorsNNw = (0.0, 1.0), + saveats = [1 / 50.0], + param = [LogNormal(6.0, 0.5)]) + +param = 2 * π +ts = vec(sol1.timepoints[1]) +u_real = [analytic_sol_func1(0.0, t) for t in ts] +u_predict = pmean(sol1.ensemblesol[1]) + +@test u_predict≈u_real atol=1.5 +@test mean(u_predict .- u_real) < 0.1 +@test sol1.estimated_de_params[1]≈param atol=param * 0.3 + +ts = vec(sol2.timepoints[1]) +u_real = [analytic_sol_func1(0.0, t) for t in ts] +u_predict = pmean(sol2.ensemblesol[1]) + +@test u_predict≈u_real atol=0.5 +@test mean(u_predict .- u_real) < 0.1 +@test sol2.estimated_de_params[1]≈param atol=param * 0.3 + +## Example Lorenz System (Parameter Estimation) +println("Example 2, Lorenz System") +@parameters t, σ_ +@variables x(..), y(..), z(..) 
+Dt = Differential(t) +eqs = [Dt(x(t)) ~ σ_ * (y(t) - x(t)), + Dt(y(t)) ~ x(t) * (28.0 - z(t)) - y(t), + Dt(z(t)) ~ x(t) * y(t) - 8 / 3 * z(t)] + +bcs = [x(0) ~ 1.0, y(0) ~ 0.0, z(0) ~ 0.0] +domains = [t ∈ Interval(0.0, 1.0)] + +input_ = length(domains) +n = 7 +chain = [ + Lux.Chain(Lux.Dense(input_, n, Lux.tanh), Lux.Dense(n, n, Lux.tanh), + Lux.Dense(n, 1)), + Lux.Chain(Lux.Dense(input_, n, Lux.tanh), Lux.Dense(n, n, Lux.tanh), + Lux.Dense(n, 1)), + Lux.Chain(Lux.Dense(input_, n, Lux.tanh), Lux.Dense(n, n, Lux.tanh), + Lux.Dense(n, 1)), +] + +#Generate Data +function lorenz!(du, u, p, t) + du[1] = 10.0 * (u[2] - u[1]) + du[2] = u[1] * (28.0 - u[3]) - u[2] + du[3] = u[1] * u[2] - (8 / 3) * u[3] +end + +u0 = [1.0; 0.0; 0.0] +tspan = (0.0, 1.0) +prob = ODEProblem(lorenz!, u0, tspan) +sol = solve(prob, Tsit5(), dt = 0.01, saveat = 0.05) +ts = sol.t +us = hcat(sol.u...) +us = us .+ ((0.05 .* randn(size(us))) .* us) +ts_ = hcat(sol(ts).t...)[1, :] +dataset = [hcat(us[i, :], ts_) for i in 1:3] + +# using Plots, StatsPlots +# plot(hcat(sol.u...)[1, :], hcat(sol.u...)[2, :], hcat(sol.u...)[3, :]) +# plot!(dataset[1][:, 1], dataset[2][:, 1], dataset[3][:, 1]) +# plot(dataset[1][:, 2:end], dataset[1][:, 1]) +# plot!(dataset[2][:, 2:end], dataset[2][:, 1]) +# plot!(dataset[3][:, 2:end], dataset[3][:, 1]) + +discretization = NeuralPDE.BayesianPINN(chain, NeuralPDE.GridTraining([0.01]); + param_estim = true, dataset = [dataset, nothing]) + +@named pde_system = PDESystem(eqs, bcs, domains, + [t], [x(t), y(t), z(t)], [σ_], defaults = Dict([p => 1.0 for p in [σ_]])) + +sol1 = ahmc_bayesian_pinn_pde(pde_system, + discretization; + draw_samples = 50, + bcstd = [0.3, 0.3, 0.3], + phystd = [0.1, 0.1, 0.1], + l2std = [1, 1, 1], + priorsNNw = (0.0, 1.0), + saveats = [0.01], + param = [Normal(12.0, 2)]) + +idealp = 10.0 +p_ = sol1.estimated_de_params[1] + +# plot(pmean(sol1.ensemblesol[1]), pmean(sol1.ensemblesol[2]), pmean(sol1.ensemblesol[3])) +# plot(sol1.timepoints[1]', pmean(sol1.ensemblesol[1])) +# plot!(sol1.timepoints[2]', pmean(sol1.ensemblesol[2])) +# plot!(sol1.timepoints[3]', pmean(sol1.ensemblesol[3])) + +@test sum(abs, pmean(p_) - 10.00) < 0.3 * idealp[1] +# @test sum(abs, pmean(p_[2]) - (8 / 3)) < 0.3 * idealp[2] \ No newline at end of file diff --git a/test/BPINN_Tests.jl b/test/BPINN_Tests.jl index 4873d5f457..cb0303daf0 100644 --- a/test/BPINN_Tests.jl +++ b/test/BPINN_Tests.jl @@ -9,18 +9,6 @@ using NeuralPDE, MonteCarloMeasurements # on latest Julia version it performs much better for below tests Random.seed!(100) -# for sampled params->lux ComponentArray -function vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple) - @assert length(ps_new) == Lux.parameterlength(ps) - i = 1 - function get_ps(x) - z = reshape(view(ps_new, i:(i + length(x) - 1)), size(x)) - i += length(x) - return z - end - return Functors.fmap(get_ps, ps) -end - ## PROBLEM-1 (WITHOUT PARAMETER ESTIMATION) linear_analytic = (u0, p, t) -> u0 + sin(2 * π * t) / (2 * π) linear = (u, p, t) -> cos(2 * π * t) @@ -187,8 +175,8 @@ meanscurve2 = prob.u0 .+ (t .- prob.tspan[1]) .* luxmean @test mean(abs.(physsol1_1 .- sol2lux.ensemblesol[1])) < 8e-2 # ESTIMATED ODE PARAMETERS (NN1 AND NN2) -@test abs(p - sol2flux.estimated_ode_params[1]) < abs(0.15 * p) -@test abs(p - sol2lux.estimated_ode_params[1]) < abs(0.15 * p) +@test abs(p - sol2flux.estimated_de_params[1]) < abs(0.15 * p) +@test abs(p - sol2lux.estimated_de_params[1]) < abs(0.15 * p) ## PROBLEM-2 linear = (u, p, t) -> u / p + exp(t / p) * cos(t) @@ -338,11 +326,11 @@ 
param1 = mean(i[62] for i in fhsampleslux22[1000:1500])
 # (flux chain)
 @test mean(abs.(physsol2 .- sol3flux_pestim.ensemblesol[1])) < 0.15
 # estimated parameters(flux chain)
-param1 = sol3flux_pestim.estimated_ode_params[1]
+param1 = sol3flux_pestim.estimated_de_params[1]
 @test abs(param1 - p) < abs(0.45 * p)
 # (lux chain)
 @test mean(abs.(physsol2 .- sol3lux_pestim.ensemblesol[1])) < 0.15
 # estimated parameters(lux chain)
-param1 = sol3lux_pestim.estimated_ode_params[1]
+param1 = sol3lux_pestim.estimated_de_params[1]
 @test abs(param1 - p) < abs(0.45 * p)
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index afe7186a28..5d6ac6909e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -15,11 +15,15 @@ function dev_subpkg(subpkg)
 end
 @time begin
-    #fixes 682
     if GROUP == "All" || GROUP == "ODEBPINN"
         @time @safetestset "Bpinn ODE solver" begin include("BPINN_Tests.jl") end
     end
+    if GROUP == "All" || GROUP == "PDEBPINN"
+        @time @safetestset "Bpinn PDE solver" begin include("BPINN_PDE_tests.jl") end
+        @time @safetestset "Bpinn PDE invaddloss solver" begin include("BPINN_PDEinvsol_tests.jl") end
+    end
+
     if GROUP == "All" || GROUP == "NNPDE1"
         @time @safetestset "NNPDE" begin include("NNPDE_tests.jl") end
     end
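For reviewers, the end-to-end shape of the forward interface added in this PR, condensed from
`test/BPINN_PDE_tests.jl` above (the network width, standard deviations, and sample counts are
the test's settings, not tuned recommendations):

```julia
using NeuralPDE, Lux, ModelingToolkit, MonteCarloMeasurements
import ModelingToolkit: Interval

@parameters t
@variables u(..)
Dt = Differential(t)

# du/dt = cos(2πt), u(0) = 0 on t ∈ [0, 2]
eqs = Dt(u(t)) - cos(2 * π * t) ~ 0
bcs = [u(0) ~ 0.0]
domains = [t ∈ Interval(0.0, 2.0)]
@named pde_system = PDESystem(eqs, bcs, domains, [t], [u(t)])

# BayesianPINN plays the role of PhysicsInformedNN when sampling with HMC
chain = Lux.Chain(Lux.Dense(1, 6, tanh), Lux.Dense(6, 1))
discretization = NeuralPDE.BayesianPINN([chain], GridTraining([0.01]))

sol = ahmc_bayesian_pinn_pde(pde_system, discretization;
    draw_samples = 1500, bcstd = [0.02], phystd = [0.01],
    priorsNNw = (0.0, 1.0), saveats = [1 / 50.0])

ts = vec(sol.timepoints[1])          # inference grid implied by `saveats`
u_mean = pmean(sol.ensemblesol[1])   # posterior mean of the sampled networks
# sol.estimated_de_params holds Particles for any unknown DE parameters (inverse problems)
```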