Skip to content

Commit

Permalink
support u0 is param
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Jun 14, 2024
1 parent ea7c638 commit 09f4891
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 7 deletions.
18 changes: 13 additions & 5 deletions src/pino_ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,23 +103,31 @@ function physics_loss(
end

function inital_condition_loss(

Check warning on line 105 in src/pino_ode_solve.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"inital" should be "initial".
phi::PINOPhi{C, T}, prob::ODEProblem, x, θ) where {C <: DeepONet, T}
phi::PINOPhi{C, T}, prob::ODEProblem, x, θ, bounds) where {C <: DeepONet, T}
p, t = x
f = prob.f
t0 = t[:, :, [1]]
f_0 = f.(0, p, t0)
tuple = (branch = f_0, trunk = t0)
out = phi(tuple, θ)
u = vec(out)
u0_ = fill(prob.u0, size(out))
u0 = vec(u0_)
#TODO
if any(in(keys(bounds)), (:u0,))
u0_ = p
u0 = p
else
u0_ = fill(prob.u0, size(out))
u0 = vec(u0_)
end
norm = prod(size(u0_))
sum(abs2, u .- u0) / norm
end

function get_trainset(strategy::GridTraining, bounds, tspan)
db, dt = strategy.dx
p_ = bounds.p[1]:db:bounds.p[2]
v = values(bounds)[1]
#TODO for all v
p_ = v[1]:db:v[2]
p = reshape(p_, 1, size(p_)[1], 1)
t_ = collect(tspan[1]:dt:tspan[2])
t = reshape(t_, 1, 1, size(t_)[1])
Expand All @@ -129,7 +137,7 @@ end
function generate_loss(strategy::GridTraining, prob::ODEProblem, phi, bounds, tspan)
x = get_trainset(strategy, bounds, tspan)
function loss(θ, _)
inital_condition_loss(phi, prob, x, θ) + physics_loss(phi, prob, x, θ)
inital_condition_loss(phi, prob, x, θ, bounds) + physics_loss(phi, prob, x, θ)

Check warning on line 140 in src/pino_ode_solve.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"inital" should be "initial".
end
end

Expand Down
87 changes: 85 additions & 2 deletions test/PINO_ode_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ using NeuralPDE
end

@testset "Example du = cos(p * t) + u" begin
eq(u, p, t) = cos(p * t) + u
eq_(u, p, t) = cos(p * t) + u
tspan = (0.0f0, 1.0f0)
u0 = 1.0f0
prob = ODEProblem(eq, u0, tspan)
prob = ODEProblem(eq_, u0, tspan)
branch = Lux.Chain(
Lux.Dense(1, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast),
Expand Down Expand Up @@ -140,3 +140,86 @@ end

@test ground_solutionsol.u rtol=0.005
end

#u0
@testset "Example du = cos(p * t)" begin
equation = (u, p, t) -> cos(p * t)
tspan = (0.0f0, 1.0f0)
u0 = 1.0f0
prob = ODEProblem(equation, u0, tspan)

branch = Lux.Chain(
Lux.Dense(1, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast),
Lux.Dense(10, 10))
trunk = Lux.Chain(
Lux.Dense(1, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast))

deeponet = NeuralPDE.DeepONet(branch, trunk; linear = nothing)

bounds = (u0 = [0.5f0, 2.f0],)
db = (bounds.u0[2] - bounds.u0[1]) / 50
dt = (tspan[2] - tspan[1]) / 40
strategy = NeuralPDE.GridTraining([db, dt])
opt = OptimizationOptimisers.Adam(0.03)
alg = NeuralPDE.PINOODE(deeponet, opt, bounds; strategy = strategy)
sol = solve(prob, alg, verbose = true, maxiters = 2000)

ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
p_ = bounds.u0[1]:strategy.dx[1]:bounds.u0[2]
p = reshape(p_, 1, size(p_)[1], 1)
ground_solution = ground_analytic.(u0, p, sol.t.trunk)

@test ground_solutionsol.u rtol=0.01
end

plot(sol.u[1, :, :], linetype = :contourf)
plot(ground_solution[1, :, :], linetype = :contourf)

function plot_()
# Animate
anim = @animate for (i) in 1:51
plot(ground_solution[1, i, :], label = "Ground")
# plot(equation_[1, i, :], label = "equation")
plot!(sol.u[1, i, :], label = "Predicted")
end
gif(anim, "pino.gif", fps = 15)
end

plot_()

#vector outputs and multiple parameters
@testset "Example du = cos(p * t)" begin
equation = (u, p, t) -> cos(p1 * t) + p2
tspan = (0.0f0, 1.0f0)
u0 = 1.0f0
prob = ODEProblem(equation, u0, tspan)

branch = Lux.Chain(
Lux.Dense(1, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast),
Lux.Dense(10, 10))
trunk = Lux.Chain(
Lux.Dense(1, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast),
Lux.Dense(10, 10, Lux.tanh_fast))

deeponet = NeuralPDE.DeepONet(branch, trunk; linear = nothing)

bounds = (p1 = [0.1f0, pi], p2 = [0.1f0, 2.f0], u0 = [0.0f0, 2.0f0])
db = (bounds.u0[2] - bounds.u0[1]) / 50
dt = (tspan[2] - tspan[1]) / 40
strategy = NeuralPDE.GridTraining([db, dt])
opt = OptimizationOptimisers.Adam(0.03)
alg = NeuralPDE.PINOODE(deeponet, opt, bounds; strategy = strategy)
sol = solve(prob, alg, verbose = false, maxiters = 2000)
ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)

p_ = bounds.p[1]:strategy.dx[1]:bounds.p[2]
p = reshape(p_, 1, size(p_)[1], 1)
ground_solution = ground_analytic.(u0, p, sol.t.trunk)

@test ground_solutionsol.u rtol=0.01
end

0 comments on commit 09f4891

Please sign in to comment.