Skip to content

Commit

Permalink
support chain with StochasticTraining
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Oct 1, 2024
1 parent 000d8b5 commit 5e9a025
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 58 deletions.
37 changes: 25 additions & 12 deletions src/pino_ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ end

function (f::PINOPhi{C, T})(x::Tuple, θ) where {C <: DeepONet, T}
eltypeθ, typeθ = eltype(θ), parameterless_type(ComponentArrays.getdata(θ))
x = (convert.(eltypeθ, adapt(typeθ, x[1])),convert.(eltypeθ, adapt(typeθ, x[2])))
x = (convert.(eltypeθ, adapt(typeθ, x[1])), convert.(eltypeθ, adapt(typeθ, x[2])))
y, st = f.chain(x, θ, f.st)
y
end
Expand Down Expand Up @@ -172,16 +172,6 @@ function get_trainset(
(p, t)
end

function get_trainset(
strategy::StochasticTraining, chain::DeepONet, bounds, number_of_parameters, tspan, eltypeθ)
p = reduce(vcat,
[(bound[2] .- bound[1]) .* rand(1, number_of_parameters) .+ bound[1]
for bound in bounds])
t = (tspan[2] .- tspan[1]) .* rand(1, strategy.points, 1) .+ tspan[1]
p, t = convert.(eltypeθ, p), convert.(eltypeθ, t)
(p, t)
end

function get_trainset(
strategy::GridTraining, chain::Lux.Chain, bounds, number_of_parameters, tspan, eltypeθ)
dt = strategy.dx
Expand All @@ -197,14 +187,37 @@ function get_trainset(
(p, t)
end

# function get_trainset(
# strategy::StochasticTraining, chain::DeepONet, bounds, number_of_parameters, tspan, eltypeθ)
# p = reduce(vcat,
# [(bound[2] .- bound[1]) .* rand(1, number_of_parameters) .+ bound[1]
# for bound in bounds])
# t = (tspan[2] .- tspan[1]) .* rand(1, strategy.points, 1) .+ tspan[1]
# p, t = convert.(eltypeθ, p), convert.(eltypeθ, t)
# (p, t)
# end


function get_trainset(
strategy::StochasticTraining, chain::Union{DeepONet, Lux.Chain},
bounds, number_of_parameters, tspan, eltypeθ)
number_of_parameters != strategy.points &&
throw(error("number_of_parameters should be the same strategy.points for StochasticTraining"))
p = reduce(vcat,
[(bound[2] .- bound[1]) .* rand(1, number_of_parameters) .+ bound[1]
for bound in bounds])
t = (tspan[2] .- tspan[1]) .* rand(1, strategy.points, 1) .+ tspan[1]
p, t = convert.(eltypeθ, p), convert.(eltypeθ, t)
(p, t)
end

function generate_loss(
strategy::GridTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan, eltypeθ)
x = get_trainset(strategy, phi.chain, bounds, number_of_parameters, tspan, eltypeθ)
function loss(θ, _)
initial_condition_loss(phi, prob, x, θ) + physics_loss(phi, prob, x, θ)
end
end
# Zygote.gradient(θ -> initial_condition_loss(phi, prob, x, θ), θ)

function generate_loss(
strategy::StochasticTraining, prob::ODEProblem, phi, bounds, number_of_parameters, tspan, eltypeθ)
Expand Down
66 changes: 20 additions & 46 deletions test/PINO_ode_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,12 @@ function get_trainset(chain::Lux.Chain, bounds, number_of_parameters, tspan, dt)
for b in bounds]
x_ = hcat(vec(map(
points -> collect(points), Iterators.product([pspan..., tspan_]...)))...)
# x = reshape(x_, size(bounds, 1) + 1, size.(pspan, 1)..., size(tspan_, 1))
x = reshape(x_, size(bounds, 1) + 1, prod(size.(pspan, 1)), size(tspan_, 1))
p, t = x[1:(end - 1), :, :], x[[end], :, :]
(p, t)
end

#Test with Chain
#Test with Chain with Float64 accuracy
@testset "Example du = cos(p * t)" begin
equation = (u, p, t) -> cos(p * t)
tspan = (0.0f0, 1.0f0)
Expand All @@ -38,9 +37,8 @@ end
b = chain(x, θ, st)[1]

bounds = [(pi, 2pi)]
number_of_parameters = 50
strategy = GridTraining(0.1f0)
# strategy = StochasticTraining(70) #TODO chain +StochasticTraining
number_of_parameters = 300
strategy = StochasticTraining(300)
opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(
chain, opt, bounds, number_of_parameters; strategy = strategy, init_params = θ |>
Expand All @@ -51,12 +49,12 @@ end
p, t = get_trainset(chain, bounds, number_of_parameters, tspan, dt)
ground_solution = ground_analytic.(u0, p, t)
predict_sol = sol.interp(reduce(vcat, (p, t)))
@test eltype(sol.k) == eltype(predict_sol)
@test ground_solutionpredict_sol rtol=0.07
p, t = get_trainset(chain, bounds, 100, tspan, 0.01)
ground_solution = ground_analytic.(u0, p, t)
predict_sol = sol.interp(reduce(vcat, (p, t)))
@test ground_solutionpredict_sol rtol=0.07
@test eltype(sol.k) == eltype(predict_sol)
end

#Test with DeepONet
Expand Down Expand Up @@ -87,7 +85,7 @@ end
opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(deeponet, opt, bounds, number_of_parameters;
strategy = strategy, init_params = θ |> f64)
sol = solve(prob, alg, verbose = true, maxiters = 2000)
sol = solve(prob, alg, verbose = false, maxiters = 2000)
ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
dt = 0.025f0
p, t = get_trainset(deeponet, bounds, number_of_parameters, tspan, dt)
Expand All @@ -114,7 +112,7 @@ end
bounds = [(0.1f0, 2.0f0)]
number_of_parameters = 40
dt = (tspan[2] - tspan[1]) / 40
strategy = StochasticTraining(50)
strategy = GridTraining(0.1f0)
opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy)
sol = solve(prob, alg, verbose = false, maxiters = 4000)
Expand Down Expand Up @@ -177,30 +175,29 @@ end

#multiple parameters chain
@testset "Example du = cos(p * t)" begin
equation = (u, p, t) -> p[1] * cos(p[2] * t) #+ p[3]
equation = (u, p, t) -> p[1] * cos(p[2] * t) + p[3]
tspan = (0.0, 1.0)
u0 = 1.0
prob = ODEProblem(equation, u0, tspan)

input_branch_size = 2
input_branch_size = 3
chain = Chain(
Dense(input_branch_size + 1 => 10, Lux.tanh_fast),
Dense(10 => 10, Lux.tanh_fast),
Dense(10 => 10, Lux.tanh_fast), Dense(10 => 1))

x = rand(Float32, 3, 1000, 10)
x = rand(Float32, 4, 1000, 10)
θ, st = Lux.setup(Random.default_rng(), chain)
c = chain(x, θ, st)[1]

bounds = [(1.0, pi), (1.0, 2.0)]#, (2.0, 3.0)]
number_of_parameters = 10
strategy = GridTraining(0.1f0)
# strategy = StochasticTraining(20)
opt = OptimizationOptimisers.Adam(0.03)
bounds = [(1.0, pi), (1.0, 2.0), (2.0, 3.0)]
number_of_parameters = 200
strategy = StochasticTraining(200)
opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(chain, opt, bounds, number_of_parameters; strategy = strategy)
sol = solve(prob, alg, verbose = true, maxiters = 3000)
sol = solve(prob, alg, verbose = false, maxiters = 4000)

ground_solution = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t) #+ p[3] * t
ground_solution = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t) + p[3] * t

function ground_solution_f(p, t)
reduce(hcat,
Expand All @@ -210,29 +207,15 @@ end
(p, t) = get_trainset(chain, bounds, 50, tspan, 0.025f0)
ground_solution_ = ground_solution_f(p, t)
predict = sol.interp(reduce(vcat, (p, t)))[1, :, :]
@test ground_solution_predict rtol=0.4 #TODO rtol=0.05
@test ground_solution_predict rtol=0.05

p, t = get_trainset(chain, bounds, 100, tspan, 0.01f0)
p, t = get_trainset(chain, bounds, 60, tspan, 0.01f0)
ground_solution_ = ground_solution_f(p, t)
predict_sol = sol.interp(reduce(vcat, (p, t)))[1, :, :]
@test ground_solution_predict_sol rtol=0.05
@test eltype(sol.k) == eltype(predict_sol)
end

function plot_()
# Animate
anim = @animate for (i) in 1:100
plot(ground_solution_[:, i], label = "Ground")
plot!(predict[:, i], label = "Predicted")
end
gif(anim, "pino.gif", fps = 10)
end

plot_()

plot(predict[:, :], linetype = :contourf)
plot(ground_solution_, linetype = :contourf)

#multiple parameters DeepOnet
@testset "Example du = cos(p * t)" begin
equation = (u, p, t) -> p[1] * cos(p[2] * t) + p[3]
Expand All @@ -255,10 +238,9 @@ plot(ground_solution_, linetype = :contourf)
bounds = [(1.0, pi), (1.0, 2.0), (2.0, 3.0)]
number_of_parameters = 50
strategy = StochasticTraining(20)
# strategy = GridTraining(0.1f0)
opt = OptimizationOptimisers.Adam(0.03)
alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy)
sol = solve(prob, alg, verbose = false, maxiters = 3000)
sol = solve(prob, alg, verbose = false, maxiters = 4000)
ground_solution = (u0, p, t) -> u0 + p[1] / p[2] * sin(p[2] * t) + p[3] * t
function ground_solution_f(p, t)
reduce(hcat,
Expand All @@ -274,7 +256,7 @@ plot(ground_solution_, linetype = :contourf)
ground_solution_ = ground_solution_f(p, t)
predict = sol.interp((p, t))
@test ground_solution_predict rtol=0.05
@test eltype(sol.k.u) == eltype(predict_sol)
@test eltype(sol.k.u) == eltype(predict)
end

#TODO vector output TODO
Expand All @@ -296,18 +278,10 @@ end
strategy = StochasticTraining(40)
opt = OptimizationOptimisers.Adam(0.01)
alg = PINOODE(deeponet, opt, bounds, number_of_parameters; strategy = strategy)
sol = solve(prob, alg, verbose = true, maxiters = 2000)
sol = solve(prob, alg, verbose = false, maxiters = 2000)

ground_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
dt = 0.025f0
# function get_trainset(bounds, tspan, number_of_parameters, dt)
# p_ = [range(start = b[1], length = number_of_parameters, stop = b[2])
# for b in bounds]
# p = vcat([collect(reshape(p_i, 1, size(p_i, 1))) for p_i in p_]...)
# t_ = collect(tspan[1]:dt:tspan[2])
# t = collect(reshape(t_, 1, size(t_, 1), 1))
# (p, t)
# end
p, t = get_trainset(chain, bounds, tspan, number_of_parameters, dt)

ground_solution = (u0, p, t) -> [sin(2pi * t) / 2pi, -cos(2pi * t) / 2pi]
Expand Down

0 comments on commit 5e9a025

Please sign in to comment.