Skip to content

Commit b51e04c

Browse files
committed
add size(A,d) methods
1 parent 457aebf commit b51e04c

2 files changed

Lines changed: 8 additions & 0 deletions

File tree

src/TensorStoreWrapper.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ function Base.setindex!(w::TensorStoreWrapper, v, indices...; kwargs...)
150150
end
151151

152152
Base.size(w::TensorStoreWrapper) = pyconvert(Tuple, parent(w).shape)
153+
Base.size(w::TensorStoreWrapper, d::Integer) = d <= ndims(w) ? size(w)[d] : 1
153154
Base.ndims(w::TensorStoreWrapper) = pyconvert(Int, parent(w).rank)
154155

155156
const TS_TYPE_MAP = Dict(
@@ -228,6 +229,7 @@ end
228229

229230
# IndexDomainWrapper methods
230231
Base.size(w::IndexDomainWrapper) = pyconvert(Tuple, parent(w).shape)
232+
Base.size(w::IndexDomainWrapper, d::Integer) = d <= ndims(w) ? size(w)[d] : 1
231233
Base.ndims(w::IndexDomainWrapper) = pyconvert(Int, parent(w).rank)
232234
function Base.axes(w::IndexDomainWrapper)
233235
rank = ndims(w)

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ using PythonCall
2323
@test eltype(w) == Int32
2424
@test ndims(w) == 2
2525
@test size(w) == (10, 20)
26+
@test size(w, 1) == 10
27+
@test size(w, 2) == 20
28+
@test size(w, 3) == 1
2629
@test axes(w) == (1:10, 1:20)
2730
end
2831

@@ -44,6 +47,9 @@ using PythonCall
4447
# Labeled indexing
4548
sub_w = w[x=1:5, y=11:15]
4649
@test size(sub_w) == (5, 5)
50+
@test size(domain, 1) == 10
51+
@test size(domain, 2) == 20
52+
@test size(domain, 3) == 1
4753
@test axes(sub_w) == (1:5, 11:15)
4854

4955
# translate_by

0 commit comments

Comments
 (0)