Add plot recipe for fit result #180

Open · wants to merge 4 commits into `master`
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,3 +1,6 @@
benchmarks/graphs/*
*~
*.kate-swp
Manifest.toml
test/*.png
docs/build/
5 changes: 4 additions & 1 deletion Project.toml
@@ -9,18 +9,21 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
OptimBase = "87e2bd06-a317-5318-96d9-3ecbac512eee"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[compat]
Distributions = "0.18, 0.19, 0.20, 0.21, 0.22, 0.23, 0.24"
ForwardDiff = "0.10"
NLSolversBase = "7.5"
OptimBase = "2.0"
RecipesBase = "1"
StatsBase = "0.32, 0.33"
julia = "1.1"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

[targets]
test = ["Test"]
test = ["Test", "Plots"]
Binary file added docs/src/img/plots.png
27 changes: 27 additions & 0 deletions docs/src/tutorial.md
@@ -391,6 +391,33 @@ julia> fit_WLS = curve_fit(m, tdata, ydata, wt, p0)
julia> cov = estimate_covar(fit_WLS)
```

## Visualization
A `Plots.jl` plot recipe is provided for visualizing a fit. The simplest case
is plotting just the fit curve by itself, which can be done by supplying the
model and the fit object to a `plot` call:
```julia
plot(m, fit)
```
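For readers jumping straight to this section, here is a minimal end-to-end sketch; the model and data are hypothetical stand-ins (the tutorial's earlier examples define their own `m`, `tdata`, `ydata`, `p0`):

```julia
using LsqFit, Plots

# hypothetical exponential-decay model and noisy synthetic data
m(t, p) = p[1] * exp.(-p[2] * t)
tdata = range(0, 10; length=20)
ydata = m(tdata, [1.0, 2.0]) .+ 0.01 .* randn(20)

fit = curve_fit(m, collect(tdata), ydata, [0.5, 0.5])
plot(m, fit)  # fit curve only, no intervals
```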

The plot recipe can also show the confidence and prediction intervals by
supplying the `purpose` keyword (default `:neither`):
```julia
plot(
    plot(m, fit; purpose=:neither),
    plot(m, fit; purpose=:prediction),
    plot(m, fit; purpose=:confidence),
    plot(m, fit; purpose=:both);
    layout=(2, 2),
)
```

![Plots with different intervals marked.](./img/plots.png)

The keyword `significance` (default `0.05`) sets the confidence level of the
illustrated intervals: `significance=0.05` gives 95% intervals, and
`significance=0.01` gives wider 99% intervals.
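The relationship between `significance` and band width comes from the t-distribution quantile used internally. A quick check (assuming `Distributions` is available, with a hypothetical 10 degrees of freedom):

```julia
using Distributions

ν = 10  # hypothetical residual degrees of freedom
crit(sig) = quantile(TDist(ν), 1 - sig / 2)

crit(0.05)  # ≈ 2.23, scales the 95% band
crit(0.01)  # ≈ 3.17, so the 99% band is wider
```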


## References
Hansen, P. C., Pereyra, V. and Scherer, G. (2013). *Least Squares Data Fitting with Applications*. Baltimore, MD: Johns Hopkins University Press, pp. 147–155.

2 changes: 2 additions & 0 deletions src/LsqFit.jl
@@ -15,6 +15,7 @@ module LsqFit
using OptimBase
using LinearAlgebra
using ForwardDiff
using RecipesBase
import NLSolversBase: value, jacobian
import StatsBase
import StatsBase: coef, dof, nobs, rss, stderror, weights, residuals
@@ -24,5 +25,6 @@ module LsqFit
include("geodesic.jl")
include("levenberg_marquardt.jl")
include("curve_fit.jl")
include("plot.jl")

end
51 changes: 51 additions & 0 deletions src/plot.jl
@@ -0,0 +1,51 @@
@recipe function f(model::Function, fit::LsqFitResult; significance=0.05, purpose=:neither)
    @series begin
        seriestype --> :line
        label --> "Fit"
        x -> model(x, fit.param)
    end
    if purpose in (:confidence, :both)
        @series begin
            seriestype := :line
            seriescolor := :black
            linestyle := :dash
            label := ["Confidence interval" nothing]

            [
                x -> model(x, fit.param) - margin_error(model, x, fit, significance),
                x -> model(x, fit.param) + margin_error(model, x, fit, significance)
            ]
        end
    end
    if purpose in (:prediction, :both)
        @series begin
            seriestype := :line
            seriescolor := :black
            linestyle := :dot
            label := ["Prediction interval" nothing]

            [
                x -> model(x, fit.param) - margin_error(model, x, fit, significance; purpose=:prediction),
                x -> model(x, fit.param) + margin_error(model, x, fit, significance; purpose=:prediction)
            ]
        end
    end
end
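An aside on the RecipesBase operators used above (not part of the diff): `-->` sets a default the caller may still override, while `:=` forces the attribute. A minimal standalone recipe, with a hypothetical `Parabola` type, illustrating the difference:

```julia
using RecipesBase

struct Parabola end  # hypothetical plottable type

@recipe function f(::Parabola)
    seriestype --> :line  # default; plot(Parabola(); seriestype=:scatter) overrides it
    linewidth := 2        # forced; a caller-supplied linewidth is ignored
    x -> x^2              # Plots can render a function directly
end
```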


"""
```julia
margin_error(model, x, fit, alpha=0.05; purpose=:confidence)
```
Return the half-width at `x` of the confidence (`purpose=:confidence`) or
prediction (`purpose=:prediction`) interval for the fitted curve, at
significance level `alpha`.
"""
function margin_error(model::Function, x, fit::LsqFitResult, alpha=0.05; purpose=:confidence)
    mse = rss(fit) / dof(fit)
    g = ForwardDiff.gradient(p -> model(x, p), fit.param)
    # estimate_covar(fit) already contains the factor mse = rss/dof; dividing
    # it out leaves the unitless leverage term g' * inv(J'J) * g
    c = g' * estimate_covar(fit) * g / mse
    if purpose === :prediction
        # a new observation contributes one extra unit of residual variance
        c = c + 1
    end
    dist = TDist(dof(fit))
    critical_value = quantile(dist, 1 - alpha / 2)
    return sqrt(c * mse) * critical_value
end

Review discussion on `c = c + 1`:

> **Member:** I see this was in the original formulation on stackexchange, but where does the 1 come from?
>
> **Author:** My statistics textbook has a derivation for this, but it's sadly in Swedish. I'll see later today if I can find a good reference. Essentially, the difference between prediction and confidence coincidentally comes down to 1 dof.
>
> **Member:** I see. You can post the picture of the Swedish text, I'll be able to read it.
>
> **Author:** [image attachment]
>
> Here's the derivation in Seber, G. A. F. & Wild, C. J. (1989), *Statistical inference* (my textbook only had it for the linear case). My statistics isn't quite strong enough that I could convert this directly to working code, but I think it contains the source of the 1.
>
> Intuitively, your next value is likely to land in an interval that is the confidence interval plus one "standard deviation" to either side, because that is where it is likely to land given that the true mean lies inside the confidence interval. I hope this makes sense; otherwise it might be an idea to hold off on this part of the PR.
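The interval formulas behind `margin_error` can be written out (standard linearization; $J$ is the fit Jacobian, $g = \nabla_p f(x, \hat p)$ the model gradient at the estimate, $\nu$ the residual degrees of freedom, and $\hat\sigma^2 = \mathrm{rss}/\nu$). The confidence half-width is

$$
t_{1-\alpha/2,\,\nu}\;\hat\sigma\,\sqrt{g^\top (J^\top J)^{-1} g},
$$

while the prediction half-width for a single new observation also carries that observation's own noise variance:

$$
t_{1-\alpha/2,\,\nu}\;\hat\sigma\,\sqrt{1 + g^\top (J^\top J)^{-1} g}.
$$

The `1` in `c = c + 1` is exactly this extra term: the estimation error of the curve and the noise of a new observation are independent, so their variances add.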
35 changes: 35 additions & 0 deletions test/plotting.jl
@@ -0,0 +1,35 @@
let
    @testset "Plotting" begin
        # seed the RNG so the synthetic noise below is reproducible
        Random.seed!(42)
        model(x, p) = @. p[1] * x^2 + p[2] * x + p[3] + p[4] * exp(-(x - p[5])^2 / p[6]^2)
        p_true = [0.5, 0.7, 0.0, 4.5, 0.3, 0.5]
        xdata = -4.0:0.5:4.0
        ydata = model(xdata, p_true) + 0.7 * randn(size(xdata))

        f = curve_fit(model, xdata, ydata, fill(1.0, 6))

        p = plot(
            map((:neither, :prediction, :confidence, :both)) do purpose
                plot(
                    xdata,
                    ydata;
                    seriestype=:scatter,
                    label="Data",
                    legend_foreground_color=nothing,
                    legend_background_color=nothing,
                    legend=:topleft,
                )
                plot!(
                    x -> model(x, p_true);
                    seriestype=:line,
                    label="Ground truth",
                    linestyle=:dot,
                )
                plot!(model, f; purpose=purpose, title="$purpose")
            end...;
            layout=(2, 2),
        )
        @test p isa Plots.Plot
        savefig(p, "plots.png")
        println("Plots saved to `test/plots.png`")
    end
end
3 changes: 2 additions & 1 deletion test/runtests.jl
@@ -3,9 +3,10 @@
#
using LsqFit, Test, LinearAlgebra, Random
using OptimBase
using Plots
import NLSolversBase: OnceDifferentiable

my_tests = ["curve_fit.jl", "levenberg_marquardt.jl", "curve_fit_inplace.jl", "geodesic.jl"]
my_tests = ["curve_fit.jl", "levenberg_marquardt.jl", "curve_fit_inplace.jl", "geodesic.jl", "plotting.jl"]

println("Running tests:")
