Skip to content

Commit

Permalink
pure PINO with DeepOnet
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed May 2, 2024
1 parent 2cc1d1f commit f873541
Show file tree
Hide file tree
Showing 9 changed files with 298 additions and 445 deletions.
3 changes: 2 additions & 1 deletion docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ pages = ["index.md",
"manual/training_strategies.md",
"manual/adaptive_losses.md",
"manual/logging.md",
"manual/neural_adapters.md"],
"manual/neural_adapters.md",
"manual/pino_ode.md"],
"Developer Documentation" => Any["developer/debugging.md"]
]
12 changes: 1 addition & 11 deletions docs/src/manual/pino_ode.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,7 @@ PINOODE
```

```@docs
TRAINSET
DeepONet
```

```@docs
PINOsolution
```

```@docs
OperatorLearning
```

```@docs
EquationSolving
```
127 changes: 46 additions & 81 deletions docs/src/tutorials/pino_ode.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# Physics informed Neural Operator ODEs Solvers
# Physics Informed Neural Operator for ODEs Solvers

This tutorial is an introduction to using physics-informed neural operator (PINOs) for solving family of parametric ordinary diferential equations (ODEs).

#TODO two phase
This tutorial provides an example of how using Physics Informed Neural Operator (PINO) for solving family of parametric ordinary differential equations (ODEs).

## Operator Learning for a family of parametric ODE.

Expand All @@ -11,95 +9,62 @@ using Test
using OrdinaryDiffEq, OptimizationOptimisers
using Lux
using Statistics, Random
# using NeuralOperators
using NeuralPDE
linear_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
linear = (u, p, t) -> cos(p * t)
tspan = (0.0f0, 2.0f0)
u0 = 0.0f0
p = pi / 2f0
prob = ODEProblem(linear, u0, tspan, p)
equation = (u, p, t) -> cos(p * t)
tspan = (0.0f0, 1.0f0)
u0 = 1.0f0
prob = ODEProblem(equation, u0, tspan)
# initilize DeepONet operator

Check warning on line 19 in docs/src/tutorials/pino_ode.md

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"initilize" should be "initialize".
branch = Lux.Chain(
Lux.Dense(1, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast),
Lux.Dense(10, 10))
trunk = Lux.Chain(
Lux.Dense(1, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast))
deeponet = NeuralPDE.DeepONet(branch, trunk; linear = nothing)
bounds = (p = [0.1f0, pi],)
#TODO add truct

Check warning on line 32 in docs/src/tutorials/pino_ode.md

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"truct" should be "struct".
strategy = (branch_size = 50, trunk_size = 40)
# strategy = (branch_size = 50, dt = 0.1)?
opt = OptimizationOptimisers.Adam(0.03)
alg = NeuralPDE.PINOODE(deeponet, opt, bounds; strategy = strategy)
sol = solve(prob, alg, verbose = true, maxiters = 2000)
```

Generate a dataset for learning a given family of ODEs where the parameter 'a' is varied. The dataset is generated by solving the ODE for different values of 'a' and storing the solutions. The dataset is then used to train the PINO model:
* input data: set of parameters 'a',
* output data: set of solutions u(t){a} corresponding parameter 'a'.
Now let's compare the prediction from the learned operator with the ground truth solution which is obtained by analytic solution the parametric ODE. Where
Compare prediction with ground truth.

```@example pino
t0, t_end = tspan
instances_size = 50
range_ = range(t0, stop = t_end, length = instances_size)
ts = reshape(collect(range_), 1, instances_size)
batch_size = 50
as = [Float32(i) for i in range(0.1, stop = pi / 2, length = batch_size)]
u_output_ = zeros(Float32, 1, instances_size, batch_size)
prob_set = []
for (i, a_i) in enumerate(as)
prob_ = ODEProblem(ODEFunction(linear, analytic = linear_analytic), u0, tspan, a_i)
sol1 = solve(prob_, Tsit5(); saveat = 0.0204)
reshape_sol = Float32.(reshape(sol1(range_).u', 1, instances_size, 1))
push!(prob_set, prob_)
u_output_[:, :, i] = reshape_sol
end
train_set = TRAINSET(prob_set, u_output_)
using Plots
# Compute the ground truth solution for each parameter value and time in the solution
# The '.' operator is used to apply the functd ion element-wise
ground_analytic = (u0, p, t) -> begin u0 + sin(p * t) / (p)
p_ = range(bounds.p[1], stop = bounds.p[2], length = strategy.branch_size)
p = reshape(p_, 1, branch_size, 1)
ground_solution = ground_analytic.(u0, p, sol.t.trunk)
# Plot the predicted solution and the ground truth solution as a filled contour plot
# sol.u[1, :, :], represents the predicted solution for each parameter value and time
plot(sol.u[1, :, :], linetype = :contourf)
plot!(ground_solution[1, :, :], linetype = :contourf)
```

Here it used the PINO method to learning operator of the given family of parametric ODEs.

```@example pino
chain = Lux.Chain(Lux.Dense(2, 16, Lux.σ),
Lux.Dense(16, 16, Lux.σ),
Lux.Dense(16, 16, Lux.σ),
Lux.Dense(16, 32, Lux.σ),
Lux.Dense(32, 32, Lux.σ),
Lux.Dense(32, 1))
# flat_no = FourierNeuralOperator(ch = (2, 16, 16, 16, 16, 16, 32, 1), modes = (16,),
# σ = gelu)
opt = OptimizationOptimisers.Adam(0.01)
pino_phase = OperatorLearning(train_set, is_data_loss = true, is_physics_loss = true)
alg = PINOODE(chain, opt, pino_phase)
pino_solution = solve(
prob, alg, verbose = false, maxiters = 3000)
predict = pino_solution.predict
ground = u_output_
```
using Plots
Now let's compare the predictions from the learned operator with the ground truth solution which is obtained early by numerically solving the parametric ODE. Where 'i' is the index of the parameter 'a' in the dataset.
# 'i' is the index of the parameter 'a' in the dataset
i = 45
```@example pino
using Plots
i=45
# 'predict' is the predicted solution from the PINO model
# 'ground' is the ground truth solution
plot(predict[1, :, i], label = "Predicted")
plot!(ground[1, :, i], label = "Ground truth")
```

Now to move on the stage of solving a certain equation using a trained operator and physics

## Solve ODE using learned operator family of parametric ODE for fine tuning.
```@example pino
dt = (t_end - t0) / instances_size
pino_phase = EquationSolving(dt, pino_solution)
chain = Lux.Chain(Lux.Dense(2, 16, Lux.σ),
Lux.Dense(16, 16, Lux.σ),
Lux.Dense(16, 32, Lux.σ),
Lux.Dense(32, 1))
alg = PINOODE(chain, opt, pino_phase)
fine_tune_solution = solve( prob, alg, verbose = false, maxiters = 2000)
fine_tune_predict = fine_tune_solution.predict
operator_predict = pino_solution.phi(
fine_tune_solution.input_data_set, pino_solution.res.u)
ground_fine_tune = linear_analytic.(u0, p, fine_tune_solution.input_data_set[[1], :, :])
```

Compare prediction with ground truth.

```@example pino
plot(operator_predict[1, :, 1], label = "operator_predict")
plot!(fine_tune_predict[1, :, 1], label = "fine_tune_predict")
plot!(ground_fine_tune[1, :, 1], label = "Ground truth")
```
2 changes: 1 addition & 1 deletion src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ include("PDE_BPINN.jl")
include("dgm.jl")


export NNODE, NNDAE, PINOODE, DeepONet
export NNODE, NNDAE, PINOODE, DeepONet, SomeStrategy #TODO remove SomeStrategy
PhysicsInformedNN, discretize,
GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining,
WeightedIntervalTraining,
Expand Down
105 changes: 80 additions & 25 deletions src/neural_operators.jl
Original file line number Diff line number Diff line change
@@ -1,47 +1,102 @@
#TODO: Add docstrings
abstract type NeuralOperator <: Lux.AbstractExplicitLayer end

"""
DeepONet(branch,trunk)
"""
struct DeepONet{} <: Lux.AbstractExplicitLayer

"""
DeepONet(branch,trunk,linear=nothing)
`DeepONet` is differential neural operator focused for solving physic-informed parametric ODEs.
DeepONet uses two neural networks, referred to as the "branch" and "trunk", to approximate
the solution of a differential equation. The branch network takes the spatial variables as
input and the trunk network takes the temporal variables as input. The final output is
the dot product of the outputs of the branch and trunk networks.
DeepONet is composed of two separate neural networks referred to as the "branch" and "trunk",
respectively. The branch net takes on input represents a function evaluated at a collection
of fixed locations in some boundsand returns a features embedding. The trunk net takes the
continuous coordinates as inputs, and outputs a features embedding. The final output of the
DeepONet, the outputs of the branch and trunk networks are merged together via a dot product.
## Positional Arguments
* `branch`: A branch neural network.
* `trunk`: A trunk neural network.
## Keyword Arguments
* `linear`: A linear layer to apply to the output of the branch and trunk networks.
## Example
```julia
branch = Lux.Chain(
Lux.Dense(1, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast),
Lux.Dense(10, 10))
trunk = Lux.Chain(
Lux.Dense(1, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast))
linear = Lux.Chain(Lux.Dense(10, 1))
deeponet = DeepONet(branch, trunk; linear= linear)
a = rand(1, 50, 40)
b = rand(1, 1, 40)
x = (branch = a, trunk = b)
θ, st = Lux.setup(Random.default_rng(), deeponet)
y, st = deeponet(x, θ, st)
```
## References
* Lu Lu, Pengzhan Jin, George Em Karniadakis "DeepONet: Learning nonlinear operators for identifying differential equations based on the universal approximation theorem of operators"
* Sifan Wang "Learning the solution operator of parametric partial differential equations with physics-informed DeepOnets"
"""

struct DeepONet{L <: Union{Nothing, Lux.AbstractExplicitLayer }} <: NeuralOperator
branch::Lux.AbstractExplicitLayer
trunk::Lux.AbstractExplicitLayer
linear::L
end

function DeepONet(branch, trunk; linear=nothing)
DeepONet(branch, trunk, linear)
end

function Lux.setup(rng::AbstractRNG, l::DeepONet)
branch, trunk = l.branch, l.trunk
branch, trunk, linear = l.branch, l.trunk, l.linear
θ_branch, st_branch = Lux.setup(rng, branch)
θ_trunk, st_trunk = Lux.setup(rng, trunk)
θ = (branch = θ_branch, trunk = θ_trunk)
st = (branch = st_branch, trunk = st_trunk)
if linear !== nothing
θ_liner, st_liner = Lux.setup(rng, linear)
θ =..., liner = θ_liner)
st = (st..., liner = st_liner)
end
θ, st
end

# function Lux.initialparameters(rng::AbstractRNG, e::DeepONet)
# code
# end

Lux.initialstates(::AbstractRNG, ::DeepONet) = NamedTuple()

"""
example:
branch = Lux.Chain(Lux.Dense(1, 32, Lux.σ), Lux.Dense(32, 1))
trunk = Lux.Chain(Lux.Dense(1, 32, Lux.σ), Lux.Dense(32, 1))
a = rand(1, 100, 10)
t = rand(1, 1, 10)
x = (branch = a, trunk = t)
deeponet = DeepONet(branch, trunk)
θ, st = Lux.setup(Random.default_rng(), deeponet)
y = deeponet(x, θ, st)
"""
@inline function (f::DeepONet)(x::NamedTuple, θ, st::NamedTuple)
parameters, cord = x.branch, x.trunk
x_branch, x_trunk = x.branch, x.trunk
branch, trunk = f.branch, f.trunk
st_branch, st_trunk = st.branch, st.trunk
θ_branch, θ_trunk = θ.branch, θ.trunk
out_b, st_b = branch(parameters, θ_branch, st_branch)
out_t, st_t = trunk(cord, θ_trunk, st_trunk)
out = out_b' * out_t
return out, (branch = st_b, trunk = st_t)
out_b, st_b = branch(x_branch, θ_branch, st_branch)
out_t, st_t = trunk(x_trunk, θ_trunk, st_trunk)
if f.linear !== nothing
linear = f.linear
θ_liner, st_liner = θ.liner, st.liner
# out = sum(out_b .* out_t, dims = 1)
out_ = out_b .* out_t
out, st_liner = linear(out_, θ_liner, st_liner)
out = sum(out, dims = 1)
return out, (branch = st_b, trunk = st_t, liner = st_liner)
else
out = sum(out_b .* out_t, dims = 1)
return out, (branch = st_b, trunk = st_t)
end
end
Loading

0 comments on commit f873541

Please sign in to comment.