Skip to content

Commit

Permalink
test: batch = true for NNODE tstops tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Mar 22, 2024
1 parent 6c94adf commit b22ccd7
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions test/NNODE_tstops_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ dx = 1.0
@testset "Without added points" begin
println("Without added points")
# (difference between solutions should be high)
alg = NNODE(chain, opt, autodiff = false, strategy = GridTraining(dx))
alg = NNODE(chain, opt, autodiff = false, strategy = GridTraining(dx), batch = true)
sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat)
@test abs(mean(sol) - mean(true_sol)) > threshold
end
@testset "With added points" begin
println("With added points")
# (difference between solutions should be low)
alg = NNODE(chain, opt, autodiff = false, strategy = GridTraining(dx))
alg = NNODE(chain, opt, autodiff = false, strategy = GridTraining(dx), batch = true)
sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
@test abs(mean(sol) - mean(true_sol)) < threshold
end
Expand All @@ -53,14 +53,14 @@ end
@testset "Without added points" begin
println("Without added points")
# (difference between solutions should be high)
alg = NNODE(chain, opt, autodiff = false, strategy = WeightedIntervalTraining(weights, points))
alg = NNODE(chain, opt, autodiff = false, strategy = WeightedIntervalTraining(weights, points), batch = true)
sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat)
@test abs(mean(sol) - mean(true_sol)) > threshold
end
@testset "With added points" begin
println("With added points")
# (difference between solutions should be low)
alg = NNODE(chain, opt, autodiff = false, strategy = WeightedIntervalTraining(weights, points))
alg = NNODE(chain, opt, autodiff = false, strategy = WeightedIntervalTraining(weights, points), batch = true)
sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
@test abs(mean(sol) - mean(true_sol)) < threshold
end
Expand All @@ -71,14 +71,14 @@ end
@testset "Without added points" begin
println("Without added points")
# (difference between solutions should be high)
alg = NNODE(chain, opt, autodiff = false, strategy = StochasticTraining(points))
alg = NNODE(chain, opt, autodiff = false, strategy = StochasticTraining(points), batch = true)
sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat)
@test abs(mean(sol) - mean(true_sol)) > threshold
end
@testset "With added points" begin
println("With added points")
# (difference between solutions should be low)
alg = NNODE(chain, opt, autodiff = false, strategy = StochasticTraining(points))
alg = NNODE(chain, opt, autodiff = false, strategy = StochasticTraining(points), batch = true)
sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat, tstops = addedPoints)
@test abs(mean(sol) - mean(true_sol)) < threshold
end
Expand Down

0 comments on commit b22ccd7

Please sign in to comment.