Skip to content

Commit

Permalink
refactor: cleanup NNODE
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 14, 2024
1 parent b2d9ca6 commit 21c5215
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 312 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Chris Rackauckas <[email protected]>"]
version = "5.16.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down Expand Up @@ -31,17 +32,18 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1.9.0"
Adapt = "4"
AdvancedHMC = "0.6.1"
Aqua = "0.8"
Expand Down Expand Up @@ -78,6 +80,7 @@ OrdinaryDiffEq = "6.87"
Pkg = "1.10"
QuasiMonteCarlo = "0.3.2"
Random = "1"
RecursiveArrayTools = "3.27.0"
Reexport = "1.2"
RuntimeGeneratedFunctions = "0.5.12"
SafeTestsets = "0.1"
Expand All @@ -86,7 +89,6 @@ Statistics = "1.10"
SymbolicUtils = "3.7.2"
Symbolics = "6.14"
Test = "1.10"
UnPack = "1"
WeightInitializers = "1.0.3"
Zygote = "0.6.71"
julia = "1.10"
Expand Down
6 changes: 1 addition & 5 deletions src/BPINN_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,7 @@ function SciMLBase.__solve(prob::SciMLBase.ODEProblem,
saveat = 1 / 50.0,
maxiters = nothing,
numensemble = floor(Int, alg.draw_samples / 3))
@unpack chain, l2std, phystd, param, priorsNNw, Kernel, strategy,
draw_samples, dataset, init_params,
nchains, physdt, Adaptorkwargs, Integratorkwargs,
MCMCkwargs, numensemble, estim_collocate, autodiff, progress,
verbose = alg
(; chain, l2std, phystd, param, priorsNNw, Kernel, strategy, draw_samples, dataset, init_params, nchains, physdt, Adaptorkwargs, Integratorkwargs, MCMCkwargs, numensemble, estim_collocate, autodiff, progress, verbose) = alg

# ahmc_bayesian_pinn_ode needs param=[] for easier vcat operation for full vector of parameters
param = param === nothing ? [] : param
Expand Down
3 changes: 2 additions & 1 deletion src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ import DomainSets
using DomainSets: Domain, ClosedInterval, AbstractInterval, leftendpoint, rightendpoint,
ProductDomain
using SciMLBase: @add_kwonly, parameterless_type
using UnPack: @unpack

using ADTypes: AutoForwardDiff, AutoZygote
using ChainRulesCore: ChainRulesCore, @non_differentiable, @ignore_derivatives
using ComponentArrays: ComponentArrays, ComponentArray, getdata, getaxes
using ConcreteStructs: @concrete
Expand All @@ -37,6 +37,7 @@ using Lux: Lux, Chain, Dense, SkipConnection, StatefulLuxLayer
using Lux: FromFluxAdaptor, recursive_eltype
using LuxCore: AbstractLuxLayer, AbstractLuxWrapperLayer, AbstractLuxContainerLayer
using Optimisers: Optimisers, Adam
using RecursiveArrayTools: DiffEqArray
using QuasiMonteCarlo: QuasiMonteCarlo, LatinHypercubeSample
using WeightInitializers: glorot_uniform, zeros32

Expand Down
10 changes: 3 additions & 7 deletions src/discretize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,7 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs;
dict_transformation_vars = nothing,
transformation_vars = nothing,
integrating_depvars = pinnrep.depvars)
@unpack indvars, depvars, dict_indvars, dict_depvars, dict_depvar_input,
phi, derivative, integral,
multioutput, init_params, strategy, eq_params,
param_estim, default_p = pinnrep
(; indvars, depvars, dict_indvars, dict_depvars, dict_depvar_input, phi, derivative, integral, multioutput, init_params, strategy, eq_params, param_estim, default_p) = pinnrep

eltypeθ = eltype(pinnrep.flat_init_params)

Expand Down Expand Up @@ -150,7 +147,7 @@ Returns the body of loss function, which is the executable Julia function, for t
equation or boundary condition.
"""
function build_loss_function(pinnrep::PINNRepresentation, eqs, bc_indvars)
@unpack eq_params, param_estim, default_p, phi, derivative, integral = pinnrep
(; eq_params, param_estim, default_p, phi, derivative, integral) = pinnrep

bc_indvars = bc_indvars === nothing ? pinnrep.indvars : bc_indvars

Expand Down Expand Up @@ -312,8 +309,7 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, str
end

function get_numeric_integral(pinnrep::PINNRepresentation)
@unpack strategy, indvars, depvars, multioutput, derivative,
depvars, indvars, dict_indvars, dict_depvars = pinnrep
(; strategy, indvars, depvars, multioutput, derivative, depvars, indvars, dict_indvars, dict_depvars) = pinnrep

integral = (u, cord, phi, integrating_var_id, integrand_func, lb, ub, θ; strategy = strategy, indvars = indvars, depvars = depvars, dict_indvars = dict_indvars, dict_depvars = dict_depvars) -> begin
function integration_(cord, lb, ub, θ)
Expand Down
Loading

0 comments on commit 21c5215

Please sign in to comment.