Merge pull request #802 from ayushinav/master
add Deep Galerkin method
Showing 8 changed files with 460 additions and 2 deletions.
```diff
@@ -25,6 +25,8 @@ jobs:
           - AdaptiveLoss
           - Logging
           - Forward
+          - NeuralAdapter
+          - DGM
         version:
           - "1"
     steps:
```
## Solving PDEs using the Deep Galerkin Method

### Overview

The Deep Galerkin Method (DGM) is a meshless deep learning algorithm for solving high-dimensional PDEs. It approximates the solution of a PDE with a neural network whose loss function is defined in a similar spirit to PINNs: a PDE residual loss plus a boundary-condition loss.
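
Schematically, writing the PDE as $\partial_t u + \mathcal{N}[u] = 0$ on $[0, T] \times \Omega$, the objective from the original DGM paper combines three mean-squared terms (the exact weighting is an implementation detail):

```math
J(f) = \left\lVert \partial_t f + \mathcal{N}[f] \right\rVert^2_{[0,T] \times \Omega}
     + \left\lVert f - g \right\rVert^2_{[0,T] \times \partial\Omega}
     + \left\lVert f(0, \cdot) - u_0 \right\rVert^2_{\Omega},
```

where each norm is estimated by Monte Carlo averaging over sampled points.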

In the following example, we demonstrate computing the loss function with Quasi-Random Sampling, a strategy that uses quasi-Monte Carlo methods to generate low-discrepancy sequences of points in high-dimensional spaces.
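
To see what low-discrepancy sampling looks like concretely, here is a minimal sketch using QuasiMonteCarlo.jl (the package underlying NeuralPDE's sampling strategies; the Sobol sampler here is our choice for illustration):

```julia
using QuasiMonteCarlo

# 256 low-discrepancy points covering (t, x) ∈ [0, 1] × [-1, 1]:
lb = [0.0, -1.0]   # lower bounds for (t, x)
ub = [1.0, 1.0]    # upper bounds for (t, x)
pts = QuasiMonteCarlo.sample(256, lb, ub, SobolSample())
size(pts)          # (2, 256) — each column is one sample point
```

Unlike i.i.d. uniform draws, such sequences fill the domain more evenly, which typically lowers the variance of the loss estimate.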

### Algorithm

The authors of DGM suggest a network composed of LSTM-type layers that works well for most parabolic and quasi-parabolic PDEs:
```math
\begin{align*}
S^1 &= \sigma_1(W^1 \vec{x} + b^1); \\
Z^l &= \sigma_1(U^{z,l} \vec{x} + W^{z,l} S^l + b^{z,l}); \quad l = 1, \ldots, L; \\
G^l &= \sigma_1(U^{g,l} \vec{x} + W^{g,l} S^l + b^{g,l}); \quad l = 1, \ldots, L; \\
R^l &= \sigma_1(U^{r,l} \vec{x} + W^{r,l} S^l + b^{r,l}); \quad l = 1, \ldots, L; \\
H^l &= \sigma_2(U^{h,l} \vec{x} + W^{h,l}(S^l \cdot R^l) + b^{h,l}); \quad l = 1, \ldots, L; \\
S^{l+1} &= (1 - G^l) \cdot H^l + Z^l \cdot S^l; \quad l = 1, \ldots, L; \\
f(t, x; \theta) &= \sigma_{out}(W S^{L+1} + b),
\end{align*}
```

where $\vec{x}$ is the concatenation of $(t, x)$ and $L$ is the number of LSTM-type layers in the network.
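
Counting the weights above: with input dimension $d$ (so $\vec{x} \in \mathbb{R}^d$) and layer width $m$, each LSTM-type layer carries four $(U, W, b)$ triples, i.e. $4(md + m^2 + m)$ parameters; for the $d = 2$, $m = 50$ network used below, that is $10{,}600$ parameters per layer.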

### Example

Let's solve the following Burgers' equation with the Deep Galerkin Method for $\alpha = 0.05$ and compare our solution with a finite-difference solution:

$$
\partial_t u(t, x) + u(t, x) \, \partial_x u(t, x) - \alpha \, \partial_{xx} u(t, x) = 0
$$

defined over

$$
t \in [0, 1], \quad x \in [-1, 1],
$$

with initial and boundary conditions

```math
\begin{align*}
u(0, x) &= -\sin(\pi x), \\
u(t, -1) &= 0, \\
u(t, 1) &= 0.
\end{align*}
```
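
As a sanity check, this problem is classical: the Cole–Hopf transformation $u = -2\alpha \, \partial_x \ln \phi$ turns Burgers' equation into the heat equation $\partial_t \phi = \alpha \, \partial_{xx} \phi$, so accurate reference solutions are easy to obtain; below we simply use a MethodOfLines.jl finite-difference solution instead.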

### Copy-Pasteable Code

```julia
using NeuralPDE
using ModelingToolkit, Optimization, OptimizationOptimisers
import Lux: tanh, identity
using Distributions
import ModelingToolkit: Interval, infimum, supremum
using MethodOfLines, OrdinaryDiffEq

@parameters x t
@variables u(..)

Dt = Differential(t)
Dx = Differential(x)
Dxx = Dx^2
α = 0.05

# Burgers' equation
eq = Dt(u(t, x)) + u(t, x) * Dx(u(t, x)) - α * Dxx(u(t, x)) ~ 0

# initial and boundary conditions
bcs = [
    u(0.0, x) ~ -sin(π * x),
    u(t, -1.0) ~ 0.0,
    u(t, 1.0) ~ 0.0
]

domains = [t ∈ Interval(0.0, 1.0), x ∈ Interval(-1.0, 1.0)]

# MethodOfLines, for a finite-difference reference solution
dx = 0.01
order = 2
discretization = MOLFiniteDifference([x => dx], t, saveat = 0.01)
@named pde_system = PDESystem(eq, bcs, domains, [t, x], [u(t, x)])
prob = discretize(pde_system, discretization)
sol = solve(prob, Tsit5())
ts = sol[t]
xs = sol[x]

u_MOL = sol[u(t, x)]

# NeuralPDE, using the Deep Galerkin Method
strategy = QuasiRandomTraining(4_000, minibatch = 500)
discretization = DeepGalerkin(2, 1, 50, 5, tanh, tanh, identity, strategy)
@named pde_system = PDESystem(eq, bcs, domains, [t, x], [u(t, x)])
prob = discretize(pde_system, discretization)
global iter = 0
callback = function (p, l)
    global iter += 1
    if iter % 20 == 0
        println("$iter => $l")
    end
    return false
end

res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters = 300)
phi = discretization.phi

u_predict = [first(phi([t, x], res.minimizer)) for t in ts, x in xs]

diff_u = abs.(u_predict .- u_MOL)

using Plots
p1 = plot(ts, xs, u_MOL', linetype = :contourf, title = "FD");
p2 = plot(ts, xs, u_predict', linetype = :contourf, title = "predict");
p3 = plot(ts, xs, diff_u', linetype = :contourf, title = "error");
plot(p1, p2, p3)
```
```julia
struct dgm_lstm_layer{F1, F2} <: Lux.AbstractExplicitLayer
    activation1::Function
    activation2::Function
    in_dims::Int
    out_dims::Int
    init_weight::F1
    init_bias::F2
end

function dgm_lstm_layer(in_dims::Int, out_dims::Int, activation1, activation2;
        init_weight = Lux.glorot_uniform, init_bias = Lux.zeros32)
    return dgm_lstm_layer{typeof(init_weight), typeof(init_bias)}(
        activation1, activation2, in_dims, out_dims, init_weight, init_bias)
end

import Lux: initialparameters, initialstates, parameterlength, statelength

function Lux.initialparameters(rng::AbstractRNG, l::dgm_lstm_layer)
    return (
        Uz = l.init_weight(rng, l.out_dims, l.in_dims),
        Ug = l.init_weight(rng, l.out_dims, l.in_dims),
        Ur = l.init_weight(rng, l.out_dims, l.in_dims),
        Uh = l.init_weight(rng, l.out_dims, l.in_dims),
        Wz = l.init_weight(rng, l.out_dims, l.out_dims),
        Wg = l.init_weight(rng, l.out_dims, l.out_dims),
        Wr = l.init_weight(rng, l.out_dims, l.out_dims),
        Wh = l.init_weight(rng, l.out_dims, l.out_dims),
        bz = l.init_bias(rng, l.out_dims),
        bg = l.init_bias(rng, l.out_dims),
        br = l.init_bias(rng, l.out_dims),
        bh = l.init_bias(rng, l.out_dims)
    )
end

Lux.initialstates(::AbstractRNG, ::dgm_lstm_layer) = NamedTuple()
Lux.parameterlength(l::dgm_lstm_layer) = 4 * (l.out_dims * l.in_dims + l.out_dims * l.out_dims + l.out_dims)
Lux.statelength(l::dgm_lstm_layer) = 0

# One LSTM-type update: gates Z, G, R and candidate state H, matching the DGM equations.
# (`AbstractRNG` and `@unpack` are assumed to be imported at the package level.)
function (layer::dgm_lstm_layer)(S::AbstractVecOrMat{T}, x::AbstractVecOrMat{T}, ps, st::NamedTuple) where {T}
    @unpack Uz, Ug, Ur, Uh, Wz, Wg, Wr, Wh, bz, bg, br, bh = ps
    Z = layer.activation1.(Uz * x + Wz * S .+ bz)
    G = layer.activation1.(Ug * x + Wg * S .+ bg)
    R = layer.activation1.(Ur * x + Wr * S .+ br)
    H = layer.activation2.(Uh * x + Wh * (S .* R) .+ bh)
    S_new = (1 .- G) .* H .+ Z .* S
    return S_new, st
end
```
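
A quick way to exercise the layer in isolation (a sketch, not part of the source; the dimensions are arbitrary):

```julia
using Lux, Random

l = dgm_lstm_layer(2, 8, tanh, tanh)           # in_dims = 2, out_dims = 8
ps, st = Lux.setup(Random.default_rng(), l)    # via the methods defined above
x = rand(Float32, 2, 16)                       # a batch of 16 inputs (t, x)
S = rand(Float32, 8, 16)                       # hidden state, e.g. from the first Dense layer
S_new, _ = l(S, x, ps, st)                     # one update; size(S_new) == (8, 16)
```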

```julia
struct dgm_lstm_block{L <: NamedTuple} <: Lux.AbstractExplicitContainerLayer{(:layers,)}
    layers::L
end

function dgm_lstm_block(l...)
    names = ntuple(i -> Symbol("dgm_lstm_$i"), length(l))
    layers = NamedTuple{names}(l)
    return dgm_lstm_block(layers)
end

dgm_lstm_block(xs::AbstractVector) = dgm_lstm_block(xs...)

@generated function apply_dgm_lstm_block(layers::NamedTuple{fields}, S::AbstractVecOrMat,
        x::AbstractVecOrMat, ps, st::NamedTuple) where {fields}
    N = length(fields)
    S_symbols = vcat([:S], [gensym() for _ in 1:N])
    x_symbol = :x
    st_symbols = [gensym() for _ in 1:N]
    # Unroll the chain of layer calls at compile time: thread the state S through
    # the layers while feeding the same input x to each of them.
    calls = [:(($(S_symbols[i + 1]), $(st_symbols[i])) = layers.$(fields[i])(
                 $(S_symbols[i]), $(x_symbol), ps.$(fields[i]), st.$(fields[i]))) for i in 1:N]
    push!(calls, :(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),)))))
    push!(calls, :(return $(S_symbols[N + 1]), st))
    return Expr(:block, calls...)
end
```
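
For intuition, with two layers the generated body unrolls to roughly the following (a sketch with the gensym'd names spelled out):

```julia
# apply_dgm_lstm_block for fields = (:dgm_lstm_1, :dgm_lstm_2) expands to:
#   (S1, st1) = layers.dgm_lstm_1(S,  x, ps.dgm_lstm_1, st.dgm_lstm_1)
#   (S2, st2) = layers.dgm_lstm_2(S1, x, ps.dgm_lstm_2, st.dgm_lstm_2)
#   st = NamedTuple{(:dgm_lstm_1, :dgm_lstm_2)}((st1, st2))
#   return S2, st
```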

```julia
function (L::dgm_lstm_block)(S::AbstractVecOrMat{T}, x::AbstractVecOrMat{T}, ps, st::NamedTuple) where {T}
    return apply_dgm_lstm_block(L.layers, S, x, ps, st)
end

struct dgm{S, L, E} <: Lux.AbstractExplicitContainerLayer{(:d_start, :lstm, :d_end)}
    d_start::S
    lstm::L
    d_end::E
end

function (l::dgm)(x::AbstractVecOrMat{T}, ps, st::NamedTuple) where {T}
    S, st_start = l.d_start(x, ps.d_start, st.d_start)
    S, st_lstm = l.lstm(S, x, ps.lstm, st.lstm)
    y, st_end = l.d_end(S, ps.d_end, st.d_end)

    st_new = (
        d_start = st_start,
        lstm = st_lstm,
        d_end = st_end
    )
    return y, st_new
end
```

````julia
"""
    dgm(in_dims::Int, out_dims::Int, modes::Int, layers::Int, activation1, activation2, out_activation)

Returns the architecture defined for the Deep Galerkin method:

```math
\\begin{align}
S^1 &= \\sigma_1(W^1 x + b^1); \\\\
Z^l &= \\sigma_1(U^{z,l} x + W^{z,l} S^l + b^{z,l}); \\quad l = 1, \\ldots, L; \\\\
G^l &= \\sigma_1(U^{g,l} x + W^{g,l} S^l + b^{g,l}); \\quad l = 1, \\ldots, L; \\\\
R^l &= \\sigma_1(U^{r,l} x + W^{r,l} S^l + b^{r,l}); \\quad l = 1, \\ldots, L; \\\\
H^l &= \\sigma_2(U^{h,l} x + W^{h,l}(S^l \\cdot R^l) + b^{h,l}); \\quad l = 1, \\ldots, L; \\\\
S^{l+1} &= (1 - G^l) \\cdot H^l + Z^l \\cdot S^l; \\quad l = 1, \\ldots, L; \\\\
f(t, x; \\theta) &= \\sigma_{out}(W S^{L+1} + b).
\\end{align}
```

## Positional Arguments:
- `in_dims`: number of input dimensions (spatial dimension + 1)
- `out_dims`: number of output dimensions
- `modes`: width of the LSTM-type layers (output size of the first Dense layer)
- `layers`: number of LSTM-type layers
- `activation1`: activation function used in the LSTM-type layers
- `activation2`: activation function used for the output of the LSTM-type layers
- `out_activation`: activation function used for the output of the network
"""
function dgm(in_dims::Int, out_dims::Int, modes::Int, layers::Int, activation1, activation2, out_activation)
    dgm(
        Lux.Dense(in_dims, modes, activation1),
        dgm_lstm_block([dgm_lstm_layer(in_dims, modes, activation1, activation2) for i in 1:layers]),
        Lux.Dense(modes, out_dims, out_activation)
    )
end
````
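
For example, a 2-input, 1-output network with 3 LSTM-type layers of width 20 (a sketch; the sizes are arbitrary):

```julia
using Lux, Random

net = dgm(2, 1, 20, 3, tanh, tanh, identity)
ps, st = Lux.setup(Random.default_rng(), net)
y, _ = net(rand(Float32, 2, 10), ps, st)   # y has size (1, 10)
```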

````julia
"""
    DeepGalerkin(in_dims::Int, out_dims::Int, modes::Int, L::Int, activation1::Function,
        activation2::Function, out_activation::Function,
        strategy::NeuralPDE.AbstractTrainingStrategy; kwargs...)

Returns a `discretize` algorithm for the ModelingToolkit PDESystem interface, which transforms
a `PDESystem` into an `OptimizationProblem` using the Deep Galerkin method.

## Arguments:
- `in_dims`: number of input dimensions (spatial dimension + 1)
- `out_dims`: number of output dimensions
- `modes`: width of the LSTM-type layers
- `L`: number of LSTM-type layers
- `activation1`: activation function used in the LSTM-type layers
- `activation2`: activation function used for the output of the LSTM-type layers
- `out_activation`: activation function used for the output of the network
- `kwargs`: additional keyword arguments splatted into `PhysicsInformedNN`

## Examples

```julia
discretization = DeepGalerkin(2, 1, 30, 3, tanh, tanh, identity, QuasiRandomTraining(4_000))
```

## References

Sirignano, Justin and Spiliopoulos, Konstantinos, "DGM: A deep learning algorithm for solving
partial differential equations", Journal of Computational Physics, Volume 375, 2018,
Pages 1339-1364. doi: https://doi.org/10.1016/j.jcp.2018.08.029
"""
function DeepGalerkin(in_dims::Int, out_dims::Int, modes::Int, L::Int,
        activation1::Function, activation2::Function, out_activation::Function,
        strategy::NeuralPDE.AbstractTrainingStrategy; kwargs...)
    PhysicsInformedNN(
        dgm(in_dims, out_dims, modes, L, activation1, activation2, out_activation),
        strategy; kwargs...
    )
end
````