From dad881531fffe25705c35b3258c4c3f183904636 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Thu, 28 Mar 2024 04:14:26 +0000 Subject: [PATCH] refactor: error out if QuadratureTraining is used with complex parameters for NNODE --- src/ode_solve.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/ode_solve.jl b/src/ode_solve.jl index bb417b6492..eb5ae942e4 100644 --- a/src/ode_solve.jl +++ b/src/ode_solve.jl @@ -358,6 +358,8 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractODEProblem, !(chain isa Lux.AbstractExplicitLayer) && error("Only Lux.AbstractExplicitLayer neural networks are supported") phi, init_params = generate_phi_θ(chain, t0, u0, init_params) + ((eltype(eltype(init_params).types[1]) <: Complex || eltype(eltype(init_params).types[2]) <: Complex) && alg.strategy isa QuadratureTraining) && + error("QuadratureTraining cannot be used with complex parameters. Use other strategies.") init_params = if alg.param_estim ComponentArrays.ComponentArray(; depvar = ComponentArrays.ComponentArray(init_params), p = prob.p)