Skip to content

Commit

Permalink
clear up code
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Mar 6, 2024
1 parent cbf01af commit db50090
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 192 deletions.
186 changes: 28 additions & 158 deletions src/pino_ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,18 @@
which thus uses the random initialization provided by the neural network library.
## Keyword Arguments
* `minibatch`: TODO
* `minibatch`:
## Examples
TODO
```julia
```
## References
Zongyi Li "Physics-Informed Neural Operator for Learning Partial Differential Equations"
"""
#TODO
struct TRAINSET{} #T
struct TRAINSET{} #TODO #T <: Number
input_data::Vector{ODEProblem}
output_data::Array
isu0::Bool
Expand All @@ -53,7 +51,6 @@ function PINOODE(chain,
minibatch = 0,
kwargs...)
!(chain isa Lux.AbstractExplicitLayer) && (chain = Lux.transform(chain))
#TODO transform convert complex numbers to zero
PINOODE(chain, opt, train_set, init_params, minibatch, kwargs)
end

Expand All @@ -67,7 +64,7 @@ mutable struct PINOPhi{C, T, U, S}
u0::U
st::S
function PINOPhi(chain::Lux.AbstractExplicitLayer, t0, u0, st)
new{typeof(chain), typeof(t0), typeof(u0), typeof(st)}(chain, t0,u0, st)
new{typeof(chain), typeof(t0), typeof(u0), typeof(st)}(chain, t0, u0, st)
end
end

Expand All @@ -84,128 +81,45 @@ function generate_pino_phi_θ(chain::Lux.AbstractExplicitLayer,
PINOPhi(chain, t0, u0, st), init_params
end

# function (f::PINOPhi{C, T, U})(t::Number, θ) where {C <: Lux.AbstractExplicitLayer, T, U}
# y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), [t]), θ, f.st)
# ChainRulesCore.@ignore_derivatives f.st = st
# first(y)
# end

function (f::PINOPhi{C, T, U})(t::AbstractArray,
θ) where {C <: Lux.AbstractExplicitLayer, T, U}
# Batch via data as row vectors
# y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), t), θ, f.st)
y, st = f.chain(t, θ, f.st)
y, st = f.chain(adapt(parameterless_type(ComponentArrays.getdata(θ)), t), θ, f.st)
ChainRulesCore.@ignore_derivatives f.st = st
# y
f.u0 .+ (t[[1],:,:] .- f.t0) .* y
f.u0 .+ (t[[1], :, :] .- f.t0) .* y
end

# feature_dims = 2:(ndims(t) - 1)
# loss = sum( t, dims = feature_dims)
# loss = sum(.√(sum(abs2, 𝐲̂ - 𝐲, dims = feature_dims)))
# y_norm = sum(.√(sum(abs2, 𝐲, dims = feature_dims)))

# return loss / y_norm
# function dfdx(phi::PINOPhi, t::AbstractArray, θ)
# ε = sqrt(eps(eltype(t)))
# εs = [ε, zero(eltype(t))]
# # ε = [sqrt(eps(eltype(t))), zeros(eltype(t), phi.chain.layers.layer_1.in_dims - 1)...]
# # ChainRulesCore.@ignore_derivatives tl = t .+ ε
# tl = t .+ ε
# tr = t
# (phi(tl, θ) - phi(tr, θ)) ./ ε
# end
function dfdx_rand_matrix(phi::PINOPhi, t::AbstractArray, θ)
ε_ = sqrt(eps(eltype(t)))
d = Normal{eltype(t)}(0.0f0, ε_)
size_ = size(t) .- (1, 0, 0)
eps_ = ε_ .+ rand(d, size_) .* ε_
zeros_ = zeros(eltype(t), size_)
ε = cat(eps_, zeros_, dims = 1)
(phi(t .+ ε, θ) - phi(t, θ)) ./ sqrt(eps(eltype(t)))
end

function dfdx(phi::PINOPhi, t::AbstractArray, θ)
ε = [sqrt(eps(eltype(t))), zero(eltype(t))]
#TODO ε is size of t ?
# ε = [sqrt(eps(eltype(t))), zeros(eltype(t), phi.chain.layers.layer_1.in_dims - 1)...]
(phi(t .+ ε, θ) - phi(t, θ)) ./ sqrt(eps(eltype(t)))
end

function inner_physics_loss(phi::PINOPhi{C, T, U},
θ,
ts::AbstractArray,
prob::ODEProblem,
isu0::Bool,
in_) where {C, T, U}
u0 = prob.u0
p = prob.p
f = prob.f
# if isu0 == true
# #TODO data should be generate before train ?
# in_ = reduce(vcat, [ts, fill(u0, 1, size(ts)[2], 1)])
# #TODO for all case p and u0
# # u0 isa Vector
# # in_ = reduce(vcat, [ts, reduce(hcat, fill(u0, 1, size(ts)[2], 1))])
# else
# if p isa Number
# in_ = reduce(vcat, [ts, fill(p, 1, size(ts)[2], 1)])
# elseif p isa Vector
# #TODO nno for Vector
# inner = reduce(vcat, [ts, reduce(hcat, fill(p, 1, size(ts)[2], 1))])
# in_ = reshape(inner, size(inner)..., 1)
# else
# error("p should be a number or a vector")
# end
# end
out_ = phi(in_, θ)
# fs = f.f.(out_, p, ts)
if p isa Number
fs = f.f.(out_, p, ts)
elseif p isa Vector
fs = reduce(hcat, [f.f(out_[:, i], p, ts[i]) for i in 1:size(out_, 2)])
else
error("p should be a number or a vector")
end
NeuralOperators.l₂loss(dfdx(phi, in_, θ), fs)
end


function physics_loss(phi::PINOPhi{C, T, U},
θ,
ts::AbstractArray,
train_set::TRAINSET,
input_data_set) where {C, T, U}
prob_set, output_data = train_set.input_data, train_set.output_data
f = prob_set[1].f
# norm = prod(size(output_data))
# norm = size(output_data)[1] * size(output_data[1])[2] * size(output_data[1])[1]
# loss = reduce(vcat,
# [inner_physics_loss(phi, θ, ts, prob, train_set.isu0, in_)
# for (in_, prob) in zip(inputdata, prob_set)])
# sum(abs2, loss) / norm
ps = [prob.p for prob in prob_set]'
fs = f.f.(output_data, ps, ts)
loss = NeuralOperators.l₂loss(dfdx(phi, input_data_set, θ), fs)
end

function inner_data_loss(phi::PINOPhi{C, T, U},
θ,
ts::AbstractArray,
prob::ODEProblem,
out_::AbstractArray,
isu0::Bool,
in_) where {C, T, U}
u0 = prob.u0
p = prob.p
f = prob.f
if isu0 == true
in_ = reduce(vcat, [ts, fill(u0, 1, size(ts)[2], 1)])
#TODO for all case p and u0
# u0 isa Vector
# in_ = reduce(vcat, [ts, reduce(hcat, fill(u0, 1, size(ts)[2], 1))])
prob_set, output_data = train_set.input_data, train_set.output_data #TODO
f = prob_set[1].f #TODO one f for all
out_ = phi(input_data_set, θ)
if train_set.isu0 === false
ps = [prob.p for prob in prob_set] #TODO do it within generator for data
else
if p isa Number
in_ = reduce(vcat, [ts, fill(p, 1, size(ts)[2],1)])
elseif p isa Vector
inner = reduce(vcat, [ts, reduce(hcat, fill(p, 1, size(ts)[2], 1))])
in_ = reshape(inner, size(inner)..., 1)
else
error("p should be a number or a vector")
end
error("WIP")
end
NeuralOperators.l₂loss(phi(in_, θ), out_)
fs = cat([f.f.(out_[:, :, [i]], p, ts) for (i, p) in enumerate(ps)]..., dims = 3)
NeuralOperators.l₂loss(dfdx(phi, input_data_set, θ), fs)
end

function data_loss(phi::PINOPhi{C, T, U},
Expand All @@ -214,64 +128,20 @@ function data_loss(phi::PINOPhi{C, T, U},
train_set::TRAINSET,
input_data_set) where {C, T, U}
prob_set, output_data = train_set.input_data, train_set.output_data
# norm = prod(size(output_data))
# norm = size(output_data)[1] * size(output_data[1])[2] * size(output_data[1])[1]
# loss = reduce(vcat,
# [inner_data_loss(phi, θ, ts, prob, out_, train_set.isu0, in_)
# for (prob, out_, in_) in zip(prob_set, output_data, input_data_set)])
# sum(abs2, loss) / norm
loss = NeuralOperators.l₂loss(phi(input_data_set, θ), output_data)
NeuralOperators.l₂loss(phi(input_data_set, θ), output_data)
end

function generate_data(ts, prob_set, isu0)
input_data_set = []
input_data_set_right = []
for prob in prob_set
u0 = prob.u0
p = prob.p
f = prob.f
ε = sqrt(eps(eltype(ts)))
tsr = ts .+ ε
if isu0 == true
#TODO data should be generate before train ?
in_ = reduce(vcat, [ts, fill(u0, 1, size(ts)[2], 1)])

#TODO for all case p and u0
# u0 isa Vector
# in_ = reduce(vcat, [ts, reduce(hcat, fill(u0, 1, size(ts)[2], 1))])
else
if p isa Number
in_ = reduce(vcat, [ts, fill(p, 1, size(ts)[2], 1)])
in_r = reduce(vcat, [tsr, fill(p, 1, size(ts)[2], 1)])

elseif p isa Vector
#TODO nno for Vector
inner = reduce(vcat, [ts, reduce(hcat, fill(p, 1, size(ts)[2], 1))])
in_ = reshape(inner, size(inner)..., 1)
else
error("p should be a number or a vector")
end
end
push!(input_data_set, in_)
push!(input_data_set_right, in_r)
end
input_data_set, input_data_set_right
end

function generate_data_matrix(ts, prob_set, isu0)

batch_size = size(prob_set)[1]
instances_size = size(ts)[2]
dims = 2
input_data_set = Array{Float32, 3}(undef, dims, instances_size, batch_size)
for (i,prob) in enumerate(prob_set)
for (i, prob) in enumerate(prob_set)
u0 = prob.u0
p = prob.p
f = prob.f
if isu0 == true
#TODO data should be generate before train ?
in_ = reduce(vcat, [ts, fill(u0, 1, size(ts)[2], 1)])

#TODO for all case p and u0
# u0 isa Vector
# in_ = reduce(vcat, [ts, reduce(hcat, fill(u0, 1, size(ts)[2], 1))])
Expand All @@ -295,11 +165,11 @@ function generate_loss(phi::PINOPhi{C, T, U}, train_set::TRAINSET, tspan) where
t0 = tspan[1]
t_end = tspan[2]
instances_size = size(train_set.output_data)[2]
# instances_size = size(train_set.output_data[1])[2]
range_ = range(t0, stop = t_end, length = instances_size)
ts = reshape(collect(range_), 1, instances_size)
prob_set, output_data = train_set.input_data, train_set.output_data
input_data_set = generate_data_matrix(ts, prob_set, train_set.isu0)

prob_set, output_data = train_set.input_data, train_set.output_data #TODO one format data
input_data_set = generate_data(ts, prob_set, train_set.isu0)
function loss(θ, _)
data_loss(phi, θ, ts, train_set, input_data_set) +
physics_loss(phi, θ, ts, train_set, input_data_set)
Expand Down Expand Up @@ -328,7 +198,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem,
init_params = alg.init_params

# mapping between functional space of some vararible 'a' of equation (for example initial
# condition {u(t0 x)} or parameter p) join and solution of equation u(t)
# condition {u(t0 x)} or parameter p) and solution of equation u(t)
train_set = alg.train_set

!(chain isa Lux.AbstractExplicitLayer) &&
Expand Down
67 changes: 33 additions & 34 deletions test/PINO_ode_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Statistics, Random
using NeuralOperators
using NeuralPDE

@testset "Example 1" begin
@testset "Example p" begin
linear_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
linear = (u, p, t) -> cos(p * t)
tspan = (0.0f0, 2.0f0)
Expand All @@ -18,7 +18,7 @@ using NeuralPDE
batch_size = 50
as = [Float32(i) for i in range(0.1, stop = pi / 2, length = batch_size)]

u_output_ = Array{Float32, 3}(undef, 1, instances_size, batch_size)
u_output_ = zeros(Float32, 1, instances_size, batch_size)
prob_set = []
for (i, a_i) in enumerate(as)
prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), u0, tspan, a_i)
Expand All @@ -27,15 +27,6 @@ using NeuralPDE
push!(prob_set, prob)
u_output_[:, :, i] = reshape_sol
end
# u_output_ = Array{Float32, 3}[]
# prob_set = []
# for a_i in as
# prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), u0, tspan, a_i)
# sol1 = solve(prob, Tsit5(); saveat = 0.0204)
# reshape_sol = Float32.(reshape(sol1(range_).u', 1, instances_size, 1))
# push!(prob_set, prob)
# push!(u_output_, reshape_sol)
# end

"""
Set of training data:
Expand All @@ -45,48 +36,46 @@ using NeuralPDE
train_set = NeuralPDE.TRAINSET(prob_set, u_output_);
#TODO u0 ?
prob = ODEProblem(linear, u0, tspan, 0)
chain = Lux.Chain(Lux.Dense(2, 20, Lux.σ), Lux.Dense(20, 20, Lux.σ), Lux.Dense(20, 1))
chain = Lux.Chain(Lux.Dense(2, 16, Lux.σ),
Lux.Dense(16, 16, Lux.σ),
Lux.Dense(16, 16, Lux.σ),
Lux.Dense(16, 16, Lux.σ),
Lux.Dense(16, 32, Lux.σ),
Lux.Dense(32, 1))
flat_no = FourierNeuralOperator(ch = (2, 16, 16, 16, 16, 16, 32, 1), modes = (16,),
σ = gelu)
# flat_no(rand(2, 100, 1))
# Random.default_rng()
# luxm = Lux.transform(flat_no)
# θ, st = Lux.setup(Random.default_rng(), luxm)
# luxm(rand(Float32, 2, 40, 1), θ, st)[1]
# pk(c, θ) = luxm(rand(2, 40, 1), θ, st)[1]
# Zygote.gradient(θ -> sum(abs2, pk(rand(2, 100, 1), θ)), θ)
# NeuralOperators.l₂loss(pk(rand(2, 100, 1), θ), rand(1,100,1))

η₀ = 1.0f-2
opt = OptimizationOptimisers.Adam(0.03)
alg = NeuralPDE.PINOODE(chain, opt, train_set)
alg = NeuralPDE.PINOODE(flat_no, opt, train_set)

res, phi = solve(prob,
alg, verbose = true,
maxiters = 400, abstol = 1.0f-10)
maxiters = 500)

input_data_set_2 = Array{Float32, 3}(undef, 2, instances_size, batch_size)
for (i, prob) in enumerate(prob_set)
in_ = reduce(vcat, [ts, fill(p, 1, size(ts)[2], 1)])
input_data_set_2[:, :, i] = in_
end
predict_2 = phi(input_data_set_2, res.u)
predict = phi(input_data_set, res.u)
ground = output_data

predict = reduce(vcat,
[phi(
reshape(reduce(vcat, [ts, fill(train_set.input_data[i].p, 1, size(ts)[2])]),
2, instances_size, 1),
res.u)
for i in 1:batch_size])
ground = reduce(vcat, [train_set.output_data[i] for i in 1:batch_size])
@test groundpredict atol=1
end


function plot_()
# Animate
anim = @animate for (i) in 1:batch_size
plot(predict[i, :], label = "Predicted")
plot!(ground[i, :], label = "Ground truth")
plot(predict[1, :, i], label = "Predicted")
plot!(ground[1, :,i], label = "Ground truth")
end
gif(anim, "pino.gif", fps = 10)
end

plot_()

"Example 2" begin
"Example u0" begin
linear_analytic = (u0, p, t) -> u0 + sin(p * t) / (p)
linear = (u, p, t) -> cos(p * t)
tspan = (0.0, 2.0)
Expand All @@ -108,6 +97,16 @@ plot_()
push!(prob_set, prob)
end

u_output_ = zeros(Float32, 1, instances_size, batch_size)
prob_set = []
for (i, u0_i) in enumerate(u0s)
prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), u0_i, tspan, p)
sol1 = solve(prob, Tsit5(); saveat = 0.0204)
reshape_sol = Float32.(reshape(sol1(range_).u', 1, instances_size, 1))
push!(prob_set, prob)
u_output_[:, :, i] = reshape_sol
end

"""
Set of training data:
* input data: set of initial conditions 'a':
Expand Down

0 comments on commit db50090

Please sign in to comment.