Skip to content

Commit 6125381

Browse files
committed
Faster, indices-aware circshift (and non-allocating circshift!)
Fixes #16032, fixes #17581
1 parent 2d30203 commit 6125381

File tree

6 files changed

+81
-14
lines changed

6 files changed

+81
-14
lines changed

base/abstractarraymath.jl

+10-12
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,10 @@ function flipdim(A::AbstractArray, d::Integer)
158158
return B
159159
end
160160

161-
circshift(a::AbstractArray, shiftamt::Real) = circshift(a, [Integer(shiftamt)])
162-
161+
function circshift(a::AbstractArray, shiftamt::Real)
162+
circshift!(similar(a), a, (Integer(shiftamt),))
163+
end
164+
circshift(a::AbstractArray, shiftamt::DimsInteger) = circshift!(similar(a), a, shiftamt)
163165
"""
164166
circshift(A, shifts)
165167
@@ -174,29 +176,25 @@ julia> b = reshape(collect(1:16), (4,4))
174176
3 7 11 15
175177
4 8 12 16
176178
177-
julia> circshift(b, [0,2])
179+
julia> circshift(b, (0,2))
178180
4×4 Array{Int64,2}:
179181
9 13 1 5
180182
10 14 2 6
181183
11 15 3 7
182184
12 16 4 8
183185
184-
julia> circshift(b, [-1,0])
186+
julia> circshift(b, (-1,0))
185187
4×4 Array{Int64,2}:
186188
2 6 10 14
187189
3 7 11 15
188190
4 8 12 16
189191
1 5 9 13
190192
```
193+
194+
See also `circshift!`.
191195
"""
192-
function circshift{T,N}(a::AbstractArray{T,N}, shiftamts)
193-
I = ()
194-
for i=1:N
195-
s = size(a,i)
196-
d = i<=length(shiftamts) ? shiftamts[i] : 0
197-
I = tuple(I..., d==0 ? [1:s;] : mod([-d:s-1-d;], s).+1)
198-
end
199-
a[(I::NTuple{N,Vector{Int}})...]
196+
function circshift(a::AbstractArray, shiftamt)
197+
circshift!(similar(a), a, map(Integer, (shiftamt...,)))
200198
end
201199

202200
# Uses K-B-N summation

base/exports.jl

+1
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,7 @@ export
490490
checkbounds,
491491
checkindex,
492492
circshift,
493+
circshift!,
493494
clamp!,
494495
colon,
495496
conj!,

base/multidimensional.jl

+48
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,54 @@ function copy!{T,N}(dest::AbstractArray{T,N}, src::AbstractArray{T,N})
609609
dest
610610
end
611611

612+
function copy!(dest::AbstractArray, Rdest::CartesianRange, src::AbstractArray, Rsrc::CartesianRange)
613+
isempty(Rdest) && return dest
614+
size(Rdest) == size(Rsrc) || throw(ArgumentError("source and destination must have same size (got $(size(Rsrc)) and $(size(Rdest)))"))
615+
@boundscheck checkbounds(dest, Rdest.start)
616+
@boundscheck checkbounds(dest, Rdest.stop)
617+
@boundscheck checkbounds(src, Rsrc.start)
618+
@boundscheck checkbounds(src, Rsrc.stop)
619+
deltaI = Rdest.start - Rsrc.start
620+
for I in Rsrc
621+
@inbounds dest[I+deltaI] = src[I]
622+
end
623+
dest
624+
end
625+
626+
# circshift!
627+
circshift!(dest::AbstractArray, src, ::Tuple{}) = copy!(dest, src)
628+
"""
629+
circshift!(dest, src, shifts)
630+
631+
Circularly shift the data in `src`, storing the result in
632+
`dest`. `shifts` specifies the amount to shift in each dimension.
633+
634+
See also `circshift`.
635+
"""
636+
@noinline function circshift!{T,N}(dest::AbstractArray{T,N}, src, shiftamt::DimsInteger)
637+
dest === src && throw(ArgumentError("dest and src must be separate arrays"))
638+
inds = indices(src)
639+
indices(dest) == inds || throw(ArgumentError("indices of src and dest must match (got $inds and $(indices(dest)))"))
640+
_circshift!(dest, (), src, (), inds, fill_to_length(shiftamt, 0, Val{N}))
641+
end
642+
circshift!(dest::AbstractArray, src, shiftamt) = circshift!(dest, src, (shiftamt...,))
643+
644+
@inline function _circshift!(dest, rdest, src, rsrc,
645+
inds::Tuple{AbstractUnitRange,Vararg{Any}},
646+
shiftamt::Tuple{Integer,Vararg{Any}})
647+
ind1, d = inds[1], shiftamt[1]
648+
s = mod(d, length(ind1))
649+
sf, sl = first(ind1)+s, last(ind1)-s
650+
r1, r2 = first(ind1):sf-1, sf:last(ind1)
651+
r3, r4 = first(ind1):sl, sl+1:last(ind1)
652+
tinds, tshiftamt = tail(inds), tail(shiftamt)
653+
_circshift!(dest, (rdest..., r1), src, (rsrc..., r4), tinds, tshiftamt)
654+
_circshift!(dest, (rdest..., r2), src, (rsrc..., r3), tinds, tshiftamt)
655+
end
656+
# At least one of inds, shiftamt is empty
657+
function _circshift!(dest, rdest, src, rsrc, inds, shiftamt)
658+
copy!(dest, CartesianRange(rdest), src, CartesianRange(rsrc))
659+
end
612660

613661
### BitArrays
614662

doc/stdlib/arrays.rst

+12-2
Original file line numberDiff line numberDiff line change
@@ -574,20 +574,30 @@ Indexing, Assignment, and Concatenation
574574
3 7 11 15
575575
4 8 12 16
576576

577-
julia> circshift(b, [0,2])
577+
julia> circshift(b, (0,2))
578578
4×4 Array{Int64,2}:
579579
9 13 1 5
580580
10 14 2 6
581581
11 15 3 7
582582
12 16 4 8
583583

584-
julia> circshift(b, [-1,0])
584+
julia> circshift(b, (-1,0))
585585
4×4 Array{Int64,2}:
586586
2 6 10 14
587587
3 7 11 15
588588
4 8 12 16
589589
1 5 9 13
590590

591+
See also ``circshift!``\ .
592+
593+
.. function:: circshift!(dest, src, shifts)
594+
595+
.. Docstring generated from Julia source
596+
597+
Circularly shift the data in ``src``\ , storing the result in ``dest``\ . ``shifts`` specifies the amount to shift in each dimension.
598+
599+
See also ``circshift``\ .
600+
591601
.. function:: find(A)
592602

593603
.. Docstring generated from Julia source

test/arrayops.jl

+9
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,15 @@ for i = tensors
457457
@test isequal(i,permutedims(ipermutedims(i,perm),perm))
458458
end
459459

460+
## circshift
461+
462+
@test circshift(1:5, -1) == circshift(1:5, 4) == circshift(1:5, -6) == [2,3,4,5,1]
463+
@test circshift(1:5, 1) == circshift(1:5, -4) == circshift(1:5, 6) == [5,1,2,3,4]
464+
a = [1:5;]
465+
@test_throws ArgumentError Base.circshift!(a, a, 1)
466+
b = copy(a)
467+
@test Base.circshift!(b, a, 1) == [5,1,2,3,4]
468+
460469
## unique across dim ##
461470

462471
# All rows and columns unique

test/offsetarray.jl

+1
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ v = OffsetArray(rand(8), (-2,))
411411
@test rotr90(A) == OffsetArray(rotr90(parent(A)), A.offsets[[2,1]])
412412
@test flipdim(A, 1) == OffsetArray(flipdim(parent(A), 1), A.offsets)
413413
@test flipdim(A, 2) == OffsetArray(flipdim(parent(A), 2), A.offsets)
414+
@test circshift(A, (-1,2)) == OffsetArray(circshift(parent(A), (-1,2)), A.offsets)
414415

415416
@test A+1 == OffsetArray(parent(A)+1, A.offsets)
416417
@test 2*A == OffsetArray(2*parent(A), A.offsets)

0 commit comments

Comments
 (0)