Skip to content

Commit

Permalink
add StochasticTraining
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Jun 26, 2024
1 parent 30a5134 commit 2818f34
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 32 deletions.
56 changes: 38 additions & 18 deletions src/pino_ode_solve.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#TODO rewrite doc strings
"""
PINOODE(chain,
OptimizationOptimisers.Adam(0.1),
Expand Down Expand Up @@ -74,14 +75,14 @@ function generate_pino_phi_θ(chain::Lux.AbstractExplicitLayer, init_params)
PINOPhi(chain, st), init_params
end

# Call overload for `PINOPhi`: evaluates the wrapped chain on input `x` with parameters `θ`
# and returns the chain output `y`.
# NOTE(review): this span is a rendered diff. The first header line appears to be the
# pre-commit signature; the two lines after it are its replacement, which restricts `C`
# to `Lux.AbstractExplicitLayer`.
function (f::PINOPhi{C, T})(x, θ) where {C, T} #C <: NeuralOperator
function (f::PINOPhi{C, T})(
x, θ) where {C <: Lux.AbstractExplicitLayer, T}
# `x` is adapted to the parameterless array type of the data underlying `θ` before the
# chain is called with the stored state `f.st`.
y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), x), θ, f.st)
# Store the returned state on `f`; the assignment is excluded from differentiation.
ChainRulesCore.@ignore_derivatives f.st = st
y
end

#TODO migrate to LuxNeuralOperators.DeepONet
function dfdx(phi::PINOPhi{C, T}, x::Tuple, θ) where {C, T} #C <: DeepONet
function dfdx(phi::PINOPhi{C, T}, x::Tuple, θ) where {C <: CompactLuxLayer{:DeepONet,}, T}
p, t = x
branch_left, branch_right = p, p
trunk_left, trunk_right = t .+ sqrt(eps(eltype(t))), t
Expand All @@ -91,14 +92,13 @@ function dfdx(phi::PINOPhi{C, T}, x::Tuple, θ) where {C, T} #C <: DeepONet
end

function physics_loss(
phi::PINOPhi{C, T}, prob::ODEProblem, x::Tuple, θ) where {C, T} #C <: DeepONet
phi::PINOPhi{C, T}, prob::ODEProblem, x::Tuple, θ) where {
C <: CompactLuxLayer{:DeepONet,}, T}
p, t = x
f = prob.f
# x = (p, t)
out = phi(x, θ)
if size(p)[1] == 1
fs = f.(out, p, vec(t))
# fs = f.(0, p, vec(t))
f_vec = vec(fs)
else
f_vec = reduce(
Expand All @@ -110,10 +110,10 @@ function physics_loss(
end

function initial_condition_loss(
phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where {C, T} #C <: DeepONet #TODO migrate to LuxNeuralOperators.DeepONet
phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where {
C <: CompactLuxLayer{:DeepONet,}, T}
p, t = x
t0 = t[:, [1], :]
# pfs0 = pfs[:, :, [1]]
t0 = reshape([prob.tspan[1]], (1, 1, 1)) # t[:, [1], :]
x0 = (p, t0)
out = phi(x0, θ)
u = vec(out)
Expand Down Expand Up @@ -147,14 +147,35 @@ function generate_loss(
end
end

# Draw one random training batch `(p, t)` for the `StochasticTraining` strategy.
#
# - `bounds`: collection whose entries are indexed as `bound[1]` / `bound[2]`.
# - `number_of_parameters`: number of columns sampled for every entry of `bounds`.
# - `tspan`: indexed as `tspan[1]` / `tspan[2]`.
#
# Returns the tuple `(p, t)`. `p` stacks one `1 × number_of_parameters` row per bound
# (only `bounds[1]` is used when `size(bounds)[1] == 1`); `t` has size
# `(1, strategy.points, 1)`. Every sample is `(hi - lo) * rand + lo`.
function get_trainset(strategy::StochasticTraining, bounds, number_of_parameters, tspan)
    # Rescale a fresh row of uniform `rand` samples onto a single bound.
    sample_row(b) = (b[2] .- b[1]) .* rand(1, number_of_parameters) .+ b[1]

    p = if size(bounds)[1] == 1
        sample_row(bounds[1])
    else
        reduce(vcat, [sample_row(b) for b in bounds])
    end

    # Time points are sampled after the parameters, once per call.
    t_first, t_last = tspan[1], tspan[2]
    t = (t_last .- t_first) .* rand(1, strategy.points, 1) .+ t_first

    return (p, t)
end

# Build the loss closure used with a stochastic training strategy.
# NOTE(review): this span is a rendered diff. The `QuasiRandomTraining` signature line and
# the `#TODO` below it appear to be the removed pre-commit lines; the `StochasticTraining`
# signature line is their replacement.
# The returned `loss(θ, _)` ignores its second argument.
function generate_loss(
strategy::QuasiRandomTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan)
#TODO
strategy::StochasticTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan)
function loss(θ, _)
# A new training set is drawn from `get_trainset` on every loss evaluation.
x = get_trainset(strategy, bounds, number_of_parameters, tspan)
# Total loss is the initial-condition term plus the physics term on the same sample `x`.
initial_condition_loss(phi, prob, x, θ) + physics_loss(phi, prob, x, θ)
end
end

# Interpolation object attached to the solution returned by the PINO ODE solver.
# It is constructed in `SciMLBase.__solve` as `PINOODEInterpolation(phi, res.u)` and passed
# to `build_solution` through the `interp` keyword.
# NOTE(review): `res.u` is presumably the optimized parameter vector — confirm.
struct PINOODEInterpolation{T <: PINOPhi, T2}
# Network wrapper; must be a `PINOPhi`.
phi::T
# Parameters evaluated together with `phi` (type is unconstrained).
θ::T2
end

#TODO
# (f::NNODEInterpolation)(t, ...) = f.phi(t, f.θ)

function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
alg::PINOODE,
args...;
Expand All @@ -166,10 +187,9 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
@unpack tspan, u0, f = prob
@unpack chain, opt, bounds, number_of_parameters, init_params, strategy, additional_loss = alg

#TODO migrate to LuxNeuralOperators.DeepONet
# if !isa(chain, DeepONet)
# error("Only DeepONet neural networks are supported")
# end
if !isa(chain, CompactLuxLayer{:DeepONet,})
error("Only DeepONet neural networks are supported with PINO ODE")
end

!(chain isa Lux.AbstractExplicitLayer) &&
error("Only Lux.AbstractExplicitLayer neural networks are supported")
Expand All @@ -181,7 +201,6 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,

try
in_dim = chain.layers.branch.layers.layer_1.in_dims
# x = (branch = rand(in_dim, 10, 10), trunk = rand(1, 1, 10))
u = rand(in_dim, number_of_parameters)
v = rand(1, 10, 1)
x = (u, v)
Expand All @@ -197,8 +216,8 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
if strategy === nothing
dt = (tspan[2] - tspan[1]) / 50
strategy = GridTraining(dt)
elseif !(strategy isa GridTraining || strategy isa QuasiRandomTraining)
throw(ArgumentError("Only GridTraining and QuasiRandomTraining strategy is supported"))
elseif !(strategy isa GridTraining || strategy isa StochasticTraining)
throw(ArgumentError("Only GridTraining and StochasticTraining strategy is supported"))
end

inner_f = generate_loss(strategy, prob, phi, bounds, number_of_parameters, tspan)
Expand Down Expand Up @@ -233,6 +252,7 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,

sol = SciMLBase.build_solution(prob, alg, x, u;
k = res, dense = true,
interp = PINOODEInterpolation(phi, res.u),
calculate_error = false,
retcode = ReturnCode.Success,
original = res,
Expand Down
26 changes: 12 additions & 14 deletions test/PINO_ode_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,26 +31,24 @@ using NeuralPDE

bounds = [(pi, 2pi)]
number_of_parameters = 50
dt = (tspan[2] - tspan[1]) / 40
strategy = GridTraining(dt)
# strategy = QuasiRandomTraining(50)
opt = OptimizationOptimisers.Adam(0.03)
# dt = (tspan[2] - tspan[1]) / 40
# strategy = GridTraining(dt)
strategy = StochasticTraining(40)
opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy)
sol = solve(prob, alg, verbose = true, maxiters = 3000)

sol = solve(prob, alg, verbose = true, maxiters = 5000)
sol.original.objective
# TODO intrepretation output another mesh
# x = (branch = p, trunk = t)
# phi(sol.original.u)
# sol.
# TODO interpretation output with few mesh

Check warning on line 41 in test/PINO_ode_tests.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"intrepretation" should be "interpretation".
ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
#TODO another number_of_parameters
p_ = range(start = bounds[1][1], length = number_of_parameters, stop = bounds[1][2])
p = collect(reshape(p_, 1, size(p_)[1]))
ground_solution = ground_analytic.(u0, p, vec(sol.t[2]))
t_ = collect(tspan[1]:dt:tspan[2])
t = collect(reshape(t_, 1, size(t_)[1], 1))
ground_solution = ground_analytic.(u0, p, t_)
predict_sol = sol.interp.phi((p,t), sol.interp.θ)

@test ground_solution ≈ sol.u rtol=0.1
@test ground_solution ≈ sol.u rtol=0.01
@test ground_solution ≈ predict_sol rtol=0.1
@test ground_solution ≈ predict_sol rtol=0.01
end

@testset "Example du = cos(p * t) + u" begin
Expand Down

0 comments on commit 2818f34

Please sign in to comment.