Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Deep Galerkin method #802

Merged
merged 26 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
d2214a1
adding Deep Galerkin
ayushinav Feb 8, 2024
dc91cd4
Deep Galerkin for Lux
ayushinav Feb 8, 2024
617e2cd
adding tests for Deep Galerkin
ayushinav Feb 8, 2024
58c236e
SciML style
ayushinav Feb 8, 2024
9f9ef66
Function type stability, docs in Deep Galerkin
ayushinav Feb 9, 2024
1435984
Merge branch 'SciML:master' into master
ayushinav Feb 12, 2024
4640829
docs, test updates in DGM
ayushinav Feb 12, 2024
b0c5c20
fixes from review, DGM
ayushinav Feb 14, 2024
f1ed130
Merge branch 'SciML:master' into master
ayushinav Feb 14, 2024
21e34b3
Merge branch 'SciML:master' into master
ayushinav Feb 18, 2024
d1bade4
minor fix in docs, DGM
ayushinav Feb 18, 2024
095ce56
adding European Call Option in test
ayushinav Feb 23, 2024
82d7830
test fixes
ayushinav Feb 23, 2024
0370d3b
adding tutorials for DGM
ayushinav Feb 24, 2024
5d6d13e
Merge branch 'SciML:master' into master
ayushinav Feb 25, 2024
babd3a9
fixes in DGM tutorial
ayushinav Feb 25, 2024
ef66153
Merge branch 'master' of https://github.com/ayushinav/NeuralPDE.jl
ayushinav Feb 25, 2024
1f0a37d
adding Burger's eqn test for DGM
ayushinav Feb 28, 2024
0440a99
edits from review
ayushinav Feb 29, 2024
24471c0
tolerance check for DGM
ayushinav Feb 29, 2024
4d3d4da
tolerance check fix for DGM
ayushinav Feb 29, 2024
15f8d88
doc updates for DGM
ayushinav Mar 4, 2024
eced4ef
doc updates for DGM
ayushinav Mar 4, 2024
a867862
Merge branch 'dgm'
ayushinav Mar 4, 2024
6189427
Merge branch 'master' of https://github.com/ayushinav/NeuralPDE.jl
ayushinav Mar 4, 2024
c380453
Merge branch 'master' into master
ChrisRackauckas Mar 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ include("neural_adapter.jl")
include("advancedHMC_MCMC.jl")
include("BPINN_ode.jl")
include("PDE_BPINN.jl")
include("deep_galerkin.jl")

export NNODE, NNDAE,
PhysicsInformedNN, discretize,
Expand All @@ -62,6 +63,7 @@ export NNODE, NNDAE,
AbstractAdaptiveLoss, NonAdaptiveLoss, GradientScaleAdaptiveLoss,
MiniMaxAdaptiveLoss, LogOptions,
ahmc_bayesian_pinn_ode, BNNODE, ahmc_bayesian_pinn_pde, vector_to_parameters,
BPINNsolution, BayesianPINN
BPINNsolution, BayesianPINN,
DeepGalerkin

end # module
179 changes: 179 additions & 0 deletions src/deep_galerkin.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
struct dgm_lstm_layer{F1, F2} <:Lux.AbstractExplicitLayer
activation1::Function
activation2::Function
in_dims::Int
out_dims::Int
init_weight::F1
init_bias::F2
end

function dgm_lstm_layer(in_dims::Int, out_dims::Int, activation1, activation2;
ayushinav marked this conversation as resolved.
Show resolved Hide resolved
init_weight = Lux.glorot_uniform, init_bias = Lux.zeros32)
return dgm_lstm_layer{typeof(init_weight), typeof(init_bias)}(activation1, activation2, in_dims, out_dims, init_weight, init_bias);
end

import Lux:initialparameters, initialstates, parameterlength, statelength

function Lux.initialparameters(rng::AbstractRNG, l::dgm_lstm_layer)
return (
Uz = l.init_weight(rng, l.out_dims, l.in_dims),
Ug = l.init_weight(rng, l.out_dims, l.in_dims),
Ur = l.init_weight(rng, l.out_dims, l.in_dims),
Uh = l.init_weight(rng, l.out_dims, l.in_dims),
Wz = l.init_weight(rng, l.out_dims, l.out_dims),
Wg = l.init_weight(rng, l.out_dims, l.out_dims),
Wr = l.init_weight(rng, l.out_dims, l.out_dims),
Wh = l.init_weight(rng, l.out_dims, l.out_dims),
bz = l.init_bias(rng, l.out_dims) ,
bg = l.init_bias(rng, l.out_dims) ,
br = l.init_bias(rng, l.out_dims) ,
bh = l.init_bias(rng, l.out_dims)
)
end

Lux.initialstates(::AbstractRNG, ::dgm_lstm_layer) = NamedTuple()
Lux.parameterlength(l::dgm_lstm_layer) = 4* (l.out_dims * l.in_dims + l.out_dims * l.out_dims + l.out_dims)
Lux.statelength(l::dgm_lstm_layer) = 0

function (layer::dgm_lstm_layer)(S::AbstractVecOrMat{T}, x::AbstractVecOrMat{T}, ps, st::NamedTuple) where T
@unpack Uz, Ug, Ur, Uh, Wz, Wg, Wr, Wh, bz, bg, br, bh = ps
Z = layer.activation1.(Uz*x+ Wz*S .+ bz);
G = layer.activation1.(Ug*x+ Wg*S .+ bg);
R = layer.activation1.(Ur*x+ Wr*S .+ br);
H = layer.activation2.(Uh*x+ Wh*(S.*R) .+ bh);
S_new = (1. .- G) .* H .+ Z .* S;
return S_new, st;
ayushinav marked this conversation as resolved.
Show resolved Hide resolved
end

struct dgm_lstm_block{L <:NamedTuple} <: Lux.AbstractExplicitContainerLayer{(:layers,)}
layers::L
end

function dgm_lstm_block(l...)
names = ntuple(i-> Symbol("dgm_lstm_$i"), length(l));
layers = NamedTuple{names}(l);
return dgm_lstm_block(layers);
end

dgm_lstm_block(xs::AbstractVector) = dgm_lstm_block(xs...)

@generated function apply_dgm_lstm_block(layers::NamedTuple{fields}, S::AbstractVecOrMat, x::AbstractVecOrMat, ps, st::NamedTuple) where fields
N = length(fields);
S_symbols = vcat([:S], [gensym() for _ in 1:N])
x_symbol = :x;
st_symbols = [gensym() for _ in 1:N]
calls = [:(($(S_symbols[i + 1]), $(st_symbols[i])) = layers.$(fields[i])(
$(S_symbols[i]), $(x_symbol), ps.$(fields[i]), st.$(fields[i]))) for i in 1:N]
push!(calls, :(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),)))))
push!(calls, :(return $(S_symbols[N + 1]), st))
return Expr(:block, calls...)
end

function (L::dgm_lstm_block)(S::AbstractVecOrMat{T}, x::AbstractVecOrMat{T}, ps, st::NamedTuple) where T
return apply_dgm_lstm_block(L.layers, S, x, ps, st)
end

struct dgm{S, L, E} <: Lux.AbstractExplicitContainerLayer{(:d_start, :lstm, :d_end)}
d_start::S
lstm:: L
d_end:: E
end

function (l::dgm)(x::AbstractVecOrMat{T}, ps, st::NamedTuple) where T

S, st_start = l.d_start(x, ps.d_start, st.d_start);
S, st_lstm = l.lstm(S, x, ps.lstm, st.lstm);
y, st_end = l.d_end(S, ps.d_end, st.d_end);

st_new = (
d_start= st_start,
lstm= st_lstm,
d_end= st_end
)
return y, st_new;

end
"""
`dgm(in_dims::Int, out_dims::Int, modes::Int, L::Int, activation1, activation2, out_activation= Lux.identity)`:
sathvikbhagavan marked this conversation as resolved.
Show resolved Hide resolved
returns the architecture defined for Deep Galerkin method

```math

\\begin{align}
S^1 &= \\sigma_1(W^1 x + b^1); \\
Z^l &= \\sigma_1(U^{z,l} x + W^{z,l} S^l + b^{z,l}); \\quad l = 1, \\ldots, L; \\
G^l &= \\sigma_1(U^{g,l} x + W^{g,l} S_l + b^{g,l}); \\quad l = 1, \\ldots, L; \\
R^l &= \\sigma_1(U^{r,l} x + W^{r,l} S^l + b^{r,l}); \\quad l = 1, \\ldots, L; \\
H^l &= \\sigma_2(U^{h,l} x + W^{h,l}(S^l \\cdot R^l) + b^{h,l}); \\quad l = 1, \\ldots, L; \\
S^{l+1} &= (1 - G^l) \\cdot H^l + Z^l \\cdot S^{l}; \\quad l = 1, \\ldots, L; \\
f(t, x, \\theta) &= \\sigma_{out}(W S^{L+1} + b).
\\end{align}
```

### Arguments:
ayushinav marked this conversation as resolved.
Show resolved Hide resolved

`in_dims`: number of input dimensions= (spatial dimension+ 1)

ayushinav marked this conversation as resolved.
Show resolved Hide resolved
`out_dims`: number of output dimensions

`modes`: Width of the LSTM type layer (output of the first Dense layer)

`layers`: number of LSTM type layers

`activation1`: activation function used in LSTM type layers

`activation2`: activation function used for the output of LSTM type layers

`out_activation`: activation fn used for the output of the network

`kwargs`: additional arguments to be splatted into `PhysicsInformedNN`

### Examples
ayushinav marked this conversation as resolved.
Show resolved Hide resolved

```julia
discretization= DeepGalerkin(2, 1, 30, 3, tanh, tanh, identity, QuasiRandomTraining(4_000));
```

### References
ayushinav marked this conversation as resolved.
Show resolved Hide resolved

Sirignano, Justin and Spiliopoulos, Konstantinos, "DGM: A deep learning algorithm for solving partial differential equations",
Journal of Computational Physics, Volume 375, 2018, Pages 1339-1364, doi: https://doi.org/10.1016/j.jcp.2018.08.029
"""
function dgm(in_dims::Int, out_dims::Int, modes::Int, layers::Int, activation1, activation2, out_activation)
dgm(
Lux.Dense(in_dims, modes, activation1),
dgm_lstm_block([dgm_lstm_layer(in_dims, modes, activation1, activation2) for i in 1:layers]),
Lux.Dense(modes, out_dims, out_activation)
)
end

"""
`DeepGalerkin(in_dims::Int, out_dims::Int, modes::Int, L::Int, activation1::Function, activation2::Function, out_activation::Function,
strategy::NeuralPDE.AbstractTrainingStrategy; kwargs...)`:
returns a `discretize` algorithm for the ModelingToolkit PDESystem interface, which transforms a `PDESystem` into an
`OptimizationProblem` using the Deep Galerkin method.

### Arguments:

`in_dims`: number of input dimensions= (spatial dimension+ 1)

`out_dims`: number of output dimensions

`modes`: Width of the LSTM type layer
ayushinav marked this conversation as resolved.
Show resolved Hide resolved

`L`: number of LSTM type layers
ayushinav marked this conversation as resolved.
Show resolved Hide resolved

`activation1`: activation fn used in LSTM type layers
ayushinav marked this conversation as resolved.
Show resolved Hide resolved

`activation2`: activation fn used for the output of LSTM type layers

`out_activation`: activation fn used for the output of the network

`kwargs`: additional arguments to be splatted into `PhysicsInformedNN`
"""
function DeepGalerkin(in_dims::Int, out_dims::Int, modes::Int, L::Int, activation1::Function, activation2::Function, out_activation::Function, strategy::NeuralPDE.AbstractTrainingStrategy; kwargs...)
PhysicsInformedNN(
dgm(in_dims, out_dims, modes, L, activation1, activation2, out_activation),
strategy; kwargs...
)
end
48 changes: 48 additions & 0 deletions test/other_algs_test.jl
ayushinav marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
using NeuralPDE

using ModelingToolkit, Optimization, OptimizationOptimisers
import ModelingToolkit: Interval, infimum, supremum
import Lux: tanh, identity

@testset begin
sathvikbhagavan marked this conversation as resolved.
Show resolved Hide resolved
@parameters x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2

# 2D PDE
eq = Dxx(u(x, y)) + Dyy(u(x, y)) ~ -sin(pi * x) * sin(pi * y)

# Initial and boundary conditions
bcs = [u(0, y) ~ 0.0, u(1, y) ~ -sin(pi * 1) * sin(pi * y),
u(x, 0) ~ 0.0, u(x, 1) ~ -sin(pi * x) * sin(pi * 1)]
# Space and time domains
domains = [x ∈ Interval(0.0, 1.0), y ∈ Interval(0.0, 1.0)]

strategy = QuasiRandomTraining(4_000, minibatch= 500);
discretization= DeepGalerkin(2, 1, 30, 3, tanh, tanh, identity, strategy);

@named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)])
prob = discretize(pde_system, discretization)

global iter = 0;
callback = function (p, l)
global iter += 1;
if iter%10 == 0
println("$iter => $l")
end
return false
end

res = Optimization.solve(prob, ADAM(0.01); callback = callback, maxiters = 500)
phi = discretization.phi

xs, ys = [infimum(d.domain):0.01:supremum(d.domain) for d in domains]
analytic_sol_func(x, y) = (sin(pi * x) * sin(pi * y)) / (2pi^2)

u_predict = reshape([first(phi([x, y], res.minimizer)) for x in xs for y in ys],
(length(xs), length(ys)))
u_real = reshape([analytic_sol_func(x, y) for x in xs for y in ys],
(length(xs), length(ys)))
@test u_predict≈u_real atol=0.1
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,8 @@ end
if !is_APPVEYOR && GROUP == "GPU"
@safetestset "NNPDE_gpu_Lux" begin include("NNPDE_tests_gpu_Lux.jl") end
end

if GROUP == "All" || GROUP == "Other algos"
ayushinav marked this conversation as resolved.
Show resolved Hide resolved
@time @safetestset "Deep Galerkin solver" begin include("other_algs_test.jl") end
end
end
Loading