diff --git a/src/curve_fit.jl b/src/curve_fit.jl index 23794d8..b795864 100755 --- a/src/curve_fit.jl +++ b/src/curve_fit.jl @@ -15,16 +15,7 @@ StatsBase.weights(lfr::LsqFitResult) = lfr.wt StatsBase.residuals(lfr::LsqFitResult) = lfr.resid mse(lfr::LsqFitResult) = rss(lfr)/dof(lfr) -function check_data_health(xdata, ydata) - if any(ismissing, xdata) || any(ismissing, ydata) - error("Data contains `missing` values and a fit cannot be performed") - end - if any(isinf, xdata) || any(isinf, ydata) || any(isnan, xdata) || any(isnan, ydata) - error("Data contains `Inf` or `NaN` values and a fit cannot be performed") - end -end - -function check_data_health(xdata, ydata, wt) +function check_data_health(xdata, ydata, wt = []) if any(ismissing, xdata) || any(ismissing, ydata) || any(ismissing, wt) error("Data contains `missing` values and a fit cannot be performed") end @@ -146,7 +137,7 @@ function curve_fit(model, jacobian_model, end function curve_fit(model, xdata::AbstractArray, ydata::AbstractArray, wt::AbstractArray, p0::AbstractArray; inplace = false, kwargs...) - check_data_health(xdata, ydata) + check_data_health(xdata, ydata, wt) # construct a weighted cost function, with a vector weight for each ydata # for example, this might be wt = 1/sigma where sigma is some error term u = sqrt.(wt) # to be consistant with the matrix form @@ -162,7 +153,7 @@ end function curve_fit(model, jacobian_model, xdata::AbstractArray, ydata::AbstractArray, wt::AbstractArray, p0::AbstractArray; inplace = false, kwargs...) - check_data_health(xdata, ydata) + check_data_health(xdata, ydata, wt) u = sqrt.(wt) # to be consistant with the matrix form if inplace @@ -177,7 +168,7 @@ function curve_fit(model, jacobian_model, end function curve_fit(model, xdata::AbstractArray, ydata::AbstractArray, wt::AbstractMatrix, p0::AbstractArray; kwargs...) - check_data_health(xdata, ydata) + check_data_health(xdata, ydata, wt) # as before, construct a weighted cost function with where this # method uses a matrix weight. @@ -194,7 +185,7 @@ end function curve_fit(model, jacobian_model, xdata::AbstractArray, ydata::AbstractArray, wt::AbstractMatrix, p0::AbstractArray; kwargs...) - check_data_health(xdata, ydata) + check_data_health(xdata, ydata, wt) u = cholesky(wt).U