Skip to content

Commit dc6696a

Browse files
kshyattlkdvos
andauthored
Small fixes to let BlockTensorMaps work with GPU arrays (#44)
* Small fixes to let BlockTensorMaps work with GPU arrays * Indexing horror * Update src/linalg/factorizations.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Use views * Add a check for blocksquarity * Update src/linalg/factorizations.jl Co-authored-by: Lukas Devos <ldevos98@gmail.com> * No checks full gas --------- Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 3392c90 commit dc6696a

2 files changed

Lines changed: 17 additions & 9 deletions

File tree

src/linalg/factorizations.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,18 @@ import MatrixAlgebraKit as MAK
77

88
const BlockBlasMat{T <: MAK.BlasFloat} = BlockMatrix{T}
99

10+
function MatrixAlgebraKit.zero!(A::BlockBlasMat)
11+
for bj in blockaxes(A, 2), bi in blockaxes(A, 1)
12+
a = view(A, bi, bj)
13+
MAK.zero!(a)
14+
end
15+
return A
16+
end
17+
1018
function MatrixAlgebraKit.one!(A::BlockBlasMat)
11-
_one, _zero = one(eltype(A)), zero(eltype(A))
12-
@inbounds for j in axes(A, 2), i in axes(A, 1)
13-
A[i, j] = ifelse(i == j, _one, _zero)
19+
for bj in blockaxes(A, 2), bi in blockaxes(A, 1)
20+
a = view(A, bi, bj)
21+
bi == bj ? MAK.one!(a) : MAK.zero!(a)
1422
end
1523
return A
1624
end

src/tensors/blocktensor.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,10 @@ end
123123
# ------------------------
124124
for (fname, felt) in ((:zeros, :zero), (:ones, :one))
125125
@eval begin
126-
function Base.$fname(::Type{T}, V::TensorMapSumSpace) where {T}
127-
TT = blocktensormaptype(spacetype(V), numout(V), numin(V), T)
126+
function Base.$fname(::Type{TorA}, V::TensorMapSumSpace) where {TorA}
127+
TT = blocktensormaptype(spacetype(V), numout(V), numin(V), TorA)
128128
t = TT(undef, V)
129-
fill!(t, $felt(T))
129+
fill!(t, $felt(scalartype(t)))
130130
return t
131131
end
132132
end
@@ -136,9 +136,9 @@ for randfun in (:rand, :randn, :randexp)
136136
randfun! = Symbol(randfun, :!)
137137
@eval begin
138138
function Random.$randfun(
139-
rng::Random.AbstractRNG, ::Type{T}, V::TensorMapSumSpace
140-
) where {T}
141-
TT = blocktensormaptype(spacetype(V), numout(V), numin(V), T)
139+
rng::Random.AbstractRNG, ::Type{TorA}, V::TensorMapSumSpace
140+
) where {TorA}
141+
TT = blocktensormaptype(spacetype(V), numout(V), numin(V), TorA)
142142
t = TT(undef, V)
143143
Random.$randfun!(rng, t)
144144
return t

0 commit comments

Comments
 (0)