Skip to content

Commit f1e5ee9

Browse files
MaxenceGollieramontoison
authored andcommitted
lbfgs: optimize memory accesses
1 parent 7ef0d25 commit f1e5ee9

1 file changed

Lines changed: 8 additions & 14 deletions

File tree

src/lbfgs.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -124,27 +124,23 @@ function InverseLBFGSOperator(T::Type, n::I; kwargs...) where {I <: Integer}
124124
q = data.Ax # tmp vector
125125
q .= x
126126

127-
for i = 1:(data.mem)
127+
@inbounds for i = 1:(data.mem)
128128
k = mod(data.insert - i - 1, data.mem) + 1
129129
if data.ys[k] != 0
130130
αk = dot(data.s[k], q) / data.ys[k]
131131
data.α[k] = αk
132-
for j eachindex(q)
133-
q[j] -= αk * data.y[k][j]
134-
end
132+
@. q -= αk * data.y[k]
135133
end
136134
end
137135

138136
data.scaling && (q .*= data.scaling_factor)
139137

140-
for i = 1:(data.mem)
138+
@inbounds for i = 1:(data.mem)
141139
k = mod(data.insert + i - 2, data.mem) + 1
142140
if data.ys[k] != 0
143141
αk = data.α[k]
144142
β = αk - dot(data.y[k], q) / data.ys[k]
145-
for j eachindex(q)
146-
q[j] += β * data.s[k][j]
147-
end
143+
@. q += β * data.s[k]
148144
end
149145
end
150146
if βm == zero(T2)
@@ -227,12 +223,12 @@ function push_common!(
227223
if !op.inverse
228224
@. data.b[insert] = y / sqrt(ys)
229225

230-
for i = 1:(data.mem)
226+
@inbounds for i = 1:(data.mem)
231227
k = mod(insert + i - 1, data.mem) + 1
232228
if data.ys[k] != 0
233229
@. data.a[k] = data.s[k] / data.scaling_factor # B₀ = I / γ.
234230

235-
for j = 1:(i - 1)
231+
@inbounds for j = 1:(i - 1)
236232
l = mod(insert + j - 1, data.mem) + 1
237233
if data.ys[l] != 0
238234
data.a[k] .+= dot(data.b[l], data.s[k]) .* data.b[l]
@@ -379,12 +375,10 @@ function diag!(op::LBFGSOperator{T}, d) where {T}
379375
fill!(d, 1)
380376
data.scaling && (d ./= data.scaling_factor)
381377

382-
for i = 1:(data.mem)
378+
@inbounds for i = 1:(data.mem)
383379
k = mod(data.insert + i - 2, data.mem) + 1
384380
if data.ys[k] != 0
385-
for j = 1:(op.nrow)
386-
d[j] = d[j] + data.b[k][j]^2 - data.a[k][j]^2
387-
end
381+
@. d += data.b[k].^2 - data.a[k].^2
388382
end
389383
end
390384
return d

0 commit comments

Comments
 (0)