Skip to content

Commit

Permalink
fixup! test: debug CI slowdown using loss callback for the case it is…
Browse files Browse the repository at this point in the history
… stuck
  • Loading branch information
sathvikbhagavan committed Feb 29, 2024
1 parent 6d39d02 commit fa90f4a
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions test/neural_adapter_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Test, NeuralPDE
using Optimization, OptimizationOptimJL
using Optimization
import ModelingToolkit: Interval, infimum, supremum
import Lux, OptimizationOptimisers
using Statistics
Expand All @@ -8,12 +8,19 @@ using ComponentArrays
using Random
Random.seed!(100)

global iter = 0

callback = function (p, l)
println("Current loss is: $l")
global iter
iter += 1
if iter % 100 == 0
println("Current loss at iteration $iter is: $l")
end
return false
end

@testset "Example, 2D Poisson equation with Neural adapter" begin
# @testset "Example, 2D Poisson equation with Neural adapter" begin
begin
@parameters x y
@variables u(..)
Dxx = Differential(x)^2
Expand Down Expand Up @@ -43,7 +50,7 @@ end
@named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)])
prob = NeuralPDE.discretize(pde_system, discretization)
println("Poisson equation, strategy: $(nameof(typeof(quadrature_strategy)))")
@time res = solve(prob, OptimizationOptimisers.Adam(5e-3); maxiters = 10000)
@time res = solve(prob, OptimizationOptimisers.Adam(5e-3); maxiters = 10000, callback)
phi = discretization.phi

inner_ = 8
Expand All @@ -56,6 +63,7 @@ end
init_params2 = Float64.(ComponentArrays.ComponentArray(initp))

function loss(cord, θ)
global st
ch2, st = chain2(cord, θ, st)
ch2 .- phi(cord, res.minimizer)
end
Expand All @@ -69,18 +77,16 @@ end
reses_1 = map(strategies1) do strategy_
println("Neural adapter Poisson equation, strategy: $(nameof(typeof(strategy_)))")
prob_ = NeuralPDE.neural_adapter(loss, init_params2, pde_system, strategy_)
if strategy_ isa QuadratureTraining
@time res_ = solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 10000, callback)
else
@time res_ = solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 10000)
end
global iter = 0
@time res_ = solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 10000, callback)
end

strategies2 = [stochastic_strategy, quasirandom_strategy]
reses_2 = map(strategies2) do strategy_
println("Neural adapter Poisson equation, strategy: $(nameof(typeof(strategy_)))")
prob_ = NeuralPDE.neural_adapter(loss, init_params2, pde_system, strategy_)
@time res_ = solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 10000)
global iter = 0
@time res_ = solve(prob_, OptimizationOptimisers.Adam(5e-3); maxiters = 10000, callback)
end

reses_ = [reses_1; reses_2]
Expand Down

0 comments on commit fa90f4a

Please sign in to comment.