-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsymbolicarray.jl
More file actions
51 lines (51 loc) · 1.98 KB
/
symbolicarray.jl
File metadata and controls
51 lines (51 loc) · 1.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# TODO: Allow dynamic/unknown number of dimensions by supporting vector axes.
"""
    SymbolicArray{T}(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}})

Array placeholder identified only by `name` and a tuple of unit-range axes `ax`; no
element data is stored. `T` is the element type; the dimensionality `N` is taken from
the number of axes, and `Name`/`Axes` record the concrete types of `name` and `ax`.
"""
struct SymbolicArray{
        T, N, Name, Axes <: NTuple{N, AbstractUnitRange{<:Integer}}
    } <: AbstractArray{T, N}
    name::Name
    axes::Axes
    function SymbolicArray{T}(
            name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}
        ) where {T}
        # The number of dimensions is fixed by the length of the axes tuple.
        return new{T, length(ax), typeof(name), typeof(ax)}(name, ax)
    end
end
# Construct a `SymbolicArray` without naming an element type: the element type
# defaults to `Any`.
SymbolicArray(name, ax::Tuple{Vararg{AbstractUnitRange{<:Integer}}}) =
    SymbolicArray{Any}(name, ax)
# The name the symbolic array was constructed with.
function symname(a::SymbolicArray)
    return getfield(a, :name)
end
# The stored axes tuple, returned as-is.
function Base.axes(a::SymbolicArray)
    return getfield(a, :axes)
end
# Size is the length of each stored axis.
Base.size(a::SymbolicArray) = map(length, axes(a))
# Two symbolic arrays are equal when their names and their axes are equal. The element
# type parameter is not compared, and no elements are read.
function Base.:(==)(a::SymbolicArray, b::SymbolicArray)
    symname(a) == symname(b) || return false
    return axes(a) == axes(b)
end
# `isequal` uses exactly the same criterion as `==` for symbolic arrays.
function Base.isequal(a::SymbolicArray, b::SymbolicArray)
    return a == b
end
# Hash a `SymbolicArray` from its name and size, mixed with a type tag. This agrees
# with `==`/`isequal`: equal arrays have equal names and equal axes, hence equal sizes,
# so they hash equally. Elements are never read.
# The seed is typed `UInt` (not `UInt64`) so that `hash(a)`, which seeds with
# `zero(UInt)`, dispatches here on every platform. On 32-bit systems `UInt === UInt32`;
# with a `UInt64` seed this method would be skipped in favor of the generic
# element-wise `AbstractArray` hash.
function Base.hash(a::SymbolicArray, h::UInt)
    h = hash(:SymbolicArray, h)
    h = hash(symname(a), h)
    return hash(size(a), h)
end
# Scalar reads are rejected: a symbolic array stores no element values.
Base.getindex(::SymbolicArray{<:Any, N}, ::Vararg{Int, N}) where {N} =
    error("Indexing into SymbolicArray not supported.")
# Scalar writes are rejected: a symbolic array has no element storage to write into.
Base.setindex!(::SymbolicArray{<:Any, N}, value, ::Vararg{Int, N}) where {N} =
    error("Indexing into SymbolicArray not supported.")
using DerivableInterfaces: DerivableInterfaces
# Forward DerivableInterfaces' `permuteddims` to `permutedims` for symbolic arrays.
function DerivableInterfaces.permuteddims(a::SymbolicArray, p)
    return permutedims(a, p)
end
# Return a `SymbolicArray` with the same name and element type as `a`, whose axis `i`
# is `axes(a)[p[i]]`. Throws an `ArgumentError` unless `p` is a permutation of
# `1:ndims(a)`, consistent with `Base.permutedims` on ordinary arrays.
function Base.permutedims(a::SymbolicArray, p)
    # Validate with real exceptions rather than `@assert`: assertions are for internal
    # invariants and may be disabled, which would let a bad `p` pass unchecked.
    if length(p) != ndims(a)
        msg = "expected permutation of size $(ndims(a)), but length(perm)=$(length(p))"
        throw(ArgumentError(msg))
    end
    isperm(p) || throw(ArgumentError("no valid permutation of dimensions"))
    ax = axes(a)
    # Keep the element type. The untyped `SymbolicArray(name, axes)` constructor
    # would reset it to `Any`.
    return SymbolicArray{eltype(a)}(symname(a), ntuple(i -> ax[p[i]], ndims(a)))
end
# Multi-line (text/plain) display: the array summary followed by ":", then on the next
# line the `repr` of the array's name.
function Base.show(io::IO, ::MIME"text/plain", a::SymbolicArray)
    summary(io, a)
    print(io, ":\n", repr(symname(a)))
    return nothing
end
# One-line display: `SymbolicArray(<name>, <size tuple>)`. The name is written with
# `print`, not `repr`, so a Symbol name `:a` appears as `a`.
function Base.show(io::IO, a::SymbolicArray)
    nm = symname(a)
    dims = size(a)
    print(io, "SymbolicArray(", nm, ", ", dims, ")")
    return nothing
end
using AbstractTrees: AbstractTrees
# Tree-node label for AbstractTrees printing: the `repr` of the array's name.
function AbstractTrees.printnode(io::IO, a::SymbolicArray)
    label = repr(symname(a))
    print(io, label)
    return nothing
end