BPINN PDE solver #745
Conversation
Force-pushed 8c6a46f to 8684c57
The docs successfully built on master, so I put this on top of a successful master.
Force-pushed 8684c57 to 3e7e784
That looks like a legitimate test failure. You may want to split BPINN PDE into a new test group.
@info("Sampling Complete.")
@info("Current Physics Log-likelihood : ", physloglikelihood(ℓπ, samples[end]))
@info("Current Prior Log-likelihood : ", priorweights(ℓπ, samples[end]))
@info("Current MSE against dataset Log-likelihood : ",
    L2LossData(ℓπ, samples[end]))
Is this intended?
Yes, it helps in choosing a better set of parameters and priors.
Seems that should be returned as information rather than just printed.
I could do that, but the info would only be returned after sampling finishes. The reason I've chosen to display it is so that the values of the different log-likelihoods are visible at the beginning of sampling; if their proportions relative to each other are off (weak priors require larger L2 log-likelihoods than residual log-likelihoods), we can simply exit sampling and adjust the stds, etc., instead of waiting out the whole sampling run.
The final values of the log-likelihoods show which objective has been optimized more.
src/advancedHMC_MCMC.jl
Outdated
function vector_to_parameters(ps_new::AbstractVector,
    ps::Union{NamedTuple, ComponentArrays.ComponentVector, AbstractVector})
    if (ps isa ComponentArrays.ComponentVector) || (ps isa NamedTuple)
Use dispatch...
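For reference, the dispatch-based shape being suggested could look roughly like this (a sketch, not the PR's implementation; the `NamedTuple` case is omitted and the method bodies are illustrative):

```julia
using ComponentArrays

# One method per parameter container; multiple dispatch replaces the
# `isa` branching inside a single function body.
vector_to_parameters(ps_new::AbstractVector, ps::ComponentVector) =
    ComponentArray(ps_new, getaxes(ps))

# Plain vectors need no restructuring.
vector_to_parameters(ps_new::AbstractVector, ::AbstractVector) = ps_new
```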
src/PDE_BPINN.jl
Outdated
# multioutput case for Lux chains, for each depvar ps would contain Lux ComponentVectors
# which we use for mapping current ahmc sampled vector of parameters onto NNs
i = 0
Luxparams = Vector{ComponentArrays.ComponentVector}()
This is non-concrete and will slow things down.
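For illustration, the concreteness issue can be seen directly (a toy example, not the PR's code; `ComponentVector` without type parameters is an abstract type, so a `Vector{ComponentVector}` stores boxed elements):

```julia
using ComponentArrays

# Abstractly typed container: eltype is the abstract ComponentVector.
xs = ComponentArrays.ComponentVector[]
push!(xs, ComponentArray(a = [1.0], b = [2.0]))

# Concretely typed alternative: let a comprehension infer the eltype.
ys = [ComponentArray(a = [Float64(i)], b = [2.0]) for i in 1:3]
# isconcretetype(eltype(xs)) is false; isconcretetype(eltype(ys)) is true
```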
# Tar.allstd)
# end

function setparameters(Tar::PDELogTargetDensity, θ)
Doesn't Lux have a utility for transforming a vector into a ComponentVector for this kind of thing?
I couldn't find it in the docs (at least nothing specific to this case), so I worked it out.
Let's make sure an issue gets opened about this after merging.
Typically you store the axes of the ComponentVector via `ax = getaxes(ps_ca)` and then construct the CA via `ComponentArray(theta, ax)`.
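A minimal sketch of that pattern (toy parameter values standing in for the NN parameters and the AHMC-sampled flat vector):

```julia
using ComponentArrays

ps_ca = ComponentArray(a = [1.0, 2.0], b = [3.0])  # stand-in for NN parameters
ax = getaxes(ps_ca)         # store the axes once, outside the sampling loop
theta = [10.0, 20.0, 30.0]  # stand-in for the sampled flat vector
ps_new = ComponentArray(theta, ax)
# ps_new.a == [10.0, 20.0] and ps_new.b == [30.0]
```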
src/PDE_BPINN.jl
Outdated
if Kernel == HMCDA
    δ, λ = MCMCkwargs[:δ], MCMCkwargs[:λ]
    Kernel(δ, λ)
elseif Kernel == NUTS
That won't match if it's a type. It needs to use `isa`.
Kernel here is the constructor name; the constructors are exported by AdvancedHMC.jl.
Yes, but why dispatch on the datatype? That's going to ruin inference, and it doesn't allow kernel-specific parameters.
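The distinction can be seen with a toy stand-in for a kernel type (hypothetical structs, not AdvancedHMC's actual kernels):

```julia
# Hypothetical stand-in for a sampler kernel such as HMCDA.
struct ToyHMCDA
    δ::Float64
    λ::Float64
end

k = ToyHMCDA(0.65, 0.3)
k == ToyHMCDA   # false: a configured instance never `==` the type object
k isa ToyHMCDA  # true: `isa` matches instances, and the instance carries
                # its kernel-specific parameters (δ, λ) with it
```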
src/PDE_BPINN.jl
Outdated
if nchains > Threads.nthreads()
    throw(error("number of chains is greater than available threads"))
Why wouldn't that be allowed?
Sampling for each chain is done on a single thread: https://turinglang.org/AdvancedHMC.jl/stable/#Parallel-sampling
But you could still have more chains than threads. Nothing stops that; it just batches.
Generally looking good, though there are some performance issues to clean up and some style issues where dispatch isn't used when it would clean things up.
src/discretize.jl
Outdated
if bayesian
    # required as Physics loss also needed on dataset domain points
    pde_loss_functions1, bc_loss_functions1 = if !(dataset_given isa Nothing)
        if !(strategy isa GridTraining)
Why would only GridTraining be allowed? Wouldn't you just add these points?
Yeah, the points would get included in the training sets. The GridTraining strategy is used for the loss on dataset domain points anyway, so I was thinking GridTraining throughout the domain is better than having GridTraining for the dataset points and Stochastic/etc. on the other points.
src/discretize.jl
Outdated
@@ -401,7 +401,7 @@ to the PDE.
 For more information, see `discretize` and `PINNRepresentation`.
 """
 function SciMLBase.symbolic_discretize(pde_system::PDESystem,
-    discretization::PhysicsInformedNN)
+    discretization::PhysicsInformedNN; bayesian::Bool = false, dataset_given = nothing)
This is missing from the docstring.
oh actually, this is an interface break.
That's an interface break, and not one that makes sense. It's probably best to make a new discretization type and then dispatch on that instead.
Almost there. So I think make the discretizer object so the interface is kept, and then open issues for the following:
src/discretize.jl
Outdated
@@ -401,7 +401,7 @@ to the PDE.
 For more information, see `discretize` and `PINNRepresentation`.
 """
 function SciMLBase.symbolic_discretize(pde_system::PDESystem,
-    discretization::PhysicsInformedNN)
+    discretization)
This is type piracy. It needs to specify the top level dispatch.
Missed out on this, thanks.
Awesome, thanks @ChrisRackauckas @Vaibhavdixit02 and @xtalax.
for #712 and #205