Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/linalg/factorizations.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using MatrixAlgebraKit
using MatrixAlgebraKit: AbstractAlgorithm, YALAPACK.BlasMat, Algorithm
using MatrixAlgebraKit: AbstractAlgorithm, YALAPACK.BlasMat, Algorithm, diagview
import MatrixAlgebraKit as MAK

# Type piracy for defining the MAK rules on BlockArrays!
Expand All @@ -9,8 +9,11 @@ const BlockBlasMat{T <: MAK.BlasFloat} = BlockMatrix{T}

function MatrixAlgebraKit.one!(A::BlockBlasMat)
_one, _zero = one(eltype(A)), zero(eltype(A))
@inbounds for j in axes(A, 2), i in axes(A, 1)
A[i, j] = ifelse(i == j, _one, _zero)
A .= _zero
n_blocks = blocksize(A)[1]
# awful workaround to BlockArrays indexing interface
for bi in 1:n_blocks
A[Block(bi), Block(bi)] .= diagm(fill(_one, blocksizes(A)[bi, bi][1]))
Comment thread
kshyatt marked this conversation as resolved.
Outdated
end
return A
end
Expand Down
12 changes: 6 additions & 6 deletions src/tensors/blocktensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,10 @@ end
# ------------------------
for (fname, felt) in ((:zeros, :zero), (:ones, :one))
@eval begin
function Base.$fname(::Type{T}, V::TensorMapSumSpace) where {T}
TT = blocktensormaptype(spacetype(V), numout(V), numin(V), T)
function Base.$fname(::Type{TorA}, V::TensorMapSumSpace) where {TorA}
TT = blocktensormaptype(spacetype(V), numout(V), numin(V), TorA)
t = TT(undef, V)
fill!(t, $felt(T))
fill!(t, $felt(scalartype(t)))
return t
end
end
Expand All @@ -136,9 +136,9 @@ for randfun in (:rand, :randn, :randexp)
randfun! = Symbol(randfun, :!)
@eval begin
function Random.$randfun(
rng::Random.AbstractRNG, ::Type{T}, V::TensorMapSumSpace
) where {T}
TT = blocktensormaptype(spacetype(V), numout(V), numin(V), T)
rng::Random.AbstractRNG, ::Type{TorA}, V::TensorMapSumSpace
) where {TorA}
TT = blocktensormaptype(spacetype(V), numout(V), numin(V), TorA)
t = TT(undef, V)
Random.$randfun!(rng, t)
return t
Expand Down