Skip to content

Commit b378ece

Browse files
authored
Merge pull request #17861 from JuliaLang/teh/circshift
Faster, indices-aware circshift (and non-allocating circshift!)
2 parents 59f158c + 60660b5 commit b378ece

File tree

6 files changed

+86
-14
lines changed

6 files changed

+86
-14
lines changed

base/abstractarraymath.jl

+10-12
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,10 @@ function flipdim(A::AbstractArray, d::Integer)
158158
return B
159159
end
160160

161-
circshift(a::AbstractArray, shiftamt::Real) = circshift(a, [Integer(shiftamt)])
162-
161+
function circshift(a::AbstractArray, shiftamt::Real)
162+
circshift!(similar(a), a, (Integer(shiftamt),))
163+
end
164+
circshift(a::AbstractArray, shiftamt::DimsInteger) = circshift!(similar(a), a, shiftamt)
163165
"""
164166
circshift(A, shifts)
165167
@@ -174,29 +176,25 @@ julia> b = reshape(collect(1:16), (4,4))
174176
3 7 11 15
175177
4 8 12 16
176178
177-
julia> circshift(b, [0,2])
179+
julia> circshift(b, (0,2))
178180
4×4 Array{Int64,2}:
179181
9 13 1 5
180182
10 14 2 6
181183
11 15 3 7
182184
12 16 4 8
183185
184-
julia> circshift(b, [-1,0])
186+
julia> circshift(b, (-1,0))
185187
4×4 Array{Int64,2}:
186188
2 6 10 14
187189
3 7 11 15
188190
4 8 12 16
189191
1 5 9 13
190192
```
193+
194+
See also `circshift!`.
191195
"""
192-
function circshift{T,N}(a::AbstractArray{T,N}, shiftamts)
193-
I = ()
194-
for i=1:N
195-
s = size(a,i)
196-
d = i<=length(shiftamts) ? shiftamts[i] : 0
197-
I = tuple(I..., d==0 ? [1:s;] : mod([-d:s-1-d;], s).+1)
198-
end
199-
a[(I::NTuple{N,Vector{Int}})...]
196+
function circshift(a::AbstractArray, shiftamt)
197+
circshift!(similar(a), a, map(Integer, (shiftamt...,)))
200198
end
201199

202200
# Uses K-B-N summation

base/exports.jl

+1
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,7 @@ export
490490
checkbounds,
491491
checkindex,
492492
circshift,
493+
circshift!,
493494
clamp!,
494495
colon,
495496
conj!,

base/multidimensional.jl

+51
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,57 @@ function copy!{T,N}(dest::AbstractArray{T,N}, src::AbstractArray{T,N})
610610
dest
611611
end
612612

613+
function copy!(dest::AbstractArray, Rdest::CartesianRange, src::AbstractArray, Rsrc::CartesianRange)
614+
isempty(Rdest) && return dest
615+
size(Rdest) == size(Rsrc) || throw(ArgumentError("source and destination must have same size (got $(size(Rsrc)) and $(size(Rdest)))"))
616+
@boundscheck checkbounds(dest, Rdest.start)
617+
@boundscheck checkbounds(dest, Rdest.stop)
618+
@boundscheck checkbounds(src, Rsrc.start)
619+
@boundscheck checkbounds(src, Rsrc.stop)
620+
deltaI = Rdest.start - Rsrc.start
621+
for I in Rsrc
622+
@inbounds dest[I+deltaI] = src[I]
623+
end
624+
dest
625+
end
626+
627+
# circshift!
628+
circshift!(dest::AbstractArray, src, ::Tuple{}) = copy!(dest, src)
629+
"""
630+
circshift!(dest, src, shifts)
631+
632+
Circularly shift the data in `src`, storing the result in
633+
`dest`. `shifts` specifies the amount to shift in each dimension.
634+
635+
The `dest` array must be distinct from the `src` array (they cannot
636+
alias each other).
637+
638+
See also `circshift`.
639+
"""
640+
@noinline function circshift!{T,N}(dest::AbstractArray{T,N}, src, shiftamt::DimsInteger)
641+
dest === src && throw(ArgumentError("dest and src must be separate arrays"))
642+
inds = indices(src)
643+
indices(dest) == inds || throw(ArgumentError("indices of src and dest must match (got $inds and $(indices(dest)))"))
644+
_circshift!(dest, (), src, (), inds, fill_to_length(shiftamt, 0, Val{N}))
645+
end
646+
circshift!(dest::AbstractArray, src, shiftamt) = circshift!(dest, src, (shiftamt...,))
647+
648+
@inline function _circshift!(dest, rdest, src, rsrc,
649+
inds::Tuple{AbstractUnitRange,Vararg{Any}},
650+
shiftamt::Tuple{Integer,Vararg{Any}})
651+
ind1, d = inds[1], shiftamt[1]
652+
s = mod(d, length(ind1))
653+
sf, sl = first(ind1)+s, last(ind1)-s
654+
r1, r2 = first(ind1):sf-1, sf:last(ind1)
655+
r3, r4 = first(ind1):sl, sl+1:last(ind1)
656+
tinds, tshiftamt = tail(inds), tail(shiftamt)
657+
_circshift!(dest, (rdest..., r1), src, (rsrc..., r4), tinds, tshiftamt)
658+
_circshift!(dest, (rdest..., r2), src, (rsrc..., r3), tinds, tshiftamt)
659+
end
660+
# At least one of inds, shiftamt is empty
661+
function _circshift!(dest, rdest, src, rsrc, inds, shiftamt)
662+
copy!(dest, CartesianRange(rdest), src, CartesianRange(rsrc))
663+
end
613664

614665
### BitArrays
615666

doc/stdlib/arrays.rst

+14-2
Original file line numberDiff line numberDiff line change
@@ -590,20 +590,32 @@ Indexing, Assignment, and Concatenation
590590
3 7 11 15
591591
4 8 12 16
592592

593-
julia> circshift(b, [0,2])
593+
julia> circshift(b, (0,2))
594594
4×4 Array{Int64,2}:
595595
9 13 1 5
596596
10 14 2 6
597597
11 15 3 7
598598
12 16 4 8
599599

600-
julia> circshift(b, [-1,0])
600+
julia> circshift(b, (-1,0))
601601
4×4 Array{Int64,2}:
602602
2 6 10 14
603603
3 7 11 15
604604
4 8 12 16
605605
1 5 9 13
606606

607+
See also ``circshift!``\ .
608+
609+
.. function:: circshift!(dest, src, shifts)
610+
611+
.. Docstring generated from Julia source
612+
613+
Circularly shift the data in ``src``\ , storing the result in ``dest``\ . ``shifts`` specifies the amount to shift in each dimension.
614+
615+
The ``dest`` array must be distinct from the ``src`` array (they cannot alias each other).
616+
617+
See also ``circshift``\ .
618+
607619
.. function:: find(A)
608620

609621
.. Docstring generated from Julia source

test/arrayops.jl

+9
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,15 @@ for i = tensors
462462
@test isequal(i,permutedims(ipermutedims(i,perm),perm))
463463
end
464464

465+
## circshift
466+
467+
@test circshift(1:5, -1) == circshift(1:5, 4) == circshift(1:5, -6) == [2,3,4,5,1]
468+
@test circshift(1:5, 1) == circshift(1:5, -4) == circshift(1:5, 6) == [5,1,2,3,4]
469+
a = [1:5;]
470+
@test_throws ArgumentError Base.circshift!(a, a, 1)
471+
b = copy(a)
472+
@test Base.circshift!(b, a, 1) == [5,1,2,3,4]
473+
465474
## unique across dim ##
466475

467476
# All rows and columns unique

test/offsetarray.jl

+1
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ v = OffsetArray(rand(8), (-2,))
411411
@test rotr90(A) == OffsetArray(rotr90(parent(A)), A.offsets[[2,1]])
412412
@test flipdim(A, 1) == OffsetArray(flipdim(parent(A), 1), A.offsets)
413413
@test flipdim(A, 2) == OffsetArray(flipdim(parent(A), 2), A.offsets)
414+
@test circshift(A, (-1,2)) == OffsetArray(circshift(parent(A), (-1,2)), A.offsets)
414415

415416
@test A+1 == OffsetArray(parent(A)+1, A.offsets)
416417
@test 2*A == OffsetArray(2*parent(A), A.offsets)

0 commit comments

Comments
 (0)