Skip to content

Commit ac48189

Browse files
committed
base: make diff() use views and broadcasting
1 parent 1717adb commit ac48189

File tree

2 files changed

+16
-12
lines changed

2 files changed

+16
-12
lines changed

base/multidimensional.jl

+10-12
Original file line numberDiff line numberDiff line change
@@ -659,10 +659,7 @@ end
659659
end
660660
end
661661

662-
function diff(a::AbstractVector)
663-
@assert !has_offset_axes(a)
664-
[ a[i+1] - a[i] for i=1:length(a)-1 ]
665-
end
662+
diff(a::AbstractVector) = diff(a, dims=1)
666663

667664
"""
668665
diff(A::AbstractVector)
@@ -690,14 +687,15 @@ julia> diff(vec(a))
690687
12
691688
```
692689
"""
693-
function diff(A::AbstractMatrix; dims::Integer)
694-
if dims == 1
695-
[A[i+1,j] - A[i,j] for i=1:size(A,1)-1, j=1:size(A,2)]
696-
elseif dims == 2
697-
[A[i,j+1] - A[i,j] for i=1:size(A,1), j=1:size(A,2)-1]
698-
else
699-
throw(ArgumentError("dimension must be 1 or 2, got $dims"))
700-
end
690+
function diff(a::AbstractArray{T,N}; dims::Integer) where {T,N}
691+
has_offset_axes(a) && throw(ArgumentError("offset axes unsupported"))
692+
1 <= dims <= N || throw(ArgumentError("dimension $dims out of range (1:$N)"))
693+
694+
r = axes(a)
695+
r0 = ntuple(i -> i == dims ? UnitRange(1, last(r[i]) - 1) : UnitRange(r[i]), N)
696+
r1 = ntuple(i -> i == dims ? UnitRange(2, last(r[i])) : UnitRange(r[i]), N)
697+
698+
return view(a, r1...) .- view(a, r0...)
701699
end
702700

703701
### from abstractarray.jl

test/arrayops.jl

+6
Original file line numberDiff line numberDiff line change
@@ -2283,6 +2283,9 @@ end
22832283

22842284
@testset "diff" begin
22852285
# test diff, throw ArgumentError for invalid dimension argument
2286+
v = [7, 3, 5, 1, 9]
2287+
@test diff(v) == [-4, 2, -4, 8]
2288+
@test diff(v,dims=1) == [-4, 2, -4, 8]
22862289
X = [3 9 5;
22872290
7 4 2;
22882291
2 1 10]
@@ -2292,6 +2295,9 @@ end
22922295
@test diff(view(X, 1:2, 1:2),dims=2) == reshape([6; -3], (2,1))
22932296
@test diff(view(X, 2:3, 2:3),dims=1) == [-3 8]
22942297
@test diff(view(X, 2:3, 2:3),dims=2) == reshape([-2; 9], (2,1))
2298+
Y = cat([1 3; 4 3], [6 5; 1 4], dims=3)
2299+
@test diff(Y, dims=3) == reshape([5 2; -3 1], (2, 2, 1))
2300+
@test_throws UndefKeywordError diff(X)
22952301
@test_throws ArgumentError diff(X,dims=3)
22962302
@test_throws ArgumentError diff(X,dims=-1)
22972303
end

0 commit comments

Comments
 (0)