Skip to content

Commit 0c16631

Browse files
committed
Small fixes to let BlockTensorMaps work with GPU arrays
1 parent 3392c90 commit 0c16631

2 files changed

Lines changed: 9 additions & 10 deletions

File tree

src/linalg/factorizations.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using MatrixAlgebraKit
2-
using MatrixAlgebraKit: AbstractAlgorithm, YALAPACK.BlasMat, Algorithm
2+
using MatrixAlgebraKit: AbstractAlgorithm, YALAPACK.BlasMat, Algorithm, diagview
33
import MatrixAlgebraKit as MAK
44

55
# Type piracy for defining the MAK rules on BlockArrays!
@@ -9,9 +9,8 @@ const BlockBlasMat{T <: MAK.BlasFloat} = BlockMatrix{T}
99

1010
function MatrixAlgebraKit.one!(A::BlockBlasMat)
1111
_one, _zero = one(eltype(A)), zero(eltype(A))
12-
@inbounds for j in axes(A, 2), i in axes(A, 1)
13-
A[i, j] = ifelse(i == j, _one, _zero)
14-
end
12+
A .= _zero
13+
diagview(A) .= _one
1514
return A
1615
end
1716

src/tensors/blocktensor.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,10 @@ end
123123
# ------------------------
124124
for (fname, felt) in ((:zeros, :zero), (:ones, :one))
125125
@eval begin
126-
function Base.$fname(::Type{T}, V::TensorMapSumSpace) where {T}
127-
TT = blocktensormaptype(spacetype(V), numout(V), numin(V), T)
126+
function Base.$fname(::Type{TorA}, V::TensorMapSumSpace) where {TorA}
127+
TT = blocktensormaptype(spacetype(V), numout(V), numin(V), TorA)
128128
t = TT(undef, V)
129-
fill!(t, $felt(T))
129+
fill!(t, $felt(scalartype(t)))
130130
return t
131131
end
132132
end
@@ -136,9 +136,9 @@ for randfun in (:rand, :randn, :randexp)
136136
randfun! = Symbol(randfun, :!)
137137
@eval begin
138138
function Random.$randfun(
139-
rng::Random.AbstractRNG, ::Type{T}, V::TensorMapSumSpace
140-
) where {T}
141-
TT = blocktensormaptype(spacetype(V), numout(V), numin(V), T)
139+
rng::Random.AbstractRNG, ::Type{TorA}, V::TensorMapSumSpace
140+
) where {TorA}
141+
TT = blocktensormaptype(spacetype(V), numout(V), numin(V), TorA)
142142
t = TT(undef, V)
143143
Random.$randfun!(rng, t)
144144
return t

0 commit comments

Comments
 (0)