Skip to content

Commit

Permalink
vector outputs and multiple parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Jun 20, 2024
1 parent d754e30 commit 214b178
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 77 deletions.
40 changes: 22 additions & 18 deletions src/pino_ode_solve.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
struct ParametricFunction{}
function_ ::Union{Nothing, Function}
function_::Union{Nothing, Function}
bounds::Any
end

Expand Down Expand Up @@ -83,12 +83,14 @@ function (f::PINOPhi{C, T})(x::NamedTuple, θ) where {C <: NeuralOperator, T}
end

function dfdx(phi::PINOPhi{C, T}, x::Tuple, θ, prob::ODEProblem) where {C <: DeepONet, T}
# @unpack function_, bounds = parametric_function
# branch_left, branch_right = function_.(p, t), function_.(p, t .+ sqrt(eps(eltype(t))))
pfs, p, t = x
# branch_left, branch_right = pfs, pfs
branch_left, branch_right = p, p
trunk_left, trunk_right = t .+ sqrt(eps(eltype(t))), t
x_left = (branch = pfs, trunk = trunk_left)
x_right = (branch = pfs, trunk = trunk_right)
(phi(x_left, θ) .- phi(x_right, θ)) / sqrt(eps(eltype(t)))
x_left = (branch = branch_left, trunk = trunk_left)
x_right = (branch = branch_right, trunk = trunk_right)
(phi(x_left, θ) .- phi(x_right, θ)) ./ sqrt(eps(eltype(t)))
end

# function physics_loss(
Expand Down Expand Up @@ -122,24 +124,25 @@ function physics_loss(
phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where {C <: DeepONet, T}
pfs, p, t = x
f = prob.f
du = vec(dfdx(phi, x, θ, prob))
tuple = (branch = pfs, trunk = t)
tuple = (branch = p, trunk = t)
out = phi(tuple, θ)
# if size(p)[1] == 1
if size(p)[1] == 1
fs = f.(out, p, t)
f_ = vec(fs)
# else
# f_ = reduce(vcat,[reduce(vcat, [f(out[i], p[i], t[j]) for i in axes(p, 2)]) for j in axes(t, 3)])
# end
norm = prod(size(out))
sum(abs2, du .- f_) / norm
f_vec= vec(fs)
# out_ = vec(out)
else
f_vec = reduce(vcat,[[f(out[i], p[:, i, 1], t[j]) for i in axes(p, 2)] for j in axes(t, 3)])
end
du = vec(dfdx(phi, x, θ, prob))
norm = prod(size(du))
sum(abs2, du .- f_vec) / norm
end

function initial_condition_loss(phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where {C <: DeepONet, T}
pfs, p, t = x
t0 = t[:, :, [1]]
pfs0 = pfs[:, :, [1]]
tuple = (branch = pfs0, trunk = t0)
# pfs0 = pfs[:, :, [1]]
tuple = (branch = p, trunk = t0)
out = phi(tuple, θ)
u = vec(out)
u0 = vec(fill(prob.u0, size(out)))
Expand Down Expand Up @@ -210,7 +213,8 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
throw(error("The PINOODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))

try
x = (branch = rand(1, 10, 10), trunk = rand(1, 1, 10))
in_dim = chain.branch.layers.layer_1.in_dims
x = (branch = rand(in_dim, 10, 10), trunk = rand(1, 1, 10))
phi(x, init_params)
catch err
if isa(err, DimensionMismatch)
Expand Down Expand Up @@ -258,7 +262,7 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
res = solve(optprob, opt; callback, maxiters, alg.kwargs...)

pfs, p, t = get_trainset(strategy, parametric_function, tspan)
tuple = (branch = pfs, trunk = t)
tuple = (branch = p, trunk = t)
u = phi(tuple, res.u)

sol = SciMLBase.build_solution(prob, alg, tuple, u;
Expand Down
120 changes: 61 additions & 59 deletions test/PINO_ode_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using NeuralPDE
# dG(u(t, p), t) = f(G,u(t, p))
@testset "Example du = cos(p * t)" begin
equation = (u, p, t) -> cos(p * t)
tspan = (0.0f0, 1.0f0)
tspan = (0.0f0, 2.0f0)
u0 = 1.0f0
prob = ODEProblem(equation, u0, tspan)

Expand All @@ -27,15 +27,17 @@ using NeuralPDE
θ, st = Lux.setup(Random.default_rng(), deeponet)

c = deeponet(x, θ, st)[1]
function_(p, t) = cos(p * t)
bounds = (0.1f0, pi)
function_(p, t) = cos(p*t)
bounds = (pi, 2pi)
parametric_function = ParametricFunction(function_, bounds)
dt = (tspan[2] - tspan[1]) / 40
strategy = GridTraining(dt)
opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(deeponet, opt, parametric_function; strategy = strategy)
sol = solve(prob, alg, verbose = false, maxiters = 5000)
sol.original
sol = solve(prob, alg, verbose = true, maxiters = 3000)

phi(tuple, sol.original.u)
sol.original.objective
# TODO intrepretation output another mesh

Check warning on line 41 in test/PINO_ode_tests.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"intrepretation" should be "interpretation".
# x = (branch = p, trunk = t)
# phi(sol.original.u)
Expand All @@ -46,24 +48,9 @@ using NeuralPDE
p = collect(reshape(p_, 1, size(p_)[1], 1))
ground_solution = ground_analytic.(u0, p, sol.t.trunk)

@test ground_solutionsol.u rtol=0.05
end

plot(sol.u[1, :, :], linetype = :contourf)
plot!(ground_solution[1, :, :], linetype = :contourf)

function plot_()
# Animate
anim = @animate for (i) in 1:41
plot(ground_solution[1, i, :], label = "Ground")
# plot(equation_[1, i, :], label = "equation")
plot!(sol.u[1, i, :], label = "Predicted")
end
gif(anim, "pino.gif", fps = 15)
@test ground_solutionsol.u rtol=0.01
end

plot_()

@testset "Example du = cos(p * t) + u" begin
eq_(u, p, t) = cos(p * t) + u
tspan = (0.0f0, 1.0f0)
Expand All @@ -79,34 +66,29 @@ plot_()
Lux.Dense(10, 10, Lux.tanh_fast))

deeponet = DeepONet(branch, trunk)

parametric_function = Parametric_Function(nothing, bounds)

function_(p, t) = cos(p * t)
bounds = (0.1f0, 2.f0)
parametric_function = ParametricFunction(function_, bounds)
sol.original.objective
dt = (tspan[2] - tspan[1]) / 40
strategy = GridTraining(dt)

opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(deeponet, opt, bounds; strategy = strategy)

sol = solve(prob, alg, verbose = false, maxiters = 10000)
alg = PINOODE(deeponet, opt, parametric_function; strategy = strategy)

sol = solve(prob, alg, verbose = false, maxiters = 5000)
sol.original.objective
#if u0 == 1
ground_analytic_(u0, p, t) = (p * sin(p * t) - cos(p * t) + (p^2 + 2) * exp(t)) /
(p^2 + 1)

ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
size_of_p = 50
p_ = [range(start = b[1], length = size_of_p, stop = b[2]) for b in bounds]
p = vcat([collect(reshape(p_i, 1, size(p_i)[1], 1)) for p_i in p_]...)
ground_solution = ground_analytic.(u0, p, sol.t.trunk)
p_ = range(start = bounds[1], length = size_of_p, stop = bounds[2])
p = collect(reshape(p_, 1, size(p_)[1], 1))
ground_solution = ground_analytic_.(u0, p, sol.t.trunk)

@test ground_solutionsol.u rtol=0.05
@test ground_solutionsol.u rtol=0.01
end





@testset "Example with data du = p*t^2" begin
equation = (u, p, t) -> p * t^2
tspan = (0.0f0, 1.0f0)
Expand All @@ -122,20 +104,22 @@ end
Lux.Dense(10, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast))
linear = Lux.Chain(Lux.Dense(10, 1))
deeponet = NeuralPDE.DeepONet(branch, trunk; linear = linear)
deeponet = DeepONet(branch, trunk; linear = linear)

bounds = (p = [0.0f0, 10.0f0],)
function_(p, t) = cos(p * t)
bounds = (0.0f0, 10.0f0)
parametric_function = ParametricFunction(function_, bounds)

# db = (bounds.p[2] - bounds.p[1]) / 50
dt = (tspan[2] - tspan[1]) / 40
strategy = NeuralPDE.GridTraining([db, dt])
strategy = GridTraining(dt)

opt = OptimizationOptimisers.Adam(0.03)

#generate data
ground_analytic = (u0, p, t) -> u0 + p * t^3 / 3
function get_trainset(branch_size, trunk_size, bounds, tspan)
p_ = range(bounds.p[1], stop = bounds.p[2], length = branch_size)
p_ = range(bounds[1], stop = bounds[1], length = branch_size)
p = reshape(p_, 1, branch_size, 1)
t_ = collect(range(tspan[1], stop = tspan[2], length = trunk_size))
t = reshape(t_, 1, 1, trunk_size)
Expand All @@ -157,11 +141,12 @@ end
norm = prod(size(u))
sum(abs2, u .- data) / norm
end
alg = NeuralPDE.PINOODE(
deeponet, opt, bounds; strategy = strategy, additional_loss = additional_loss_)
sol = solve(prob, alg, verbose = false, maxiters = 2000)
alg = PINOODE(
deeponet, opt, parametric_function; strategy = strategy, additional_loss = additional_loss_)
sol = solve(prob, alg, verbose = true, maxiters = 2000)

p_ = bounds.p[1]:strategy.dx[1]:bounds.p[2]
size_of_p = 50
p_ = range(start = bounds[1], length = size_of_p, stop = bounds[2])
p = reshape(p_, 1, size(p_)[1], 1)
ground_solution = ground_analytic.(u0, p, sol.t.trunk)

Expand All @@ -175,36 +160,53 @@ end
cos(p1 * t) + p2
end

equation = (u, p, t) -> cos(p * t)
equation = (u, p, t) -> p[1]*cos(p[2] * t)
tspan = (0.0f0, 1.0f0)
u0 = 1.0f0
prob = ODEProblem(equation, u0, tspan)

branch = Lux.Chain(
Lux.Dense(1, 10, Lux.tanh_fast),
Lux.Dense(2, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast),
Lux.Dense(10, 10))
trunk = Lux.Chain(
Lux.Dense(1, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast))

deeponet = DeepONet(branch, trunk; linear = nothing)
# p1 = [0.1f0, pi]; p2 = [0.1f0, 2.0f0]
# bounds = (p = [p1, p2],)
deeponet = DeepONet(branch, trunk)

#TODO add size_of_p = 50
bounds = [[0.1f0, pi], [0.1f0, 2.0f0]]
# db = 0.025f0
function_(p, t) = cos(p * t)
bounds = [(0.1f0, pi), (1.0f0, 2.0f0)]
parametric_function = ParametricFunction(function_, bounds)
dt = (tspan[2] - tspan[1]) / 40
strategy = GridTraining(dt)
opt = OptimizationOptimisers.Adam(0.03)
alg = PINOODE(deeponet, opt, bounds; strategy = strategy)
sol = solve(prob, alg, verbose = false, maxiters = 2000)

ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
p_ = bounds.p[1]:strategy.dx[1]:bounds.p[2]
p = reshape(p_, 1, size(p_)[1], 1)
ground_solution = ground_analytic.(u0, p, sol.t.trunk)
alg = PINOODE(deeponet, opt, parametric_function; strategy = strategy)
sol = solve(prob, alg, verbose = true, maxiters = 3000)

ga = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t)
p_ = [range(start = b[1], length = size_of_p, stop = b[2]) for b in bounds]
p = vcat([collect(reshape(p_i, 1, size(p_i)[1], 1)) for p_i in p_]...)
ground_solution_ = f_vec = reduce(
hcat, [reduce(
vcat, [ga(u0, p[:, i, 1], t[j]) for i in axes(p, 2)]) for j in axes(t, 3)])
ground_solution = reshape(ground_solution_, 1, size(ground_solution_)...)
@test ground_solutionsol.u rtol=0.01
end

# plot(sol.u[1, :, :], linetype = :contourf)
# plot!(ground_solution[1, :, :], linetype = :contourf)

# function plot_()
# # Animate
# anim = @animate for (i) in 1:41
# plot(ground_solution[1, i, :], label = "Ground")
# # plot(equation_[1, i, :], label = "equation")
# plot!(sol.u[1, i, :], label = "Predicted")
# end
# gif(anim, "pino.gif", fps = 10)
# end

# plot_()

0 comments on commit 214b178

Please sign in to comment.