Skip to content

Commit 36e9bf1

Browse files
authored
Merge pull request #3 from bjarthur/bja/size
add size(A,d) methods
2 parents d3ee910 + 911282d commit 36e9bf1

2 files changed

Lines changed: 20 additions & 0 deletions

File tree

src/TensorStoreWrapper.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,12 @@ function Base.setindex!(w::TensorStoreWrapper, v, indices...; kwargs...)
150150
end
151151

152152
Base.size(w::TensorStoreWrapper) = pyconvert(Tuple, parent(w).shape)
153+
function Base.size(w::TensorStoreWrapper, d::Integer)
154+
d < 1 && throw(ArgumentError("dimension must be ≥ 1"))
155+
d > ndims(w) && return 1
156+
di = Base.to_index(d)
157+
return pyconvert(Int, parent(w).shape[di-1])
158+
end
153159
Base.ndims(w::TensorStoreWrapper) = pyconvert(Int, parent(w).rank)
154160

155161
const TS_TYPE_MAP = Dict(
@@ -228,6 +234,12 @@ end
228234

229235
# IndexDomainWrapper methods
230236
Base.size(w::IndexDomainWrapper) = pyconvert(Tuple, parent(w).shape)
237+
function Base.size(w::IndexDomainWrapper, d::Integer)
238+
d < 1 && throw(ArgumentError("dimension must be ≥ 1"))
239+
d > ndims(w) && return 1
240+
di = Base.to_index(d)
241+
return pyconvert(Int, parent(w).shape[di-1])
242+
end
231243
Base.ndims(w::IndexDomainWrapper) = pyconvert(Int, parent(w).rank)
232244
function Base.axes(w::IndexDomainWrapper)
233245
rank = ndims(w)

test/runtests.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ using PythonCall
2323
@test eltype(w) == Int32
2424
@test ndims(w) == 2
2525
@test size(w) == (10, 20)
26+
@test size(w, 1) == 10
27+
@test size(w, 2) == 20
28+
@test size(w, 3) == 1
29+
@test_throws ArgumentError size(w, 0)
2630
@test axes(w) == (1:10, 1:20)
2731
end
2832

@@ -44,6 +48,10 @@ using PythonCall
4448
# Labeled indexing
4549
sub_w = w[x=1:5, y=11:15]
4650
@test size(sub_w) == (5, 5)
51+
@test size(domain, 1) == 10
52+
@test size(domain, 2) == 20
53+
@test size(domain, 3) == 1
54+
@test_throws ArgumentError size(domain, 0)
4755
@test axes(sub_w) == (1:5, 11:15)
4856

4957
# translate_by

0 commit comments

Comments
 (0)