diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 8662940570..59f046cc9f 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -25,6 +25,8 @@ jobs:
           - AdaptiveLoss
           - Logging
           - Forward
+          - NeuralAdapter
+          - DGM
         version:
           - "1"
     steps:
diff --git a/Project.toml b/Project.toml
index 91aa4524f3..836b78eec4 100644
--- a/Project.toml
+++ b/Project.toml
@@ -82,6 +82,7 @@ Symbolics = "5"
 Test = "1"
 UnPack = "1"
 Zygote = "0.6"
+MethodOfLines = "0.10.7"
 julia = "1.6"
 
 [extras]
@@ -95,6 +96,7 @@ OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"
 
 [targets]
-test = ["Aqua", "Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "LuxCUDA", "Flux"]
+test = ["Aqua", "Test", "CUDA", "SafeTestsets", "OptimizationOptimJL", "Pkg", "OrdinaryDiffEq", "LineSearches", "LuxCUDA", "Flux", "MethodOfLines"]
diff --git a/docs/pages.jl b/docs/pages.jl
index c1016b4c23..054b01f9ef 100644
--- a/docs/pages.jl
+++ b/docs/pages.jl
@@ -3,6 +3,7 @@ pages = ["index.md",
         "Bayesian PINNs for Coupled ODEs" => "tutorials/Lotka_Volterra_BPINNs.md",
         "PINNs DAEs" => "tutorials/dae.md",
         "Parameter Estimation with PINNs for ODEs" => "tutorials/ode_parameter_estimation.md",
+        "Deep Galerkin Method" => "tutorials/dgm.md"
         #"examples/nnrode_example.md", # currently incorrect
     ],
     "PDE PINN Tutorials" => Any["Introduction to NeuralPDE for PDEs" => "tutorials/pdesystem.md",
diff --git a/docs/src/tutorials/dgm.md b/docs/src/tutorials/dgm.md
new file mode 100644
index 0000000000..2d5fac5b7e
--- /dev/null
+++ b/docs/src/tutorials/dgm.md
@@ -0,0 +1,112 @@
## Solving PDEs using the Deep Galerkin Method

### Overview

The Deep Galerkin Method (DGM) is a meshless deep learning algorithm for solving high-dimensional PDEs. It approximates the solution of a PDE with a neural network, and its loss function is defined in a similar spirit to PINNs: it is composed of a PDE residual loss and a boundary condition loss.

In the following example, we demonstrate computing the loss function using Quasi-Random Sampling, a technique that uses quasi-Monte Carlo sampling to generate low-discrepancy sequences in high-dimensional spaces.

### Algorithm

The authors of DGM suggest a network composed of LSTM-type layers, which works well for most parabolic and quasi-parabolic PDEs:

```math
\begin{align*}
S^1 &= \sigma_1(W^1 \vec{x} + b^1); \\
Z^l &= \sigma_1(U^{z,l} \vec{x} + W^{z,l} S^l + b^{z,l}); \quad l = 1, \ldots, L; \\
G^l &= \sigma_1(U^{g,l} \vec{x} + W^{g,l} S^l + b^{g,l}); \quad l = 1, \ldots, L; \\
R^l &= \sigma_1(U^{r,l} \vec{x} + W^{r,l} S^l + b^{r,l}); \quad l = 1, \ldots, L; \\
H^l &= \sigma_2(U^{h,l} \vec{x} + W^{h,l}(S^l \cdot R^l) + b^{h,l}); \quad l = 1, \ldots, L; \\
S^{l+1} &= (1 - G^l) \cdot H^l + Z^l \cdot S^{l}; \quad l = 1, \ldots, L; \\
f(t, x; \theta) &= \sigma_{out}(W S^{L+1} + b),
\end{align*}
```

where $\vec{x}$ is the concatenated vector of $(t, x)$ and $L$ is the number of LSTM-type layers in the network.
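To make the recurrence above concrete, here is a minimal, self-contained sketch of a single LSTM-type layer update ``S^l \to S^{l+1}`` in plain Julia. It only mirrors the equations; the function name `dgm_layer_step` and the parameter container `p` are illustrative placeholders, not part of the NeuralPDE API (the package's actual Lux implementation lives in `src/dgm.jl`).

```julia
# Illustrative sketch (not the NeuralPDE implementation) of one DGM LSTM-type layer update.
# `p` is assumed to be a NamedTuple holding the weights Uz, Wz, bz, ..., Uh, Wh, bh;
# for `modes = m` and input dimension d, each U is m×d, each W is m×m, each b has length m.
function dgm_layer_step(S, x, p, σ1, σ2)
    Z = σ1.(p.Uz * x .+ p.Wz * S .+ p.bz)        # update-gate-like term
    G = σ1.(p.Ug * x .+ p.Wg * S .+ p.bg)        # forget-gate-like term
    R = σ1.(p.Ur * x .+ p.Wr * S .+ p.br)        # reset-gate-like term
    H = σ2.(p.Uh * x .+ p.Wh * (S .* R) .+ p.bh) # candidate state
    return (1 .- G) .* H .+ Z .* S               # S^{l+1}
end
```

Stacking `L` of these updates after the first dense layer, and applying a final dense layer with `σ_out`, gives the full network `f(t, x; θ)`.
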
### Example

Let's solve the following Burgers' equation using the Deep Galerkin Method for $\alpha = 0.05$ and compare the solution with a finite-difference reference:

```math
\partial_t u(t, x) + u(t, x) \partial_x u(t, x) - \alpha \partial_{xx} u(t, x) = 0
```

defined over

```math
t \in [0, 1], \quad x \in [-1, 1],
```

with initial and boundary conditions

```math
\begin{align*}
u(0, x) &= -\sin(\pi x), \\
u(t, -1) &= 0, \\
u(t, 1) &= 0.
\end{align*}
```

### Copy-Pasteable Code

```julia
using NeuralPDE
using ModelingToolkit, Optimization, OptimizationOptimisers
import Lux: tanh, identity
using Distributions
import ModelingToolkit: Interval, infimum, supremum
using MethodOfLines, OrdinaryDiffEq

@parameters x t
@variables u(..)

Dt = Differential(t)
Dx = Differential(x)
Dxx = Dx^2
α = 0.05

# Burgers' equation
eq = Dt(u(t, x)) + u(t, x) * Dx(u(t, x)) - α * Dxx(u(t, x)) ~ 0

# initial and boundary conditions
bcs = [
    u(0.0, x) ~ -sin(π * x),
    u(t, -1.0) ~ 0.0,
    u(t, 1.0) ~ 0.0
]

domains = [t ∈ Interval(0.0, 1.0), x ∈ Interval(-1.0, 1.0)]

# MethodOfLines, for the finite-difference reference solution
dx = 0.01
order = 2
discretization = MOLFiniteDifference([x => dx], t, approx_order = order, saveat = 0.01)
@named pde_system = PDESystem(eq, bcs, domains, [t, x], [u(t, x)])
prob = discretize(pde_system, discretization)
sol = solve(prob, Tsit5())
ts = sol[t]
xs = sol[x]

u_MOL = sol[u(t, x)]

# NeuralPDE, using the Deep Galerkin Method
strategy = QuasiRandomTraining(4_000, minibatch = 500)
discretization = DeepGalerkin(2, 1, 50, 5, tanh, tanh, identity, strategy)
@named pde_system = PDESystem(eq, bcs, domains, [t, x], [u(t, x)])
prob = discretize(pde_system, discretization)

global iter = 0
callback = function (p, l)
    global iter += 1
    if iter % 20 == 0
        println("$iter => $l")
    end
    return false
end

res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters = 300)
phi = discretization.phi

u_predict = [first(phi([t, x], res.minimizer)) for t in ts, x in xs]

diff_u = abs.(u_predict .- u_MOL)

using Plots
p1 = plot(ts, xs, u_MOL', linetype = :contourf, title = "FD")
p2 = plot(ts, xs, u_predict', linetype = :contourf, title = "predict")
p3 = plot(ts, xs, diff_u', linetype = :contourf, title = "error")
plot(p1, p2, p3)
```
diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl
index ea4f563b67..d367bf8b6c 100644
--- a/src/NeuralPDE.jl
+++ b/src/NeuralPDE.jl
@@ -51,6 +51,7 @@ include("neural_adapter.jl")
 include("advancedHMC_MCMC.jl")
 include("BPINN_ode.jl")
 include("PDE_BPINN.jl")
+include("dgm.jl")
 
 export NNODE, NNDAE,
        PhysicsInformedNN, discretize,
@@ -62,6 +63,7 @@ export NNODE, NNDAE,
        AbstractAdaptiveLoss, NonAdaptiveLoss, GradientScaleAdaptiveLoss,
        MiniMaxAdaptiveLoss, LogOptions,
        ahmc_bayesian_pinn_ode, BNNODE, ahmc_bayesian_pinn_pde, vector_to_parameters,
-       BPINNsolution, BayesianPINN
+       BPINNsolution, BayesianPINN,
+       DeepGalerkin
 
 end # module
diff --git a/src/dgm.jl b/src/dgm.jl
new file mode 100644
index 0000000000..ac273b2cbe
--- /dev/null
+++ b/src/dgm.jl
@@ -0,0 +1,174 @@
struct dgm_lstm_layer{F1, F2} <: Lux.AbstractExplicitLayer
    activation1::Function
    activation2::Function
    in_dims::Int
    out_dims::Int
    init_weight::F1
    init_bias::F2
end

function dgm_lstm_layer(in_dims::Int, out_dims::Int, activation1, activation2;
        init_weight = Lux.glorot_uniform, init_bias = Lux.zeros32)
    return dgm_lstm_layer{typeof(init_weight), typeof(init_bias)}(
        activation1, activation2, in_dims, out_dims, init_weight, init_bias)
end

import Lux: initialparameters, initialstates,
    parameterlength, statelength

function Lux.initialparameters(rng::AbstractRNG, l::dgm_lstm_layer)
    return (
        Uz = l.init_weight(rng, l.out_dims, l.in_dims),
        Ug = l.init_weight(rng, l.out_dims, l.in_dims),
        Ur = l.init_weight(rng, l.out_dims, l.in_dims),
        Uh = l.init_weight(rng, l.out_dims, l.in_dims),
        Wz = l.init_weight(rng, l.out_dims, l.out_dims),
        Wg = l.init_weight(rng, l.out_dims, l.out_dims),
        Wr = l.init_weight(rng, l.out_dims, l.out_dims),
        Wh = l.init_weight(rng, l.out_dims, l.out_dims),
        bz = l.init_bias(rng, l.out_dims),
        bg = l.init_bias(rng, l.out_dims),
        br = l.init_bias(rng, l.out_dims),
        bh = l.init_bias(rng, l.out_dims)
    )
end

Lux.initialstates(::AbstractRNG, ::dgm_lstm_layer) = NamedTuple()
Lux.parameterlength(l::dgm_lstm_layer) = 4 * (l.out_dims * l.in_dims + l.out_dims * l.out_dims + l.out_dims)
Lux.statelength(l::dgm_lstm_layer) = 0

function (layer::dgm_lstm_layer)(S::AbstractVecOrMat{T}, x::AbstractVecOrMat{T}, ps, st::NamedTuple) where T
    @unpack Uz, Ug, Ur, Uh, Wz, Wg, Wr, Wh, bz, bg, br, bh = ps
    Z = layer.activation1.(Uz * x .+ Wz * S .+ bz)
    G = layer.activation1.(Ug * x .+ Wg * S .+ bg)
    R = layer.activation1.(Ur * x .+ Wr * S .+ br)
    H = layer.activation2.(Uh * x .+ Wh * (S .* R) .+ bh)
    # integer literal keeps the element type (e.g. Float32) of the parameters
    S_new = (1 .- G) .* H .+ Z .* S
    return S_new, st
end

struct dgm_lstm_block{L <: NamedTuple} <: Lux.AbstractExplicitContainerLayer{(:layers,)}
    layers::L
end

function dgm_lstm_block(l...)
    names = ntuple(i -> Symbol("dgm_lstm_$i"), length(l))
    layers = NamedTuple{names}(l)
    return dgm_lstm_block(layers)
end

dgm_lstm_block(xs::AbstractVector) = dgm_lstm_block(xs...)

@generated function apply_dgm_lstm_block(layers::NamedTuple{fields}, S::AbstractVecOrMat,
        x::AbstractVecOrMat, ps, st::NamedTuple) where fields
    N = length(fields)
    S_symbols = vcat([:S], [gensym() for _ in 1:N])
    x_symbol = :x
    st_symbols = [gensym() for _ in 1:N]
    calls = [:(($(S_symbols[i + 1]), $(st_symbols[i])) = layers.$(fields[i])(
                 $(S_symbols[i]), $(x_symbol), ps.$(fields[i]), st.$(fields[i]))) for i in 1:N]
    push!(calls, :(st = NamedTuple{$fields}((($(Tuple(st_symbols)...),)))))
    push!(calls, :(return $(S_symbols[N + 1]), st))
    return Expr(:block, calls...)
end

function (L::dgm_lstm_block)(S::AbstractVecOrMat{T}, x::AbstractVecOrMat{T}, ps, st::NamedTuple) where T
    return apply_dgm_lstm_block(L.layers, S, x, ps, st)
end

struct dgm{S, L, E} <: Lux.AbstractExplicitContainerLayer{(:d_start, :lstm, :d_end)}
    d_start::S
    lstm::L
    d_end::E
end

function (l::dgm)(x::AbstractVecOrMat{T}, ps, st::NamedTuple) where T
    S, st_start = l.d_start(x, ps.d_start, st.d_start)
    S, st_lstm = l.lstm(S, x, ps.lstm, st.lstm)
    y, st_end = l.d_end(S, ps.d_end, st.d_end)

    st_new = (
        d_start = st_start,
        lstm = st_lstm,
        d_end = st_end
    )
    return y, st_new
end

"""
`dgm(in_dims::Int, out_dims::Int, modes::Int, layers::Int, activation1, activation2, out_activation)`:
returns the architecture defined for the Deep Galerkin method:

```math
\\begin{align}
S^1 &= \\sigma_1(W^1 x + b^1); \\\\
Z^l &= \\sigma_1(U^{z,l} x + W^{z,l} S^l + b^{z,l}); \\quad l = 1, \\ldots, L; \\\\
G^l &= \\sigma_1(U^{g,l} x + W^{g,l} S^l + b^{g,l}); \\quad l = 1, \\ldots, L; \\\\
R^l &= \\sigma_1(U^{r,l} x + W^{r,l} S^l + b^{r,l}); \\quad l = 1, \\ldots, L; \\\\
H^l &= \\sigma_2(U^{h,l} x + W^{h,l}(S^l \\cdot R^l) + b^{h,l}); \\quad l = 1, \\ldots, L; \\\\
S^{l+1} &= (1 - G^l) \\cdot H^l + Z^l \\cdot S^{l}; \\quad l = 1, \\ldots, L; \\\\
f(t, x; \\theta) &= \\sigma_{out}(W S^{L+1} + b).
\\end{align}
```

## Positional Arguments:

`in_dims`: number of input dimensions (spatial dimension + 1)

`out_dims`: number of output dimensions

`modes`: width of the LSTM-type layers (output size of the first Dense layer)

`layers`: number of LSTM-type layers

`activation1`: activation function used in the LSTM-type layers

`activation2`: activation function applied to the candidate state inside the LSTM-type layers

`out_activation`: activation function used for the output of the network
"""
function dgm(in_dims::Int, out_dims::Int, modes::Int, layers::Int, activation1, activation2, out_activation)
    dgm(
        Lux.Dense(in_dims, modes, activation1),
        dgm_lstm_block([dgm_lstm_layer(in_dims, modes, activation1, activation2) for i in 1:layers]),
        Lux.Dense(modes, out_dims, out_activation)
    )
end

"""
`DeepGalerkin(in_dims::Int, out_dims::Int, modes::Int, L::Int, activation1::Function, activation2::Function,
    out_activation::Function, strategy::NeuralPDE.AbstractTrainingStrategy; kwargs...)`:

returns a `discretize` algorithm for the ModelingToolkit `PDESystem` interface, which transforms a `PDESystem` into an
`OptimizationProblem` using the Deep Galerkin method.

## Arguments:

`in_dims`: number of input dimensions (spatial dimension + 1)

`out_dims`: number of output dimensions

`modes`: width of the LSTM-type layers

`L`: number of LSTM-type layers

`activation1`: activation function used in the LSTM-type layers

`activation2`: activation function applied to the candidate state inside the LSTM-type layers

`out_activation`: activation function used for the output of the network

`kwargs`: additional keyword arguments to be splatted into `PhysicsInformedNN`

## Examples

```julia
discretization = DeepGalerkin(2, 1, 30, 3, tanh, tanh, identity, QuasiRandomTraining(4_000))
```

## References

Sirignano, Justin and Spiliopoulos, Konstantinos, "DGM: A deep learning algorithm for solving partial differential equations",
Journal of Computational Physics, Volume 375, 2018, Pages 1339-1364, doi: https://doi.org/10.1016/j.jcp.2018.08.029
"""
function DeepGalerkin(in_dims::Int, out_dims::Int, modes::Int, L::Int, activation1::Function,
        activation2::Function, out_activation::Function,
        strategy::NeuralPDE.AbstractTrainingStrategy; kwargs...)
    PhysicsInformedNN(
        dgm(in_dims, out_dims, modes, L, activation1, activation2, out_activation),
        strategy; kwargs...
    )
end
\ No newline at end of file
diff --git a/test/dgm_test.jl b/test/dgm_test.jl
new file mode 100644
index 0000000000..32c3cfb2a3
--- /dev/null
+++ b/test/dgm_test.jl
@@ -0,0 +1,161 @@
using NeuralPDE, Test

using ModelingToolkit, Optimization, OptimizationOptimisers, Distributions, MethodOfLines, OrdinaryDiffEq
import ModelingToolkit: Interval, infimum, supremum
import Lux: tanh, identity

@testset "Poisson's equation" begin
    @parameters x y
    @variables u(..)
    Dxx = Differential(x)^2
    Dyy = Differential(y)^2

    # 2D PDE
    eq = Dxx(u(x, y)) + Dyy(u(x, y)) ~ -sin(pi * x) * sin(pi * y)

    # Boundary conditions
    bcs = [u(0, y) ~ 0.0, u(1, y) ~ -sin(pi * 1) * sin(pi * y),
        u(x, 0) ~ 0.0, u(x, 1) ~ -sin(pi * x) * sin(pi * 1)]
    # Space domain
    domains = [x ∈ Interval(0.0, 1.0), y ∈ Interval(0.0, 1.0)]

    strategy = QuasiRandomTraining(4_000, minibatch = 500)
    discretization = DeepGalerkin(2, 1, 30, 3, tanh, tanh, identity, strategy)

    @named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)])
    prob = discretize(pde_system, discretization)

    global iter = 0
    callback = function (p, l)
        global iter += 1
        if iter % 50 == 0
            println("$iter => $l")
        end
        return false
    end

    res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters = 500)
    prob = remake(prob, u0 = res.minimizer)
    res = Optimization.solve(prob, Adam(0.001); callback = callback, maxiters = 200)
    phi = discretization.phi

    xs, ys = [infimum(d.domain):0.01:supremum(d.domain) for d in domains]
    analytic_sol_func(x, y) = (sin(pi * x) * sin(pi * y)) / (2pi^2)

    u_predict = reshape([first(phi([x, y], res.minimizer)) for x in xs for y in ys],
        (length(xs), length(ys)))
    u_real = reshape([analytic_sol_func(x, y) for x in xs for y in ys],
        (length(xs), length(ys)))
    @test u_predict≈u_real atol=0.1
end

@testset "Black-Scholes PDE: European Call Option" begin
    K = 50.0
    T = 1.0
    r = 0.05
    σ = 0.25
    S = 130.0
    S_multiplier = 1.3

    @parameters x t
    @variables g(..)

    G(x) = max(x - K, 0.0)

    Dt = Differential(t)
    Dx = Differential(x)
    Dxx = Dx^2

    eq = Dt(g(t, x)) + r * x * Dx(g(t, x)) + 0.5 * σ^2 * Dxx(g(t, x)) ~ r * g(t, x)

    bcs = [g(T, x) ~ G(x)] # terminal condition

    domains = [t ∈ Interval(0.0, T), x ∈ Interval(0.0, S * S_multiplier)]

    strategy = QuasiRandomTraining(4_000, minibatch = 500)
    discretization = DeepGalerkin(2, 1, 30, 3, tanh, tanh, identity, strategy)

    @named pde_system = PDESystem(eq, bcs, domains, [t, x], [g(t, x)])
    prob = discretize(pde_system, discretization)

    global iter = 0
    callback = function (p, l)
        global iter += 1
        if iter % 50 == 0
            println("$iter => $l")
        end
        return false
    end

    res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters = 300)
    prob = remake(prob, u0 = res.minimizer)
    res = Optimization.solve(prob, Adam(0.001); callback = callback, maxiters = 300)
    phi = discretization.phi

    function analytical_soln(t, x, K, σ, T)
        d₊ = (log(x / K) + (r + 0.5 * σ^2) * (T - t)) / (σ * sqrt(T - t))
        d₋ = d₊ - (σ * sqrt(T - t))
        return x * cdf(Normal(0, 1), d₊) - K * exp(-r * (T - t)) * cdf(Normal(0, 1), d₋)
    end
    analytic_sol_func(t, x) = analytical_soln(t, x, K, σ, T)

    domains2 = [t ∈ Interval(0.0, T - 0.001), x ∈ Interval(0.0, S)]
    ts = collect(infimum(domains2[1].domain):0.01:supremum(domains2[1].domain))
    xs = collect(infimum(domains2[2].domain):1.0:supremum(domains2[2].domain))

    u_real = [analytic_sol_func(t, x) for t in ts, x in xs]
    u_predict = [first(phi([t, x], res.minimizer)) for t in ts, x in xs]
    @test u_predict≈u_real rtol=0.05
end

@testset "Burgers' equation" begin
    @parameters x t
    @variables u(..)

    Dt = Differential(t)
    Dx = Differential(x)
    Dxx = Dx^2
    α = 0.05
    eq = Dt(u(t, x)) + u(t, x) * Dx(u(t, x)) - α * Dxx(u(t, x)) ~ 0 # Burgers' equation

    bcs = [
        u(0.0, x) ~ -sin(π * x),
        u(t, -1.0) ~ 0.0,
        u(t, 1.0) ~ 0.0
    ]

    domains = [t ∈ Interval(0.0, 1.0), x ∈ Interval(-1.0, 1.0)]

    # MethodOfLines reference solution
    dx = 0.01
    order = 2
    discretization = MOLFiniteDifference([x => dx], t, approx_order = order, saveat = 0.01)
    @named pde_system = PDESystem(eq, bcs, domains, [t, x], [u(t, x)])
    prob = discretize(pde_system, discretization)
    sol = solve(prob, Tsit5())
    ts = sol[t]
    xs = sol[x]

    u_MOL = sol[u(t, x)]

    # NeuralPDE solution with the Deep Galerkin Method
    strategy = QuasiRandomTraining(4_000, minibatch = 500)
    discretization = DeepGalerkin(2, 1, 50, 5, tanh, tanh, identity, strategy)
    @named pde_system = PDESystem(eq, bcs, domains, [t, x], [u(t, x)])
    prob = discretize(pde_system, discretization)

    global iter = 0
    callback = function (p, l)
        global iter += 1
        if iter % 20 == 0
            println("$iter => $l")
        end
        return false
    end

    res = Optimization.solve(prob, Adam(0.01); callback = callback, maxiters = 300)
    phi = discretization.phi

    u_predict = [first(phi([t, x], res.minimizer)) for t in ts, x in xs]

    @test u_predict≈u_MOL rtol=0.025
end
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index e21f554fb4..34bd50ae81 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -64,4 +64,8 @@ end
     if !is_APPVEYOR && GROUP == "GPU"
         @safetestset "NNPDE_gpu_Lux" begin include("NNPDE_tests_gpu_Lux.jl") end
     end
+
+    if GROUP == "All" || GROUP == "DGM"
+        @time @safetestset "Deep Galerkin solver" begin include("dgm_test.jl") end
+    end
 end
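The new `DGM` group added to the CI matrix and to `runtests.jl` can also be exercised on its own. A minimal sketch, assuming the usual SciML convention that `test/runtests.jl` reads the group name from the `GROUP` environment variable (as the CI matrix above suggests):

```julia
# Hypothetical local invocation of only the new Deep Galerkin test group.
# Assumes runtests.jl selects groups via ENV["GROUP"]; adjust if your setup differs.
using Pkg
ENV["GROUP"] = "DGM"
Pkg.test("NeuralPDE")
```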