diff --git a/.gitignore b/.gitignore index 90debbd..bc8c91c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ benchmarks/graphs/* *~ *.kate-swp +Manifest.toml +test/*.png +docs/build/ diff --git a/Project.toml b/Project.toml index aaa2ef0..585c560 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c" OptimBase = "87e2bd06-a317-5318-96d9-3ecbac512eee" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] @@ -16,11 +17,13 @@ Distributions = "0.18, 0.19, 0.20, 0.21, 0.22, 0.23, 0.24" ForwardDiff = "0.10" NLSolversBase = "7.5" OptimBase = "2.0" +RecipesBase = "1" StatsBase = "0.32, 0.33" julia = "1.1" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" [targets] -test = ["Test"] +test = ["Test", "Plots"] diff --git a/docs/src/img/plots.png b/docs/src/img/plots.png new file mode 100644 index 0000000..59ff55a Binary files /dev/null and b/docs/src/img/plots.png differ diff --git a/docs/src/tutorial.md b/docs/src/tutorial.md index 171a4f8..1b3233a 100644 --- a/docs/src/tutorial.md +++ b/docs/src/tutorial.md @@ -391,6 +391,33 @@ julia> fit_WLS = curve_fit(m, tdata, ydata, wt, p0) julia> cov = estimate_covar(fit_WLS) ``` +## Visualization +A `Plots.jl` plot recipe is provided for visualizing a fit. The simplest case +is plotting just the fit curve by itself, which can be done by supplying the +model and the fit object to a `plot` call: +```julia +plot(m, fit) +``` + +The plot recipe can also show the confidence and prediction intervals by +supplying the `purpose` keyword (default `:neither`): +```julia +plot( + plot(m, fit; purpose=:neither), + plot(m, fit; purpose=:prediction), + plot(m, fit; purpose=:confidence), + plot(m, fit; purpose=:both), + ; + layout=(2,2), +) +``` + +![Plots with different intervals marked.](./img/plots.png) + +Changing the keyword `significance` (default `0.05`) changes which confidence +is being illustrated. + + ## References Hansen, P. C., Pereyra, V. and Scherer, G. (2013) Least squares data fitting with applications. Baltimore, Md: Johns Hopkins University Press, p. 147-155. diff --git a/src/LsqFit.jl b/src/LsqFit.jl index e1ecc91..0d139d8 100644 --- a/src/LsqFit.jl +++ b/src/LsqFit.jl @@ -15,6 +15,7 @@ module LsqFit using OptimBase using LinearAlgebra using ForwardDiff + using RecipesBase import NLSolversBase: value, jacobian import StatsBase import StatsBase: coef, dof, nobs, rss, stderror, weights, residuals @@ -24,5 +25,6 @@ module LsqFit include("geodesic.jl") include("levenberg_marquardt.jl") include("curve_fit.jl") + include("plot.jl") end diff --git a/src/plot.jl b/src/plot.jl new file mode 100644 index 0000000..f4d76b4 --- /dev/null +++ b/src/plot.jl @@ -0,0 +1,51 @@ +@recipe function f(model::Function, fit::LsqFitResult; significance=0.05, purpose=:neither) + @series begin + seriestype --> :line + label --> "Fit" + x->model(x, fit.param) + end + if purpose in (:confidence, :both) + @series begin + seriestype := :line + seriescolor := :black + linestyle := :dash + label := ["Confidence interval" nothing] + + [ + x -> model(x, fit.param) - margin_error(model, x, fit, significance), + x -> model(x, fit.param) + margin_error(model, x, fit, significance) + ] + end + end + if purpose in (:prediction, :both) + @series begin + seriestype := :line + seriescolor := :black + linestyle := :dot + label := ["Prediction interval" nothing] + + [ + x -> model(x, fit.param) - margin_error(model, x, fit, significance; purpose=:prediction) + x -> model(x, fit.param) + margin_error(model, x, fit, significance; purpose=:prediction) + ] + end + end +end + + +""" +```julia +margin_error(model, x, fit, significance; purpose) +``` +Find the width at `x` of the confidence or prediction interval. +""" +function margin_error(model::Function, x, fit::LsqFitResult, alpha=0.05; purpose=:confidence) + g = p -> ForwardDiff.gradient(p -> model(x, p), fit.param) + c = g(fit.param)' * estimate_covar(fit) * g(fit.param) + if purpose === :prediction + c = c + 1 + end + dist = TDist(dof(fit)) + critical_values = quantile(dist, 1 - alpha/2) + return sqrt(c*rss(fit)/dof(fit))*critical_values +end diff --git a/test/plotting.jl b/test/plotting.jl new file mode 100644 index 0000000..23d8f1b --- /dev/null +++ b/test/plotting.jl @@ -0,0 +1,35 @@ +let + @testset "Plotting" begin + model(x, p) = @. p[1] * x^2 + p[2] * x + p[3] + p[4] * exp(-(x - p[5])^2 / p[6]^2) + p_true = [0.5, 0.7, 0.0, 4.5, 0.3, 0.5] + xdata = -4.0:0.5:4.0 + ydata = model(xdata, p_true) + 0.7 * randn(size(xdata)) + + f = curve_fit(model, xdata, ydata, fill(1.0, 6)) + + p = plot( + map((:neither, :prediction, :confidence, :both)) do purpose + plot( + xdata, + ydata; + seriestype=:scatter, + label="Data", + legend_foreground_color=nothing, + legend_background_color=nothing, + legend=:topleft, + ) + plot!( + x -> model(x, p_true); + seriestype=:line, + label="Ground truth", + linestyle=:dot, + ) + plot!(model, f; purpose=purpose, title="$purpose") + end...; + layout=(2, 2), + ) + @test p isa Plots.Plot + savefig(p, "plots.png") + println("Plots saved to `test/plots.png`") + end +end diff --git a/test/runtests.jl b/test/runtests.jl index ce4b553..11638a6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,9 +3,10 @@ # using LsqFit, Test, LinearAlgebra, Random using OptimBase +using Plots import NLSolversBase: OnceDifferentiable -my_tests = ["curve_fit.jl", "levenberg_marquardt.jl", "curve_fit_inplace.jl", "geodesic.jl"] +my_tests = ["curve_fit.jl", "levenberg_marquardt.jl", "curve_fit_inplace.jl", "geodesic.jl", "plotting.jl"] println("Running tests:")