diff --git a/src/dspbase.jl b/src/dspbase.jl index fbcb5a7d7..a487d62f5 100644 --- a/src/dspbase.jl +++ b/src/dspbase.jl @@ -679,52 +679,77 @@ end # May switch argument order """ - conv(u,v) + conv(u,v; mode = :full) Convolution of two arrays. Uses either FFT convolution or overlap-save, depending on the size of the input. `u` and `v` can be N-dimensional arrays, with arbitrary indexing offsets, but their axes must be a `UnitRange`. + +:full — Return the full convolution. + +:same — Return the central part of the convolution, which is the same size as u. + +:valid — Return only parts of the convolution that are computed without zero-padded edges. + +Warning: `:same` and `:valid` will result in an extra copy of the result. """ function conv(u::AbstractArray{T, N}, - v::AbstractArray{T, N}) where {T<:BLAS.BlasFloat, N} + v::AbstractArray{T, N}; + mode::Symbol = :full) where {T<:BLAS.BlasFloat, N} su = size(u) sv = size(v) - if prod(su) >= prod(sv) - _conv(u, v, su, sv) + if mode == :full + if prod(su) >= prod(sv) + _conv(u, v, su, sv) + else + _conv(v, u, sv, su) + end + elseif mode == :same + conv_res = conv(u, v) + outsize = CartesianIndex(Int.(floor.(sv ./2 .+ 1))...):CartesianIndex(Int.(floor.(sv ./ 2) .+ su)...) + conv_res[outsize] + elseif mode == :valid + conv_res = conv(u, v) + outsize = CartesianIndex(sv...):CartesianIndex(su...) + conv_res[outsize] else - _conv(v, u, sv, su) + throw(ArgumentError("mode keyword argument must be :full or :same or :valid")) end end function conv(u::AbstractArray{<:BLAS.BlasFloat, N}, - v::AbstractArray{<:BLAS.BlasFloat, N}) where N + v::AbstractArray{<:BLAS.BlasFloat, N}; + mode::Symbol = :full) where N fu, fv = promote(u, v) - conv(fu, fv) + conv(fu, fv, mode = mode) end -conv(u::AbstractArray{<:Integer, N}, v::AbstractArray{<:Integer, N}) where {N} = - round.(Int, conv(float(u), float(v))) +conv(u::AbstractArray{<:Integer, N}, v::AbstractArray{<:Integer, N}; mode::Symbol = :full) where {N} = + round.(Int, conv(float(u), float(v), mode = mode)) -conv(u::AbstractArray{<:Number, N}, v::AbstractArray{<:Number, N}) where {N} = - conv(float(u), float(v)) +conv(u::AbstractArray{<:Number, N}, v::AbstractArray{<:Number, N}; mode::Symbol = :full) where {N} = + conv(float(u), float(v), mode = mode) function conv(u::AbstractArray{<:Number, N}, - v::AbstractArray{<:BLAS.BlasFloat, N}) where N - conv(float(u), v) + v::AbstractArray{<:BLAS.BlasFloat, N}; + mode::Symbol = :full) where N + conv(float(u), v, mode = mode) end function conv(u::AbstractArray{<:BLAS.BlasFloat, N}, - v::AbstractArray{<:Number, N}) where N - conv(u, float(v)) + v::AbstractArray{<:Number, N}; + mode::Symbol = :full) where N + conv(u, float(v), mode = mode) end function conv(A::AbstractArray{<:Number, M}, - B::AbstractArray{<:Number, N}) where {M, N} + B::AbstractArray{<:Number, N}; + mode::Symbol = :full) where {M, N} if (M < N) - conv(cat(A, dims=N)::AbstractArray{eltype(A), N}, B) + conv(cat(A, dims=N)::AbstractArray{eltype(A), N}, B, mode = mode) else @assert M > N - conv(A, cat(B, dims=M)::AbstractArray{eltype(B), M}) + conv(A, cat(B, dims=M)::AbstractArray{eltype(B), M}, mode = mode) end end diff --git a/test/dsp.jl b/test/dsp.jl index 91dc942d7..6b62a136e 100644 --- a/test/dsp.jl +++ b/test/dsp.jl @@ -46,6 +46,8 @@ end im_expectation = [1, 3, 6, 6, 5, 3] a32 = convert(Array{Int32}, a) @test conv(a, b) == expectation + @test conv(a, b, mode=:same) == expectation[2:5] + @test conv(a, b, mode=:valid) == expectation[3:4] @test conv(a32, b) == expectation fa = convert(Array{Float64}, a) f32a = convert(Array{Float32}, a) @@ -90,6 +92,8 @@ end # Real Integers a32 = convert(Array{Int32}, a) @test conv(a, b) == expectation + @test conv(a, b, mode=:same) == expectation[2:4, 2:4] + @test conv(a, b, mode=:valid) == expectation[2:3, 2:3] @test conv(a32, b) == expectation # Floats fa = convert(Array{Float64}, a) @@ -161,6 +165,8 @@ end 47, 96, 100, 51, 25, 51, 53, 27], (4, 4, 4))) @test conv(a, b) == exp + @test conv(a, b, mode=:same) == exp[2:4, 2:4, 2:4] + @test conv(a, b, mode=:valid) == exp[2:3, 2:3, 2:3] #6D, trivial, just to see if it works