Merge pull request #802 from ayushinav/master
add Deep Galerkin method
ChrisRackauckas authored Mar 4, 2024
2 parents 389086f + c380453 commit a20efb6
Showing 8 changed files with 460 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/CI.yml
@@ -25,6 +25,8 @@ jobs:
- AdaptiveLoss
- Logging
- Forward
- NeuralAdapter
- DGM
version:
- "1"
steps:
4 changes: 3 additions & 1 deletion Project.toml
@@ -82,6 +82,7 @@ Symbolics = "5"
Test = "1"
UnPack = "1"
Zygote = "0.6"
MethodOfLines = "0.10.7"
julia = "1.6"

[extras]
@@ -95,6 +96,7 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"

[targets]
test = ["Aqua", "Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "LuxCUDA", "Flux"]
test = ["Aqua", "Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "LuxCUDA", "Flux", "MethodOfLines"]
1 change: 1 addition & 0 deletions docs/pages.jl
@@ -3,6 +3,7 @@ pages = ["index.md",
"Bayesian PINNs for Coupled ODEs" => "tutorials/Lotka_Volterra_BPINNs.md",
"PINNs DAEs" => "tutorials/dae.md",
"Parameter Estimation with PINNs for ODEs" => "tutorials/ode_parameter_estimation.md",
"Deep Galerkin Method" => "tutorials/dgm.md"
#"examples/nnrode_example.md", # currently incorrect
],
"PDE PINN Tutorials" => Any["Introduction to NeuralPDE for PDEs" => "tutorials/pdesystem.md",
112 changes: 112 additions & 0 deletions docs/src/tutorials/dgm.md
@@ -0,0 +1,112 @@
## Solving PDEs using the Deep Galerkin Method

### Overview

The Deep Galerkin Method (DGM) is a meshless deep-learning algorithm for solving high-dimensional PDEs. It approximates the solution of a PDE with a neural network whose loss function, in the same spirit as PINNs, is composed of a PDE residual loss and a boundary-condition loss.
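
Schematically, following the DGM paper by Sirignano & Spiliopoulos (referenced in the `DeepGalerkin` docstring added in this PR), the objective for a network $f(t, x; \theta)$ combines the PDE residual over the interior with the mismatch on the boundary and at the initial time; the operator $\mathcal{N}$, boundary data $g$, and initial data $u_0$ below are generic placeholders:

```math
J(\theta) = \left\| \partial_t f + \mathcal{N}[f] \right\|^2_{[0,T] \times \Omega}
          + \left\| f - g \right\|^2_{[0,T] \times \partial\Omega}
          + \left\| f(0, \cdot\,; \theta) - u_0 \right\|^2_{\Omega}.
```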

In the following example, we demonstrate how the loss function is computed using Quasi-Random Sampling, a training strategy that uses quasi-Monte Carlo sampling to generate low-discrepancy sequences in high-dimensional spaces.
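
To make the sampling idea concrete, here is a small standalone sketch (not part of the tutorial code below) that draws low-discrepancy points on the $(t, x)$ domain of this example with QuasiMonteCarlo.jl, the package NeuralPDE's `QuasiRandomTraining` strategy builds on; the `SobolSample` sequence is just one illustrative choice.

```julia
using QuasiMonteCarlo

# Bounds of the (t, x) domain used later in this tutorial
lb = [0.0, -1.0]   # t_min, x_min
ub = [1.0, 1.0]    # t_max, x_max

# 4_000 low-discrepancy samples in 2D; each column is one (t, x) point
points = QuasiMonteCarlo.sample(4_000, lb, ub, SobolSample())
size(points)  # (2, 4000)
```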

### Algorithm
The authors of DGM suggest a network composed of LSTM-type layers that works well for most parabolic and quasi-parabolic PDEs.

```math
\begin{align*}
S^1 &= \sigma_1(W^1 \vec{x} + b^1); \\
Z^l &= \sigma_1(U^{z,l} \vec{x} + W^{z,l} S^l + b^{z,l}); \quad l = 1, \ldots, L; \\
G^l &= \sigma_1(U^{g,l} \vec{x} + W^{g,l} S_l + b^{g,l}); \quad l = 1, \ldots, L; \\
R^l &= \sigma_1(U^{r,l} \vec{x} + W^{r,l} S^l + b^{r,l}); \quad l = 1, \ldots, L; \\
H^l &= \sigma_2(U^{h,l} \vec{x} + W^{h,l}(S^l \cdot R^l) + b^{h,l}); \quad l = 1, \ldots, L; \\
S^{l+1} &= (1 - G^l) \cdot H^l + Z^l \cdot S^{l}; \quad l = 1, \ldots, L; \\
f(t, x; \theta) &= \sigma_{out}(W S^{L+1} + b).
\end{align*}
```

where $\vec{x}$ is the concatenated vector $(t, x)$ and $L$ is the number of LSTM-type layers in the network.
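
As a rough functional sketch (independent of the actual `src/dgm.jl` implementation added in this commit), a single layer update of the recursion above could be written as follows; the parameter container `p` and its field names are assumptions made for illustration.

```julia
# One DGM LSTM-like layer step, mirroring the Z/G/R/H equations above.
# S is the current hidden state, x the concatenated (t, x) input,
# σ1/σ2 the two activation functions.
function dgm_layer_step(S, x, p, σ1, σ2)
    Z = σ1.(p.Uz * x .+ p.Wz * S .+ p.bz)
    G = σ1.(p.Ug * x .+ p.Wg * S .+ p.bg)
    R = σ1.(p.Ur * x .+ p.Wr * S .+ p.br)
    H = σ2.(p.Uh * x .+ p.Wh * (S .* R) .+ p.bh)
    return (1 .- G) .* H .+ Z .* S   # S^{l+1}
end
```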

### Example

Let's solve the following Burgers' equation using the Deep Galerkin Method for $\alpha = 0.05$ and compare the solution with the finite difference method:

```math
\partial_t u(t, x) + u(t, x) \partial_x u(t, x) - \alpha \partial_{xx} u(t, x) = 0
```

defined over

```math
t \in [0, 1], \quad x \in [-1, 1],
```

with the initial and boundary conditions
```math
\begin{align*}
u(0, x) &= - \sin(\pi x), \\
u(t, -1) &= 0, \\
u(t, 1) &= 0.
\end{align*}
```

### Copy-Pasteable Code
```julia
using NeuralPDE
using ModelingToolkit, Optimization, OptimizationOptimisers
import Lux: tanh, identity
using Distributions
import ModelingToolkit: Interval, infimum, supremum
using MethodOfLines, OrdinaryDiffEq

@parameters x t
@variables u(..)

Dt = Differential(t)
Dx = Differential(x)
Dxx = Dx^2
α = 0.05

# Burgers' equation
eq = Dt(u(t, x)) + u(t, x) * Dx(u(t, x)) - α * Dxx(u(t, x)) ~ 0

# boundary conditions
bcs = [
    u(0.0, x) ~ -sin(π * x),
    u(t, -1.0) ~ 0.0,
    u(t, 1.0) ~ 0.0
]

domains = [t ∈ Interval(0.0, 1.0), x ∈ Interval(-1.0, 1.0)]

# MethodOfLines, for FD solution
dx = 0.01
order = 2
discretization = MOLFiniteDifference([x => dx], t, saveat = 0.01)
@named pde_system = PDESystem(eq, bcs, domains, [t, x], [u(t, x)])
prob = discretize(pde_system, discretization)
sol = solve(prob, Tsit5())
ts = sol[t]
xs = sol[x]

u_MOL = sol[u(t, x)]

# NeuralPDE, using Deep Galerkin Method
strategy = QuasiRandomTraining(4_000, minibatch = 500)
discretization = DeepGalerkin(2, 1, 50, 5, tanh, tanh, identity, strategy)
@named pde_system = PDESystem(eq, bcs, domains, [t, x], [u(t, x)])
prob = discretize(pde_system, discretization)

global iter = 0
callback = function (p, l)
    global iter += 1
    if iter % 20 == 0
        println("$iter => $l")
    end
    return false
end

res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters = 300)
phi = discretization.phi

u_predict = [first(phi([t, x], res.minimizer)) for t in ts, x in xs]

diff_u = abs.(u_predict .- u_MOL);

using Plots
p1 = plot(ts, xs, u_MOL', linetype = :contourf, title = "FD")
p2 = plot(ts, xs, u_predict', linetype = :contourf, title = "predict")
p3 = plot(ts, xs, diff_u', linetype = :contourf, title = "error")
plot(p1, p2, p3)
```
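
As a quick, optional check (not part of the tutorial code above), the discrepancy against the MethodOfLines reference can be summarized with a couple of scalar statistics:

```julia
# Hypothetical follow-up using the `diff_u` array computed above.
println("max abs error:  ", maximum(diff_u))
println("mean abs error: ", sum(diff_u) / length(diff_u))
```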
4 changes: 3 additions & 1 deletion src/NeuralPDE.jl
@@ -51,6 +51,7 @@ include("neural_adapter.jl")
include("advancedHMC_MCMC.jl")
include("BPINN_ode.jl")
include("PDE_BPINN.jl")
include("dgm.jl")

export NNODE, NNDAE,
PhysicsInformedNN, discretize,
@@ -62,6 +63,7 @@ export NNODE, NNDAE,
AbstractAdaptiveLoss, NonAdaptiveLoss, GradientScaleAdaptiveLoss,
MiniMaxAdaptiveLoss, LogOptions,
ahmc_bayesian_pinn_ode, BNNODE, ahmc_bayesian_pinn_pde, vector_to_parameters,
BPINNsolution, BayesianPINN
BPINNsolution, BayesianPINN,
DeepGalerkin

end # module
174 changes: 174 additions & 0 deletions src/dgm.jl
@@ -0,0 +1,174 @@
struct dgm_lstm_layer{F1, F2} <: Lux.AbstractExplicitLayer
    activation1::Function
    activation2::Function
    in_dims::Int
    out_dims::Int
    init_weight::F1
    init_bias::F2
end

function dgm_lstm_layer(in_dims::Int, out_dims::Int, activation1, activation2;
        init_weight = Lux.glorot_uniform, init_bias = Lux.zeros32)
    return dgm_lstm_layer{typeof(init_weight), typeof(init_bias)}(
        activation1, activation2, in_dims, out_dims, init_weight, init_bias)
end

import Lux: initialparameters, initialstates, parameterlength, statelength

function Lux.initialparameters(rng::AbstractRNG, l::dgm_lstm_layer)
    return (
        Uz = l.init_weight(rng, l.out_dims, l.in_dims),
        Ug = l.init_weight(rng, l.out_dims, l.in_dims),
        Ur = l.init_weight(rng, l.out_dims, l.in_dims),
        Uh = l.init_weight(rng, l.out_dims, l.in_dims),
        Wz = l.init_weight(rng, l.out_dims, l.out_dims),
        Wg = l.init_weight(rng, l.out_dims, l.out_dims),
        Wr = l.init_weight(rng, l.out_dims, l.out_dims),
        Wh = l.init_weight(rng, l.out_dims, l.out_dims),
        bz = l.init_bias(rng, l.out_dims),
        bg = l.init_bias(rng, l.out_dims),
        br = l.init_bias(rng, l.out_dims),
        bh = l.init_bias(rng, l.out_dims)
    )
end

Lux.initialstates(::AbstractRNG, ::dgm_lstm_layer) = NamedTuple()
Lux.parameterlength(l::dgm_lstm_layer) = 4 * (l.out_dims * l.in_dims + l.out_dims * l.out_dims + l.out_dims)
Lux.statelength(l::dgm_lstm_layer) = 0

function (layer::dgm_lstm_layer)(S::AbstractVecOrMat{T}, x::AbstractVecOrMat{T}, ps, st::NamedTuple) where T
    @unpack Uz, Ug, Ur, Uh, Wz, Wg, Wr, Wh, bz, bg, br, bh = ps
    # gate updates Z, G, R, H followed by the state update S^{l+1}
    Z = layer.activation1.(Uz * x .+ Wz * S .+ bz)
    G = layer.activation1.(Ug * x .+ Wg * S .+ bg)
    R = layer.activation1.(Ur * x .+ Wr * S .+ br)
    H = layer.activation2.(Uh * x .+ Wh * (S .* R) .+ bh)
    S_new = (1 .- G) .* H .+ Z .* S
    return S_new, st
end

struct dgm_lstm_block{L <: NamedTuple} <: Lux.AbstractExplicitContainerLayer{(:layers,)}
    layers::L
end

function dgm_lstm_block(l...)
    names = ntuple(i -> Symbol("dgm_lstm_$i"), length(l))
    layers = NamedTuple{names}(l)
    return dgm_lstm_block(layers)
end

dgm_lstm_block(xs::AbstractVector) = dgm_lstm_block(xs...)

@generated function apply_dgm_lstm_block(layers::NamedTuple{fields}, S::AbstractVecOrMat, x::AbstractVecOrMat, ps, st::NamedTuple) where fields
    N = length(fields)
    S_symbols = vcat([:S], [gensym() for _ in 1:N])
    x_symbol = :x
    st_symbols = [gensym() for _ in 1:N]
    calls = [:(($(S_symbols[i + 1]), $(st_symbols[i])) = layers.$(fields[i])(
        $(S_symbols[i]), $(x_symbol), ps.$(fields[i]), st.$(fields[i]))) for i in 1:N]
    push!(calls, :(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),)))))
    push!(calls, :(return $(S_symbols[N + 1]), st))
    return Expr(:block, calls...)
end

function (L::dgm_lstm_block)(S::AbstractVecOrMat{T}, x::AbstractVecOrMat{T}, ps, st::NamedTuple) where T
return apply_dgm_lstm_block(L.layers, S, x, ps, st)
end

struct dgm{S, L, E} <: Lux.AbstractExplicitContainerLayer{(:d_start, :lstm, :d_end)}
    d_start::S
    lstm::L
    d_end::E
end

function (l::dgm)(x::AbstractVecOrMat{T}, ps, st::NamedTuple) where T
    S, st_start = l.d_start(x, ps.d_start, st.d_start)
    S, st_lstm = l.lstm(S, x, ps.lstm, st.lstm)
    y, st_end = l.d_end(S, ps.d_end, st.d_end)

    st_new = (
        d_start = st_start,
        lstm = st_lstm,
        d_end = st_end
    )
    return y, st_new
end

"""
`dgm(in_dims::Int, out_dims::Int, modes::Int, layers::Int, activation1, activation2, out_activation)`:
returns the architecture defined for the Deep Galerkin method
```math
\\begin{align}
S^1 &= \\sigma_1(W^1 x + b^1); \\\\
Z^l &= \\sigma_1(U^{z,l} x + W^{z,l} S^l + b^{z,l}); \\quad l = 1, \\ldots, L; \\\\
G^l &= \\sigma_1(U^{g,l} x + W^{g,l} S^l + b^{g,l}); \\quad l = 1, \\ldots, L; \\\\
R^l &= \\sigma_1(U^{r,l} x + W^{r,l} S^l + b^{r,l}); \\quad l = 1, \\ldots, L; \\\\
H^l &= \\sigma_2(U^{h,l} x + W^{h,l}(S^l \\cdot R^l) + b^{h,l}); \\quad l = 1, \\ldots, L; \\\\
S^{l+1} &= (1 - G^l) \\cdot H^l + Z^l \\cdot S^{l}; \\quad l = 1, \\ldots, L; \\\\
f(t, x; \\theta) &= \\sigma_{out}(W S^{L+1} + b).
\\end{align}
```
## Positional Arguments:
`in_dims`: number of input dimensions (spatial dimension + 1)
`out_dims`: number of output dimensions
`modes`: width of the LSTM-type layers (output of the first `Dense` layer)
`layers`: number of LSTM-type layers
`activation1`: activation function used in the LSTM-type layers
`activation2`: activation function used for the output of the LSTM-type layers
`out_activation`: activation function used for the output of the network
"""
function dgm(in_dims::Int, out_dims::Int, modes::Int, layers::Int, activation1, activation2, out_activation)
    dgm(
        Lux.Dense(in_dims, modes, activation1),
        dgm_lstm_block([dgm_lstm_layer(in_dims, modes, activation1, activation2) for i in 1:layers]),
        Lux.Dense(modes, out_dims, out_activation)
    )
end

"""
`DeepGalerkin(in_dims::Int, out_dims::Int, modes::Int, L::Int, activation1::Function, activation2::Function, out_activation::Function,
strategy::NeuralPDE.AbstractTrainingStrategy; kwargs...)`:
returns a `discretize` algorithm for the ModelingToolkit PDESystem interface, which transforms a `PDESystem` into an
`OptimizationProblem` using the Deep Galerkin method.
## Arguments:
`in_dims`: number of input dimensions (spatial dimension + 1)
`out_dims`: number of output dimensions
`modes`: width of the LSTM-type layers
`L`: number of LSTM-type layers
`activation1`: activation function used in the LSTM-type layers
`activation2`: activation function used for the output of the LSTM-type layers
`out_activation`: activation function used for the output of the network
`strategy`: the training strategy (e.g. `QuasiRandomTraining`) used to sample points for the loss
`kwargs`: additional keyword arguments to be splatted into `PhysicsInformedNN`
## Examples
```julia
discretization = DeepGalerkin(2, 1, 30, 3, tanh, tanh, identity, QuasiRandomTraining(4_000))
```
## References
Sirignano, Justin and Spiliopoulos, Konstantinos, "DGM: A deep learning algorithm for solving partial differential equations",
Journal of Computational Physics, Volume 375, 2018, Pages 1339-1364, doi: https://doi.org/10.1016/j.jcp.2018.08.029
"""
function DeepGalerkin(in_dims::Int, out_dims::Int, modes::Int, L::Int,
        activation1::Function, activation2::Function, out_activation::Function,
        strategy::NeuralPDE.AbstractTrainingStrategy; kwargs...)
    PhysicsInformedNN(
        dgm(in_dims, out_dims, modes, L, activation1, activation2, out_activation),
        strategy; kwargs...
    )
end