diff --git a/src/pino_ode_solve.jl b/src/pino_ode_solve.jl index 8ed273888d..abdd6f852f 100644 --- a/src/pino_ode_solve.jl +++ b/src/pino_ode_solve.jl @@ -1,3 +1,13 @@ +struct TRAINSET{} + input_data::Vector{ODEProblem} + output_data::Array + isu0::Bool +end + +function TRAINSET(input_data, output_data; isu0 = false) + TRAINSET(input_data, output_data, isu0) +end + """ PINOODE(chain, OptimizationOptimisers.Adam(0.1), @@ -49,16 +59,6 @@ function PINOODE(chain, PINOODE(chain, opt, train_set, is_data_loss, is_physics_loss, init_params, kwargs) end -struct TRAINSET{} - input_data::Vector{ODEProblem} - output_data::Array - isu0::Bool -end - -function TRAINSET(input_data, output_data; isu0 = false) - TRAINSET(input_data, output_data, isu0) -end - mutable struct PINOPhi{C, T, U, S} chain::C t0::T