diff --git a/.travis.yml b/.travis.yml index d90ad059d3..38eeb116c4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -4,7 +4,7 @@ os: - linux - osx julia: - - 1.0 + - 1.1 - nightly matrix: allow_failures: diff --git a/REQUIRE b/REQUIRE deleted file mode 100644 index 78339d8950..0000000000 --- a/REQUIRE +++ /dev/null @@ -1,8 +0,0 @@ -julia 0.5 -DiffEqBase -Plots -ParameterizedFunctions -DiffEqProblemLibrary -Flux -Compat 0.17.0 -Reexport diff --git a/src/NeuralNetDiffEq.jl b/src/NeuralNetDiffEq.jl index 1b78b1376d..83c37922fa 100644 --- a/src/NeuralNetDiffEq.jl +++ b/src/NeuralNetDiffEq.jl @@ -5,13 +5,13 @@ using Reexport using Flux abstract type NeuralNetDiffEqAlgorithm <: DiffEqBase.AbstractODEAlgorithm end -struct nnode <: NeuralNetDiffEqAlgorithm - hl_width::Int +struct nnode{C,O} <: NeuralNetDiffEqAlgorithm + chain::C + opt::O end -nnode(;hl_width=10) = nnode(hl_width) +nnode(chain;opt=Flux.ADAM(0.1)) = nnode(chain,opt) export nnode include("solve.jl") -include("training_utils.jl") end # module diff --git a/src/solve.jl b/src/solve.jl index 7c10738b79..850518de5d 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -2,66 +2,60 @@ function DiffEqBase.solve( prob::DiffEqBase.AbstractODEProblem, alg::NeuralNetDiffEqAlgorithm, args...; - dt = error("dt must be set."), + dt, timeseries_errors = true, save_everystep=true, adaptive=false, + abstol = 1f-6, + verbose = false, maxiters = 100) + DiffEqBase.isinplace(prob) && error("Only out-of-place methods are allowed!") + u0 = prob.u0 tspan = prob.tspan f = prob.f p = prob.p t0 = tspan[1] - #types and dimensions - # uElType = eltype(u0) - # tType = typeof(tspan[1]) - # outdim = length(u0) - #hidden layer - hl_width = alg.hl_width - - #initialization of weights and bias - P = init_params(hl_width) - - #The phi trial solution - phi(P,x) = u0 .+ x.*predict(P,x) + chain = alg.chain + opt = alg.opt + ps = Flux.params(chain) + data = Iterators.repeated((), maxiters) #train points generation - x = generate_data(tspan[1],tspan[2],dt) - y = [f(phi(P, i)[1].data, p, i) for i in x] - px =Flux.param(x) - data = [(px, y)] + ts = tspan[1]:dt:tspan[2] - #initialization of optimization parameters (ADAM by default for now) - η = 0.1 - β1 = 0.9 - β2 = 0.95 - opt = Flux.ADAM(η, (β1, β2)) - - ps = Flux.Params(P) - - #derivatives of a function f - dfdx(i) = Tracker.gradient(() -> sum(phi(P,i)), Flux.params(i); nest = true) - #loss function for training - loss(x, y) = sum(abs2, [dfdx(i)[i] for i in x] .- y) + #The phi trial solution + phi(t) = u0 .+ (t .- tspan[1]).*chain(Tracker.collect([t])) + + if u0 isa Number + dfdx = t -> Tracker.gradient(t -> sum(phi(t)), t; nest = true)[1] + loss = () -> sum(abs2,sum(abs2,dfdx(t) .- f(phi(t)[1],p,t)[1]) for t in ts) + else + dfdx = t -> (phi(t+sqrt(eps(typeof(dt)))) - phi(t)) / sqrt(eps(typeof(dt))) + #dfdx(t) = Flux.Tracker.forwarddiff(phi,t) + #dfdx(t) = Tracker.collect([Flux.Tracker.gradient(t->phi(t)[i],t, nest=true) for i in 1:length(u0)]) + #loss function for training + loss = () -> sum(abs2,sum(abs2,dfdx(t) - f(phi(t),p,t)) for t in ts) + end - @time for iters=1:maxiters - Flux.train!(loss, ps, data, opt) - if mod(iters,50) == 0 - loss_value = loss(px,y).data - println((:iteration,iters,:loss,loss_value)) - if loss_value < 10^(-6.0) - break - end - end + cb = function () + l = loss() + verbose && println("Current loss is: $l") + l < abstol && Flux.stop() end + Flux.train!(loss, ps, data, opt; cb = cb) #solutions at timepoints - u = [phi(P,i)[1].data for i in x] + if u0 isa Number + u = [phi(t)[1].data for t in ts] + else + u = [phi(t).data for t in ts] + end - sol = DiffEqBase.build_solution(prob,alg,x,u,calculate_error = false) + sol = DiffEqBase.build_solution(prob,alg,ts,u,calculate_error = false) DiffEqBase.has_analytic(prob.f) && DiffEqBase.calculate_solution_errors!(sol;timeseries_errors=true,dense_errors=false) sol end #solve diff --git a/src/training_utils.jl b/src/training_utils.jl deleted file mode 100644 index 7861d2cf74..0000000000 --- a/src/training_utils.jl +++ /dev/null @@ -1,16 +0,0 @@ -sigm(x) = 1 ./ (1 .+ exp.(.-x)) - -function predict(P, x) - w, b, v = P - h = sigm(w * x .+ b) - return v * h -end - -function init_params(hl_width) - w = Flux.param(rand(hl_width,1)) - b = Flux.param(zeros(hl_width,1)) - v = Flux.param(randn(1, hl_width)) - return [w ,b, v] -end - -generate_data(low, high, dt) = collect(low:dt:high) diff --git a/test/REQUIRE b/test/REQUIRE deleted file mode 100644 index df87e0d523..0000000000 --- a/test/REQUIRE +++ /dev/null @@ -1,2 +0,0 @@ -DiffEqDevTools -DiffEqProblemLibrary diff --git a/test/runtests.jl b/test/runtests.jl index 343b65d1de..300ce8bb52 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,29 +1,58 @@ -using NeuralNetDiffEq, Test +using Test, Flux, NeuralNetDiffEq using DiffEqDevTools -# Run a solve +# Run a solve on scalars linear = (u,p,t) -> cos(2pi*t) -tspan = (0.0,1.0) -u0 = 0.0 +tspan = (0.0f0, 1.0f0) +u0 = 0.0f0 prob = ODEProblem(linear, u0 ,tspan) -sol = solve(prob, NeuralNetDiffEq.nnode(5), dt=1/20, maxiters=300) -# println(sol) -#plot(sol) -#plot!(sol.t, t -> sin(2pi*t) / (2*pi), lw=3,ls=:dash,label="True Solution!") +chain = Flux.Chain(Dense(1,5,σ),Dense(5,1)) +opt = Flux.ADAM(0.1, (0.9, 0.95)) +sol = solve(prob, NeuralNetDiffEq.nnode(chain,opt), dt=1/20f0, verbose = true, + abstol=1e-10, maxiters = 200) + +# Run a solve on vectors +linear = (u,p,t) -> [cos(2pi*t)] +tspan = (0.0f0, 1.0f0) +u0 = [0.0f0] +prob = ODEProblem(linear, u0 ,tspan) +chain = Flux.Chain(Dense(1,5,σ),Dense(5,1)) +opt = Flux.ADAM(0.1, (0.9, 0.95)) +sol = solve(prob, NeuralNetDiffEq.nnode(chain,opt), dt=1/20f0, abstol=1e-10, + verbose = true, maxiters=200) #Example 1 -linear = (u,p,t) -> t^3 + 2*t + (t^2)*((1+3*(t^2))/(1+t+(t^3))) - u*(t + ((1+3*(t^2))/(1+t+t^3))) -linear_analytic = (u0,p,t) -> exp(-(t^2)/2)/(1+t+t^3) + t^2 -prob = ODEProblem(ODEFunction(linear,analytic=linear_analytic),1/2,(0.0,1.0)) -dts = 1 ./ 2 .^ (10:-1:7) -sim = test_convergence(dts, prob, nnode()) -@test abs(sim.𝒪est[:l2]) < 0.3 +linear = (u,p,t) -> @. t^3 + 2*t + (t^2)*((1+3*(t^2))/(1+t+(t^3))) - u*(t + ((1+3*(t^2))/(1+t+t^3))) +linear_analytic = (u0,p,t) -> [exp(-(t^2)/2)/(1+t+t^3) + t^2] +prob = ODEProblem(ODEFunction(linear,analytic=linear_analytic),[1f0],(0.0f0,1.0f0)) +chain = Flux.Chain(Dense(1,5,σ),Dense(5,1)) +opt = Flux.ADAM(0.1, (0.9, 0.95)) +sol = solve(prob,NeuralNetDiffEq.nnode(chain,opt),verbose = true, dt=1/5f0) +err = sol.errors[:l2] +sol = solve(prob,NeuralNetDiffEq.nnode(chain,opt),verbose = true, dt=1/20f0) +sol.errors[:l2]/err < 0.5 + +#= +dts = 1f0 ./ 2f0 .^ (6:-1:2) +sim = test_convergence(dts, prob, NeuralNetDiffEq.nnode(chain, opt)) +@test abs(sim.𝒪est[:l2]) < 0.1 @test minimum(sim.errors[:l2]) < 0.5 +=# #Example 2 linear = (u,p,t) -> -u/5 + exp(-t/5).*cos(t) linear_analytic = (u0,p,t) -> exp(-t/5)*(u0 + sin(t)) -prob = ODEProblem(ODEFunction(linear,analytic=linear_analytic),0.0,(0.0,1.0)) -sim = test_convergence(dts, prob, nnode()) +prob = ODEProblem(ODEFunction(linear,analytic=linear_analytic),0.0f0,(0.0f0,1.0f0)) +chain = Flux.Chain(Dense(1,5,σ),Dense(5,1)) +opt = Flux.ADAM(0.1, (0.9, 0.95)) +sol = solve(prob,NeuralNetDiffEq.nnode(chain,opt),verbose = true, dt=1/5f0) +err = sol.errors[:l2] +sol = solve(prob,NeuralNetDiffEq.nnode(chain,opt),verbose = true, dt=1/20f0) +sol.errors[:l2]/err < 0.5 + +#= +dts = 1f0 ./ 2f0 .^ (6:-1:2) +sim = test_convergence(dts, prob, NeuralNetDiffEq.nnode(chain, opt)) @test abs(sim.𝒪est[:l2]) < 0.5 -@test minimum(sim.errors[:l2]) < 0.3 +@test minimum(sim.errors[:l2]) < 0.1 +=#