diff --git a/.gitignore b/.gitignore index 770a1818c9..18672428f2 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,4 @@ Manifest.toml docs/build/* scratch scratch/* -.DS_store \ No newline at end of file +.DS_store diff --git a/Project.toml b/Project.toml index a8e43f9b08..d5076a3951 100644 --- a/Project.toml +++ b/Project.toml @@ -27,6 +27,7 @@ MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" Optim = "429524aa-4258-5aef-a3af-852621145aeb" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" +PDEBase = "a7812802-0625-4b9e-961c-d332478797e5" QuasiMonteCarlo = "8a4e6c94-4038-4cdc-81c3-7e6ffdb2a71b" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -77,6 +78,7 @@ Reexport = "1.0" RuntimeGeneratedFunctions = "0.5" SafeTestsets = "0.1" SciMLBase = "2" +PDEBase = "0.1.7" Statistics = "1" StochasticDiffEq = "6.13" SymbolicUtils = "1" diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index 0f217b07d4..7d2d0ea922 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -17,28 +17,36 @@ using QuasiMonteCarlo: LatinHypercubeSample import QuasiMonteCarlo using RuntimeGeneratedFunctions using SciMLBase +using PDEBase +using PDEBase: cardinalize_eqs!, get_depvars, get_indvars, differential_order using Statistics using ArrayInterface import Optim -using Symbolics: wrap, unwrap, arguments, operation +using DomainSets +using Symbolics +using Symbolics: wrap, unwrap, arguments, operation, symtype, @arrayop, Arr using SymbolicUtils using AdvancedHMC, LogDensityProblems, LinearAlgebra, Functors, MCMCChains -using MonteCarloMeasurements: Particles -using ModelingToolkit: value, nameof, toexpr, build_expr, expand_derivatives, Interval, infimum, supremum -import DomainSets -using DomainSets: Domain, ClosedInterval, AbstractInterval, leftendpoint, rightendpoint, ProductDomain -using SciMLBase: @add_kwonly, parameterless_type -using UnPack: @unpack +using MonteCarloMeasurements +using SymbolicUtils.Code +using SymbolicUtils: Prewalk, Postwalk, Chain +import ModelingToolkit: value, nameof, toexpr, build_expr, expand_derivatives +import DomainSets: Domain, ClosedInterval +import ModelingToolkit: Interval, infimum, supremum #,Ball +import SciMLBase: @add_kwonly, parameterless_type +import UnPack: @unpack import ChainRulesCore, Lux, ComponentArrays -using ChainRulesCore: @non_differentiable +import ChainRulesCore: @non_differentiable, @ignore_derivatives RuntimeGeneratedFunctions.init(@__MODULE__) -abstract type AbstractPINN end +abstract type AbstractPINN <: SciMLBase.AbstractDiscretization end abstract type AbstractTrainingStrategy end +abstract type AbstractGridfreeStrategy <: AbstractTrainingStrategy end include("pinn_types.jl") +include("eq_data.jl") include("symbolic_utilities.jl") include("training_strategies.jl") include("adaptive_losses.jl") @@ -46,6 +54,7 @@ include("ode_solve.jl") # include("rode_solve.jl") include("dae_solve.jl") include("transform_inf_integral.jl") +include("loss_function_generation.jl") include("discretize.jl") include("neural_adapter.jl") include("advancedHMC_MCMC.jl") diff --git a/src/discretize.jl b/src/discretize.jl index af035980b3..be71cfb531 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -1,199 +1,25 @@ """ -Build a loss function for a PDE or a boundary condition. 
- -# Examples: System of PDEs: - -Take expressions in the form: - -[Dx(u1(x,y)) + 4*Dy(u2(x,y)) ~ 0, - Dx(u2(x,y)) + 9*Dy(u1(x,y)) ~ 0] - -to - -:((cord, θ, phi, derivative, u)->begin - #= ... =# - #= ... =# - begin - (u1, u2) = (θ.depvar.u1, θ.depvar.u2) - (phi1, phi2) = (phi[1], phi[2]) - let (x, y) = (cord[1], cord[2]) - [(+)(derivative(phi1, u, [x, y], [[ε, 0.0]], 1, u1), (*)(4, derivative(phi2, u, [x, y], [[0.0, ε]], 1, u1))) - 0, - (+)(derivative(phi2, u, [x, y], [[ε, 0.0]], 1, u2), (*)(9, derivative(phi1, u, [x, y], [[0.0, ε]], 1, u2))) - 0] - end - end - end) - -for Lux.AbstractExplicitLayer. -""" -function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs; - eq_params = SciMLBase.NullParameters(), - param_estim = false, - default_p = nothing, - bc_indvars = pinnrep.indvars, - integrand = nothing, - dict_transformation_vars = nothing, - transformation_vars = nothing, - integrating_depvars = pinnrep.depvars) - @unpack indvars, depvars, dict_indvars, dict_depvars, dict_depvar_input, - phi, derivative, integral, - multioutput, init_params, strategy, eq_params, - param_estim, default_p = pinnrep - - eltypeθ = eltype(pinnrep.flat_init_params) - - if integrand isa Nothing - loss_function = parse_equation(pinnrep, eqs) - this_eq_pair = pair(eqs, depvars, dict_depvars, dict_depvar_input) - this_eq_indvars = unique(vcat(values(this_eq_pair)...)) - else - this_eq_pair = Dict(map(intvars -> dict_depvars[intvars] => dict_depvar_input[intvars], - integrating_depvars)) - this_eq_indvars = transformation_vars isa Nothing ? - unique(vcat(values(this_eq_pair)...)) : transformation_vars - loss_function = integrand - end - - vars = :(cord, $θ, phi, derivative, integral, u, p) - ex = Expr(:block) - if multioutput - θ_nums = Symbol[] - phi_nums = Symbol[] - for v in depvars - num = dict_depvars[v] - push!(θ_nums, :($(Symbol(:($θ), num)))) - push!(phi_nums, :($(Symbol(:phi, num)))) - end - - expr_θ = Expr[] - expr_phi = Expr[] - - acum = [0; accumulate(+, map(length, init_params))] - sep = [(acum[i] + 1):acum[i + 1] for i in 1:(length(acum) - 1)] - - for i in eachindex(depvars) - push!(expr_θ, :($θ.depvar.$(depvars[i]))) - push!(expr_phi, :(phi[$i])) - end - - vars_θ = Expr(:(=), build_expr(:tuple, θ_nums), build_expr(:tuple, expr_θ)) - push!(ex.args, vars_θ) - - vars_phi = Expr(:(=), build_expr(:tuple, phi_nums), build_expr(:tuple, expr_phi)) - push!(ex.args, vars_phi) - end - - #Add an expression for parameter symbols - if param_estim == true && eq_params != SciMLBase.NullParameters() - params_symbols = Symbol[] - expr_params = Expr[] - for (i, eq_param) in enumerate(eq_params) - push!(expr_params, :($θ.p[$((i):(i))])) - push!(params_symbols, Symbol(:($eq_param))) - end - params_eq = Expr(:(=), build_expr(:tuple, params_symbols), - build_expr(:tuple, expr_params)) - push!(ex.args, params_eq) - end - - if eq_params != SciMLBase.NullParameters() && param_estim == false - params_symbols = Symbol[] - expr_params = Expr[] - for (i, eq_param) in enumerate(eq_params) - push!(expr_params, :(ArrayInterface.allowed_getindex(p, ($i):($i)))) - push!(params_symbols, Symbol(:($eq_param))) - end - params_eq = Expr(:(=), build_expr(:tuple, params_symbols), - build_expr(:tuple, expr_params)) - push!(ex.args, params_eq) - end - - eq_pair_expr = Expr[] - for i in keys(this_eq_pair) - push!(eq_pair_expr, :($(Symbol(:cord, :($i))) = vcat($(this_eq_pair[i]...)))) - end - vcat_expr = Expr(:block, :($(eq_pair_expr...))) - vcat_expr_loss_functions = Expr(:block, vcat_expr, loss_function) # TODO rename - - if 
strategy isa QuadratureTraining - indvars_ex = get_indvars_ex(bc_indvars) - left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex - vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs), - build_expr(:tuple, right_arg_pairs)) - else - indvars_ex = [:($:cord[[$i], :]) for (i, x) in enumerate(this_eq_indvars)] - left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex - vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs), - build_expr(:tuple, right_arg_pairs)) - end - - if !(dict_transformation_vars isa Nothing) - transformation_expr_ = Expr[] - for (i, u) in dict_transformation_vars - push!(transformation_expr_, :($i = $u)) - end - transformation_expr = Expr(:block, :($(transformation_expr_...))) - vcat_expr_loss_functions = Expr(:block, transformation_expr, vcat_expr, - loss_function) - end - let_ex = Expr(:let, vars_eq, vcat_expr_loss_functions) - push!(ex.args, let_ex) - expr_loss_function = :(($vars) -> begin $ex end) -end - -""" - build_loss_function(eqs, indvars, depvars, phi, derivative, init_params; bc_indvars=nothing) - -Returns the body of loss function, which is the executable Julia function, for the main -equation or boundary condition. -""" -function build_loss_function(pinnrep::PINNRepresentation, eqs, bc_indvars) - @unpack eq_params, param_estim, default_p, phi, derivative, integral = pinnrep - - bc_indvars = bc_indvars === nothing ? pinnrep.indvars : bc_indvars - - expr_loss_function = build_symbolic_loss_function(pinnrep, eqs; - bc_indvars = bc_indvars, - eq_params = eq_params, - param_estim = param_estim, - default_p = default_p) - u = get_u() - _loss_function = @RuntimeGeneratedFunction(expr_loss_function) - loss_function = (cord, θ) -> begin _loss_function(cord, θ, phi, derivative, integral, u, - default_p) end - return loss_function -end - -""" - generate_training_sets(domains,dx,bcs,_indvars::Array,_depvars::Array) - +```julia +generate_training_sets(domains,dx,bcs,_indvars::Array,_depvars::Array) +``` Returns training sets for equations and boundary condition, that is used for GridTraining strategy. 
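+
+A hedged sketch of a call and its output shape (here `varmap` is the `VariableMap`
+built for the system inside `symbolic_discretize`, and `Float64` stands in for `eltypeθ`):
+
+```julia
+pde_sets, bcs_sets = generate_training_sets(domains, 0.1, eqs, bcs, Float64, varmap)
+# for u(x) on x ∈ Interval(0.0, 1.0) with dx = 0.1, pde_sets[1] is a matrix
+# whose columns are the grid points 0.0:0.1:1.0
+```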
""" function generate_training_sets end -function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, _indvars::Array, - _depvars::Array) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - return generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars, - dict_depvars) -end - # Generate training set in the domain and on the boundary -function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::Dict, - dict_depvars::Dict) +function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, varmap) if dx isa Array dxs = dx else dxs = fill(dx, length(domains)) end - spans = [infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, dxs)] - dict_var_span = Dict([Symbol(d.variables) => infimum(d.domain):dx:supremum(d.domain) + dict_var_span = Dict([d.variables => infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, dxs)]) - bound_args = get_argument(bcs, dict_indvars, dict_depvars) - bound_vars = get_variables(bcs, dict_indvars, dict_depvars) + bound_args = get_argument(bcs, varmap) + bound_vars = get_variables(bcs, varmap) dif = [eltypeθ[] for i in 1:size(domains)[1]] for _args in bound_vars @@ -208,7 +34,7 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D setdiff(c, d) end - dict_var_span_ = Dict([Symbol(d.variables) => bc for (d, bc) in zip(domains, bc_data)]) + dict_var_span_ = Dict([d.variables => bc for (d, bc) in zip(domains, bc_data)]) bcs_train_sets = map(bound_args) do bt span = map(b -> get(dict_var_span, b, b), bt) @@ -216,8 +42,8 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D hcat(vec(map(points -> collect(points), Iterators.product(span...)))...)) end - pde_vars = get_variables(eqs, dict_indvars, dict_depvars) - pde_args = get_argument(eqs, dict_indvars, dict_depvars) + pde_vars = get_variables(eqs, varmap) + pde_args = get_argument(eqs, varmap) pde_train_set = adapt(eltypeθ, hcat(vec(map(points -> collect(points), @@ -239,25 +65,10 @@ training strategy: StochasticTraining, QuasiRandomTraining, QuadratureTraining. 
""" function get_bounds end -function get_bounds(domains, eqs, bcs, eltypeθ, _indvars::Array, _depvars::Array, strategy) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - return get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) -end - -function get_bounds(domains, eqs, bcs, eltypeθ, _indvars::Array, _depvars::Array, - strategy::QuadratureTraining) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - return get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) -end - -function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, - strategy::QuadratureTraining) - dict_lower_bound = Dict([Symbol(d.variables) => infimum(d.domain) for d in domains]) - dict_upper_bound = Dict([Symbol(d.variables) => supremum(d.domain) for d in domains]) - - pde_args = get_argument(eqs, dict_indvars, dict_depvars) +function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::QuadratureTraining) + dict_lower_bound = Dict([d.variables => infimum(d.domain) for d in domains]) + dict_upper_bound = Dict([d.variables => supremum(d.domain) for d in domains]) + pde_args = get_argument(eqs, v) pde_lower_bounds = map(pde_args) do pd span = map(p -> get(dict_lower_bound, p, p), pd) @@ -269,7 +80,7 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, end pde_bounds = [pde_lower_bounds, pde_upper_bounds] - bound_vars = get_variables(bcs, dict_indvars, dict_depvars) + bound_vars = get_variables(bcs, v) bcs_lower_bounds = map(bound_vars) do bt map(b -> dict_lower_bound[b], bt) @@ -278,26 +89,25 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, map(b -> dict_upper_bound[b], bt) end bcs_bounds = [bcs_lower_bounds, bcs_upper_bounds] - [pde_bounds, bcs_bounds] end -function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) +function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy) dx = 1 / strategy.points - dict_span = Dict([Symbol(d.variables) => [ + dict_span = Dict([d.variables => [ infimum(d.domain) + dx, supremum(d.domain) - dx, ] for d in domains]) # pde_bounds = [[infimum(d.domain),supremum(d.domain)] for d in domains] - pde_args = get_argument(eqs, dict_indvars, dict_depvars) + pde_args = get_argument(eqs, v) pde_bounds = map(pde_args) do pde_arg bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, pde_arg) bds = eltypeθ.(bds) bds[1, :], bds[2, :] end - bound_args = get_argument(bcs, dict_indvars, dict_depvars) + bound_args = get_argument(bcs, v) bcs_bounds = map(bound_args) do bound_arg bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, bound_arg) bds = eltypeθ.(bds) @@ -305,12 +115,11 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, str end return pde_bounds, bcs_bounds end - +# TODO: Get this to work with varmap function get_numeric_integral(pinnrep::PINNRepresentation) - @unpack strategy, indvars, depvars, multioutput, derivative, - depvars, indvars, dict_indvars, dict_depvars = pinnrep + @unpack strategy, multioutput, derivative, varmap = pinnrep - integral = (u, cord, phi, integrating_var_id, integrand_func, lb, ub, θ; strategy = strategy, indvars = indvars, depvars = depvars, dict_indvars = dict_indvars, dict_depvars = dict_depvars) -> begin + integral = (u, cord, phi, integrating_var_id, integrand_func, lb, ub, θ; strategy = strategy, varmap=varmap) -> begin function integration_(cord, lb, ub, θ) cord_ = 
cord function integrand_(x, p) @@ -353,7 +162,10 @@ function get_numeric_integral(pinnrep::PINNRepresentation) end """ - prob = symbolic_discretize(pde_system::PDESystem, discretization::AbstractPINN) +```julia + +prob = symbolic_discretize(pde_system::PDESystem, discretization::AbstractPINN) +``` `symbolic_discretize` is the lower level interface to `discretize` for inspecting internals. It transforms a symbolic description of a ModelingToolkit-defined `PDESystem` into a @@ -364,15 +176,17 @@ which is later optimized upon to give Solution or the Solution Distribution of t For more information, see `discretize` and `PINNRepresentation`. """ -function SciMLBase.symbolic_discretize(pde_system::PDESystem, - discretization::AbstractPINN) - eqs = pde_system.eqs - bcs = pde_system.bcs +function SciMLBase.symbolic_discretize(pdesys::PDESystem, + discretization::PhysicsInformedNN) + cardinalize_eqs!(pdesys) + eqs = pdesys.eqs + bcs = pdesys.bcs + chain = discretization.chain - domains = pde_system.domain - eq_params = pde_system.ps - defaults = pde_system.defaults + domains = pdesys.domain + eq_params = pdesys.ps + defaults = pdesys.defaults default_p = eq_params == SciMLBase.NullParameters() ? nothing : [defaults[ep] for ep in eq_params] @@ -380,11 +194,23 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, additional_loss = discretization.additional_loss adaloss = discretization.adaptive_loss - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(pde_system.indvars, - pde_system.depvars) multioutput = discretization.multioutput init_params = discretization.init_params + phi = discretization.phi + + derivative = discretization.derivative + strategy = discretization.strategy + + logger = discretization.logger + log_frequency = discretization.log_options.log_frequency + iteration = discretization.iteration + self_increment = discretization.self_increment + + v = VariableMap(pdesys, discretization) + + eqdata = EquationData(pdesys, v, strategy) + if init_params === nothing # Use the initialization of the neural network framework @@ -425,84 +251,61 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, ComponentArrays.ComponentArray(; depvar = flat_init_params, p = default_p) end - eltypeθ = eltype(flat_init_params) - - if adaloss === nothing - adaloss = NonAdaptiveLoss{eltypeθ}() - end - - phi = discretization.phi - if (phi isa Vector && phi[1].f isa Lux.AbstractExplicitLayer) for ϕ in phi ϕ.st = adapt(parameterless_type(ComponentArrays.getdata(flat_init_params)), - ϕ.st) + ϕ.st) end elseif (!(phi isa Vector) && phi.f isa Lux.AbstractExplicitLayer) phi.st = adapt(parameterless_type(ComponentArrays.getdata(flat_init_params)), - phi.st) + phi.st) end - derivative = discretization.derivative - strategy = discretization.strategy - - logger = discretization.logger - log_frequency = discretization.log_options.log_frequency - iteration = discretization.iteration - self_increment = discretization.self_increment - - if !(eqs isa Array) - eqs = [eqs] - end - - pde_indvars = if strategy isa QuadratureTraining - get_argument(eqs, dict_indvars, dict_depvars) + if multioutput + dvs = v.ū + acum = [0; accumulate(+, map(length, init_params))] + sep = [(acum[i] + 1):acum[i + 1] for i in 1:(length(acum) - 1)] + phi = map(enumerate(dvs)) do (i, dv) + if (phi isa Vector && phi[1].f isa Optimisers.Restructure) || + (!(phi isa Vector) && phi.f isa Optimisers.Restructure) + # Flux.Chain + (coord, expr_θ) -> phi[i](coord, expr_θ[sep[i]]) + else # Lux.AbstractExplicitLayer + 
(coord, expr_θ) -> phi[i](coord, getproperty(expr_θ.depvar, sym_op(dv)))
+            end
+        end
     else
-        get_variables(eqs, dict_indvars, dict_depvars)
+        phimap = nothing
     end
 
-    bc_indvars = if strategy isa QuadratureTraining
-        get_argument(bcs, dict_indvars, dict_depvars)
-    else
-        get_variables(bcs, dict_indvars, dict_depvars)
+    eltypeθ = eltype(flat_init_params)
+
+    if adaloss === nothing
+        adaloss = NonAdaptiveLoss{eltypeθ}()
     end
 
-    pde_integration_vars = get_integration_variables(eqs, dict_indvars, dict_depvars)
-    bc_integration_vars = get_integration_variables(bcs, dict_indvars, dict_depvars)
+    eqs = map(eq -> eq.lhs, eqs)
+    bcs = map(bc -> bc.lhs, bcs)
 
     pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p,
-                                 param_estim, additional_loss, adaloss, depvars, indvars,
-                                 dict_indvars, dict_depvars, dict_depvar_input, logger,
+                                 param_estim, additional_loss, adaloss, v, logger,
                                  multioutput, iteration, init_params, flat_init_params,
                                  phi, derivative,
-                                 strategy, pde_indvars, bc_indvars, pde_integration_vars,
-                                 bc_integration_vars, nothing, nothing, nothing, nothing)
+                                 strategy, eqdata, nothing, nothing, nothing, nothing)
 
-    integral = get_numeric_integral(pinnrep)
+    #integral = get_numeric_integral(pinnrep)
 
-    symbolic_pde_loss_functions = [build_symbolic_loss_function(pinnrep, eq;
-                                                                bc_indvars = pde_indvar)
-                                   for (eq, pde_indvar) in zip(eqs, pde_indvars,
-                                                               pde_integration_vars)]
+    #symbolic_pde_loss_functions = [build_symbolic_loss_function(pinnrep, eq) for eq in eqs]
 
-    symbolic_bc_loss_functions = [build_symbolic_loss_function(pinnrep, bc;
-                                                               bc_indvars = bc_indvar)
-                                  for (bc, bc_indvar) in zip(bcs, bc_indvars,
-                                                             bc_integration_vars)]
+    #symbolic_bc_loss_functions = [build_symbolic_loss_function(pinnrep, bc) |> toexpr for bc in bcs]
 
-    pinnrep.integral = integral
-    pinnrep.symbolic_pde_loss_functions = symbolic_pde_loss_functions
-    pinnrep.symbolic_bc_loss_functions = symbolic_bc_loss_functions
+    #pinnrep.integral = integral
+    #pinnrep.symbolic_pde_loss_functions = symbolic_pde_loss_functions
+    #pinnrep.symbolic_bc_loss_functions = symbolic_bc_loss_functions
 
-    datafree_pde_loss_functions = [build_loss_function(pinnrep, eq, pde_indvar)
-                                   for (eq, pde_indvar, integration_indvar) in zip(eqs,
-                                                                                   pde_indvars,
-                                                                                   pde_integration_vars)]
+    datafree_pde_loss_functions = [build_loss_function(pinnrep, eq) for eq in eqs]
 
-    datafree_bc_loss_functions = [build_loss_function(pinnrep, bc, bc_indvar)
-                                  for (bc, bc_indvar, integration_indvar) in zip(bcs,
-                                                                                 bc_indvars,
-                                                                                 bc_integration_vars)]
+    datafree_bc_loss_functions = [build_loss_function(pinnrep, bc) for bc in bcs]
 
     pde_loss_functions, bc_loss_functions = merge_strategy_with_loss_function(pinnrep,
                                                                               strategy,
@@ -698,14 +501,16 @@
 end
 
 """
-    prob = discretize(pde_system::PDESystem, discretization::PhysicsInformedNN)
+```julia
+prob = discretize(pdesys::PDESystem, discretization::PhysicsInformedNN)
+```
 
 Transforms a symbolic description of a ModelingToolkit-defined `PDESystem` and generates
 an `OptimizationProblem` for [Optimization.jl](https://docs.sciml.ai/Optimization/stable/)
 whose solution is the solution to the PDE.
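+
+A minimal usage sketch (assumes a `pdesys` built with `@named PDESystem(...)` and a
+`chain` sized to match the problem; the training strategy shown is an illustrative
+choice, not the only option):
+
+```julia
+discretization = PhysicsInformedNN(chain, QuadratureTraining())
+prob = discretize(pdesys, discretization)
+res = Optimization.solve(prob, OptimizationOptimisers.Adam(0.1); maxiters = 1000)
+```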
""" -function SciMLBase.discretize(pde_system::PDESystem, discretization::PhysicsInformedNN) - pinnrep = symbolic_discretize(pde_system, discretization) +function SciMLBase.discretize(pdesys::PDESystem, discretization::PhysicsInformedNN) + pinnrep = symbolic_discretize(pdesys, discretization) f = OptimizationFunction(pinnrep.loss_functions.full_loss_function, Optimization.AutoZygote()) Optimization.OptimizationProblem(f, pinnrep.flat_init_params) diff --git a/src/eq_data.jl b/src/eq_data.jl new file mode 100644 index 0000000000..d94e4345d5 --- /dev/null +++ b/src/eq_data.jl @@ -0,0 +1,95 @@ +struct EquationData <: PDEBase.AbstractVarEqMapping + depvarmap + indvarmap + args + ivargs + argmap +end + +function EquationData(pdesys, v, strategy) + eqs = map(eq -> eq.lhs, pdesys.eqs) + bcs = map(eq -> eq.lhs, pdesys.bcs) + alleqs = vcat(eqs, bcs) + + argmap = map(alleqs) do eq + eq => get_argument([eq], v)[1] + end |> Dict + depvarmap = map(alleqs) do eq + eq => get_depvars(eq, v.depvar_ops) + end |> Dict + indvarmap = map(alleqs) do eq + eq => get_indvars(eq, v) + end |> Dict + + if strategy isa QuadratureTraining + _args = get_argument(alleqs, v) + else + _args = get_variables(alleqs, v) + end + + args = map(zip(alleqs, _args)) do (eq, args) + eq => args + end |> Dict + + ivargs = get_iv_argument(alleqs, v) + + ivargs = map(zip(alleqs, ivargs)) do (eq, args) + eq => args + end |> Dict + + EquationData(depvarmap, indvarmap, args, ivargs, argmap) +end + +function depvars(eq, eqdata::EquationData) + eqdata.depvarmap[eq] +end + +function indvars(eq, eqdata::EquationData) + eqdata.indvarmap[eq] +end + +function eq_args(eq, eqdata::EquationData) + eqdata.args[eq] +end + +function eq_iv_args(eq, eqdata::EquationData) + eqdata.ivargs[eq] +end + +argument(eq, eqdata) = eqdata.argmap[eq] + +function get_iv_argument(eqs, v::VariableMap) + vars = map(eqs) do eq + _vars = map(depvar -> get_depvars(eq, [depvar]), v.depvar_ops) + f_vars = filter(x -> !isempty(x), _vars) + mapreduce(vars -> mapreduce(op -> v.args[op], vcat, operation.(vars), init = []), vcat, f_vars, init = []) + end + args_ = map(vars) do _vars + seen = [] + filter(_vars) do x + if x isa Number + error("Unreachable") + else + if any(isequal(x), seen) + false + else + push!(seen, x) + true + end + end + end + end + return args_ +end + +""" +``julia +get_variables(eqs,_indvars,_depvars) +``` + +Returns all variables that are used in each equations or boundary condition. +""" +function get_iv_variables(eqs, v::VariableMap) + args = get_iv_argument(eqs, v) + return map(arg -> filter(x -> !(x isa Number), arg), args) +end diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl new file mode 100644 index 0000000000..831800a6ba --- /dev/null +++ b/src/loss_function_generation.jl @@ -0,0 +1,178 @@ +# TODO: add multioutput +# TODO: add integrals + +function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; + eq_params = SciMLBase.NullParameters(), + param_estim = false, + default_p = [], + integrand = nothing, + transformation_vars = nothing) + @unpack varmap, eqdata, + phi, derivative, integral, + multioutput, init_params, strategy, eq_params, + param_estim, default_p = pinnrep + + eltypeθ = eltype(pinnrep.flat_init_params) + + eq = eq isa Equation ? 
eq.lhs : eq
+
+    eq_args = get(eqdata.ivargs, eq, varmap.x̄)
+
+    if integrand isa Nothing
+        this_eq_indvars = indvars(eq, eqdata)
+        this_eq_depvars = depvars(eq, eqdata)
+        loss_function = parse_equation(pinnrep, eq, eq_iv_args(eq, eqdata))
+    else
+        this_eq_indvars = transformation_vars isa Nothing ?
+                          unique(indvars(eq, eqdata)) : transformation_vars
+        loss_function = integrand
+    end
+
+    n = length(this_eq_indvars)
+
+    if param_estim == true && eq_params != SciMLBase.NullParameters()
+        param_len = length(eq_params)
+        # check parameter format to use correct indexing
+        psform = (phi isa Vector && phi[1].f isa Optimisers.Restructure) ||
+                 (!(phi isa Vector) && phi.f isa Optimisers.Restructure)
+
+        if psform
+            last_indx = [0; accumulate(+, map(length, init_params))][end]
+            ps_range = (1:param_len) .+ last_indx
+            get_ps = (θ) -> θ[ps_range]
+        else
+            ps_range = 1:param_len
+            get_ps = (θ) -> θ.p[ps_range]
+        end
+    else
+        get_ps = (θ) -> default_p
+    end
+
+    function get_coords(cord)
+        num_numbers = 0
+        out = map(enumerate(eq_args)) do (i, x)
+            if x isa Number
+                fill(convert(eltypeθ, x), size(cord[[1], :]))
+            else
+                cord[[i], :]
+            end
+        end
+        if out === nothing
+            return []
+        else
+            return out
+        end
+    end
+
+    full_loss_func = (cord, θ, phi, p) -> begin
+        coords = [[nothing]]
+        @ignore_derivatives coords = get_coords(cord)
+        loss_function(coords, θ, phi, get_ps(θ))
+    end
+    return full_loss_func
+end
+
+function build_loss_function(pinnrep, eqs)
+    @unpack eq_params, param_estim, default_p, phi, multioutput, derivative, integral = pinnrep
+
+    _loss_function = build_symbolic_loss_function(pinnrep, eqs,
+                                                  eq_params = eq_params,
+                                                  param_estim = param_estim)
+    loss_function = (cord, θ) -> begin _loss_function(cord, θ, phi,
+                                                      default_p) end
+    return loss_function
+end
+
+function operations(ex)
+    if istree(ex)
+        op = operation(ex)
+        return vcat(operations.(arguments(ex))..., op)
+    end
+    return []
+end
+
+############################################################################################
+# Parse equation
+############################################################################################
+
+function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = false,
+                        dict_transformation_vars = nothing,
+                        transformation_vars = nothing)
+    @unpack varmap, eqdata, derivative, integral, flat_init_params, multioutput = pinnrep
+    eltypeθ = eltype(flat_init_params)
+
+    ex_vars = get_depvars(term, varmap.depvar_ops)
+
+    if multioutput
+        dummyvars = @variables phi[1:length(varmap.ū)](..), θ_SYMBOL, switch
+    else
+        dummyvars = @variables phi(..), θ_SYMBOL, switch
+    end
+
+    dummyvars = unwrap.(dummyvars)
+    deriv_rules = generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap, multioutput)
+
+    ch = Prewalk(Chain(deriv_rules))
+
+    expr = ch(term)
+    #expr = swch(expr)
+
+    sym_coords = DestructuredArgs(ivs)
+    ps = DestructuredArgs(varmap.ps)
+
+    args = [sym_coords, θ_SYMBOL, phi, ps]
+
+    ex = Func(args, [], expr) |> toexpr |> _dot_
+
+    f = @RuntimeGeneratedFunction ex
+    return f
+end
+
+function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative, varmap, multioutput)
+    phi, θ, switch = dummyvars
+    if symtype(phi) isa AbstractArray
+        phi = collect(phi)
+    end
+
+    dvs = get_depvars(term, varmap.depvar_ops)
+    # Orthodox derivatives
+    n(w) = length(arguments(w))
+    rs = reduce(vcat, [reduce(vcat, [[@rule $((Differential(x)^d)(w)) =>
+                derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ),
+                           get_ε(n(w), j, eltypeθ, d),
+                           d, θ)
+ for d in differential_order(term, x)] + for (j, x) in enumerate(varmap.args[operation(w)])], init = []) + for w in dvs], init = []) + # Mixed derivatives + mx = mapreduce(vcat, dvs, init = []) do w + mapreduce(vcat, enumerate(varmap.args[operation(w)]), init = []) do (j, x) + mapreduce(vcat, enumerate(varmap.args[operation(w)]), init = []) do (k, y) + if isequal(x, y) + [(_) -> nothing] + else + ε1 = get_ε(n(w), j, eltypeθ, 1) + ε2 = get_ε(n(w), k, eltypeθ, 1) + [@rule $((Differential(x))((Differential(y))(w))) => + derivative((coord_, θ_) -> derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ), + ε2, 1, θ_), + reducevcat(arguments(w), eltypeθ, switch), ε1, 1, θ)] + end + end + end + end + vr = mapreduce(vcat, dvs, init = []) do w + @rule w => ufunc(w, phi, varmap)(reducevcat(arguments(w), eltypeθ), θ) + end + + return [mx; rs; vr] +end + +function generate_integral_rules(eq, eqdata, dummyvars) + phi, u, θ = dummyvars + #! all that should be needed is to solve an integral problem, the trick is doing this + #! with rules without putting symbols through the solve + +end \ No newline at end of file diff --git a/src/neural_adapter.jl b/src/neural_adapter.jl index 8b8ae68c97..0ef682c5d2 100644 --- a/src/neural_adapter.jl +++ b/src/neural_adapter.jl @@ -15,18 +15,17 @@ function get_loss_function_(loss, init_params, pde_system, strategy::GridTrainin eqs = [eqs] end domains = pde_system.domain - depvars, indvars, dict_indvars, dict_depvars = get_vars(pde_system.indvars, - pde_system.depvars) + eltypeθ = eltype(init_params) dx = strategy.dx train_set = generate_training_sets(domains, dx, eqs, eltypeθ) get_loss_function(loss, train_set, eltypeθ, strategy) end -function get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy) +function get_bounds_(domains, eqs, eltypeθ, varmap, strategy) dict_span = Dict([Symbol(d.variables) => [infimum(d.domain), supremum(d.domain)] for d in domains]) - args = get_argument(eqs, dict_indvars, dict_depvars) + args = get_argument(eqs, varmap) bounds = map(args) do pd span = map(p -> get(dict_span, p, p), pd) @@ -35,44 +34,38 @@ function get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strateg bounds end -function get_loss_function_(loss, init_params, pde_system, strategy::StochasticTraining) +function get_loss_function_(loss, init_params, pde_system, varmap, strategy::StochasticTraining) eqs = pde_system.eqs if !(eqs isa Array) eqs = [eqs] end domains = pde_system.domain - depvars, indvars, dict_indvars, dict_depvars = get_vars(pde_system.indvars, - pde_system.depvars) - eltypeθ = eltype(init_params) - bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)[1] + bound = get_bounds_(domains, eqs, eltypeθ, varmap, strategy)[1] get_loss_function(loss, bound, eltypeθ, strategy) end -function get_loss_function_(loss, init_params, pde_system, strategy::QuasiRandomTraining) +function get_loss_function_(loss, init_params, pde_system, varmap, strategy::QuasiRandomTraining) eqs = pde_system.eqs if !(eqs isa Array) eqs = [eqs] end domains = pde_system.domain - depvars, indvars, dict_indvars, dict_depvars = get_vars(pde_system.indvars, - pde_system.depvars) - eltypeθ = eltype(init_params) - bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy)[1] + bound = get_bounds_(domains, eqs, eltypeθ, varmap, strategy)[1] get_loss_function(loss, bound, eltypeθ, strategy) end -function get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, +function get_bounds_(domains, eqs, eltypeθ, 
varmap, strategy::QuadratureTraining) - dict_lower_bound = Dict([Symbol(d.variables) => infimum(d.domain) for d in domains]) - dict_upper_bound = Dict([Symbol(d.variables) => supremum(d.domain) for d in domains]) + dict_lower_bound = Dict([d.variables => infimum(d.domain) for d in domains]) + dict_upper_bound = Dict([d.variables => supremum(d.domain) for d in domains]) - args = get_argument(eqs, dict_indvars, dict_depvars) + args = get_argument(eqs, varmap) lower_bounds = map(args) do pd span = map(p -> get(dict_lower_bound, p, p), pd) @@ -85,18 +78,15 @@ function get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, bound = lower_bounds, upper_bounds end -function get_loss_function_(loss, init_params, pde_system, strategy::QuadratureTraining) +function get_loss_function_(loss, init_params, pde_system, varmap, strategy::QuadratureTraining) eqs = pde_system.eqs if !(eqs isa Array) eqs = [eqs] end domains = pde_system.domain - depvars, indvars, dict_indvars, dict_depvars = get_vars(pde_system.indvars, - pde_system.depvars) - eltypeθ = eltype(init_params) - bound = get_bounds_(domains, eqs, eltypeθ, dict_indvars, dict_depvars, strategy) + bound = get_bounds_(domains, eqs, eltypeθ, varmap, strategy) lb, ub = bound get_loss_function(loss, lb[1], ub[1], eltypeθ, strategy) end @@ -116,7 +106,8 @@ Trains a neural network using the results from one already obtained prediction. function neural_adapter end function neural_adapter(loss, init_params, pde_system, strategy) - loss_function__ = get_loss_function_(loss, init_params, pde_system, strategy) + varmap = VariableMap(pde_system) + loss_function__ = get_loss_function_(loss, init_params, pde_system, varmap, strategy) function loss_function_(θ, p) loss_function__(θ) @@ -126,8 +117,9 @@ function neural_adapter(loss, init_params, pde_system, strategy) end function neural_adapter(losses::Array, init_params, pde_systems::Array, strategy) - loss_functions_ = map(zip(losses, pde_systems)) do (l, p) - get_loss_function_(l, init_params, p, strategy) + varmaps = VariableMap.(pde_systems) + loss_functions_ = map(zip(losses, pde_systems, varmaps)) do (l, p, v) + get_loss_function_(l, init_params, p, v, strategy) end loss_function__ = θ -> sum(map(l -> l(θ), loss_functions_)) function loss_function_(θ, p) diff --git a/src/pinn_types.jl b/src/pinn_types.jl index e78c0da089..68352e243f 100644 --- a/src/pinn_types.jl +++ b/src/pinn_types.jl @@ -2,30 +2,31 @@ ??? """ struct LogOptions - log_frequency::Int64 - # TODO: add in an option for saving plots in the log. this is currently not done because the type of plot is dependent on the PDESystem - # possible solution: pass in a plot function? - # this is somewhat important because we want to support plotting adaptive weights that depend on pde independent variables - # and not just one weight for each loss function, i.e. pde_loss_weights(i, t, x) and since this would be function-internal, - # we'd want the plot & log to happen internally as well - # plots of the learned function can happen in the outer callback, but we might want to offer that here too - - SciMLBase.@add_kwonly function LogOptions(; log_frequency = 50) - new(convert(Int64, log_frequency)) - end + log_frequency::Int64 + # TODO: add in an option for saving plots in the log. this is currently not done because the type of plot is dependent on the PDESystem + # possible solution: pass in a plot function? 
+ # this is somewhat important because we want to support plotting adaptive weights that depend on pde independent variables + # and not just one weight for each loss function, i.e. pde_loss_weights(i, t, x) and since this would be function-internal, + # we'd want the plot & log to happen internally as well + # plots of the learned function can happen in the outer callback, but we might want to offer that here too + + SciMLBase.@add_kwonly function LogOptions(; log_frequency = 50) + new(convert(Int64, log_frequency)) + end end """This function is defined here as stubs to be overridden by the subpackage NeuralPDELogging if imported""" function logvector(logger, v::AbstractVector{R}, name::AbstractString, - step::Integer) where {R <: Real} - nothing + step::Integer) where {R <: Real} + nothing end """This function is defined here as stubs to be overridden by the subpackage NeuralPDELogging if imported""" function logscalar(logger, s::R, name::AbstractString, step::Integer) where {R <: Real} - nothing + nothing end + """ PhysicsInformedNN(chain, strategy; @@ -75,21 +76,21 @@ methodology. * `iteration`: used to control the iteration counter??? * `kwargs`: Extra keyword arguments which are splatted to the `OptimizationProblem` on `solve`. """ -struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN - chain::Any - strategy::T - init_params::P - phi::PH - derivative::DER - param_estim::PE - additional_loss::AL - adaptive_loss::ADA - logger::LOG - log_options::LogOptions - iteration::Vector{Int64} - self_increment::Bool - multioutput::Bool - kwargs::K +struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: SciMLBase.AbstractDiscretization + chain::Any + strategy::T + init_params::P + phi::PH + derivative::DER + param_estim::PE + additional_loss::AL + adaptive_loss::ADA + logger::LOG + log_options::LogOptions + iteration::Vector{Int64} + self_increment::Bool + multioutput::Bool + kwargs::K @add_kwonly function PhysicsInformedNN(chain, strategy; @@ -137,23 +138,24 @@ struct PhysicsInformedNN{T, P, PH, DER, PE, AL, ADA, LOG, K} <: AbstractPINN self_increment = true end - new{typeof(strategy), typeof(init_params), typeof(_phi), typeof(_derivative), - typeof(param_estim), - typeof(additional_loss), typeof(adaptive_loss), typeof(logger), typeof(kwargs)}(chain, - strategy, - init_params, - _phi, - _derivative, - param_estim, - additional_loss, - adaptive_loss, - logger, - log_options, - iteration, - self_increment, - multioutput, - kwargs) - end + new{typeof(strategy), typeof(init_params), typeof(_phi), typeof(_derivative), + typeof(param_estim), + typeof(additional_loss), typeof(adaptive_loss), + typeof(logger), typeof(kwargs)}(chain, + strategy, + init_params, + _phi, + _derivative, + param_estim, + additional_loss, + adaptive_loss, + logger, + log_options, + iteration, + self_increment, + multioutput, + kwargs) + end end """ @@ -348,23 +350,7 @@ mutable struct PINNRepresentation """ The dependent variables of the system """ - depvars::Any - """ - The independent variables of the system - """ - indvars::Any - """ - A dictionary form of the independent variables. Define the structure ??? - """ - dict_indvars::Any - """ - A dictionary form of the dependent variables. Define the structure ??? - """ - dict_depvars::Any - """ - ??? - """ - dict_depvar_input::Any + varmap::Any """ The logger as provided by the user """ @@ -409,19 +395,7 @@ mutable struct PINNRepresentation """ ??? """ - pde_indvars::Any - """ - ??? - """ - bc_indvars::Any - """ - ??? 
- """ - pde_integration_vars::Any - """ - ??? - """ - bc_integration_vars::Any + eqdata::Any """ ??? """ @@ -450,31 +424,31 @@ The generated functions from the PINNRepresentation $(FIELDS) """ struct PINNLossFunctions - """ - The boundary condition loss functions - """ - bc_loss_functions::Any - """ - The PDE loss functions - """ - pde_loss_functions::Any - """ - The full loss function, combining the PDE and boundary condition loss functions. - This is the loss function that is used by the optimizer. - """ - full_loss_function::Any - """ - The wrapped `additional_loss`, as pieced together for the optimizer. - """ - additional_loss_function::Any - """ - The pre-data version of the PDE loss function - """ - datafree_pde_loss_functions::Any - """ - The pre-data version of the BC loss function - """ - datafree_bc_loss_functions::Any + """ + The boundary condition loss functions + """ + bc_loss_functions::Any + """ + The PDE loss functions + """ + pde_loss_functions::Any + """ + The full loss function, combining the PDE and boundary condition loss functions. + This is the loss function that is used by the optimizer. + """ + full_loss_function::Any + """ + The wrapped `additional_loss`, as pieced together for the optimizer. + """ + additional_loss_function::Any + """ + The pre-data version of the PDE loss function + """ + datafree_pde_loss_functions::Any + """ + The pre-data version of the BC loss function + """ + datafree_bc_loss_functions::Any end """ @@ -496,59 +470,97 @@ mutable struct Phi{C, S} end function (f::Phi{<:Lux.AbstractExplicitLayer})(x::Number, θ) - y, st = f.f(adapt(parameterless_type(ComponentArrays.getdata(θ)), [x]), θ, f.st) - ChainRulesCore.@ignore_derivatives f.st = st - y + y, st = f.f(adapt(parameterless_type(ComponentArrays.getdata(θ)), [x]), θ, f.st) + ChainRulesCore.@ignore_derivatives f.st = st + y end function (f::Phi{<:Lux.AbstractExplicitLayer})(x::AbstractArray, θ) - y, st = f.f(adapt(parameterless_type(ComponentArrays.getdata(θ)), x), θ, f.st) - ChainRulesCore.@ignore_derivatives f.st = st - y + @show x, typeof(x) + y, st = f.f(adapt(parameterless_type(ComponentArrays.getdata(θ)), x), θ, f.st) + ChainRulesCore.@ignore_derivatives f.st = st + y end function (f::Phi{<:Optimisers.Restructure})(x, θ) - f.f(θ)(adapt(parameterless_type(θ), x)) + f.f(θ)(adapt(parameterless_type(θ), x)) end -function get_u() - u = (cord, θ, phi) -> phi(cord, θ) +# the method to calculate the derivative +function numeric_derivative(phi, x, ε, order, θ) + _type = parameterless_type(ComponentArrays.getdata(θ)) + + _epsilon = inv(first(ε[ε.!=zero(eltype(ε))])) + ε = adapt(_type, ε) + x = adapt(_type, x) + + if order == 4 + return (phi(x .+ 2 .* ε, θ) .- 4 .* phi(x .+ ε, θ) + .+ + 6 .* phi(x, θ) + .- + 4 .* phi(x .- ε, θ) .+ phi(x .- 2 .* ε, θ)) .* _epsilon^4 + elseif order == 3 + return (phi(x .+ 2 .* ε, θ) .- 2 .* phi(x .+ ε, θ) .+ 2 .* phi(x .- ε, θ) + - + phi(x .- 2 .* ε, θ)) .* _epsilon^3 ./ 2 + elseif order == 2 + return (phi(x .+ ε, θ) .+ phi(x .- ε, θ) .- 2 .* phi(x, θ)) .* _epsilon^2 + elseif order == 1 + return (phi(x .+ ε, θ) .- phi(x .- ε, θ)) .* _epsilon ./ 2 + else + error("This shouldn't happen! 
Got an order of $(order).") + end end -# the method to calculate the derivative -function numeric_derivative(phi, u, x, εs, order, θ) - _type = parameterless_type(ComponentArrays.getdata(θ)) - - ε = εs[order] - _epsilon = inv(first(ε[ε .!= zero(ε)])) - - ε = adapt(_type, ε) - x = adapt(_type, x) - - # any(x->x!=εs[1],εs) - # εs is the epsilon for each order, if they are all the same then we use a fancy formula - # if order 1, this is trivially true - - if order > 4 || any(x -> x != εs[1], εs) - return (numeric_derivative(phi, u, x .+ ε, @view(εs[1:(end - 1)]), order - 1, θ) - .- - numeric_derivative(phi, u, x .- ε, @view(εs[1:(end - 1)]), order - 1, θ)) .* - _epsilon ./ 2 - elseif order == 4 - return (u(x .+ 2 .* ε, θ, phi) .- 4 .* u(x .+ ε, θ, phi) - .+ - 6 .* u(x, θ, phi) - .- - 4 .* u(x .- ε, θ, phi) .+ u(x .- 2 .* ε, θ, phi)) .* _epsilon^4 - elseif order == 3 - return (u(x .+ 2 .* ε, θ, phi) .- 2 .* u(x .+ ε, θ, phi) .+ 2 .* u(x .- ε, θ, phi) - - - u(x .- 2 .* ε, θ, phi)) .* _epsilon^3 ./ 2 - elseif order == 2 - return (u(x .+ ε, θ, phi) .+ u(x .- ε, θ, phi) .- 2 .* u(x, θ, phi)) .* _epsilon^2 - elseif order == 1 - return (u(x .+ ε, θ, phi) .- u(x .- ε, θ, phi)) .* _epsilon ./ 2 - else - error("This shouldn't happen!") - end +#@register_symbolic(numeric_derivative(phi, x, ε, order, θ)) + +function ufunc(u, phi, v) + if symtype(phi) isa AbstractArray + return phi[findfirst(w -> isequal(operation(w), operation(u)), v.ū)] + else + return phi + end end + +#= +_vcat(x::Number...) = vcat(x...) +_vcat(x::AbstractArray{<:Number}...) = vcat(x...) +function _vcat(x::Union{Number, AbstractArray{<:Number}}...) + example = first(Iterators.filter(e -> !(e isa Number), x)) + dims = (1, size(example)[2:end]...) + x = map(el -> el isa Number ? (typeof(example))(fill(el, dims)) : el, x) + _vcat(x...) +end +_vcat(x...) = vcat(x...) +https://github.com/SciML/NeuralPDE.jl/pull/627/files +=# + + + +function reducevcat(vector::Vector, eltypeθ) + isnothing(vector) && return [[nothing]] + if all(x -> x isa Number, vector) + return vector + else + z = findfirst(x -> !(x isa Number), vector) + return rvcat(vector, vector[z], eltypeθ) + end +end + +function rvcat(example, sym, eltypeθ) + out = map(example) do x + if x isa Number + out = convert(eltypeθ, x) + out + else + out = x + out + end + end + #out = @arrayop (i,) out[i] i in 1:length(out) + + return out +end + +#@register_symbolic(rvcat(vector, example, eltypeθ, switch)) \ No newline at end of file diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl index b4d4c97f3a..3d988a8a45 100644 --- a/src/symbolic_utilities.jl +++ b/src/symbolic_utilities.jl @@ -18,11 +18,15 @@ julia> _dot_(e) """ dottable_(x) = Broadcast.dottable(x) dottable_(x::Function) = true +dottable_(x::typeof(numeric_derivative)) = false +dottable_(x::Phi) = false + _dot_(x) = x function _dot_(x::Expr) dotargs = Base.mapany(_dot_, x.args) - if x.head === :call && dottable_(x.args[1]) + nodot = [:phi, Symbol("NeuralPDE.numeric_derivative"), NeuralPDE.rvcat] + if x.head === :call && dottable_(x.args[1]) && all(s -> x.args[1] != s, nodot) Expr(:., dotargs[1], Expr(:tuple, dotargs[2:end]...)) elseif x.head === :comparison Expr(:comparison, @@ -34,7 +38,9 @@ function _dot_(x::Expr) Expr(:let, undot(dotargs[1]), dotargs[2]) elseif x.head === :for # don't add dots to for x=... assignments Expr(:for, undot(dotargs[1]), dotargs[2]) - elseif (x.head === :(=) || x.head === :function || x.head === :macro) && + elseif x.head === :(=) # don't add dots to x=... 
assignments + Expr(:(=), dotargs[1], dotargs[2]) + elseif (x.head === :function || x.head === :macro) && Meta.isexpr(x.args[1], :call) # function or macro definition Expr(x.head, x.args[1], dotargs[2]) elseif x.head === :(<:) || x.head === :(>:) @@ -49,7 +55,6 @@ function _dot_(x::Expr) end end end - """ Create dictionary: variable => unique number for variable @@ -114,167 +119,6 @@ where - order - order of derivative. - θ - weights in neural network. """ -function _transform_expression(pinnrep::PINNRepresentation, ex; is_integral = false, - dict_transformation_vars = nothing, - transformation_vars = nothing) - @unpack indvars, depvars, dict_indvars, dict_depvars, - dict_depvar_input, multioutput, strategy, phi, - derivative, integral, flat_init_params, init_params = pinnrep - eltypeθ = eltype(flat_init_params) - - _args = ex.args - for (i, e) in enumerate(_args) - if !(e isa Expr) - if e in keys(dict_depvars) - depvar = _args[1] - num_depvar = dict_depvars[depvar] - indvars = _args[2:end] - var_ = is_integral ? :(u) : :($(Expr(:$, :u))) - ex.args = if !multioutput - [var_, Symbol(:cord, num_depvar), :($θ), :phi] - else - [ - var_, - Symbol(:cord, num_depvar), - Symbol(:($θ), num_depvar), - Symbol(:phi, num_depvar), - ] - end - break - elseif e isa ModelingToolkit.Differential - derivative_variables = Symbol[] - order = 0 - while (_args[1] isa ModelingToolkit.Differential) - order += 1 - push!(derivative_variables, toexpr(_args[1].x)) - _args = _args[2].args - end - depvar = _args[1] - num_depvar = dict_depvars[depvar] - indvars = _args[2:end] - dict_interior_indvars = Dict([indvar .=> j - for (j, indvar) in enumerate(dict_depvar_input[depvar])]) - dim_l = length(dict_interior_indvars) - - var_ = is_integral ? :(derivative) : :($(Expr(:$, :derivative))) - εs = [get_ε(dim_l, d, eltypeθ, order) for d in 1:dim_l] - undv = [dict_interior_indvars[d_p] for d_p in derivative_variables] - εs_dnv = [εs[d] for d in undv] - - ex.args = if !multioutput - [var_, :phi, :u, Symbol(:cord, num_depvar), εs_dnv, order, :($θ)] - else - [ - var_, - Symbol(:phi, num_depvar), - :u, - Symbol(:cord, num_depvar), - εs_dnv, - order, - Symbol(:($θ), num_depvar), - ] - end - break - elseif e isa Symbolics.Integral - if _args[1].domain.variables isa Tuple - integrating_variable_ = collect(_args[1].domain.variables) - integrating_variable = toexpr.(integrating_variable_) - integrating_var_id = [dict_indvars[i] for i in integrating_variable] - else - integrating_variable = toexpr(_args[1].domain.variables) - integrating_var_id = [dict_indvars[integrating_variable]] - end - - integrating_depvars = [] - integrand_expr = _args[2] - for d in depvars - d_ex = find_thing_in_expr(integrand_expr, d) - if !isempty(d_ex) - push!(integrating_depvars, d_ex[1].args[1]) - end - end - - lb, ub = get_limits(_args[1].domain.domain) - lb, ub, _args[2], dict_transformation_vars, transformation_vars = transform_inf_integral(lb, - ub, - _args[2], - integrating_depvars, - dict_depvar_input, - dict_depvars, - integrating_variable, - eltypeθ) - - num_depvar = map(int_depvar -> dict_depvars[int_depvar], - integrating_depvars) - integrand_ = transform_expression(pinnrep, _args[2]; - is_integral = false, - dict_transformation_vars = dict_transformation_vars, - transformation_vars = transformation_vars) - integrand__ = _dot_(integrand_) - - integrand = build_symbolic_loss_function(pinnrep, nothing; - integrand = integrand__, - integrating_depvars = integrating_depvars, - eq_params = SciMLBase.NullParameters(), - dict_transformation_vars = 
dict_transformation_vars, - transformation_vars = transformation_vars, - param_estim = false, - default_p = nothing) - # integrand = repr(integrand) - lb = toexpr.(lb) - ub = toexpr.(ub) - ub_ = [] - lb_ = [] - for l in lb - if l isa Number - push!(lb_, l) - else - l_expr = NeuralPDE.build_symbolic_loss_function(pinnrep, nothing; - integrand = _dot_(l), - integrating_depvars = integrating_depvars, - param_estim = false, - default_p = nothing) - l_f = @RuntimeGeneratedFunction(l_expr) - push!(lb_, l_f) - end - end - for u_ in ub - if u_ isa Number - push!(ub_, u_) - else - u_expr = NeuralPDE.build_symbolic_loss_function(pinnrep, nothing; - integrand = _dot_(u_), - integrating_depvars = integrating_depvars, - param_estim = false, - default_p = nothing) - u_f = @RuntimeGeneratedFunction(u_expr) - push!(ub_, u_f) - end - end - - integrand_func = @RuntimeGeneratedFunction(integrand) - ex.args = [ - :($(Expr(:$, :integral))), - :u, - Symbol(:cord, num_depvar[1]), - :phi, - integrating_var_id, - integrand_func, - lb_, - ub_, - :($θ), - ] - break - end - else - ex.args[i] = _transform_expression(pinnrep, ex.args[i]; - is_integral = is_integral, - dict_transformation_vars = dict_transformation_vars, - transformation_vars = transformation_vars) - end - end - return ex -end """ Parse ModelingToolkit equation form to the inner representation. @@ -342,79 +186,9 @@ function pair(eq, depvars, dict_depvars, dict_depvar_input) Dict(filter(p -> p !== nothing, pair_)) end -function get_vars(indvars_, depvars_) - indvars = ModelingToolkit.getname.(indvars_) - depvars = Symbol[] - dict_depvar_input = Dict{Symbol, Vector{Symbol}}() - for d in depvars_ - if unwrap(d) isa SymbolicUtils.BasicSymbolic - dname = ModelingToolkit.getname(d) - push!(depvars, dname) - push!(dict_depvar_input, - dname => [nameof(unwrap(argument)) - for argument in arguments(unwrap(d))]) - else - dname = ModelingToolkit.getname(d) - push!(depvars, dname) - push!(dict_depvar_input, dname => indvars) # default to all inputs if not given - end - end - - dict_indvars = get_dict_vars(indvars) - dict_depvars = get_dict_vars(depvars) - return depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input -end - -function get_integration_variables(eqs, _indvars::Array, _depvars::Array) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - get_integration_variables(eqs, dict_indvars, dict_depvars) -end - -function get_integration_variables(eqs, dict_indvars, dict_depvars) - exprs = toexpr.(eqs) - vars = map(exprs) do expr - _vars = Symbol.(filter(indvar -> length(find_thing_in_expr(expr, indvar)) > 0, - sort(collect(keys(dict_indvars))))) - end -end - -""" - get_variables(eqs,_indvars,_depvars) - -Returns all variables that are used in each equations or boundary condition. 
-""" -function get_variables end - -function get_variables(eqs, _indvars::Array, _depvars::Array) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - return get_variables(eqs, dict_indvars, dict_depvars) -end - -function get_variables(eqs, dict_indvars, dict_depvars) - bc_args = get_argument(eqs, dict_indvars, dict_depvars) - return map(barg -> filter(x -> x isa Symbol, barg), bc_args) -end - -function get_number(eqs, dict_indvars, dict_depvars) - bc_args = get_argument(eqs, dict_indvars, dict_depvars) - return map(barg -> filter(x -> x isa Number, barg), bc_args) -end - -function find_thing_in_expr(ex::Expr, thing; ans = []) - if thing in ex.args - push!(ans, ex) - end - for e in ex.args - if e isa Expr - if thing in e.args - push!(ans, e) - end - find_thing_in_expr(e, thing; ans = ans) - end - end - return collect(Set(ans)) +function get_integration_variables(eqs, v::VariableMap) + ivs = all_ivs(v) + return map(eq -> get_indvars(eq, ivs), eqs) end """ @@ -424,34 +198,45 @@ Returns all arguments that are used in each equations or boundary condition. """ function get_argument end -# Get arguments from boundary condition functions -function get_argument(eqs, _indvars::Array, _depvars::Array) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - get_argument(eqs, dict_indvars, dict_depvars) -end -function get_argument(eqs, dict_indvars, dict_depvars) - exprs = toexpr.(eqs) - vars = map(exprs) do expr - _vars = map(depvar -> find_thing_in_expr(expr, depvar), collect(keys(dict_depvars))) +function get_argument(eqs, v::VariableMap) + vars = map(eqs) do eq + _vars = map(depvar -> get_depvars(eq, [depvar]), v.depvar_ops) f_vars = filter(x -> !isempty(x), _vars) - map(x -> first(x), f_vars) + map(first, f_vars) end args_ = map(vars) do _vars - ind_args_ = map(var -> var.args[2:end], _vars) - syms = Set{Symbol}() - filter(vcat(ind_args_...)) do ind_arg - if ind_arg isa Symbol - if ind_arg ∈ syms + seen = [] + filter(reduce(vcat, arguments.(_vars), init = [])) do x + if x isa Number + true + else + if any(isequal(x), seen) false else - push!(syms, ind_arg) + push!(seen, x) true end - else - true end end end return args_ # TODO for all arguments end + +""" +``julia +get_variables(eqs,_indvars,_depvars) +``` + +Returns all variables that are used in each equations or boundary condition. 
+""" +function get_variables(eqs, v::VariableMap) + args = get_argument(eqs, v) + return map(arg -> filter(x -> !(x isa Number), arg), args) +end + +function get_number(eqs, v::VariableMap) + args = get_argument(eqs, v) + return map(arg -> filter(x -> x isa Number, arg), args) +end + +sym_op(u) = Symbol(operation(u)) \ No newline at end of file diff --git a/src/training_strategies.jl b/src/training_strategies.jl index 5739ac4797..f19794da29 100644 --- a/src/training_strategies.jl +++ b/src/training_strategies.jl @@ -52,12 +52,12 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::GridTraining, datafree_pde_loss_function, datafree_bc_loss_function) - @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep + @unpack domains, eqs, bcs, varmap, flat_init_params = pinnrep dx = strategy.dx eltypeθ = eltype(pinnrep.flat_init_params) train_sets = generate_training_sets(domains, dx, eqs, bcs, eltypeθ, - dict_indvars, dict_depvars) + varmap) # the points in the domain and on the boundary pde_train_sets, bcs_train_sets = train_sets @@ -92,7 +92,7 @@ end * `bcs_points`: number of points in random select training set for boundary conditions (by default, it equals `points`). """ -struct StochasticTraining <: AbstractTrainingStrategy +struct StochasticTraining <: AbstractGridfreeStrategy points::Int64 bcs_points::Int64 end @@ -110,11 +110,11 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::StochasticTraining, datafree_pde_loss_function, datafree_bc_loss_function) - @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep + @unpack domains, eqs, bcs, varmap, flat_init_params = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) - bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, + bounds = get_bounds(domains, eqs, bcs, eltypeθ, varmap, strategy) pde_bounds, bcs_bounds = bounds @@ -165,7 +165,7 @@ that accelerate the convergence in high dimensional spaces over pure random sequ For more information, see [QuasiMonteCarlo.jl](https://docs.sciml.ai/QuasiMonteCarlo/stable/). """ -struct QuasiRandomTraining <: AbstractTrainingStrategy +struct QuasiRandomTraining <: AbstractGridfreeStrategy points::Int64 bcs_points::Int64 sampling_alg::QuasiMonteCarlo.SamplingAlgorithm @@ -191,11 +191,11 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::QuasiRandomTraining, datafree_pde_loss_function, datafree_bc_loss_function) - @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep + @unpack domains, eqs, bcs, varmap, flat_init_params = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) - bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, + bounds = get_bounds(domains, eqs, bcs, eltypeθ, varmap, strategy) pde_bounds, bcs_bounds = bounds @@ -265,7 +265,7 @@ For more information on the argument values and algorithm choices, see [Integrals.jl](https://docs.sciml.ai/Integrals/stable/). 
""" struct QuadratureTraining{Q <: SciMLBase.AbstractIntegralAlgorithm, T} <: - AbstractTrainingStrategy + AbstractGridfreeStrategy quadrature_alg::Q reltol::T abstol::T @@ -282,10 +282,10 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::QuadratureTraining, datafree_pde_loss_function, datafree_bc_loss_function) - @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep + @unpack domains, eqs, bcs, varmap, flat_init_params = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) - bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, + bounds = get_bounds(domains, eqs, bcs, eltypeθ, varmap, strategy) pde_bounds, bcs_bounds = bounds @@ -330,8 +330,8 @@ end """ WeightedIntervalTraining(weights, samples) -A training strategy that generates points for training based on the given inputs. -We split the timespan into equal segments based on the number of weights, +A training strategy that generates points for training based on the given inputs. +We split the timespan into equal segments based on the number of weights, then sample points in each segment based on that segments corresponding weight, such that the total number of sampled points is equivalent to the given samples @@ -344,7 +344,7 @@ such that the total number of sampled points is equivalent to the given samples This training strategy can only be used with ODEs (`NNODE`). """ -struct WeightedIntervalTraining{T} <: AbstractTrainingStrategy +struct WeightedIntervalTraining{T} <: AbstractGridfreeStrategy weights::Vector{T} points::Int end diff --git a/test/NNPDE_tests.jl b/test/NNPDE_tests.jl index e47ab36d16..751847dce2 100644 --- a/test/NNPDE_tests.jl +++ b/test/NNPDE_tests.jl @@ -36,8 +36,8 @@ function test_ode(strategy_) chain = Lux.Chain(Lux.Dense(1, 12, Lux.σ), Lux.Dense(12, 1)) discretization = PhysicsInformedNN(chain, strategy_) - @named pde_system = PDESystem(eq, bcs, domains, [θ], [u]) - prob = discretize(pde_system, discretization) + @named pde_system = PDESystem(eq, bcs, domains, [θ], [u(θ)]) + prob = NeuralPDE.discretize(pde_system, discretization) res = Optimization.solve(prob, OptimizationOptimisers.Adam(0.1); maxiters = 1000) prob = remake(prob, u0 = res.minimizer) @@ -220,7 +220,7 @@ end eq = Dx(Dxxu(x)) ~ cos(pi * x) # Initial and boundary conditions - bcs_ = [u(0.0) ~ 0.0, + bcs = [u(0.0) ~ 0.0, u(1.0) ~ cos(pi), Dxu(1.0) ~ 1.0] ep = (cbrt(eps(eltype(Float64))))^2 / 6 @@ -228,7 +228,7 @@ end der = [Dxu(x) ~ Dx(u(x)) + ep * O1(x), Dxxu(x) ~ Dx(Dxu(x)) + ep * O2(x)] - bcs = [bcs_; der] + eqs = [eq; der] # Space and time domains domains = [x ∈ Interval(0.0, 1.0)] @@ -240,7 +240,7 @@ end discretization = PhysicsInformedNN(chain, quasirandom_strategy) - @named pde_system = PDESystem(eq, bcs, domains, [x], + @named pde_system = PDESystem(eqs, bcs, domains, [x], [u(x), Dxu(x), Dxxu(x), O1(x), O2(x)]) prob = discretize(pde_system, discretization)