From 12995659ee9351c953bebd50a9fcc4a592e6adc9 Mon Sep 17 00:00:00 2001 From: Alex Jones Date: Thu, 4 Jan 2024 16:12:03 +0000 Subject: [PATCH] fix switch --- src/loss_function_generation.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl index 25bb72c417..eeb7bf5362 100644 --- a/src/loss_function_generation.jl +++ b/src/loss_function_generation.jl @@ -120,7 +120,6 @@ function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = fa sym_coords = DestructuredArgs(ivs) ps = DestructuredArgs(varmap.ps) - args = [sym_coords, θ_SYMBOL, phi, ps] ex = Func(args, [], expr) |> toexpr |> _dot_ @@ -137,12 +136,13 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative end dvs = get_depvars(term, varmap.depvar_ops) + ivs = get_indvars(term, v) @show eltypeθ @show methods(derivative) # Orthodox derivatives n(w) = length(arguments(w)) rs = reduce(vcat, [reduce(vcat, [[@rule $((Differential(x)^d)(w)) => - derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ), + derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ, switch), [get_ε(n(w), j, eltypeθ, i) for i in 1:d], d, θ) @@ -159,7 +159,7 @@ function generate_derivative_rules(term, eqdata, eltypeθ, dummyvars, derivative ε1 = [get_ε(n(w), j, eltypeθ, i) for i in 1:2] ε2 = [get_ε(n(w), k, eltypeθ, i) for i in 1:2] [@rule $((Differential(x))((Differential(y))(w))) => - derivative((coord_, θ_) -> derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ), + derivative((coord_, θ_) -> derivative(ufunc(w, phi, varmap), reducevcat(arguments(w), eltypeθ, switch), ε2, 1, θ_), reducevcat(arguments(w), eltypeθ), ε1, 1, θ)] end