From f0c5a347a7e81429b4c4e28303a1a049927bf810 Mon Sep 17 00:00:00 2001 From: Sathvik Bhagavan Date: Thu, 21 Mar 2024 14:38:02 +0000 Subject: [PATCH] fixup! test: make verbose = false for NNODE tests --- test/NNODE_tests.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/NNODE_tests.jl b/test/NNODE_tests.jl index 1d219ab05..8d2717161 100644 --- a/test/NNODE_tests.jl +++ b/test/NNODE_tests.jl @@ -9,6 +9,7 @@ Random.seed!(100) @testset "Scalar" begin # Run a solve on scalars + println("Scalar") linear = (u, p, t) -> cos(2pi * t) tspan = (0.0f0, 1.0f0) u0 = 0.0f0 @@ -36,6 +37,7 @@ end @testset "Vector" begin # Run a solve on vectors + println("Vector") linear = (u, p, t) -> [cos(2pi * t)] tspan = (0.0f0, 1.0f0) u0 = [0.0f0] @@ -59,6 +61,7 @@ end end @testset "Example 1" begin + println("Example 1") linear = (u, p, t) -> @. t^3 + 2 * t + (t^2) * ((1 + 3 * (t^2)) / (1 + t + (t^3))) - u * (t + ((1 + 3 * (t^2)) / (1 + t + t^3))) linear_analytic = (u0, p, t) -> [exp(-(t^2) / 2) / (1 + t + t^3) + t^2] @@ -92,6 +95,7 @@ end end @testset "Example 2" begin + println("Example 2") linear = (u, p, t) -> -u / 5 + exp(-t / 5) .* cos(t) linear_analytic = (u0, p, t) -> exp(-t / 5) * (u0 + sin(t)) prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), 0.0f0, (0.0f0, 1.0f0)) @@ -128,6 +132,7 @@ end end @testset "Example 3" begin + println("Example 3") linear = (u, p, t) -> [cos(2pi * t), sin(2pi * t)] tspan = (0.0f0, 1.0f0) u0 = [0.0f0, -1.0f0 / 2pi] @@ -146,6 +151,7 @@ end @testset "Training Strategies" begin @testset "WeightedIntervalTraining" begin + println("WeightedIntervalTraining") function f(u, p, t) [p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]] end @@ -162,7 +168,7 @@ end points = 200 alg = NNODE(chain, opt, autodiff = false, strategy = NeuralPDE.WeightedIntervalTraining(weights, points)) - sol = solve(prob_oop, alg, verbose = false, maxiters = 100000, saveat = 0.01) + sol = solve(prob_oop, alg, verbose = true, maxiters = 5000, saveat = 0.01) @test abs(mean(sol) - mean(true_sol)) < 0.2 end @@ -176,6 +182,7 @@ end u_analytical(x) = (1 / (2pi)) .* sin.(2pi .* x) @testset "GridTraining" begin + println("GridTraining") luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) (u_, t_) = (u_analytical(ts), ts) function additional_loss(phi, θ) @@ -188,6 +195,7 @@ end end @testset "QuadratureTraining" begin + println("QuadratureTraining") luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) (u_, t_) = (u_analytical(ts), ts) function additional_loss(phi, θ) @@ -199,6 +207,7 @@ end end @testset "StochasticTraining" begin + println("StochasticTraining") luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1)) (u_, t_) = (u_analytical(ts), ts) function additional_loss(phi, θ) @@ -212,6 +221,7 @@ end end @testset "Parameter Estimation" begin + println("Parameter Estimation") function lorenz(u, p, t) return [p[1]*(u[2]-u[1]), u[1]*(p[2]-u[3])-u[2], @@ -241,6 +251,7 @@ end end @testset "Translating from Flux" begin + println("Translating from Flux") linear = (u, p, t) -> cos(2pi * t) linear_analytic = (u, p, t) -> (1 / (2pi)) * sin(2pi * t) tspan = (0.0, 1.0)