Skip to content

Commit 9ee15ff

Browse files
committed
replace sort(), sortperm() with sort!() and sortperm!() for lts, lta, and lms
1 parent adb400c commit 9ee15ff

4 files changed

Lines changed: 14 additions & 45 deletions

File tree

src/lms.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,24 +70,22 @@ function lms(X::AbstractMatrix{Float64}, y::AbstractVector{Float64}; iters = not
7070
bestobjective = Inf
7171
bestparamaters = Array{Float64}(undef, p)
7272
bestres = Array{Float64}(undef, n)
73-
origres = Array{Float64}(undef, n)
7473
indices = collect(1:n)
7574
kindices = collect(p:n)
7675
betas = Array{Float64}(undef, p)
7776
res = Array{Float64}(undef, n)
7877

7978
for _ = 1:iters
8079
try
81-
k = rand(kindices, 1)[1]
80+
k = rand(kindices)
8281
sampledindices = sample(indices, k, replace = false)
8382
betas = X[sampledindices, :] \ y[sampledindices]
84-
origres = y .- X * betas
85-
res = sort(origres .^ 2.0)
83+
res = sort!((y .- X * betas) .^ 2.0)
8684
m2 = res[h]
8785
if m2 < bestobjective
88-
bestparamaters .= betas
86+
bestparamaters = betas
8987
bestobjective = m2
90-
bestres .= origres
88+
bestres = y .- X * betas
9189
end
9290
catch e
9391
@warn e

src/lta.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,14 @@ function lta(X::AbstractMatrix{Float64}, y::AbstractVector{Float64}; exact = fal
8585
psubsets = collect(combinations(1:n, p))
8686
else
8787
iters = p * 3000
88-
psubsets = [sample(1:n, p, replace = false) for i = 1:iters]
88+
psubsets = [sample(1:n, p, replace = false) for _ = 1:iters]
8989
end
9090

9191
function lta_cost(subsetindices::Array{Int,1})::Tuple{Float64,Vector{Float64}}
9292
try
9393
betas = X[subsetindices, :] \ y[subsetindices]
94-
res_abs = abs.(y .- X * betas)
95-
ordered_res = sort(res_abs)
96-
cost = sum(ordered_res[1:h])
94+
sorted_res_abs = sort!(abs.(y .- X * betas))
95+
cost = sum(sorted_res_abs[1:h])
9796
return (cost, betas)
9897
catch
9998
return (Inf64, [])

src/lts.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,17 @@ function iterateCSteps(
4545
subsetindices::Array{Int,1},
4646
h::Int; eps::Float64 = 0.01, maxiter::Int = 10000
4747
)
48+
n = length(y)
4849
oldobjective::Float64 = Inf64
4950
objective::Float64 = Inf64
5051
iter::Int = 0
52+
sortedresindices = Array{Int}(undef, n)
5153
while iter < maxiter
5254
tempols = ols(X[subsetindices, :], y[subsetindices])
5355
res = y - X * coef(tempols)
54-
sortedresindices = sortperm(abs.(res))
56+
sortperm!(sortedresindices, abs.(res))
5557
subsetindices = sortedresindices[1:h]
56-
objective = sum(sort(res .^ 2.0)[1:h])
58+
objective = sum(sort!(res .^ 2.0)[1:h])
5759
if isapprox(oldobjective, objective, atol=eps)
5860
break
5961
end

test.jl

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,9 @@
11
using LinRegOutliers
2+
using BenchmarkTools
23

3-
import LinRegOutliers.OrdinaryLeastSquares: ols, wls, coef, residuals
4-
import Distributions: median
5-
6-
#sett = createRegressionSetting(@formula(calls ~ year), phones)
7-
#sett = createRegressionSetting(@formula(y ~ x1 + x2 + x3), hbk)
8-
9-
function generatedata(
10-
n::Int,
11-
p::Int,
12-
contamination::Float64,
13-
direction::Symbol,
14-
)::Tuple{Array{Float64,2},Array{Float64,1}}
15-
if !(direction in [:y, :x])
16-
error("- Direction must be :x or :y")
17-
end
18-
on = ones(Float64, n)
19-
pxvars = p - 1
20-
xvars = randn(n, pxvars)
21-
b = [5.0 for i = 1:pxvars]
22-
y = 5.0 .+ xvars * b + randn(n)
23-
totalcontamination = round(Int, contamination * n)
24-
if direction == :y
25-
y[1:totalcontamination] .= maximum(y) .+ abs.(rand(totalcontamination) * 5.0)
26-
elseif direction == :x
27-
xvars[1:totalcontamination, :] .=
28-
maximum(xvars) .+ abs.(rand(totalcontamination) * 5.0)
29-
end
30-
return (hcat(on, xvars), y)
31-
end
324

335

346

357
sett = createRegressionSetting(@formula(calls ~ year), phones)
36-
x = designMatrix(sett)
37-
y = responseVector(sett)
38-
x, y = generatedata(1000, 25, 0.30, :x)
39-
result = gwcga(x, y)
8+
9+
@btime lta($sett)

0 commit comments

Comments
 (0)