Skip to content

Commit

Permalink
test: add tests for complex equations
Browse files Browse the repository at this point in the history
  • Loading branch information
sathvikbhagavan committed Mar 28, 2024
1 parent 24a65de commit 88c4116
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions test/NNODE_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import Lux, OptimizationOptimisers, OptimizationOptimJL
using Flux
using LineSearches

rng = Random.default_rng()
Random.seed!(100)

@testset "Scalar" begin
Expand Down Expand Up @@ -250,6 +251,45 @@ end
@test reduce(hcat, sol.u)u_ atol=1e-2
end

@testset "Complex Numbers" begin
function bloch_equations(u, p, t)
Ω, Δ, Γ = p
γ = Γ / 2
ρ₁₁, ρ₂₂, ρ₁₂, ρ₂₁ = u
d̢ρ = [im * Ω * (ρ₁₂ - ρ₂₁) + Γ * ρ₂₂;
-im * Ω * (ρ₁₂ - ρ₂₁) - Γ * ρ₂₂;
-+ im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁);
conj(-+ im * Δ) * ρ₁₂ - im * Ω * (ρ₂₂ - ρ₁₁))]
return d̢ρ
end

u0 = zeros(ComplexF64, 4)
u0[1] = 1
time_span = (0.0, 2.0)
parameters = [2.0, 0.0, 1.0]

problem = ODEProblem(bloch_equations, u0, time_span, parameters)

chain = Lux.Chain(
Lux.Dense(1, 16, tanh; init_weight = (rng, a...) -> Lux.kaiming_normal(rng, ComplexF64, a...)) ,
Lux.Dense(16, 4; init_weight = (rng, a...) -> Lux.kaiming_normal(rng, ComplexF64, a...))
)
ps, st = Lux.setup(rng, chain)

opt = OptimizationOptimisers.Adam(0.01)
ground_truth = solve(problem, Tsit5(), saveat = 0.01)
strategies = [StochasticTraining(500), GridTraining(0.01), WeightedIntervalTraining([0.1, 0.4, 0.4, 0.1], 500)]

@testset "$(nameof(typeof(strategy)))" for strategy in strategies
alg = NNODE(chain, opt, ps; strategy)
sol = solve(problem, alg, verbose = false, maxiters = 5000, saveat = 0.01)
@test sol.u ground_truth.u rtol=1e-1
end

alg = NNODE(chain, opt, ps; strategy = QuadratureTraining())
@test_throws ErrorException solve(problem, alg, verbose = false, maxiters = 5000, saveat = 0.01)
end

@testset "Translating from Flux" begin
println("Translating from Flux")
linear = (u, p, t) -> cos(2pi * t)
Expand Down

0 comments on commit 88c4116

Please sign in to comment.