diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl index 1122afc838..a2ffc2370a 100644 --- a/src/NeuralPDE.jl +++ b/src/NeuralPDE.jl @@ -31,7 +31,7 @@ using DomainSets: Domain, ClosedInterval, AbstractInterval, leftendpoint, righte using SciMLBase: @add_kwonly, parameterless_type using UnPack: @unpack import ChainRulesCore, Lux, ComponentArrays -using Lux: FromFluxAdaptor +using Lux: FromFluxAdaptor, recursive_eltype using ChainRulesCore: @non_differentiable RuntimeGeneratedFunctions.init(@__MODULE__) diff --git a/src/ode_solve.jl b/src/ode_solve.jl index 64d7b3ac6c..bcf9c68ebe 100644 --- a/src/ode_solve.jl +++ b/src/ode_solve.jl @@ -370,8 +370,7 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem, !(chain isa Lux.AbstractExplicitLayer) && error("Only Lux.AbstractExplicitLayer neural networks are supported") phi, init_params = generate_phi_θ(chain, t0, u0, init_params) - ((eltype(eltype(init_params).types[1]) <: Complex || - eltype(eltype(init_params).types[2]) <: Complex) && + (recursive_eltype(init_params) <: Complex && alg.strategy isa QuadratureTraining) && error("QuadratureTraining cannot be used with complex parameters. Use other strategies.")