-
-
Notifications
You must be signed in to change notification settings - Fork 202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Physics Informed Point Net #875
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,7 +1,7 @@ | ||||||
""" | ||||||
$(DocStringExtensions.README) | ||||||
""" | ||||||
module NeuralPDE | ||||||
module NeuralPDE | ||||||
|
||||||
using DocStringExtensions | ||||||
using Reexport, Statistics | ||||||
|
@@ -55,7 +55,7 @@ include("BPINN_ode.jl") | |||||
include("PDE_BPINN.jl") | ||||||
include("dgm.jl") | ||||||
|
||||||
export NNODE, NNDAE, | ||||||
export NNODE, PIPN, init_pipn_params, NNDAE, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
PhysicsInformedNN, discretize, | ||||||
GridTraining, StochasticTraining, QuadratureTraining, QuasiRandomTraining, | ||||||
WeightedIntervalTraining, | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -1,7 +1,7 @@ | ||||||||
""" | ||||||||
??? | ||||||||
""" | ||||||||
struct LogOptions | ||||||||
struct LogOptions | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
log_frequency::Int64 | ||||||||
# TODO: add in an option for saving plots in the log. this is currently not done because the type of plot is dependent on the PDESystem | ||||||||
# possible solution: pass in a plot function? | ||||||||
|
@@ -557,3 +557,160 @@ function numeric_derivative(phi, u, x, εs, order, θ) | |||||||
error("This shouldn't happen!") | ||||||||
end | ||||||||
end | ||||||||
|
||||||||
|
||||||||
Comment on lines
+560
to
+561
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
||||||||
""" | ||||||||
PIPN(chain, | ||||||||
strategy = GridTraining(0.1); | ||||||||
init_params = nothing, | ||||||||
param_estim = false, | ||||||||
additional_loss = nothing, | ||||||||
adaptive_loss = nothing, | ||||||||
logger = nothing, | ||||||||
log_options = LogOptions(), | ||||||||
iteration = nothing, | ||||||||
kwargs...) | ||||||||
|
||||||||
A `discretize` algorithm for the ModelingToolkit PDESystem interface, which transforms a | ||||||||
`PDESystem` into an `OptimizationProblem` using the Physics-Informed Point Net (PIPN) methodology, | ||||||||
an extension of the Physics-Informed Neural Networks (PINN) approach. | ||||||||
|
||||||||
## Positional Arguments | ||||||||
|
||||||||
* `chain`: a Lux chain specifying the overall network architecture. The input and output dimensions | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this even needed as shared mlps are constructed inside? |
||||||||
of this chain are used to determine the dimensions of the PIPN components. | ||||||||
* `strategy`: determines which training strategy will be used. Defaults to `GridTraining(0.1)`. | ||||||||
See the Training Strategy documentation for more details. | ||||||||
|
||||||||
## Keyword Arguments | ||||||||
|
||||||||
* `init_params`: the initial parameters of the neural networks. If not provided, default | ||||||||
initialization is used. | ||||||||
* `param_estim`: whether the parameters of the differential equation should be included in | ||||||||
the optimization. Defaults to `false`. | ||||||||
* `additional_loss`: a function `additional_loss(phi, θ, p_)` where `phi` are the neural | ||||||||
network trial solutions, `θ` are the weights of the neural network(s), | ||||||||
and `p_` are the hyperparameters of the `OptimizationProblem`. | ||||||||
* `adaptive_loss`: the choice for the adaptive loss function. See the adaptive loss documentation | ||||||||
for more details. Defaults to no adaptivity. | ||||||||
* `logger`: a logging mechanism for tracking the training process. | ||||||||
* `log_options`: options for controlling the logging behavior. | ||||||||
* `iteration`: used to control the iteration counter. If not provided, starts at 1 and | ||||||||
self-increments. | ||||||||
* `kwargs`: Extra keyword arguments which are passed to the `OptimizationProblem` on `solve`. | ||||||||
|
||||||||
## Fields | ||||||||
|
||||||||
* `shared_mlp1`: First shared multilayer perceptron in the PIPN architecture. | ||||||||
* `shared_mlp2`: Second shared multilayer perceptron in the PIPN architecture. | ||||||||
* `after_pool_mlp`: Multilayer perceptron applied after the pooling operation. | ||||||||
* `final_layer`: Final layer producing the output of the network. | ||||||||
* `strategy`: The training strategy used. | ||||||||
* `init_params`: Initial parameters of the neural networks. | ||||||||
* `param_estim`: Boolean indicating whether parameter estimation is enabled. | ||||||||
* `additional_loss`: Additional loss function, if specified. | ||||||||
* `adaptive_loss`: Adaptive loss function, if specified. | ||||||||
* `logger`: Logging mechanism for the training process. | ||||||||
* `log_options`: Options for controlling logging behavior. | ||||||||
* `iteration`: Vector containing the current iteration count. | ||||||||
* `self_increment`: Boolean indicating whether the iteration count should self-increment. | ||||||||
* `kwargs`: Additional keyword arguments passed to the optimization problem. | ||||||||
|
||||||||
The PIPN structure implements a Physics-Informed Point Net, which is designed to handle | ||||||||
point cloud data in the context of physics-informed neural networks. It uses a series of | ||||||||
shared MLPs, a global feature aggregation step, and additional processing to produce the final output. | ||||||||
""" | ||||||||
|
||||||||
struct PIPN{C1,C2,C3,F,ST,P,PE,AL,ADA,LOG,K} <: AbstractPINN | ||||||||
shared_mlp1::C1 | ||||||||
shared_mlp2::C2 | ||||||||
after_pool_mlp::C3 | ||||||||
final_layer::F | ||||||||
strategy::ST | ||||||||
init_params::P | ||||||||
param_estim::PE | ||||||||
additional_loss::AL | ||||||||
adaptive_loss::ADA | ||||||||
logger::LOG | ||||||||
log_options::LogOptions | ||||||||
iteration::Vector{Int64} | ||||||||
self_increment::Bool | ||||||||
kwargs::K | ||||||||
end | ||||||||
|
||||||||
function PIPN(chain, strategy = GridTraining(0.1); | ||||||||
init_params = nothing, | ||||||||
param_estim = false, | ||||||||
additional_loss = nothing, | ||||||||
adaptive_loss = nothing, | ||||||||
logger = nothing, | ||||||||
log_options = LogOptions(), | ||||||||
iteration = nothing, | ||||||||
shared_mlp1_sizes = [64, 64], | ||||||||
shared_mlp2_sizes = [128, 1024], | ||||||||
after_pool_mlp_sizes = [512, 256, 128], | ||||||||
Comment on lines
+650
to
+652
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't these be exposed to the user? |
||||||||
kwargs...) | ||||||||
|
||||||||
input_dim = chain[1].in_dims[1] | ||||||||
output_dim = chain[end].out_dims[1] | ||||||||
|
||||||||
# Create shared_mlp1 | ||||||||
shared_mlp1_layers = [Lux.Dense(i == 1 ? input_dim : shared_mlp1_sizes[i-1] => shared_mlp1_sizes[i], tanh) for i in 1:length(shared_mlp1_sizes)] | ||||||||
shared_mlp1 = Lux.Chain(shared_mlp1_layers...) | ||||||||
|
||||||||
# Create shared_mlp2 | ||||||||
shared_mlp2_layers = [Lux.Dense(i == 1 ? shared_mlp1_sizes[end] : shared_mlp2_sizes[i-1] => shared_mlp2_sizes[i], tanh) for i in 1:length(shared_mlp2_sizes)] | ||||||||
shared_mlp2 = Lux.Chain(shared_mlp2_layers...) | ||||||||
|
||||||||
# Create after_pool_mlp | ||||||||
after_pool_input_size = 2 * shared_mlp2_sizes[end] # Doubled due to concatenation | ||||||||
after_pool_mlp_layers = [Lux.Dense(i == 1 ? after_pool_input_size : after_pool_mlp_sizes[i-1] => after_pool_mlp_sizes[i], tanh) for i in 1:length(after_pool_mlp_sizes)] | ||||||||
after_pool_mlp = Lux.Chain(after_pool_mlp_layers...) | ||||||||
|
||||||||
final_layer = Lux.Dense(after_pool_mlp_sizes[end] => output_dim) | ||||||||
|
||||||||
if iteration isa Vector{Int64} | ||||||||
self_increment = false | ||||||||
else | ||||||||
iteration = [1] | ||||||||
self_increment = true | ||||||||
end | ||||||||
|
||||||||
PIPN(shared_mlp1, shared_mlp2, after_pool_mlp, final_layer, | ||||||||
strategy, init_params, param_estim, additional_loss, adaptive_loss, | ||||||||
logger, log_options, iteration, self_increment, kwargs) | ||||||||
end | ||||||||
|
||||||||
function (model::PIPN)(x, ps, st::NamedTuple) | ||||||||
point_features, st1 = Lux.apply(model.shared_mlp1, x, ps.shared_mlp1, st.shared_mlp1) | ||||||||
point_features, st2 = Lux.apply(model.shared_mlp2, point_features, ps.shared_mlp2, st.shared_mlp2) | ||||||||
global_feature = aggregate_global_feature(point_features) | ||||||||
global_feature_repeated = repeat(global_feature, 1, size(point_features, 2)) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why this is repeated? |
||||||||
combined_features = vcat(point_features, global_feature_repeated) | ||||||||
combined_features, st3 = Lux.apply(model.after_pool_mlp, combined_features, ps.after_pool_mlp, st.after_pool_mlp) | ||||||||
output, st4 = Lux.apply(model.final_layer, combined_features, ps.final_layer, st.final_layer) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doesn't the original architecture also have another block of shared mlps? |
||||||||
return output, (shared_mlp1=st1, shared_mlp2=st2, after_pool_mlp=st3, final_layer=st4) | ||||||||
end | ||||||||
|
||||||||
function aggregate_global_feature(points) | ||||||||
return maximum(points, dims=2) | ||||||||
end | ||||||||
|
||||||||
function init_pipn_params(model::PIPN) | ||||||||
rng = Random.default_rng() | ||||||||
ps1, st1 = Lux.setup(rng, model.shared_mlp1) | ||||||||
ps2, st2 = Lux.setup(rng, model.shared_mlp2) | ||||||||
ps3, st3 = Lux.setup(rng, model.after_pool_mlp) | ||||||||
ps4, st4 = Lux.setup(rng, model.final_layer) | ||||||||
ps = (shared_mlp1=ps1, shared_mlp2=ps2, after_pool_mlp=ps3, final_layer=ps4) | ||||||||
st = (shared_mlp1=st1, shared_mlp2=st2, after_pool_mlp=st3, final_layer=st4) | ||||||||
return ps, st | ||||||||
end | ||||||||
|
||||||||
function vector_to_parameters(θ::AbstractVector, model::PIPN) | ||||||||
ps, _ = init_pipn_params(model) | ||||||||
flat_ps = ComponentArray(ps) | ||||||||
new_flat_ps = typeof(flat_ps)(θ) | ||||||||
return ComponentArray(new_flat_ps) | ||||||||
end | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,4 +1,4 @@ | ||||||
using Pkg | ||||||
using Pkg | ||||||
using SafeTestsets | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
const GROUP = get(ENV, "GROUP", "All") | ||||||
|
@@ -100,4 +100,10 @@ end | |||||
include("dgm_test.jl") | ||||||
end | ||||||
end | ||||||
|
||||||
if GROUP == "All" || GROUP == "PIPN" | ||||||
@time @safetestset "Physics Informed Point Network" begin | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The .github/workflows/ci.yml needs to add a PIPN group as well for this to run. |
||||||
include("test_pipn.jl") | ||||||
end | ||||||
end | ||||||
end |
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,81 @@ | ||||||||
using Test | ||||||||
using .NeuralPDE | ||||||||
using Lux | ||||||||
using Random | ||||||||
using ComponentArrays | ||||||||
using ModelingToolkit | ||||||||
using OptimizationProblems | ||||||||
import ModelingToolkit: Interval | ||||||||
|
||||||||
|
||||||||
@testset "PIPN Tests" begin | ||||||||
@testset "PIPN Construction" begin | ||||||||
chain = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1)) | ||||||||
pipn = PIPN(chain) | ||||||||
println("we have passed this point") | ||||||||
@test pipn isa PIPN | ||||||||
@test pipn.shared_mlp1 isa Lux.Chain | ||||||||
@test pipn.shared_mlp2 isa Lux.Chain | ||||||||
@test pipn.after_pool_mlp isa Lux.Chain | ||||||||
@test pipn.final_layer isa Lux.Dense | ||||||||
end | ||||||||
|
||||||||
@testset "PIPN Forward Pass" begin | ||||||||
chain = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1)) | ||||||||
pipn = PIPN(chain) | ||||||||
x = rand(Float32, 2, 100) | ||||||||
println("Test input size: ", size(x)) | ||||||||
ps, st = init_pipn_params(pipn) | ||||||||
y, _ = pipn(x, ps, st) | ||||||||
@test size(y) == (1, 100) | ||||||||
end | ||||||||
|
||||||||
@testset "PIPN Parameter Initialization" begin | ||||||||
chain = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1)) | ||||||||
pipn = PIPN(chain) | ||||||||
ps, st = init_pipn_params(pipn) | ||||||||
@test ps isa NamedTuple | ||||||||
@test st isa NamedTuple | ||||||||
end | ||||||||
|
||||||||
@testset "PIPN Parameter Conversion" begin | ||||||||
chain = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1)) | ||||||||
pipn = PIPN(chain) | ||||||||
ps, _ = init_pipn_params(pipn) | ||||||||
flat_ps = ComponentArray(ps) | ||||||||
converted_ps = vector_to_parameters(flat_ps, pipn) | ||||||||
@test converted_ps isa ComponentArray | ||||||||
end | ||||||||
|
||||||||
@testset "PIPN with PDESystem" begin | ||||||||
@parameters x t | ||||||||
@variables u(..) | ||||||||
Dt = Differential(t) | ||||||||
Dxx = Differential(x)^2 | ||||||||
eq = Dt(u(x,t)) ~ Dxx(u(x,t)) | ||||||||
|
||||||||
# Define domain | ||||||||
x_min = 0.0 | ||||||||
x_max = 1.0 | ||||||||
t_min = 0.0 | ||||||||
t_max = 1.0 | ||||||||
|
||||||||
# Use DomainSets for domain definition | ||||||||
domains = [x ∈ Interval(x_min, x_max), | ||||||||
t ∈ Interval(t_min, t_max)] | ||||||||
|
||||||||
bcs = [u(x,0) ~ sin(π*x), | ||||||||
u(0,t) ~ 0.0, | ||||||||
u(1,t) ~ 0.0] | ||||||||
|
||||||||
@named pde_system = PDESystem(eq, bcs, domains, [x,t], [u(x,t)]) | ||||||||
|
||||||||
chain = Lux.Chain(Lux.Dense(2 => 16, tanh), Lux.Dense(16 => 1)) | ||||||||
strategy = GridTraining(0.1) | ||||||||
discretization = PhysicsInformedNN(chain, strategy) | ||||||||
|
||||||||
prob = discretize(pde_system, discretization) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it working with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it should work There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, cool. Can you add it to the test and also address other comments? |
||||||||
|
||||||||
@test prob isa OptimizationProblem | ||||||||
end | ||||||||
end | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.