Skip to content

Commit

Permalink
draft
Browse files Browse the repository at this point in the history
  • Loading branch information
KirillZubov committed Jul 11, 2024
1 parent 4b23584 commit 0df4564
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 7 deletions.
20 changes: 18 additions & 2 deletions src/discretize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,16 @@ which is later optimized upon to give Solution or the Solution Distribution of t
For more information, see `discretize` and `PINNRepresentation`.
"""
function SciMLBase.symbolic_discretize(pde_system::PDESystem,
discretization::AbstractPINN)
#TODO?
function SciMLBase.symbolic_discretize(pde_system::PDESystem, discretization::PhysicsInformedNO)

end
#TODO?
function SciMLBase.symbolic_discretize(
pde_system::PDESystem, discretization::PhysicsInformedNN)
end

function SciMLBase.symbolic_discretize(pde_system::PDESystem, discretization::AbstractPINN)
eqs = pde_system.eqs
bcs = pde_system.bcs
chain = discretization.chain
Expand Down Expand Up @@ -718,3 +726,11 @@ function SciMLBase.discretize(pde_system::PDESystem, discretization::PhysicsInfo
Optimization.AutoZygote())
Optimization.OptimizationProblem(f, pinnrep.flat_init_params)
end

#TODO?
function SciMLBase.discretize(pde_system::PDESystem, discretization::AbstractPINN)
pinnrep = symbolic_discretize(pde_system, discretization)
f = OptimizationFunction(pinnrep.loss_functions.full_loss_function,
Optimization.AutoZygote())
Optimization.OptimizationProblem(f, pinnrep.flat_init_params)
end
87 changes: 85 additions & 2 deletions src/pinn_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -311,10 +311,93 @@ end
"""
PhysicsInformedNO
"""
struct PhysicsInformedNO{} <: AbstractPINN
#TODO
struct PhysicsInformedNO{T, PS, P, PH, DER, AL, ADA, LOG, K} <: AbstractPINN
chain::Any
parameters::PS
strategy::T
init_params::P
phi::PH
derivative::DER
additional_loss::AL
adaptive_loss::ADA #?
logger::LOG #?
log_options::LogOptions #?
iteration::Vector{Int64} #?
self_increment::Bool#?
multioutput::Bool
kwargs::K

@add_kwonly function PhysicsInformedNN(chain,
strategy,
parameters;
init_params = nothing,
phi = nothing,
derivative = nothing,
additional_loss = nothing,
adaptive_loss = nothing,
logger = nothing,
log_options = LogOptions(),
iteration = nothing,
kwargs...)
multioutput = chain isa AbstractArray
if multioutput
!all(i -> i isa Lux.AbstractExplicitLayer, chain) &&
(chain = Lux.transform.(chain))
else
!(chain isa Lux.AbstractExplicitLayer) &&
(chain = adapt(FromFluxAdaptor(false, false), chain))
end
if phi === nothing
if multioutput
_phi = Phi.(chain)
else
_phi = Phi(chain)
end
else
if multioutput
all([phi.f[i] isa Lux.AbstractExplicitLayer for i in eachindex(phi.f)]) ||
throw(ArgumentError("Only Lux Chains are supported"))
else
(phi.f isa Lux.AbstractExplicitLayer) ||
throw(ArgumentError("Only Lux Chains are supported"))
end
_phi = phi
end

if derivative === nothing
_derivative = numeric_derivative
else
_derivative = derivative
end

if iteration isa Vector{Int64}
self_increment = false
else
iteration = [1]
self_increment = true
end

new{typeof(parameters), typeof(strategy), typeof(init_params),
typeof(_phi), typeof(_derivative),
typeof(additional_loss), typeof(adaptive_loss), typeof(logger), typeof(kwargs)}(
chain,
parameters,
strategy,
init_params,
_phi,
_derivative,
additional_loss,
adaptive_loss,
logger,
log_options,
iteration,
self_increment,
multioutput,
kwargs)
end
end


"""
`PINNRepresentation``
Expand Down
26 changes: 26 additions & 0 deletions src/pino_pde.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

#TODO?
function PhysicsInformedNO(
neural_operator,
parameters,#bounds
strategy;
kwargs...
)
PhysicsInformedNN(neural_operator,
strategy; kwargs...
)
end
#TODO?
function SciMLBase.discretize(pde_system::PDESystem, neural_operator::PhysicsInformedNO)
pinnrep = symbolic_discretize(pde_system, neural_operator)
f = OptimizationFunction(pinnrep.loss_functions.full_loss_function,
Optimization.AutoZygote())
Optimization.OptimizationProblem(f, pinnrep.flat_init_params)
end
#TODO?
function SciMLBase.discretize(pde_system::PDESystem, neural_operator::PhysicsInformed)
pinnrep = symbolic_discretize(pde_system, neural_operator)
f = OptimizationFunction(pinnrep.loss_functions.full_loss_function,
Optimization.AutoZygote())
Optimization.OptimizationProblem(f, pinnrep.flat_init_params)
end
7 changes: 4 additions & 3 deletions test/PINO_PDE_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@ using ModelingToolkit
import ModelingToolkit: Interval, infimum, supremum
using DomainSets
using NeuralPDE
using LuxNeuralOperators


##example ODE
@parameters t
@variables u(..)
# @parameters p #[bounds = (0.1f0, pi)]
@parameters p #[bounds = (0.1f0, pi)]
Dt = Differential(t)
eq = [Dt(u(t)) ~ cos(t)]
bc = [u(0) ~ 1.0f0]

dom = [x Interval(0.0, 1.0)]
dom = [t Interval(0.0, 1.0)]
# neural_operator = SomeNeuralOperator(some_args)
neural_operator = Lux.Chain(
Lux.Dense(1, 10, Lux.tanh),
Expand All @@ -26,7 +27,7 @@ neural_operator = Lux.Chain(
# pino = PhysicsInformedNO(neural_operator, sometrainig)
pino = NeuralPDE.PhysicsInformedNN(neural_operator, NeuralPDE.GridTraining(0.1))

@named pde_system = PDESystem(eq, bc, dom, [t], [u(t)]) #[p]; defaults = Dict([p => 1.0 for p in [p]]))
@named pde_system = PDESystem(eq, bc, dom, [t], [u(t)],[p]) #[p]; defaults = Dict([p => 1.0 for p in [p]]))

# hasbounds(pde_system.ps[1])
# getbounds(pde_system.ps[1])
Expand Down

0 comments on commit 0df4564

Please sign in to comment.