Skip to content

Commit

Permalink
Merge pull request #836 from sathvikbhagavan/sb/batch
Browse files Browse the repository at this point in the history
refactor: correctly lower quadrature training strategy in NNODE
  • Loading branch information
ChrisRackauckas authored Mar 22, 2024
2 parents 99feba6 + 7e3de98 commit 0cb2e06
Show file tree
Hide file tree
Showing 11 changed files with 82 additions and 87 deletions.
8 changes: 4 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ AdvancedHMC = "0.6.1"
Aqua = "0.8"
ArrayInterface = "7.7"
CUDA = "5.2"
ChainRulesCore = "1.18"
ChainRulesCore = "1.21"
ComponentArrays = "0.15.8"
Cubature = "1.5"
DiffEqBase = "6.144"
Expand All @@ -59,7 +59,7 @@ Integrals = "4"
LineSearches = "7.2"
LinearAlgebra = "1"
LogDensityProblems = "2"
Lux = "0.5.14"
Lux = "0.5.22"
LuxCUDA = "0.3.2"
MCMCChains = "6"
MethodOfLines = "0.10.7"
Expand All @@ -82,7 +82,7 @@ SymbolicUtils = "1.4"
Symbolics = "5.17"
Test = "1"
UnPack = "1"
Zygote = "0.6.68"
Zygote = "0.6.69"
julia = "1.10"

[extras]
Expand All @@ -91,12 +91,12 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"

[targets]
test = ["Aqua", "Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "LuxCUDA", "Flux", "MethodOfLines"]
8 changes: 4 additions & 4 deletions docs/src/tutorials/neural_adapter.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ function loss(cord, θ)
ch2 .- phi(cord, res.u)
end
strategy = NeuralPDE.QuadratureTraining()
strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6)
prob_ = NeuralPDE.neural_adapter(loss, init_params2, pde_system, strategy)
res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 10000)
Expand Down Expand Up @@ -173,7 +173,7 @@ for i in 1:count_decomp
bcs_ = create_bcs(domains_[1].domain, phi_bound)
@named pde_system_ = PDESystem(eq, bcs_, domains_, [x, y], [u(x, y)])
push!(pde_system_map, pde_system_)
strategy = NeuralPDE.QuadratureTraining()
strategy = NeuralPDE.QuadratureTraining(; reltol = 1e-6)
discretization = NeuralPDE.PhysicsInformedNN(chains[i], strategy;
init_params = init_params[i])
Expand Down Expand Up @@ -243,10 +243,10 @@ callback = function (p, l)
end
prob_ = NeuralPDE.neural_adapter(losses, init_params2, pde_system_map,
NeuralPDE.QuadratureTraining())
NeuralPDE.QuadratureTraining(; reltol = 1e-6))
res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 5000)
prob_ = NeuralPDE.neural_adapter(losses, res_.u, pde_system_map,
NeuralPDE.QuadratureTraining())
NeuralPDE.QuadratureTraining(; reltol = 1e-6))
res_ = Optimization.solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 5000)
phi_ = PhysicsInformedNN(chain2, strategy; init_params = res_.u).phi
Expand Down
2 changes: 1 addition & 1 deletion src/BPINN_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ function BNNODE(chain, Kernel = HMC; strategy = nothing, draw_samples = 2000,
targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,),
autodiff = false, progress = false, verbose = false)
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
!(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
BNNODE(chain, Kernel, strategy,
draw_samples, priorsNNw, param, l2std,
phystd, dataset, physdt, MCMCkwargs,
Expand Down
1 change: 1 addition & 0 deletions src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ using DomainSets: Domain, ClosedInterval, AbstractInterval, leftendpoint, righte
using SciMLBase: @add_kwonly, parameterless_type
using UnPack: @unpack
import ChainRulesCore, Lux, ComponentArrays
using Lux: FromFluxAdaptor
using ChainRulesCore: @non_differentiable

RuntimeGeneratedFunctions.init(@__MODULE__)
Expand Down
2 changes: 1 addition & 1 deletion src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
MCMCkwargs = (n_leapfrog = 30,),
progress = false, verbose = false)

!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
!(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
# NN parameter prior mean and variance(PriorsNN must be a tuple)
if isinplace(prob)
throw(error("The BPINN ODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))
Expand Down
2 changes: 1 addition & 1 deletion src/dae_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ end

function NNDAE(chain, opt, init_params = nothing; strategy = nothing, autodiff = false,
kwargs...)
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
!(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
NNDAE(chain, opt, init_params, autodiff, strategy, kwargs)
end

Expand Down
43 changes: 14 additions & 29 deletions src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ of the physics-informed neural network which is used as a solver for a standard
## Positional Arguments
* `chain`: A neural network architecture, defined as a `Lux.AbstractExplicitLayer` or `Flux.Chain`.
`Flux.Chain` will be converted to `Lux` using `Lux.transform`.
`Flux.Chain` will be converted to `Lux` using `adapt(FromFluxAdaptor(false, false), chain)`.
* `opt`: The optimizer to train the neural network.
* `init_params`: The initial parameter of the neural network. By default, this is `nothing`
which thus uses the random initialization provided by the neural network library.
Expand All @@ -27,11 +27,10 @@ of the physics-informed neural network which is used as a solver for a standard
the PDE operators. The reverse mode of the loss function is always
automatic differentiation (via Zygote), this is only for the derivative
in the loss function (the derivative with respect to time).
* `batch`: The batch size to use for the internal quadrature. Defaults to `0`, which
means the application of the neural network is done at individual time points one
at a time. `batch>0` means the neural network is applied at a row vector of values
`t` simultaneously, i.e. it's the batch size for the neural network evaluations.
This requires a neural network compatible with batched data.
* `batch`: The batch size for the loss computation. Defaults to `true`, means the neural network is applied at a row vector of values
`t` simultaneously, i.e. it's the batch size for the neural network evaluations. This requires a neural network compatible with batched data.
`false` means which means the application of the neural network is done at individual time points one at a time.
This is not applicable to `QuadratureTraining` where `batch` is passed in the `strategy` which is the number of points it can parallelly compute the integrand.
* `param_estim`: Boolean to indicate whether parameters of the differential equations are learnt along with parameters of the neural network.
* `strategy`: The training strategy used to choose the points for the evaluations.
Default of `nothing` means that `QuadratureTraining` with QuadGK is used if no
Expand Down Expand Up @@ -88,8 +87,8 @@ struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function},
end
function NNODE(chain, opt, init_params = nothing;
strategy = nothing,
autodiff = false, batch = nothing, param_estim = false, additional_loss = nothing, kwargs...)
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
autodiff = false, batch = true, param_estim = false, additional_loss = nothing, kwargs...)
!(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
NNODE(chain, opt, init_params, autodiff, batch, strategy, param_estim, additional_loss, kwargs)
end

Expand All @@ -111,11 +110,7 @@ end

function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params)
θ, st = Lux.setup(Random.default_rng(), chain)
if init_params === nothing
init_params = ComponentArrays.ComponentArray(θ)
else
init_params = ComponentArrays.ComponentArray(init_params)
end
isnothing(init_params) && (init_params = θ)
ODEPhi(chain, t, u0, st), init_params
end

Expand Down Expand Up @@ -182,7 +177,7 @@ function ode_dfdx(phi::ODEPhi, t::AbstractVector, θ, autodiff::Bool)
end

"""
inner_loss(phi, f, autodiff, t, θ, p)
inner_loss(phi, f, autodiff, t, θ, p, param_estim)
Simple L2 inner loss at a time `t` with parameters `θ` of the neural network.
"""
Expand Down Expand Up @@ -220,7 +215,7 @@ function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
end

"""
generate_loss(strategy, phi, f, autodiff, tspan, p, batch)
generate_loss(strategy, phi, f, autodiff, tspan, p, batch, param_estim)
Representation of the loss function, parametric on the training strategy `strategy`.
"""
Expand All @@ -229,14 +224,13 @@ function generate_loss(strategy::QuadratureTraining, phi, f, autodiff::Bool, tsp
integrand(t::Number, θ) = abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim))

integrand(ts, θ) = [abs2(inner_loss(phi, f, autodiff, t, θ, p, param_estim)) for t in ts]
@assert batch == 0 # not implemented

function loss(θ, _)
intprob = IntegralProblem(integrand, (tspan[1], tspan[2]), θ)
sol = solve(intprob, QuadGKJL(); abstol = strategy.abstol, reltol = strategy.reltol)
intf = BatchIntegralFunction(integrand, max_batch = strategy.batch)
intprob = IntegralProblem(intf, (tspan[1], tspan[2]), θ)
sol = solve(intprob, strategy.quadrature_alg; abstol = strategy.abstol, reltol = strategy.reltol, maxiters = strategy.maxiters)
sol.u
end

return loss
end

Expand Down Expand Up @@ -395,16 +389,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
alg.strategy
end

batch = if alg.batch === nothing
if strategy isa QuadratureTraining
strategy.batch
else
true
end
else
alg.batch
end

batch = alg.batch
inner_f = generate_loss(strategy, phi, f, autodiff, tspan, p, batch, param_estim)
additional_loss = alg.additional_loss
(param_estim && isnothing(additional_loss)) && throw(ArgumentError("Please provide `additional_loss` in `NNODE` for parameter estimation (`param_estim` is true)."))
Expand Down
6 changes: 3 additions & 3 deletions src/pinn_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ methodology.
* `chain`: a vector of Lux/Flux chains with a d-dimensional input and a
1-dimensional output corresponding to each of the dependent variables. Note that this
specification respects the order of the dependent variables as specified in the PDESystem.
Flux chains will be converted to Lux internally using `Lux.transform`.
Flux chains will be converted to Lux internally using `adapt(FromFluxAdaptor(false, false), chain)`.
* `strategy`: determines which training strategy will be used. See the Training Strategy
documentation for more details.
Expand Down Expand Up @@ -107,7 +107,7 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN
if multioutput
!all(i -> i isa Lux.AbstractExplicitLayer, chain) && (chain = Lux.transform.(chain))
else
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
!(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
end
if phi === nothing
if multioutput
Expand Down Expand Up @@ -243,7 +243,7 @@ struct BayesianPINN{T, P, PH, DER, PE, AL, ADA, LOG, D, K} <: AbstractPINN
if multioutput
!all(i -> i isa Lux.AbstractExplicitLayer, chain) && (chain = Lux.transform.(chain))
else
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
!(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
end
if phi === nothing
if multioutput
Expand Down
6 changes: 1 addition & 5 deletions src/training_strategies.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ struct QuadratureTraining{Q <: SciMLBase.AbstractIntegralAlgorithm, T} <:
batch::Int64
end

function QuadratureTraining(; quadrature_alg = CubatureJLh(), reltol = 1e-6, abstol = 1e-3,
function QuadratureTraining(; quadrature_alg = CubatureJLh(), reltol = 1e-3, abstol = 1e-6,
maxiters = 1_000, batch = 100)
QuadratureTraining(quadrature_alg, reltol, abstol, maxiters, batch)
end
Expand Down Expand Up @@ -306,11 +306,7 @@ function get_loss_function(loss_function, lb, ub, eltypeθ, strategy::Quadrature
end
area = eltypeθ(prod(abs.(ub .- lb)))
f_ = (lb, ub, loss_, θ) -> begin
# last_x = 1
function integrand(x, θ)
# last_x = x
# mean(abs2,loss_(x,θ), dims=2)
# size_x = fill(size(x)[2],(1,1))
x = adapt(parameterless_type(ComponentArrays.getdata(θ)), x)
sum(abs2, view(loss_(x, θ), 1, :), dims = 2) #./ size_x
end
Expand Down
Loading

0 comments on commit 0cb2e06

Please sign in to comment.