
BPINN PDE solver #745

Merged: 62 commits into SciML:master from Bpinn_pde, Jan 7, 2024

Conversation

AstitvaAggarwal (Contributor)

for #712 and #205

@ChrisRackauckas (Member)

The docs successfully built on master, so I put this on top of a successful master.

@ChrisRackauckas (Member)

That looks like a legitimate test failure. You may want to split BPINN PDE into a new test group.

src/discretize.jl (Outdated; thread resolved)
Comment on lines +611 to +615
@info("Sampling Complete.")
@info("Current Physics Log-likelihood : ", physloglikelihood(ℓπ, samples[end]))
@info("Current Prior Log-likelihood : ", priorweights(ℓπ, samples[end]))
@info("Current MSE against dataset Log-likelihood : ",
L2LossData(ℓπ, samples[end]))
Member

Is this intended?

AstitvaAggarwal (Contributor Author), Jan 5, 2024

Yes, it helps in choosing a better set of parameters and priors.

Member

Seems like that should be returned as information rather than just printed.

AstitvaAggarwal (Contributor Author), Jan 5, 2024

I could do that, but the info would only be returned after sampling. The reason I've chosen to display it is that at the beginning of sampling the values of the different log-likelihoods are visible, and if their proportions relative to each other are off (weak priors will need us to have larger L2 log-likelihoods than residual log-likelihoods), we can simply exit sampling and adjust the stds, etc., instead of waiting out the whole sampling.

AstitvaAggarwal (Contributor Author)

The final values of the log-likelihoods show which objective has been optimized more.

Comment on lines 71 to 73
function vector_to_parameters(ps_new::AbstractVector,
ps::Union{NamedTuple, ComponentArrays.ComponentVector, AbstractVector})
if (ps isa ComponentArrays.ComponentVector) || (ps isa NamedTuple)
Member

Use dispatch...
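A minimal sketch of the dispatch-based version (illustrating the suggestion, not the merged code; the method bodies follow the `getaxes`/`ComponentArray` pattern discussed further down):

```julia
using ComponentArrays

# Flat vector in, flat vector out: nothing to restructure.
vector_to_parameters(ps_new::AbstractVector, ps::AbstractVector) = ps_new

# Structured parameters: rebuild on the same axes, mapping the sampled
# flat vector back onto the NN parameter layout.
vector_to_parameters(ps_new::AbstractVector, ps::ComponentVector) =
    ComponentArray(ps_new, getaxes(ps))

# NamedTuple parameters: route through a ComponentVector of the same layout.
vector_to_parameters(ps_new::AbstractVector, ps::NamedTuple) =
    vector_to_parameters(ps_new, ComponentArray(ps))
```

Each method is type-stable on its own, and the runtime `isa` branch disappears.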

src/PDE_BPINN.jl Outdated
# multioutput case for Lux chains, for each depvar ps would contain Lux ComponentVectors
# which we use for mapping current ahmc sampled vector of parameters onto NNs
i = 0
Luxparams = Vector{ComponentArrays.ComponentVector}()
Member

This is non-concrete and will slow things down.
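For context, a sketch of the problem and one concretely typed alternative (`θ`, `ranges`, and `axes_per_net` are hypothetical stand-ins for the PR's flat sample vector and per-network layouts):

```julia
using ComponentArrays

# `ComponentVector` without its type parameters is abstract, so this container
# holds abstractly typed elements and every access is dynamically dispatched:
isconcretetype(ComponentVector)        # false
Luxparams = Vector{ComponentVector}()  # abstractly typed container

# Building the container in one pass lets the comprehension infer a concrete
# element type instead:
# Luxparams = [ComponentArray(θ[r], ax) for (r, ax) in zip(ranges, axes_per_net)]
```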

# Tar.allstd)
# end

function setparameters(Tar::PDELogTargetDensity, θ)
Member

Doesn't Lux have a utility for transforming a vector into a ComponentVector for this kind of thing?

AstitvaAggarwal (Contributor Author)

I couldn't find it in the docs (at least nothing specific to this case), so I worked it out.

Member

Let's make sure an issue gets opened about this after merging.

Member

Typically you store the axes of the ComponentVector via `ax = getaxes(ps_ca)` and then construct the CA via `ComponentArray(theta, ax)`.
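End to end, that pattern looks roughly like this (a sketch assuming a small Lux `chain`; the variable names are illustrative):

```julia
using Lux, ComponentArrays, Random

chain = Chain(Dense(2 => 8, tanh), Dense(8 => 1))
ps, st = Lux.setup(Random.default_rng(), chain)

ps_ca = ComponentArray(ps)      # structured NN parameters
ax = getaxes(ps_ca)             # store the axes once, up front

θ = collect(ps_ca)              # the flat vector a sampler works with
ps_new = ComponentArray(θ, ax)  # flat sample mapped back onto the NN layout
```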

src/PDE_BPINN.jl Outdated
Comment on lines 199 to 202
if Kernel == HMCDA
δ, λ = MCMCkwargs[:δ], MCMCkwargs[:λ]
Kernel(δ, λ)
elseif Kernel == NUTS
Member

That won't match if it's a type. It needs to use `isa`.
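The distinction in miniature (a sketch; `HMCDA(δ, λ)` mirrors the constructor call already in the PR):

```julia
using AdvancedHMC

k = HMCDA(0.8, 1.0)  # a configured kernel instance
k == HMCDA           # false: an instance never equals the type object
k isa HMCDA          # true:  `isa` matches instances of the type
```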

AstitvaAggarwal (Contributor Author)

`Kernel` here is the constructor name; the constructors are exported by AdvancedHMC.jl.

Member

Yes, but why on the datatype? That's going to ruin inference, and it doesn't allow kernel-specific parameters.

src/PDE_BPINN.jl Outdated
Comment on lines 358 to 359
if nchains > Threads.nthreads()
throw(error("number of chains is greater than available threads"))
Member

Why wouldn't that be allowed?

AstitvaAggarwal (Contributor Author)

Sampling for each chain is done on a single thread; see https://turinglang.org/AdvancedHMC.jl/stable/#Parallel-sampling.

Member

But you could still have more chains than threads. Nothing stops that; it just batches.
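A sketch of why that works (the `rand` body is a stand-in for one chain's sampling): `Threads.@threads` batches the loop iterations over however many threads exist, so extra chains simply queue.

```julia
nchains = 8   # may exceed Threads.nthreads(); iterations are batched
chains = Vector{Vector{Float64}}(undef, nchains)

Threads.@threads for i in 1:nchains
    chains[i] = [rand() for _ in 1:1000]  # stand-in for per-chain sampling
end
```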

src/PDE_BPINN.jl (Outdated; thread resolved)
@ChrisRackauckas (Member)

Generally looking good, though there are some performance issues to clean up, and some style issues where dispatch isn't used when it would clean things up.

if bayesian
# required as Physics loss also needed on dataset domain points
pde_loss_functions1, bc_loss_functions1 = if !(dataset_given isa Nothing)
if !(strategy isa GridTraining)
Member

Why would only GridTraining be allowed? Wouldn't you just add these points?

AstitvaAggarwal (Contributor Author), Jan 7, 2024

Yeah, the points would get included in the training sets. The GridTraining strategy is used for the loss on dataset domain points anyway, so I was thinking GridTraining throughout the domain is better than having GridTraining for the dataset points and Stochastic/etc. on the other points.

@@ -401,7 +401,7 @@ to the PDE.
For more information, see `discretize` and `PINNRepresentation`.
"""
function SciMLBase.symbolic_discretize(pde_system::PDESystem,
discretization::PhysicsInformedNN)
discretization::PhysicsInformedNN; bayesian::Bool = false,dataset_given=nothing)
Member

This is missing from the docstring.

Member

Oh, actually, this is an interface break.

@ChrisRackauckas (Member)

function SciMLBase.symbolic_discretize(pde_system::PDESystem,
    discretization::PhysicsInformedNN; bayesian::Bool = false,dataset_given=nothing)

That's an interface break, and not one that makes sense. What would `bayesian` or `dataset_given` mean for a finite difference method? Those cannot be added to the high-level interface, as that's not well thought out in a general sense.

It's probably best to make a new discretization type, and then have an AbstractPINNDiscretizer that is used in order to make it easy to reuse dispatches. `bayesian` wouldn't be needed then. I'm not entirely sure about `dataset_given`, though; shouldn't that be known by the PDESystem or the discretization?
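A hedged sketch of that design (the concrete names below are hypothetical illustrations of the suggestion, not the merged code):

```julia
abstract type AbstractPINNDiscretizer end

struct PhysicsInformedNN{C, S} <: AbstractPINNDiscretizer
    chain::C
    strategy::S
end

# A separate discretizer type replaces the `bayesian::Bool` keyword, and the
# dataset lives on the discretizer rather than being a keyword argument.
struct BayesianPINN{C, S, D} <: AbstractPINNDiscretizer
    chain::C
    strategy::S
    dataset::D
end
```

`symbolic_discretize` can then dispatch on the discretizer type, and the existing interface stays intact.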

@ChrisRackauckas (Member)

Almost there.

So I think make the discretizer object so the interface is kept, and then open issues for the following:

  1. BPINNs on PDEs need a tutorial and other docs.
  2. The vector-to-parameters code is a bit slow due to type-stability issues, but we can merge without handling that, since I don't think the functionality should live here anyway; we should have an issue about reusing some Lux functionality for this.

@@ -401,7 +401,7 @@ to the PDE.
For more information, see `discretize` and `PINNRepresentation`.
"""
function SciMLBase.symbolic_discretize(pde_system::PDESystem,
discretization::PhysicsInformedNN)
discretization)
Member

This is type piracy. It needs to specify the top-level dispatch.
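Type piracy in miniature (a self-contained toy, not the PR's code): extending a function you don't own on argument types you don't own affects every user of that function, while restricting the dispatch to a type your package owns is safe.

```julia
module Upstream
    f(x::Int) = x + 1          # plays the role of SciMLBase.symbolic_discretize
end

module Pirate
    import ..Upstream
    Upstream.f(x) = 0          # PIRACY: owns neither the function nor x's type
end

module Downstream
    import ..Upstream
    struct MyDiscretizer end   # plays the role of the package's own type
    Upstream.f(::MyDiscretizer) = 42   # OK: dispatch restricted to an owned type
end
```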

AstitvaAggarwal (Contributor Author)

Missed this one, thanks.

src/discretize.jl (Outdated; thread resolved)
ChrisRackauckas merged commit 0687aaf into SciML:master on Jan 7, 2024. 10 of 12 checks passed.
@AstitvaAggarwal (Contributor Author)

Awesome, thanks @ChrisRackauckas, @Vaibhavdixit02, and @xtalax.

AstitvaAggarwal deleted the Bpinn_pde branch on January 7, 2024 at 19:25.