@@ -58,21 +58,30 @@ function calculate_separables_nokw(::Type{AT}, fct, sz::NTuple{N, Int},
5858 # args_1d = ntuple((d) -> arg_n(d, args), Val(N))
5959 # in_place_assing!.(all_axes, 1, fct, idcs, sz, args_1d)
6060 for (res, sz1d, d) in zip (all_axes, sz, 1 : N)
61- off = get_vec_dim (offset, d, sz)
62- sca = get_vec_dim (scale, d, sz)
61+ # off = get_vec_dim(offset, d, sz) # no longer needed: get_1d_ids now extracts the per-dimension offset itself
62+ # sca = get_vec_dim(scale, d, sz)
6363 idc = get_1d_ids (d, sz, offset, scale)
6464 args_d = arg_n (d, args, RT, sz) #
6565 # in_place_assing!(res, 1, fct, idc, sz1d, args_d)
6666 # @show size(res)
67- # @show size(idc)
6867 res .= fct .(idc, sz1d, args_d... ) # 5 allocs, 160 bytes
6968 end
7069 return all_axes
7170 # return res
7271end
7372
73+ """
74+ get_1d_ids(d, sz::NTuple{N, Int}, offset, scale) where {N}
75+
76+ Return the one-dimensional indices `1:sz[d]` for dimension `d` of an `N`-dimensional array of size `sz`.
77+ The indices are first shifted by subtracting `offset` and then multiplied by `scale`; both may be numbers or vectors.
78+ """
79+ # Scalar (Number) offset and scale: shift and scale the plain range 1:sz[d] first and call reorient last; per the original author this ordering keeps the method CUDA-compatible.
80+ get_1d_ids (d, sz:: NTuple{N, Int} , offset:: Number , scale:: Number ) where {N} = (reorient (get_vec_dim (scale, d, sz) .* ((1 : sz[d]) .- get_vec_dim (offset, d, sz)), d, Val (N)))
81+ # Generic fallback (e.g. array-valued offset or scale): reorient the range 1:sz[d] first, then broadcast the offset subtraction and the scale multiplication against the reoriented range.
7482get_1d_ids (d, sz:: NTuple{N, Int} , offset, scale) where {N} = get_vec_dim (scale, d, sz) .* (reorient ((1 : sz[d]), d, Val (N)) .- get_vec_dim (offset, d, sz))
75- get_1d_ids (d, sz:: NTuple{N, Int} , offset) where {N} = (reorient ((1 : sz[d]), d, Val (N)) .- get_vec_dim (offset, d, sz))
83+ get_1d_ids (d, sz:: NTuple{N, Int} , offset:: Number ) where {N} = (reorient ((1 : sz[d]) .- get_vec_dim (offset, d, sz), d, Val (N))) # offset-only variant for a Number offset: subtract the offset from 1:sz[d], then reorient last (mirrors the Number/Number method, without the scale factor)
84+ get_1d_ids (d, sz:: NTuple{N, Int} , offset) where {N} = reorient (1 : sz[d], d, Val (N)) .- get_vec_dim (offset, d, sz) # offset-only generic fallback: reorient 1:sz[d] first, then broadcast-subtract the offset (mirrors the generic four-argument method, without the scale factor)
7685# get_1d_ids(d, sz, offset, scale) = pick_n(d, scale) .* ((1:sz[d]) .- pick_n(d, offset))
7786# get_1d_ids(d, sz, offset::NTuple, scale::NTuple) = scale[d] .* ((1:sz[d]) .- offset[d])
7887
@@ -340,6 +349,7 @@ function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(cal
340349 sca = isnothing (args[2 ]) ? RAT ([one (RT)]) : RT .(args[2 ])
341350
342351 ids = ntuple ((d) -> get_1d_ids (d, sz, off, sca), Val (N)) # offset==args[1] and scale==args[2]
352+ # ids_offset_only = ntuple((d) -> get_1d_ids(d, sz, off), Val(N)) # offset==args[1] and scale==args[2]
343353 ids_offset_only = get_1d_ids .(1 : N, Ref (sz), Ref (off)) # , one(eltype(AT))
344354
345355 extra_sz = get_arg_sz (sz, args... )
0 commit comments