diff --git a/.github/workflows/Tests.yml b/.github/workflows/Tests.yml
index a290993f27..36980161e8 100644
--- a/.github/workflows/Tests.yml
+++ b/.github/workflows/Tests.yml
@@ -36,6 +36,7 @@ jobs:
           - "NNODE"
          - "NeuralAdapter"
          - "IntegroDiff"
+          - "PIPN"
    uses: "SciML/.github/.github/workflows/tests.yml@v1"
    with:
      group: "${{ matrix.group }}"
diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl
index a2ffc2370a..b791f617fc 100644
--- a/src/NeuralPDE.jl
+++ b/src/NeuralPDE.jl
@@ -55,7 +55,7 @@ include("BPINN_ode.jl")
 include("PDE_BPINN.jl")
 include("dgm.jl")
 
-export NNODE, NNDAE,
+export NNODE, PIPN, init_pipn_params, NNDAE,
        PhysicsInformedNN, discretize,
        GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining,
        WeightedIntervalTraining,
diff --git a/src/pinn_types.jl b/src/pinn_types.jl
index 59480d8a60..27845c13d9 100644
--- a/src/pinn_types.jl
+++ b/src/pinn_types.jl
@@ -557,3 +557,160 @@ function numeric_derivative(phi, u, x, εs, order, θ)
         error("This shouldn't happen!")
     end
 end
+
+"""
+    PIPN(chain,
+         strategy = GridTraining(0.1);
+         init_params = nothing,
+         param_estim = false,
+         additional_loss = nothing,
+         adaptive_loss = nothing,
+         logger = nothing,
+         log_options = LogOptions(),
+         iteration = nothing,
+         shared_mlp1_sizes = [64, 64],
+         shared_mlp2_sizes = [128, 1024],
+         after_pool_mlp_sizes = [512, 256, 128],
+         kwargs...)
+
+An algorithm for the ModelingToolkit PDESystem interface, intended to transform a
+`PDESystem` into an `OptimizationProblem` using the Physics-Informed PointNet (PIPN)
+methodology, an extension of the Physics-Informed Neural Network (PINN) approach that
+operates on unordered point clouds.
+
+## Positional Arguments
+
+* `chain`: a Lux chain specifying the overall network architecture. Only the input and
+  output dimensions of this chain are used, to determine the dimensions of the PIPN
+  components; the chain itself is not evaluated.
+* `strategy`: determines which training strategy will be used. Defaults to
+  `GridTraining(0.1)`. See the Training Strategy documentation for more details.
+
+## Keyword Arguments
+
+* `init_params`: the initial parameters of the neural networks. If not provided, the
+  default Lux initialization is used.
+* `param_estim`: whether the parameters of the differential equation should be included
+  in the optimization. Defaults to `false`.
+* `additional_loss`: a function `additional_loss(phi, θ, p_)` where `phi` are the neural
+  network trial solutions, `θ` are the weights of the neural network(s), and `p_` are
+  the hyperparameters of the `OptimizationProblem`.
+* `adaptive_loss`: the choice of adaptive loss function. See the adaptive loss
+  documentation for more details. Defaults to no adaptivity.
+* `logger`: a logging mechanism for tracking the training process.
+* `log_options`: options controlling the logging behavior.
+* `iteration`: used to control the iteration counter. If not provided, starts at 1 and
+  self-increments.
+* `shared_mlp1_sizes`: layer widths of the first shared MLP. Defaults to `[64, 64]`.
+* `shared_mlp2_sizes`: layer widths of the second shared MLP. Defaults to `[128, 1024]`.
+* `after_pool_mlp_sizes`: layer widths of the MLP applied after pooling. Defaults to
+  `[512, 256, 128]`.
+* `kwargs`: extra keyword arguments which are passed to the `OptimizationProblem` on
+  `solve`.
+
+## Fields
+
+* `shared_mlp1`: first shared multilayer perceptron in the PIPN architecture.
+* `shared_mlp2`: second shared multilayer perceptron in the PIPN architecture.
+* `after_pool_mlp`: multilayer perceptron applied after the pooling operation.
+* `final_layer`: final layer producing the output of the network.
+* `strategy`: the training strategy used.
+* `init_params`: initial parameters of the neural networks.
+* `param_estim`: whether parameter estimation is enabled.
+* `additional_loss`: additional loss function, if specified.
+* `adaptive_loss`: adaptive loss function, if specified.
+* `logger`: logging mechanism for the training process.
+* `log_options`: options controlling logging behavior.
+* `iteration`: vector containing the current iteration count.
+* `self_increment`: whether the iteration count should self-increment.
+* `kwargs`: additional keyword arguments passed to the optimization problem.
+
+`PIPN` implements a Physics-Informed PointNet, designed to handle point-cloud data in
+the context of physics-informed neural networks. It applies a series of shared MLPs to
+each point, aggregates a global feature via max pooling, concatenates that feature back
+onto each point's local features, and processes the result to produce the final output.
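+
+## Example
+
+A minimal usage sketch (illustrative only; the layer sizes and point count here are
+arbitrary choices, not API requirements):
+
+```julia
+using NeuralPDE, Lux
+
+# The chain is only used to infer input/output dimensions (2 inputs -> 1 output).
+chain = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1))
+pipn = PIPN(chain)
+
+# Initialize parameters/states, then evaluate a batch of 100 two-dimensional
+# points stored one point per column.
+ps, st = init_pipn_params(pipn)
+x = rand(Float32, 2, 100)
+y, _ = pipn(x, ps, st)  # size(y) == (1, 100)
+```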
+"""
+struct PIPN{C1, C2, C3, F, ST, P, PE, AL, ADA, LOG, K} <: AbstractPINN
+    shared_mlp1::C1
+    shared_mlp2::C2
+    after_pool_mlp::C3
+    final_layer::F
+    strategy::ST
+    init_params::P
+    param_estim::PE
+    additional_loss::AL
+    adaptive_loss::ADA
+    logger::LOG
+    log_options::LogOptions
+    iteration::Vector{Int64}
+    self_increment::Bool
+    kwargs::K
+end
+
+function PIPN(chain, strategy = GridTraining(0.1);
+              init_params = nothing,
+              param_estim = false,
+              additional_loss = nothing,
+              adaptive_loss = nothing,
+              logger = nothing,
+              log_options = LogOptions(),
+              iteration = nothing,
+              shared_mlp1_sizes = [64, 64],
+              shared_mlp2_sizes = [128, 1024],
+              after_pool_mlp_sizes = [512, 256, 128],
+              kwargs...)
+    # The user-supplied chain is only used to infer the input/output dimensions.
+    input_dim = chain[1].in_dims
+    output_dim = chain[end].out_dims
+
+    # Note the parentheses around the ternaries: `=>` binds tighter than `? :`, so
+    # `a ? b : c => d` would otherwise parse as `a ? b : (c => d)` and pass a bare
+    # integer instead of a Pair to `Lux.Dense` for the first layer.
+    shared_mlp1_layers = [Lux.Dense((i == 1 ? input_dim : shared_mlp1_sizes[i - 1]) =>
+                                        shared_mlp1_sizes[i], tanh)
+                          for i in 1:length(shared_mlp1_sizes)]
+    shared_mlp1 = Lux.Chain(shared_mlp1_layers...)
+
+    shared_mlp2_layers = [Lux.Dense((i == 1 ? shared_mlp1_sizes[end] : shared_mlp2_sizes[i - 1]) =>
+                                        shared_mlp2_sizes[i], tanh)
+                          for i in 1:length(shared_mlp2_sizes)]
+    shared_mlp2 = Lux.Chain(shared_mlp2_layers...)
+
+    # The pooled global feature is concatenated onto each point's local feature,
+    # so the post-pool MLP sees twice the width of `shared_mlp2`'s output.
+    after_pool_input_size = 2 * shared_mlp2_sizes[end]
+    after_pool_mlp_layers = [Lux.Dense((i == 1 ? after_pool_input_size : after_pool_mlp_sizes[i - 1]) =>
+                                           after_pool_mlp_sizes[i], tanh)
+                             for i in 1:length(after_pool_mlp_sizes)]
+    after_pool_mlp = Lux.Chain(after_pool_mlp_layers...)
+
+    final_layer = Lux.Dense(after_pool_mlp_sizes[end] => output_dim)
+
+    if iteration isa Vector{Int64}
+        self_increment = false
+    else
+        iteration = [1]
+        self_increment = true
+    end
+
+    PIPN(shared_mlp1, shared_mlp2, after_pool_mlp, final_layer,
+         strategy, init_params, param_estim, additional_loss, adaptive_loss,
+         logger, log_options, iteration, self_increment, kwargs)
+end
+
+function (model::PIPN)(x, ps, st::NamedTuple)
+    # Per-point (column-wise) feature extraction through the shared MLPs.
+    point_features, st1 = Lux.apply(model.shared_mlp1, x, ps.shared_mlp1, st.shared_mlp1)
+    point_features, st2 = Lux.apply(model.shared_mlp2, point_features, ps.shared_mlp2,
+                                    st.shared_mlp2)
+    # Aggregate a single global feature and attach it to every point.
+    global_feature = aggregate_global_feature(point_features)
+    global_feature_repeated = repeat(global_feature, 1, size(point_features, 2))
+    combined_features = vcat(point_features, global_feature_repeated)
+    combined_features, st3 = Lux.apply(model.after_pool_mlp, combined_features,
+                                       ps.after_pool_mlp, st.after_pool_mlp)
+    output, st4 = Lux.apply(model.final_layer, combined_features, ps.final_layer,
+                            st.final_layer)
+    return output,
+           (shared_mlp1 = st1, shared_mlp2 = st2, after_pool_mlp = st3, final_layer = st4)
+end
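+
+# `aggregate_global_feature` below performs symmetric max pooling over the point
+# dimension (one point per column), as in PointNet. An illustrative property
+# (a sketch, not part of the API): the column-wise maximum is invariant to the
+# ordering of the points, e.g.
+#
+#     using Random
+#     pts = rand(Float32, 4, 10)
+#     aggregate_global_feature(pts) == aggregate_global_feature(pts[:, randperm(10)])  # true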
+function aggregate_global_feature(points)
+    return maximum(points, dims = 2)
+end
+
+function init_pipn_params(model::PIPN)
+    rng = Random.default_rng()
+    ps1, st1 = Lux.setup(rng, model.shared_mlp1)
+    ps2, st2 = Lux.setup(rng, model.shared_mlp2)
+    ps3, st3 = Lux.setup(rng, model.after_pool_mlp)
+    ps4, st4 = Lux.setup(rng, model.final_layer)
+    ps = (shared_mlp1 = ps1, shared_mlp2 = ps2, after_pool_mlp = ps3, final_layer = ps4)
+    st = (shared_mlp1 = st1, shared_mlp2 = st2, after_pool_mlp = st3, final_layer = st4)
+    return ps, st
+end
+
+function vector_to_parameters(θ::AbstractVector, model::PIPN)
+    ps, _ = init_pipn_params(model)
+    flat_ps = ComponentArray(ps)
+    # Rebuild a ComponentArray with the same axes as `ps` but the flat data `θ`.
+    return ComponentArray(θ, getaxes(flat_ps))
+end
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index e6248eae60..24cdb4574f 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -100,4 +100,10 @@ end
             include("dgm_test.jl")
         end
     end
+
+    if GROUP == "All" || GROUP == "PIPN"
+        @time @safetestset "Physics Informed Point Network" begin
+            include("test_pipn.jl")
+        end
+    end
 end
diff --git a/test/test_pipn.jl b/test/test_pipn.jl
new file mode 100644
index 0000000000..452e17b867
--- /dev/null
+++ b/test/test_pipn.jl
@@ -0,0 +1,81 @@
+using Test
+using NeuralPDE
+using Lux
+using Random
+using ComponentArrays
+using ModelingToolkit
+using Optimization
+import ModelingToolkit: Interval
+
+@testset "PIPN Tests" begin
+    @testset "PIPN Construction" begin
+        chain = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1))
+        pipn = PIPN(chain)
+        @test pipn isa PIPN
+        @test pipn.shared_mlp1 isa Lux.Chain
+        @test pipn.shared_mlp2 isa Lux.Chain
+        @test pipn.after_pool_mlp isa Lux.Chain
+        @test pipn.final_layer isa Lux.Dense
+    end
+
+    @testset "PIPN Forward Pass" begin
+        chain = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1))
+        pipn = PIPN(chain)
+        x = rand(Float32, 2, 100)
+        ps, st = init_pipn_params(pipn)
+        y, _ = pipn(x, ps, st)
+        @test size(y) == (1, 100)
+    end
+
+    @testset "PIPN Parameter Initialization" begin
+        chain = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1))
+        pipn = PIPN(chain)
+        ps, st = init_pipn_params(pipn)
+        @test ps isa NamedTuple
+        @test st isa NamedTuple
+    end
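+
+    # Illustrative extra check (a suggested property test, not required by the
+    # implementation): since the pooled global feature is permutation-invariant,
+    # permuting the input points should permute the per-point outputs accordingly.
+    @testset "PIPN Permutation Consistency" begin
+        chain = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1))
+        pipn = PIPN(chain)
+        ps, st = init_pipn_params(pipn)
+        x = rand(Float32, 2, 50)
+        perm = randperm(50)
+        y1, _ = pipn(x, ps, st)
+        y2, _ = pipn(x[:, perm], ps, st)
+        @test y2 ≈ y1[:, perm]
+    end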
+
+    @testset "PIPN Parameter Conversion" begin
+        chain = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1))
+        pipn = PIPN(chain)
+        ps, _ = init_pipn_params(pipn)
+        flat_ps = ComponentArray(ps)
+        converted_ps = vector_to_parameters(flat_ps, pipn)
+        @test converted_ps isa ComponentArray
+    end
+
+    @testset "PIPN with PDESystem" begin
+        @parameters x t
+        @variables u(..)
+        Dt = Differential(t)
+        Dxx = Differential(x)^2
+        eq = Dt(u(x, t)) ~ Dxx(u(x, t))
+
+        # Domain bounds for the 1D heat equation
+        x_min, x_max = 0.0, 1.0
+        t_min, t_max = 0.0, 1.0
+
+        # Use DomainSets intervals for the domain definition
+        domains = [x ∈ Interval(x_min, x_max),
+            t ∈ Interval(t_min, t_max)]
+
+        bcs = [u(x, 0) ~ sin(π * x),
+            u(0, t) ~ 0.0,
+            u(1, t) ~ 0.0]
+
+        @named pde_system = PDESystem(eq, bcs, domains, [x, t], [u(x, t)])
+
+        chain = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1))
+        strategy = GridTraining(0.1)
+        # PIPN does not yet hook into `discretize`, so this exercises the standard
+        # PhysicsInformedNN pipeline on the same chain and strategy.
+        discretization = PhysicsInformedNN(chain, strategy)
+
+        prob = discretize(pde_system, discretization)
+
+        @test prob isa OptimizationProblem
+    end
+end
\ No newline at end of file