refactor: use stateless apply
sathvikbhagavan committed Jul 19, 2024
1 parent f9e45a0 commit 1a74b0a
Showing 4 changed files with 75 additions and 29 deletions.
1 change: 1 addition & 0 deletions src/NeuralPDE.jl
@@ -36,6 +36,7 @@ using Lux: FromFluxAdaptor, recursive_eltype
using ChainRulesCore: @non_differentiable, @ignore_derivatives
using PDEBase: AbstractVarEqMapping, VariableMap, cardinalize_eqs!, get_depvars,
get_indvars, differential_order
using LuxCore: stateless_apply

RuntimeGeneratedFunctions.init(@__MODULE__)

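For readers unfamiliar with the new import: LuxCore.stateless_apply evaluates a Lux model without threading the state tuple that Lux.apply returns, which is what lets the network call be embedded in a symbolic expression below. A minimal sketch of the difference, assuming a small stateless Chain (the model and sizes are illustrative, not taken from this commit):

    using Lux, LuxCore, Random

    model = Chain(Dense(2 => 16, tanh), Dense(16 => 1))
    ps, st = Lux.setup(Random.default_rng(), model)
    x = rand(Float32, 2, 10)

    y_stateful, st_new = Lux.apply(model, x, ps, st)      # stateful call: (output, updated state)
    y_stateless = LuxCore.stateless_apply(model, x, ps)   # output only; valid when the model carries no state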
57 changes: 46 additions & 11 deletions src/discretize.jl
@@ -162,6 +162,10 @@ function get_numeric_integral(pinnrep::PINNRepresentation)
end
end

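# Wrap `convert(T, x)` in a lazy symbolic array term of known size, so the flattened
# parameters are only converted back to their concrete type when the generated code runs.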
function lazyconvert(T, x::Symbolics.Arr)
Symbolics.array_term(convert, T, x, size = size(x))
end

"""
prob = symbolic_discretize(pde_system::PDESystem, discretization::AbstractPINN)
@@ -215,7 +219,7 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem,
end
# chain_names = ntuple(i -> depvars(eqs[i].lhs, eqdata), length(chain))
# @show chain_names
chain_names = Tuple(Symbol.(pdesys.dvs))
chain_names = Tuple(Symbol.(operation.(unwrap.(pdesys.dvs))))
init_params = ComponentArrays.ComponentArray(NamedTuple{chain_names}(i
for i in x))
else
@@ -227,6 +231,37 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem,
init_params = init_params
end

if phi isa AbstractVector
chain_params_symbols = map(chain_names) do chain_name
_params = getproperty(init_params, chain_name)
[
first(@parameters $(Symbol("pss_" * string(chain_name)))[1:length(_params)]),
first(@parameters $(Symbol("T_" * string(chain_name)))::typeof(typeof(_params))=typeof(_params) [tunable = false])
]
end
outs = []
for i in eachindex(phi)
out = x -> stateless_apply(phi[i].f, x,
lazyconvert(chain_params_symbols[i][2], chain_params_symbols[i][1]))[1]
push!(outs, out)
end
else
chain_params_symbols = [
first(@parameters pss[1:length(init_params)]),
first(@parameters T::typeof(typeof(init_params))=typeof(init_params) [tunable = false])
]
outs = []
for i in eachindex(pdesys.dvs)
out = x -> stateless_apply(
phi.f, x, lazyconvert(chain_params_symbols[2], chain_params_symbols[1]))[i]
push!(outs, out)
end
end

depvars_outs_map = Dict(
operation.(unwrap.(pdesys.dvs)) .=> outs
)

flat_init_params = if init_params isa ComponentArrays.ComponentArray
init_params
# elseif multioutput
@@ -253,15 +288,15 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem,
phi.st)
end

if multioutput
# acum = [0; accumulate(+, map(length, init_params))]
phi = map(enumerate(pdesys.dvs)) do (i, dv)
(coord, expr_θ) -> phi[i](coord, expr_θ.depvar.$(dv))
end
else
# phimap = nothing
phi = (coord, expr_θ) -> phi(coord, expr_θ.depvar)
end
# if multioutput
# # acum = [0; accumulate(+, map(length, init_params))]
# phi = map(enumerate(pdesys.dvs)) do (i, dv)
# (coord, expr_θ) -> phi[i](coord, expr_θ.depvar.$(dv))
# end
# else
# # phimap = nothing
# phi = (coord, expr_θ) -> phi(coord, expr_θ.depvar)
# end

eltypeθ = eltype(flat_init_params)

@@ -275,7 +310,7 @@ function SciMLBase.symbolic_discretize(pdesys::PDESystem,
pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p,
param_estim, additional_loss, adaloss, varmap, logger,
multioutput, iteration, init_params, flat_init_params, phi,
derivative,
derivative, depvars_outs_map,
strategy, eqdata, nothing, nothing, nothing, nothing)

#integral = get_numeric_integral(pinnrep)
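For context on the @parameters block added above: each network's flattened weights become a symbolic array parameter, and a second, non-tunable parameter stores the concrete parameter type so lazyconvert can restore it at evaluation time. A hedged, standalone sketch of that pattern (the name, size, and plain Vector stand-in for the ComponentArray are illustrative only):

    using ModelingToolkit

    weights = rand(4)                                   # stand-in for getproperty(init_params, chain_name)
    pname = Symbol("pss_", "u")                         # runtime-generated name, e.g. :pss_u
    pss = first(@parameters $pname[1:length(weights)])  # symbolic array parameter for the weights
    T = first(@parameters T_u::DataType=typeof(weights) [tunable = false])  # remembers the concrete type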
42 changes: 24 additions & 18 deletions src/loss_function_generation.jl
@@ -61,9 +61,14 @@ function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq;
return full_loss_func
end

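# Register Phi as an opaque array-valued symbolic function: Symbolics does not trace into
# the network and only uses the output size and eltype declared below.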
@register_array_symbolic (f::Phi{<:Lux.AbstractExplicitLayer})(
x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}) begin
size = LuxCore.outputsize(f.f, x, LuxCore._default_rng())
eltype = Real
end

function build_loss_function(pinnrep, eq)
@unpack eq_params, param_estim, default_p, phi, multioutput, derivative, integral = pinnrep

_loss_function = build_symbolic_loss_function(pinnrep, eq,
eq_params = eq_params,
param_estim = param_estim)
@@ -89,20 +94,21 @@ end
function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = false,
dict_transformation_vars = nothing,
transformation_vars = nothing)
@unpack varmap, eqdata, derivative, integral, flat_init_params, multioutput = pinnrep
@unpack varmap, eqdata, derivative, integral, flat_init_params, phi, depvars_outs_map = pinnrep
eltypeθ = eltype(flat_init_params)

ex_vars = get_depvars(term, varmap.depvar_ops)

if multioutput
dummyvars = @variables (phi(..))[1:length(varmap.ū)], θ_SYMBOL, switch
else
dummyvars = @variables phi(..), θ_SYMBOL, switch
end
# if multioutput
# dummyvars = @variables switch
# else
# dummyvars = @variables switch
# end
dummyvars = @variables switch

dummyvars = unwrap.(dummyvars)
deriv_rules = generate_derivative_rules(
term, eqdata, eltypeθ, dummyvars, derivative, varmap, multioutput)
term, eqdata, eltypeθ, dummyvars, derivative, varmap, depvars_outs_map)
ch = Prewalk(Chain(deriv_rules))

expr = ch(term)
@@ -111,7 +117,7 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa
sym_coords = DestructuredArgs(ivs)
ps = DestructuredArgs(varmap.ps)

args = [sym_coords, θ_SYMBOL, phi, ps]
args = [sym_coords, ps]

ex = Func(args, [], expr) |> toexpr |> _dot_

@@ -121,11 +127,11 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa
end

function generate_derivative_rules(
term, eqdata, eltypeθ, dummyvars, derivative, varmap, multioutput)
phi, θ, switch = dummyvars
if symtype(phi) isa AbstractArray
phi = collect(phi)
end
term, eqdata, eltypeθ, dummyvars, derivative, varmap, depvars_outs_map)
switch = dummyvars
# if symtype(phi) isa AbstractArray
# phi = collect(phi)
# end

dvs = get_depvars(term, varmap.depvar_ops)

@@ -134,7 +140,7 @@ function generate_derivative_rules(
rs = reduce(vcat,
[reduce(vcat,
[[@rule $((Differential(x)^d)(w)) => derivative(
ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ),
depvars_outs_map[operation(w)], arguments(w),
get_ε(n(w), j, eltypeθ, d),
d, θ)
for d in differential_order(term, x)]
@@ -154,16 +160,16 @@
ε2 = get_ε(n(w), k, eltypeθ, 1)
[@rule $((Differential(x))((Differential(y))(w))) => derivative(
(coord_, θ_) -> derivative(
ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ),
depvars_outs_map[operation(w)], arguments(w),
ε2, 1, θ_),
reducevcat(arguments(w), eltypeθ, switch), ε1, 1, θ)]
arguments(w), ε1, 1, θ)]
end
end
end
end

vr = mapreduce(vcat, dvs, init = []) do w
@rule w => ufunc(w, phi, varmap)(reducevcat(arguments(w), eltypeθ), θ)
@rule w => depvars_outs_map[operation(w)](arguments(w))
end

return [mx; rs; vr]
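The derivative and variable rules above are ordinary SymbolicUtils rules, applied with Prewalk(Chain(...)) in parse_equation. A toy, self-contained illustration of that rewriting machinery (the rule is unrelated to the PDE rules; it only shows a matching subterm being rewritten in place):

    using SymbolicUtils
    using SymbolicUtils.Rewriters: Prewalk, Chain

    @syms a::Real
    double_angle = @rule sin(2 * ~x) => 2 * sin(~x) * cos(~x)
    rw = Prewalk(Chain([double_angle]))
    rw(a^2 + sin(2a))   # only sin(2a) is rewritten: a^2 + 2sin(a)cos(a)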
4 changes: 4 additions & 0 deletions src/pinn_types.jl
@@ -396,6 +396,10 @@ mutable struct PINNRepresentation
"""
derivative::Any
"""
Map from each dependent variable to the symbolic output of its corresponding neural network.
"""
depvars_outs_map::Any
"""
The training strategy as provided by the user
"""
strategy::AbstractTrainingStrategy
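The new depvars_outs_map field built in discretize.jl is simply a Dict from each dependent variable's operation to a callable producing that network's symbolic output; the rules in generate_derivative_rules look entries up via operation(w). A toy sketch of its shape (plain Symbol keys and throwaway functions stand in for the real Symbolics operations and stateless_apply closures):

    toy_out_u(coords) = sum(coords)    # stands in for x -> stateless_apply(phi.f, x, ...)[1]
    toy_out_v(coords) = prod(coords)
    depvars_outs_map = Dict(:u => toy_out_u, :v => toy_out_v)
    depvars_outs_map[:u]([1.0, 2.0])   # => 3.0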
