Skip to content

Commit

Permalink
fixup! test: make verbose = false for NNODE tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Mar 21, 2024
1 parent 6e26018 commit f0c5a34
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion test/NNODE_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Random.seed!(100)

@testset "Scalar" begin
# Run a solve on scalars
println("Scalar")
linear = (u, p, t) -> cos(2pi * t)
tspan = (0.0f0, 1.0f0)
u0 = 0.0f0
Expand Down Expand Up @@ -36,6 +37,7 @@ end

@testset "Vector" begin
# Run a solve on vectors
println("Vector")
linear = (u, p, t) -> [cos(2pi * t)]
tspan = (0.0f0, 1.0f0)
u0 = [0.0f0]
Expand All @@ -59,6 +61,7 @@ end
end

@testset "Example 1" begin
println("Example 1")
linear = (u, p, t) -> @. t^3 + 2 * t + (t^2) * ((1 + 3 * (t^2)) / (1 + t + (t^3))) -
u * (t + ((1 + 3 * (t^2)) / (1 + t + t^3)))
linear_analytic = (u0, p, t) -> [exp(-(t^2) / 2) / (1 + t + t^3) + t^2]
Expand Down Expand Up @@ -92,6 +95,7 @@ end
end

@testset "Example 2" begin
println("Example 2")
linear = (u, p, t) -> -u / 5 + exp(-t / 5) .* cos(t)
linear_analytic = (u0, p, t) -> exp(-t / 5) * (u0 + sin(t))
prob = ODEProblem(ODEFunction(linear, analytic = linear_analytic), 0.0f0, (0.0f0, 1.0f0))
Expand Down Expand Up @@ -128,6 +132,7 @@ end
end

@testset "Example 3" begin
println("Example 3")
linear = (u, p, t) -> [cos(2pi * t), sin(2pi * t)]
tspan = (0.0f0, 1.0f0)
u0 = [0.0f0, -1.0f0 / 2pi]
Expand All @@ -146,6 +151,7 @@ end

@testset "Training Strategies" begin
@testset "WeightedIntervalTraining" begin
println("WeightedIntervalTraining")
function f(u, p, t)
[p[1] * u[1] - p[2] * u[1] * u[2], -p[3] * u[2] + p[4] * u[1] * u[2]]
end
Expand All @@ -162,7 +168,7 @@ end
points = 200
alg = NNODE(chain, opt, autodiff = false,
strategy = NeuralPDE.WeightedIntervalTraining(weights, points))
sol = solve(prob_oop, alg, verbose = false, maxiters = 100000, saveat = 0.01)
sol = solve(prob_oop, alg, verbose = true, maxiters = 5000, saveat = 0.01)
@test abs(mean(sol) - mean(true_sol)) < 0.2
end

Expand All @@ -176,6 +182,7 @@ end
u_analytical(x) = (1 / (2pi)) .* sin.(2pi .* x)

@testset "GridTraining" begin
println("GridTraining")
luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
(u_, t_) = (u_analytical(ts), ts)
function additional_loss(phi, θ)
Expand All @@ -188,6 +195,7 @@ end
end

@testset "QuadratureTraining" begin
println("QuadratureTraining")
luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
(u_, t_) = (u_analytical(ts), ts)
function additional_loss(phi, θ)
Expand All @@ -199,6 +207,7 @@ end
end

@testset "StochasticTraining" begin
println("StochasticTraining")
luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
(u_, t_) = (u_analytical(ts), ts)
function additional_loss(phi, θ)
Expand All @@ -212,6 +221,7 @@ end
end

@testset "Parameter Estimation" begin
println("Parameter Estimation")
function lorenz(u, p, t)
return [p[1]*(u[2]-u[1]),
u[1]*(p[2]-u[3])-u[2],
Expand Down Expand Up @@ -241,6 +251,7 @@ end
end

@testset "Translating from Flux" begin
println("Translating from Flux")
linear = (u, p, t) -> cos(2pi * t)
linear_analytic = (u, p, t) -> (1 / (2pi)) * sin(2pi * t)
tspan = (0.0, 1.0)
Expand Down

0 comments on commit f0c5a34

Please sign in to comment.