Skip to content

Commit

Permalink
Simplified check_data_health fun
Browse files Browse the repository at this point in the history
  • Loading branch information
briederer committed Nov 2, 2021
1 parent 6fd5d24 commit 94a15f0
Showing 1 changed file with 5 additions and 14 deletions.
19 changes: 5 additions & 14 deletions src/curve_fit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,7 @@ StatsBase.weights(lfr::LsqFitResult) = lfr.wt
StatsBase.residuals(lfr::LsqFitResult) = lfr.resid
mse(lfr::LsqFitResult) = rss(lfr)/dof(lfr)

function check_data_health(xdata, ydata)
if any(ismissing, xdata) || any(ismissing, ydata)
error("Data contains `missing` values and a fit cannot be performed")
end
if any(isinf, xdata) || any(isinf, ydata) || any(isnan, xdata) || any(isnan, ydata)
error("Data contains `Inf` or `NaN` values and a fit cannot be performed")
end
end

function check_data_health(xdata, ydata, wt)
function check_data_health(xdata, ydata, wt = [])
if any(ismissing, xdata) || any(ismissing, ydata) || any(ismissing, wt)
error("Data contains `missing` values and a fit cannot be performed")
end
Expand Down Expand Up @@ -146,7 +137,7 @@ function curve_fit(model, jacobian_model,
end

function curve_fit(model, xdata::AbstractArray, ydata::AbstractArray, wt::AbstractArray, p0::AbstractArray; inplace = false, kwargs...)
check_data_health(xdata, ydata)
check_data_health(xdata, ydata, wt)
# construct a weighted cost function, with a vector weight for each ydata
# for example, this might be wt = 1/sigma where sigma is some error term
u = sqrt.(wt) # to be consistant with the matrix form
Expand All @@ -162,7 +153,7 @@ end

function curve_fit(model, jacobian_model,
xdata::AbstractArray, ydata::AbstractArray, wt::AbstractArray, p0::AbstractArray; inplace = false, kwargs...)
check_data_health(xdata, ydata)
check_data_health(xdata, ydata, wt)
u = sqrt.(wt) # to be consistant with the matrix form

if inplace
Expand All @@ -177,7 +168,7 @@ function curve_fit(model, jacobian_model,
end

function curve_fit(model, xdata::AbstractArray, ydata::AbstractArray, wt::AbstractMatrix, p0::AbstractArray; kwargs...)
check_data_health(xdata, ydata)
check_data_health(xdata, ydata, wt)

# as before, construct a weighted cost function with where this
# method uses a matrix weight.
Expand All @@ -194,7 +185,7 @@ end

function curve_fit(model, jacobian_model,
xdata::AbstractArray, ydata::AbstractArray, wt::AbstractMatrix, p0::AbstractArray; kwargs...)
check_data_health(xdata, ydata)
check_data_health(xdata, ydata, wt)

u = cholesky(wt).U

Expand Down

0 comments on commit 94a15f0

Please sign in to comment.