diff --git a/test/NNODE_tstops_test.jl b/test/NNODE_tstops_test.jl index bc4a4b08d6..d58739c1c5 100644 --- a/test/NNODE_tstops_test.jl +++ b/test/NNODE_tstops_test.jl @@ -35,14 +35,14 @@ dx = 1.0 @testset "Without added points" begin println("Without added points") # (difference between solutions should be high) - alg = NNODE(chain, opt, autodiff = false, strategy = GridTraining(dx)) + alg = NNODE(chain, opt, autodiff = false, strategy = GridTraining(dx), batch = true) sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat) @test abs(mean(sol) - mean(true_sol)) > threshold end @testset "With added points" begin println("With added points") # (difference between solutions should be low) - alg = NNODE(chain, opt, autodiff = false, strategy = GridTraining(dx)) + alg = NNODE(chain, opt, autodiff = false, strategy = GridTraining(dx), batch = true) sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat, tstops = addedPoints) @test abs(mean(sol) - mean(true_sol)) < threshold end @@ -53,14 +53,14 @@ end @testset "Without added points" begin println("Without added points") # (difference between solutions should be high) - alg = NNODE(chain, opt, autodiff = false, strategy = WeightedIntervalTraining(weights, points)) + alg = NNODE(chain, opt, autodiff = false, strategy = WeightedIntervalTraining(weights, points), batch = true) sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat) @test abs(mean(sol) - mean(true_sol)) > threshold end @testset "With added points" begin println("With added points") # (difference between solutions should be low) - alg = NNODE(chain, opt, autodiff = false, strategy = WeightedIntervalTraining(weights, points)) + alg = NNODE(chain, opt, autodiff = false, strategy = WeightedIntervalTraining(weights, points), batch = true) sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat, tstops = addedPoints) @test abs(mean(sol) - mean(true_sol)) < threshold end @@ -71,14 +71,14 @@ end @testset "Without added points" begin println("Without added points") # (difference between solutions should be high) - alg = NNODE(chain, opt, autodiff = false, strategy = StochasticTraining(points)) + alg = NNODE(chain, opt, autodiff = false, strategy = StochasticTraining(points), batch = true) sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat) @test abs(mean(sol) - mean(true_sol)) > threshold end @testset "With added points" begin println("With added points") # (difference between solutions should be low) - alg = NNODE(chain, opt, autodiff = false, strategy = StochasticTraining(points)) + alg = NNODE(chain, opt, autodiff = false, strategy = StochasticTraining(points), batch = true) sol = solve(prob_oop, alg, verbose = false, maxiters = maxiters, saveat = saveat, tstops = addedPoints) @test abs(mean(sol) - mean(true_sol)) < threshold end