Skip to content

Commit 05197a5

Browse files
timholytkelman
authored andcommitted
Support non-1 indices and fix type problems in DFT (fixes #17896)
(cherry picked from commit 996e275) ref #17919
1 parent a36da57 commit 05197a5

File tree

4 files changed

+147
-11
lines changed

4 files changed

+147
-11
lines changed

base/dft.jl

+29-10
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,41 @@ export fft, ifft, bfft, fft!, ifft!, bfft!,
2020
plan_fft, plan_ifft, plan_bfft, plan_fft!, plan_ifft!, plan_bfft!,
2121
rfft, irfft, brfft, plan_rfft, plan_irfft, plan_brfft
2222

23-
complexfloat{T<:AbstractFloat}(x::AbstractArray{Complex{T}}) = x
23+
typealias FFTWFloat Union{Float32,Float64}
24+
fftwfloat(x) = _fftwfloat(float(x))
25+
_fftwfloat{T<:FFTWFloat}(::Type{T}) = T
26+
_fftwfloat(::Type{Float16}) = Float32
27+
_fftwfloat{T}(::Type{T}) = error("type $T not supported")
28+
_fftwfloat{T}(x::T) = _fftwfloat(T)(x)
29+
30+
complexfloat{T<:FFTWFloat}(x::StridedArray{Complex{T}}) = x
31+
realfloat{T<:FFTWFloat}(x::StridedArray{T}) = x
2432

2533
# return an Array, rather than similar(x), to avoid an extra copy for FFTW
2634
# (which only works on StridedArray types).
27-
complexfloat{T<:Complex}(x::AbstractArray{T}) = copy!(Array{typeof(float(one(T)))}(size(x)), x)
28-
complexfloat{T<:AbstractFloat}(x::AbstractArray{T}) = copy!(Array{typeof(complex(one(T)))}(size(x)), x)
29-
complexfloat{T<:Real}(x::AbstractArray{T}) = copy!(Array{typeof(complex(float(one(T))))}(size(x)), x)
35+
complexfloat{T<:Complex}(x::AbstractArray{T}) = copy1(typeof(fftwfloat(one(T))), x)
36+
complexfloat{T<:Real}(x::AbstractArray{T}) = copy1(typeof(complex(fftwfloat(one(T)))), x)
37+
38+
realfloat{T<:Real}(x::AbstractArray{T}) = copy1(typeof(fftwfloat(one(T))), x)
39+
40+
# copy to a 1-based array, using circular permutation
41+
function copy1{T}(::Type{T}, x)
42+
y = Array{T}(map(length, indices(x)))
43+
Base.circcopy!(y, x)
44+
end
45+
46+
to1(x::AbstractArray) = _to1(indices(x), x)
47+
_to1(::Tuple{Base.OneTo,Vararg{Base.OneTo}}, x) = x
48+
_to1(::Tuple, x) = copy1(eltype(x), x)
3049

3150
# implementations only need to provide plan_X(x, region)
3251
# for X in (:fft, :bfft, ...):
3352
for f in (:fft, :bfft, :ifft, :fft!, :bfft!, :ifft!, :rfft)
3453
pf = Symbol("plan_", f)
3554
@eval begin
36-
$f(x::AbstractArray) = $pf(x) * x
37-
$f(x::AbstractArray, region) = $pf(x, region) * x
38-
$pf(x::AbstractArray; kws...) = $pf(x, 1:ndims(x); kws...)
55+
$f(x::AbstractArray) = (y = to1(x); $pf(y) * y)
56+
$f(x::AbstractArray, region) = (y = to1(x); $pf(y, region) * y)
57+
$pf(x::AbstractArray; kws...) = (y = to1(x); $pf(y, 1:ndims(y); kws...))
3958
end
4059
end
4160

@@ -187,11 +206,11 @@ for f in (:fft, :bfft, :ifft)
187206
$pf{T<:Union{Integer,Rational}}(x::AbstractArray{Complex{T}}, region; kws...) = $pf(complexfloat(x), region; kws...)
188207
end
189208
end
190-
rfft{T<:Union{Integer,Rational}}(x::AbstractArray{T}, region=1:ndims(x)) = rfft(float(x), region)
191-
plan_rfft{T<:Union{Integer,Rational}}(x::AbstractArray{T}, region; kws...) = plan_rfft(float(x), region; kws...)
209+
rfft{T<:Union{Integer,Rational}}(x::AbstractArray{T}, region=1:ndims(x)) = rfft(realfloat(x), region)
210+
plan_rfft(x::AbstractArray, region; kws...) = plan_rfft(realfloat(x), region; kws...)
192211

193212
# only require implementation to provide *(::Plan{T}, ::Array{T})
194-
*{T}(p::Plan{T}, x::AbstractArray) = p * copy!(Array{T}(size(x)), x)
213+
*{T}(p::Plan{T}, x::AbstractArray) = p * copy1(T, x)
195214

196215
# Implementations should also implement A_mul_B!(Y, plan, X) so as to support
197216
# pre-allocated output arrays. We don't define * in terms of A_mul_B!

base/multidimensional.jl

+78
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,22 @@ See also `circshift`.
645645
end
646646
circshift!(dest::AbstractArray, src, shiftamt) = circshift!(dest, src, (shiftamt...,))
647647

648+
# For each dimension, we copy the first half of src to the second half
649+
# of dest, and the second half of src to the first half of dest. This
650+
# uses a recursive bifurcation strategy so that these splits can be
651+
# encoded by ranges, which means that we need only one call to `mod`
652+
# per dimension rather than one call per index.
653+
# `rdest` and `rsrc` are tuples-of-ranges that grow one dimension at a
654+
# time; when all the dimensions have been filled in, you call `copy!`
655+
# for that block. In other words, in two dimensions schematically we
656+
# have the following call sequence (--> means a call):
657+
# circshift!(dest, src, shiftamt) -->
658+
# _circshift!(dest, src, ("first half of dim1",)) -->
659+
# _circshift!(dest, src, ("first half of dim1", "first half of dim2")) --> copy!
660+
# _circshift!(dest, src, ("first half of dim1", "second half of dim2")) --> copy!
661+
# _circshift!(dest, src, ("second half of dim1",)) -->
662+
# _circshift!(dest, src, ("second half of dim1", "first half of dim2")) --> copy!
663+
# _circshift!(dest, src, ("second half of dim1", "second half of dim2")) --> copy!
648664
@inline function _circshift!(dest, rdest, src, rsrc,
649665
inds::Tuple{AbstractUnitRange,Vararg{Any}},
650666
shiftamt::Tuple{Integer,Vararg{Any}})
@@ -662,6 +678,68 @@ function _circshift!(dest, rdest, src, rsrc, inds, shiftamt)
662678
copy!(dest, CartesianRange(rdest), src, CartesianRange(rsrc))
663679
end
664680

681+
# circcopy!
682+
"""
683+
circcopy!(dest, src)
684+
685+
Copy `src` to `dest`, indexing each dimension modulo its length.
686+
`src` and `dest` must have the same size, but can be offset in
687+
their indices; any offset results in a (circular) wraparound. If the
688+
arrays have overlapping indices, then on the domain of the overlap
689+
`dest` agrees with `src`.
690+
691+
```julia
692+
julia> src = reshape(collect(1:16), (4,4))
693+
4×4 Array{Int64,2}:
694+
1 5 9 13
695+
2 6 10 14
696+
3 7 11 15
697+
4 8 12 16
698+
699+
julia> dest = OffsetArray{Int}((0:3,2:5))
700+
701+
julia> circcopy!(dest, src)
702+
OffsetArrays.OffsetArray{Int64,2,Array{Int64,2}} with indices 0:3×2:5:
703+
8 12 16 4
704+
5 9 13 1
705+
6 10 14 2
706+
7 11 15 3
707+
708+
julia> dest[1:3,2:4] == src[1:3,2:4]
709+
true
710+
```
711+
"""
712+
function circcopy!(dest, src)
713+
dest === src && throw(ArgumentError("dest and src must be separate arrays"))
714+
indssrc, indsdest = indices(src), indices(dest)
715+
if (szsrc = map(length, indssrc)) != (szdest = map(length, indsdest))
716+
throw(DimensionMismatch("src and dest must have the same sizes (got $szsrc and $szdest)"))
717+
end
718+
shift = map((isrc, idest)->first(isrc)-first(idest), indssrc, indsdest)
719+
all(x->x==0, shift) && return copy!(dest, src)
720+
_circcopy!(dest, (), indsdest, src, (), indssrc)
721+
end
722+
723+
# This uses the same strategy described above for _circshift!
724+
@inline function _circcopy!(dest, rdest, indsdest::Tuple{AbstractUnitRange,Vararg{Any}},
725+
src, rsrc, indssrc::Tuple{AbstractUnitRange,Vararg{Any}})
726+
indd1, inds1 = indsdest[1], indssrc[1]
727+
l = length(indd1)
728+
s = mod(first(inds1)-first(indd1), l)
729+
sdf = first(indd1)+s
730+
rd1, rd2 = first(indd1):sdf-1, sdf:last(indd1)
731+
ssf = last(inds1)-s
732+
rs1, rs2 = first(inds1):ssf, ssf+1:last(inds1)
733+
tindsd, tindss = tail(indsdest), tail(indssrc)
734+
_circcopy!(dest, (rdest..., rd1), tindsd, src, (rsrc..., rs2), tindss)
735+
_circcopy!(dest, (rdest..., rd2), tindsd, src, (rsrc..., rs1), tindss)
736+
end
737+
738+
# At least one of indsdest, indssrc are empty (and both should be, since we've checked)
739+
function _circcopy!(dest, rdest, indsdest, src, rsrc, indssrc)
740+
copy!(dest, CartesianRange(rdest), src, CartesianRange(rsrc))
741+
end
742+
665743
### BitArrays
666744

667745
## getindex

test/fft.jl

+8
Original file line numberDiff line numberDiff line change
@@ -326,3 +326,11 @@ for x in (randn(10),randn(10,12))
326326
# note: inference doesn't work for plan_fft_ since the
327327
# algorithm steps are included in the CTPlan type
328328
end
329+
330+
# issue #17896
331+
a = rand(5)
332+
@test fft(a) == fft(view(a,:)) == fft(view(a, 1:5)) == fft(view(a, [1:5;]))
333+
@test rfft(a) == rfft(view(a,:)) == rfft(view(a, 1:5)) == rfft(view(a, [1:5;]))
334+
a16 = convert(Vector{Float16}, a)
335+
@test fft(a16) == fft(view(a16,:)) == fft(view(a16, 1:5)) == fft(view(a16, [1:5;]))
336+
@test rfft(a16) == rfft(view(a16,:)) == rfft(view(a16, 1:5)) == rfft(view(a16, [1:5;]))

test/offsetarray.jl

+32-1
Original file line numberDiff line numberDiff line change
@@ -411,10 +411,41 @@ v = OffsetArray(rand(8), (-2,))
411411
@test rotr90(A) == OffsetArray(rotr90(parent(A)), A.offsets[[2,1]])
412412
@test flipdim(A, 1) == OffsetArray(flipdim(parent(A), 1), A.offsets)
413413
@test flipdim(A, 2) == OffsetArray(flipdim(parent(A), 2), A.offsets)
414-
@test circshift(A, (-1,2)) == OffsetArray(circshift(parent(A), (-1,2)), A.offsets)
415414

416415
@test A+1 == OffsetArray(parent(A)+1, A.offsets)
417416
@test 2*A == OffsetArray(2*parent(A), A.offsets)
418417
@test A+A == OffsetArray(parent(A)+parent(A), A.offsets)
419418
@test A.*A == OffsetArray(parent(A).*parent(A), A.offsets)
419+
420+
@test circshift(A, (-1,2)) == OffsetArray(circshift(parent(A), (-1,2)), A.offsets)
421+
422+
src = reshape(collect(1:16), (4,4))
423+
dest = OffsetArray(Array{Int}(4,4), (-1,1))
424+
circcopy!(dest, src)
425+
@test parent(dest) == [8 12 16 4; 5 9 13 1; 6 10 14 2; 7 11 15 3]
426+
@test dest[1:3,2:4] == src[1:3,2:4]
427+
428+
e = eye(5)
429+
a = [e[:,1], e[:,2], e[:,3], e[:,4], e[:,5]]
430+
a1 = zeros(5)
431+
c = [ones(Complex{Float64}, 5),
432+
exp(-2*pi*im*(0:4)/5),
433+
exp(-4*pi*im*(0:4)/5),
434+
exp(-6*pi*im*(0:4)/5),
435+
exp(-8*pi*im*(0:4)/5)]
436+
for s = -5:5
437+
for i = 1:5
438+
thisa = OffsetArray(a[i], (s,))
439+
thisc = c[mod1(i+s+5,5)]
440+
@test_approx_eq fft(thisa) thisc
441+
@test_approx_eq fft(thisa, 1) thisc
442+
@test_approx_eq ifft(fft(thisa)) circcopy!(a1, thisa)
443+
@test_approx_eq ifft(fft(thisa, 1), 1) circcopy!(a1, thisa)
444+
@test_approx_eq rfft(thisa) thisc[1:3]
445+
@test_approx_eq rfft(thisa, 1) thisc[1:3]
446+
@test_approx_eq irfft(rfft(thisa, 1), 5, 1) a1
447+
@test_approx_eq irfft(rfft(thisa, 1), 5, 1) a1
448+
end
420449
end
450+
451+
end # let

0 commit comments

Comments
 (0)