Skip to content

Commit

Permalink
introduce view in LTS and LMS
Browse files Browse the repository at this point in the history
  • Loading branch information
jbytecode committed Oct 27, 2024
1 parent c76ce2c commit 4a2bcf8
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# v0.11.5 (Upcoming Release)

- Initial implementation of the robust hat matrix regression estimator
- Introduce `view` in LTS and LMS.

# v0.11.4

Expand Down
2 changes: 1 addition & 1 deletion src/lms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ function lms(X::AbstractMatrix{Float64}, y::AbstractVector{Float64}; iters = not
try
k = rand(kindices)
sampledindices = sample(indices, k, replace = false)
betas = X[sampledindices, :] \ y[sampledindices]
betas = view(X, sampledindices, :) \ view(y, sampledindices)
res = sort!((y .- X * betas) .^ 2.0)
m2 = res[h]
if m2 < bestobjective
Expand Down
8 changes: 4 additions & 4 deletions src/lts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ function iterateCSteps(
iter::Int = 0
sortedresindices = Array{Int}(undef, n)
while iter < maxiter
tempols = ols(X[subsetindices, :], y[subsetindices])
tempols = ols(view(X, subsetindices, :), view(y, subsetindices))
res = y - X * coef(tempols)
sortperm!(sortedresindices, abs.(res))
subsetindices = sortedresindices[1:h]
objective = sum(sort!(res .^ 2.0)[1:h])
subsetindices = view(sortedresindices, 1:h)
objective = sum(view(sort!(res .^ 2.0), 1:h))
if isapprox(oldobjective, objective, atol=eps)
break
end
Expand Down Expand Up @@ -172,7 +172,7 @@ function lts(X::AbstractMatrix{Float64}, y::AbstractVector{Float64}; iters=nothi
end
end

ltsbetas = X[besthsubset, :] \ y[besthsubset]
ltsbetas = view(X, besthsubset, :) \ view(y, besthsubset)
ltsres = y - X * ltsbetas
ltsS = sqrt(sum((ltsres .^ 2.0)[1:h]) / (h - p))
ltsresmean = mean(ltsres[besthsubset])
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ using LinearAlgebra
using LinRegOutliers
import Plots: RGBX

include("testdiagnostics.jl")
include("testbasis.jl")
include("testols.jl")
include("testdiagnostics.jl")
include("tesths93.jl")
include("testks89.jl")
include("testsmr98.jl")
Expand Down

0 comments on commit 4a2bcf8

Please sign in to comment.