Skip to content

Commit f490921

Browse files
authored
Apply runic formatting (#43)
1 parent bd295ac commit f490921

9 files changed

Lines changed: 341 additions & 211 deletions

File tree

benchmarks/benchtests.jl

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function benchmark_sum(sizes)
2222
return times
2323
end
2424

25-
function benchmark_permute(sizes, p=(4, 3, 2, 1))
25+
function benchmark_permute(sizes, p = (4, 3, 2, 1))
2626
times = zeros(length(sizes), 4)
2727
for (i, s) in enumerate(sizes)
2828
A = randn(Float64, s .* one.(p))
@@ -41,7 +41,7 @@ permute_times1 = benchmark_permute(sizes, (4, 3, 2, 1))
4141
permute_times2 = benchmark_permute(sizes, (2, 3, 4, 1))
4242
permute_times3 = benchmark_permute(sizes, (3, 4, 1, 2))
4343

44-
function benchmark_mul(sizesm, sizesk=sizesm, sizesn=sizesm)
44+
function benchmark_mul(sizesm, sizesk = sizesm, sizesn = sizesm)
4545
N = Threads.nthreads()
4646
@assert length(sizesm) == length(sizesk) == length(sizesn)
4747
times = zeros(length(sizesm), 4)
@@ -62,23 +62,23 @@ function benchmark_mul(sizesm, sizesk=sizesm, sizesn=sizesm)
6262
BLAS.set_num_threads(1) # single-threaded blas with strided multithreading
6363
Strided.enable_threaded_mul()
6464
times[i, 4] = @belapsed @strided mul!($C, $A, $B)
65-
println("step $i: sizes $((m,k,n)) => times = $(times[i, :])")
65+
println("step $i: sizes $((m, k, n)) => times = $(times[i, :])")
6666
end
6767
return times
6868
end
6969

7070
function tensorcontraction!(wEnv, hamAB, hamBA, rhoBA, rhoAB, w, v, u)
7171
@tensor wEnv[-1, -2, -3] = hamAB[7, 8, -1, 9] * rhoBA[4, 3, -3, 2] * conj(w[7, 5, 4]) *
72-
u[9, 10, -2, 11] * conj(u[8, 10, 5, 6]) * v[1, 11, 2] *
73-
conj(v[1, 6, 3]) +
74-
hamBA[1, 2, 3, 4] * rhoBA[10, 7, -3, 6] *
75-
conj(w[-1, 11, 10]) * u[3, 4, -2, 8] * conj(u[1, 2, 11, 9]) *
76-
v[5, 8, 6] * conj(v[5, 9, 7]) +
77-
hamAB[5, 7, 3, 1] * rhoBA[10, 9, -3, 8] *
78-
conj(w[-1, 11, 10]) * u[4, 3, -2, 2] * conj(u[4, 5, 11, 6]) *
79-
v[1, 2, 8] * conj(v[7, 6, 9]) +
80-
hamBA[3, 7, 2, -1] * rhoAB[5, 6, 4, -3] * v[2, 1, 4] *
81-
conj(v[3, 1, 5]) * conj(w[7, -2, 6])
72+
u[9, 10, -2, 11] * conj(u[8, 10, 5, 6]) * v[1, 11, 2] *
73+
conj(v[1, 6, 3]) +
74+
hamBA[1, 2, 3, 4] * rhoBA[10, 7, -3, 6] *
75+
conj(w[-1, 11, 10]) * u[3, 4, -2, 8] * conj(u[1, 2, 11, 9]) *
76+
v[5, 8, 6] * conj(v[5, 9, 7]) +
77+
hamAB[5, 7, 3, 1] * rhoBA[10, 9, -3, 8] *
78+
conj(w[-1, 11, 10]) * u[4, 3, -2, 2] * conj(u[4, 5, 11, 6]) *
79+
v[1, 2, 8] * conj(v[7, 6, 9]) +
80+
hamBA[3, 7, 2, -1] * rhoAB[5, 6, 4, -3] * v[2, 1, 4] *
81+
conj(v[3, 1, 5]) * conj(w[7, -2, 6])
8282
return wEnv
8383
end
8484

@@ -100,32 +100,42 @@ function benchmark_tensorcontraction(sizes)
100100
BLAS.set_num_threads(1)
101101
Strided.disable_threads()
102102
Strided.disable_threaded_mul()
103-
times[i, 1] = @belapsed tensorcontraction!($wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
104-
$w, $v, $u)
103+
times[i, 1] = @belapsed tensorcontraction!(
104+
$wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
105+
$w, $v, $u
106+
)
105107

106108
BLAS.set_num_threads(1)
107109
Strided.enable_threads()
108110
Strided.disable_threaded_mul()
109-
times[i, 2] = @belapsed tensorcontraction!($wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
110-
$w, $v, $u)
111+
times[i, 2] = @belapsed tensorcontraction!(
112+
$wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
113+
$w, $v, $u
114+
)
111115

112116
BLAS.set_num_threads(N)
113117
Strided.disable_threads()
114118
Strided.disable_threaded_mul()
115-
times[i, 3] = @belapsed tensorcontraction!($wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
116-
$w, $v, $u)
119+
times[i, 3] = @belapsed tensorcontraction!(
120+
$wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
121+
$w, $v, $u
122+
)
117123

118124
BLAS.set_num_threads(N)
119125
Strided.enable_threads()
120126
Strided.disable_threaded_mul()
121-
times[i, 4] = @belapsed tensorcontraction!($wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
122-
$w, $v, $u)
127+
times[i, 4] = @belapsed tensorcontraction!(
128+
$wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
129+
$w, $v, $u
130+
)
123131

124132
BLAS.set_num_threads(1)
125133
Strided.enable_threads()
126134
Strided.enable_threaded_mul()
127-
times[i, 5] = @belapsed tensorcontraction!($wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
128-
$w, $v, $u)
135+
times[i, 5] = @belapsed tensorcontraction!(
136+
$wEnv, $hamAB, $hamBA, $rhoBA, $rhoAB,
137+
$w, $v, $u
138+
)
129139

130140
println("step $i: size $s => times = $(times[i, :])")
131141
end

src/Strided.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module Strided
22

33
import Base: parent, size, strides, tail, setindex
44
using Base: @propagate_inbounds, RangeIndex, Dims
5-
const SliceIndex = Union{RangeIndex,Colon}
5+
const SliceIndex = Union{RangeIndex, Colon}
66

77
using LinearAlgebra
88

@@ -27,7 +27,7 @@ function set_num_threads(n::Int)
2727
return _NTHREADS[] = n
2828
end
2929
@noinline function _set_num_threads_warn(n)
30-
@warn "Maximal number of threads limited by number of Julia threads,
30+
return @warn "Maximal number of threads limited by number of Julia threads,
3131
setting number of threads equal to Threads.nthreads() = $n"
3232
end
3333

src/broadcast.jl

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,33 @@ using Base.Broadcast: BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Bro
33
struct StridedArrayStyle{N} <: AbstractArrayStyle{N}
44
end
55

6-
Broadcast.BroadcastStyle(::Type{<:StridedView{<:Any,N}}) where {N} = StridedArrayStyle{N}()
6+
Broadcast.BroadcastStyle(::Type{<:StridedView{<:Any, N}}) where {N} = StridedArrayStyle{N}()
77

88
StridedArrayStyle(::Val{N}) where {N} = StridedArrayStyle{N}()
9-
StridedArrayStyle{M}(::Val{N}) where {M,N} = StridedArrayStyle{N}()
9+
StridedArrayStyle{M}(::Val{N}) where {M, N} = StridedArrayStyle{N}()
1010

1111
Broadcast.BroadcastStyle(a::StridedArrayStyle, ::DefaultArrayStyle{0}) = a
1212
function Broadcast.BroadcastStyle(::StridedArrayStyle{N}, a::DefaultArrayStyle) where {N}
1313
return BroadcastStyle(DefaultArrayStyle{N}(), a)
1414
end
15-
function Broadcast.BroadcastStyle(::StridedArrayStyle{N},
16-
::Broadcast.Style{Tuple}) where {N}
15+
function Broadcast.BroadcastStyle(
16+
::StridedArrayStyle{N},
17+
::Broadcast.Style{Tuple}
18+
) where {N}
1719
return DefaultArrayStyle{N}()
1820
end
1921

20-
function Base.similar(bc::Broadcasted{<:StridedArrayStyle{N}}, ::Type{T}) where {N,T}
22+
function Base.similar(bc::Broadcasted{<:StridedArrayStyle{N}}, ::Type{T}) where {N, T}
2123
return StridedView(similar(convert(Broadcasted{DefaultArrayStyle{N}}, bc), T))
2224
end
2325

24-
Base.dotview(a::StridedView{<:Any,N}, I::Vararg{SliceIndex,N}) where {N} = getindex(a, I...)
26+
Base.dotview(a::StridedView{<:Any, N}, I::Vararg{SliceIndex, N}) where {N} = getindex(a, I...)
2527

2628
# Broadcasting implementation
27-
@inline function Base.copyto!(dest::StridedView{<:Any,N},
28-
bc::Broadcasted{StridedArrayStyle{N}}) where {N}
29+
@inline function Base.copyto!(
30+
dest::StridedView{<:Any, N},
31+
bc::Broadcasted{StridedArrayStyle{N}}
32+
) where {N}
2933
# convert to map
3034

3135
# flatten and only keep the StridedView arguments
@@ -36,7 +40,7 @@ Base.dotview(a::StridedView{<:Any,N}, I::Vararg{SliceIndex,N}) where {N} = getin
3640
return dest
3741
end
3842

39-
const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}}
43+
const WrappedScalarArgs = Union{AbstractArray{<:Any, 0}, Ref{<:Any}}
4044

4145
@inline function capturestridedargs(t::Broadcasted, rest...)
4246
return (capturestridedargs(t.args...)..., capturestridedargs(rest...)...)
@@ -64,7 +68,7 @@ function promoteshape1(sz::Dims{N}, a::StridedView) where {N}
6468
return StridedView(a.parent, sz, newstrides, a.offset, a.op)
6569
end
6670

67-
struct CaptureArgs{F,Args<:Tuple}
71+
struct CaptureArgs{F, Args <: Tuple}
6872
f::F
6973
args::Args
7074
end
@@ -84,7 +88,7 @@ end
8488

8589
# Evaluate CaptureArgs
8690
(c::CaptureArgs)(vals...) = consume(c, vals)[1]
87-
@inline function consume(c::CaptureArgs{F,Args}, vals) where {F,Args}
91+
@inline function consume(c::CaptureArgs{F, Args}, vals) where {F, Args}
8892
args, newvals = t_consume(c.args, vals)
8993
return c.f(args...), newvals
9094
end

src/convert.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ function Base.Array(a::StridedView)
44
return b
55
end
66

7-
function (Base.Array{T})(a::StridedView{S,N}) where {T,S,N}
7+
function (Base.Array{T})(a::StridedView{S, N}) where {T, S, N}
88
b = Array{T}(undef, size(a))
99
copy!(StridedView(b), a)
1010
return b
1111
end
1212

13-
function (Base.Array{T,N})(a::StridedView{S,N}) where {T,S,N}
13+
function (Base.Array{T, N})(a::StridedView{S, N}) where {T, S, N}
1414
b = Array{T}(undef, size(a))
1515
copy!(StridedView(b), a)
1616
return b

src/linalg.jl

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,35 +2,43 @@
22
LinearAlgebra.rmul!(dst::StridedView, α::Number) = mul!(dst, dst, α)
33
LinearAlgebra.lmul!(α::Number, dst::StridedView) = mul!(dst, α, dst)
44

5-
function LinearAlgebra.mul!(dst::StridedView{<:Number,N}, α::Number,
6-
src::StridedView{<:Number,N}) where {N}
5+
function LinearAlgebra.mul!(
6+
dst::StridedView{<:Number, N}, α::Number,
7+
src::StridedView{<:Number, N}
8+
) where {N}
79
if α == 1
810
copy!(dst, src)
911
else
1012
dst .= α .* src
1113
end
1214
return dst
1315
end
14-
function LinearAlgebra.mul!(dst::StridedView{<:Number,N}, src::StridedView{<:Number,N},
15-
α::Number) where {N}
16+
function LinearAlgebra.mul!(
17+
dst::StridedView{<:Number, N}, src::StridedView{<:Number, N},
18+
α::Number
19+
) where {N}
1620
if α == 1
1721
copy!(dst, src)
1822
else
1923
dst .= src .* α
2024
end
2125
return dst
2226
end
23-
function LinearAlgebra.axpy!(a::Number, X::StridedView{<:Number,N},
24-
Y::StridedView{<:Number,N}) where {N}
27+
function LinearAlgebra.axpy!(
28+
a::Number, X::StridedView{<:Number, N},
29+
Y::StridedView{<:Number, N}
30+
) where {N}
2531
if a == 1
2632
Y .= X .+ Y
2733
else
2834
Y .= a .* X .+ Y
2935
end
3036
return Y
3137
end
32-
function LinearAlgebra.axpby!(a::Number, X::StridedView{<:Number,N},
33-
b::Number, Y::StridedView{<:Number,N}) where {N}
38+
function LinearAlgebra.axpby!(
39+
a::Number, X::StridedView{<:Number, N},
40+
b::Number, Y::StridedView{<:Number, N}
41+
) where {N}
3442
if b == 1
3543
axpy!(a, X, Y)
3644
elseif b == 0
@@ -41,9 +49,11 @@ function LinearAlgebra.axpby!(a::Number, X::StridedView{<:Number,N},
4149
return Y
4250
end
4351

44-
function LinearAlgebra.mul!(C::StridedView{T,2},
45-
A::StridedView{<:Any,2}, B::StridedView{<:Any,2},
46-
α::Number=true, β::Number=false) where {T}
52+
function LinearAlgebra.mul!(
53+
C::StridedView{T, 2},
54+
A::StridedView{<:Any, 2}, B::StridedView{<:Any, 2},
55+
α::Number = true, β::Number = false
56+
) where {T}
4757
if !(eltype(C) <: LinearAlgebra.BlasFloat && eltype(A) == eltype(B) == eltype(C))
4858
return __mul!(C, A, B, α, β)
4959
end
@@ -62,7 +72,7 @@ function LinearAlgebra.mul!(C::StridedView{T,2},
6272
return C
6373
end
6474

65-
function isblasmatrix(A::StridedView{T,2}) where {T<:LinearAlgebra.BlasFloat}
75+
function isblasmatrix(A::StridedView{T, 2}) where {T <: LinearAlgebra.BlasFloat}
6676
if A.op == identity
6777
return stride(A, 1) == 1 || stride(A, 2) == 1
6878
elseif A.op == conj
@@ -71,7 +81,7 @@ function isblasmatrix(A::StridedView{T,2}) where {T<:LinearAlgebra.BlasFloat}
7181
return false
7282
end
7383
end
74-
function getblasmatrix(A::StridedView{T,2}) where {T<:LinearAlgebra.BlasFloat}
84+
function getblasmatrix(A::StridedView{T, 2}) where {T <: LinearAlgebra.BlasFloat}
7585
if A.op == identity
7686
if stride(A, 1) == 1
7787
return A, 'N'
@@ -84,8 +94,10 @@ function getblasmatrix(A::StridedView{T,2}) where {T<:LinearAlgebra.BlasFloat}
8494
end
8595

8696
# here we will have C.op == :identity && stride(C,1) < stride(C,2)
87-
function _mul!(C::StridedView{T,2}, A::StridedView{T,2}, B::StridedView{T,2},
88-
α::Number, β::Number) where {T<:LinearAlgebra.BlasFloat}
97+
function _mul!(
98+
C::StridedView{T, 2}, A::StridedView{T, 2}, B::StridedView{T, 2},
99+
α::Number, β::Number
100+
) where {T <: LinearAlgebra.BlasFloat}
89101
if stride(C, 1) == 1 && isblasmatrix(A) && isblasmatrix(B)
90102
nthreads = use_threaded_mul() ? get_num_threads() : 1
91103
_threaded_blas_mul!(C, A, B, α, β, nthreads)
@@ -94,41 +106,53 @@ function _mul!(C::StridedView{T,2}, A::StridedView{T,2}, B::StridedView{T,2},
94106
end
95107
end
96108

97-
function _threaded_blas_mul!(C::StridedView{T,2}, A::StridedView{T,2}, B::StridedView{T,2},
98-
α::Number, β::Number,
99-
nthreads) where {T<:LinearAlgebra.BlasFloat}
109+
function _threaded_blas_mul!(
110+
C::StridedView{T, 2}, A::StridedView{T, 2}, B::StridedView{T, 2},
111+
α::Number, β::Number,
112+
nthreads
113+
) where {T <: LinearAlgebra.BlasFloat}
100114
m, n = size(C)
101115
m == size(A, 1) && n == size(B, 2) || throw(DimensionMismatch())
102-
if nthreads == 1 || m * n < 1024
116+
return if nthreads == 1 || m * n < 1024
103117
A2, CA = getblasmatrix(A)
104118
B2, CB = getblasmatrix(B)
105119
LinearAlgebra.BLAS.gemm!(CA, CB, convert(T, α), A2, B2, convert(T, β), C)
106120
else
107121
if m > n
108122
m2 = round(Int, m / 16) * 8
109123
nthreads2 = nthreads >> 1
110-
t = Threads.@spawn _threaded_blas_mul!(C[1:($m2), :], A[1:($m2), :], B, α, β,
111-
$nthreads2)
112-
_threaded_blas_mul!(C[(m2 + 1):m, :], A[(m2 + 1):m, :], B, α, β,
113-
nthreads - nthreads2)
124+
t = Threads.@spawn _threaded_blas_mul!(
125+
C[1:($m2), :], A[1:($m2), :], B, α, β,
126+
$nthreads2
127+
)
128+
_threaded_blas_mul!(
129+
C[(m2 + 1):m, :], A[(m2 + 1):m, :], B, α, β,
130+
nthreads - nthreads2
131+
)
114132
wait(t)
115133
return C
116134
else
117135
n2 = round(Int, n / 16) * 8
118136
nthreads2 = nthreads >> 1
119-
t = Threads.@spawn _threaded_blas_mul!(C[:, 1:($n2)], A, B[:, 1:($n2)], α, β,
120-
$nthreads2)
121-
_threaded_blas_mul!(C[:, (n2 + 1):n], A, B[:, (n2 + 1):n], α, β,
122-
nthreads - nthreads2)
137+
t = Threads.@spawn _threaded_blas_mul!(
138+
C[:, 1:($n2)], A, B[:, 1:($n2)], α, β,
139+
$nthreads2
140+
)
141+
_threaded_blas_mul!(
142+
C[:, (n2 + 1):n], A, B[:, (n2 + 1):n], α, β,
143+
nthreads - nthreads2
144+
)
123145
wait(t)
124146
return C
125147
end
126148
end
127149
end
128150

129151
# This implementation is faster than LinearAlgebra.generic_matmatmul
130-
function __mul!(C::StridedView{<:Any,2}, A::StridedView{<:Any,2}, B::StridedView{<:Any,2},
131-
α::Number, β::Number)
152+
function __mul!(
153+
C::StridedView{<:Any, 2}, A::StridedView{<:Any, 2}, B::StridedView{<:Any, 2},
154+
α::Number, β::Number
155+
)
132156
(size(C, 1) == size(A, 1) && size(C, 2) == size(B, 2) && size(A, 2) == size(B, 1)) ||
133157
throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B)), C has size $(size(C))"))
134158

0 commit comments

Comments
 (0)