Skip to content

Commit 4665b91

Browse files
Bug fix with Cuda for Tuple of Vector offsets
1 parent 0cdf70b commit 4665b91

2 files changed

Lines changed: 4 additions & 4 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SeparableFunctions"
22
uuid = "c8c7ead4-852c-491e-a42d-3d43bc74259e"
33
authors = ["RainerHeintzmann <heintzmann@gmail.com>"]
4-
version = "0.2.1"
4+
version = "0.2.2"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/general.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ function calculate_separables_nokw(::Type{AT}, fct, sz::NTuple{N, Int},
6363
idc = get_1d_ids(d, sz, offset, scale)
6464
args_d = arg_n(d, args, RT, sz) #
6565
# in_place_assing!(res, 1, fct, idc, sz1d, args_d)
66-
# @show size(res)
6766
res .= fct.(idc, sz1d, args_d...) # 5 allocs, 160 bytes
6867
end
6968
return all_axes
@@ -76,8 +75,9 @@ end
7675
returns one-dimensional indices for a given dimension `d` of an N-dimensional array.
7776
The indices are shifted by `offset` and scaled by `scale`, which can also be vectors
7877
"""
79-
# for Numbers, the reorient comes last, to have it CUDA-compatible
80-
get_1d_ids(d, sz::NTuple{N, Int}, offset::Number, scale::Number) where {N} = (reorient(get_vec_dim(scale, d, sz) .* ((1:sz[d]) .- get_vec_dim(offset, d, sz)), d, Val(N)))
78+
# for Numbers or Vectors, the reorient comes last, to have it CUDA-compatible
79+
NumVecTup = Union{Number, Vector, NTuple}
80+
get_1d_ids(d, sz::NTuple{N, Int}, offset::NumVecTup, scale::NumVecTup) where {N} = (reorient(get_vec_dim(scale, d, sz) .* ((1:sz[d]) .- get_vec_dim(offset, d, sz)), d, Val(N)))
8181
# for abstract arrays, we first have to reorient.
8282
get_1d_ids(d, sz::NTuple{N, Int}, offset, scale) where {N} = get_vec_dim(scale, d, sz) .* (reorient((1:sz[d]), d, Val(N)) .- get_vec_dim(offset, d, sz))
8383
get_1d_ids(d, sz::NTuple{N, Int}, offset::Number) where {N} = (reorient((1:sz[d]) .- get_vec_dim(offset, d, sz), d, Val(N)))

0 commit comments

Comments
 (0)