Skip to content

Commit aed6887

Browse files
committed
Add support for Adapt
1 parent c31fc91 commit aed6887

4 files changed

Lines changed: 39 additions & 16 deletions

File tree

Project.toml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,23 @@
11
name = "SerializedArrays"
22
uuid = "621c0da3-e96e-4f80-bd06-5ae31cdfcb39"
33
authors = ["ITensor developers <support@itensor.org> and contributors"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
77
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
88
DiskArrays = "3c3547ce-8d99-4f5e-a174-61eb10b00ae3"
9-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
109
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
1110

11+
[weakdeps]
12+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
13+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
14+
15+
[extensions]
16+
SerializedArraysAdaptExt = "Adapt"
17+
SerializedArraysLinearAlgebraExt = "LinearAlgebra"
18+
1219
[compat]
20+
Adapt = "4.3.0"
1321
ConstructionBase = "1.5.8"
1422
DiskArrays = "0.4.12"
1523
LinearAlgebra = "1.10"
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
module SerializedArraysAdaptExt
2+
3+
using Adapt: Adapt
4+
using SerializedArrays: SerializedArray
5+
6+
function Adapt.adapt_storage(arrayt::Type{<:SerializedArray}, a::AbstractArray)
7+
return convert(arrayt, a)
8+
end
9+
10+
end
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module SerializedArraysLinearAlgebraExt
2+
3+
using LinearAlgebra: LinearAlgebra, mul!
4+
using SerializedArrays: SerializedArray
5+
6+
function LinearAlgebra.mul!(
7+
a_dest::AbstractMatrix, a1::SerializedArray, a2::SerializedArray, α::Number, β::Number
8+
)
9+
mul!(a_dest, copy(a1), copy(a2), α, β)
10+
return a_dest
11+
end
12+
13+
end

src/SerializedArrays.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ module SerializedArrays
22

33
using ConstructionBase: constructorof
44
using DiskArrays: DiskArrays, AbstractDiskArray, Unchunked
5-
using LinearAlgebra: LinearAlgebra, mul!
65
using Serialization: deserialize, serialize
76

87
struct SerializedArray{T,N,A<:AbstractArray{T,N},Axes} <: AbstractDiskArray{T,N}
@@ -14,21 +13,25 @@ Base.axes(a::SerializedArray) = getfield(a, :axes)
1413
arraytype(a::SerializedArray{<:Any,<:Any,A}) where {A} = A
1514

1615
function SerializedArray(file::String, a::AbstractArray)
17-
serialize(file, a)
16+
serialize(file, vec(a))
1817
ax = axes(a)
1918
return SerializedArray{eltype(a),ndims(a),typeof(a),typeof(ax)}(file, ax)
2019
end
2120
function SerializedArray(a::AbstractArray)
2221
return SerializedArray(tempname(), a)
2322
end
2423

24+
function Base.convert(arrayt::Type{<:SerializedArray}, a::AbstractArray)
25+
return arrayt(a)
26+
end
27+
2528
function Base.similar(a::SerializedArray, elt::Type, dims::Tuple{Vararg{Int}})
2629
return constructorof(arraytype(a)){elt}(undef, dims...)
2730
end
2831

2932
function Base.copy(a::SerializedArray)
3033
arrayt = arraytype(a)
31-
return convert(arrayt, deserialize(file(a)))::arrayt
34+
return convert(arrayt, reshape(deserialize(file(a)), axes(a)))::arrayt
3235
end
3336

3437
Base.size(a::SerializedArray) = length.(axes(a))
@@ -106,15 +109,4 @@ function Base.copy(broadcasted::Broadcasted{SerializedArrayStyle{N}}) where {N}
106109
return BroadcastSerializedArray(flatten(broadcasted))
107110
end
108111

109-
#
110-
# LinearAlgebra
111-
#
112-
113-
function LinearAlgebra.mul!(
114-
a_dest::AbstractMatrix, a1::SerializedArray, a2::SerializedArray, α::Number, β::Number
115-
)
116-
mul!(a_dest, copy(a1), copy(a2), α, β)
117-
return a_dest
118-
end
119-
120112
end

0 commit comments

Comments
 (0)