Skip to content

Commit 373e61b

Browse files
committedFeb 3, 2015
ngenerate/nsplat: multidimensional algorithms on AbstractArrays
1 parent c8c5b6f commit 373e61b

File tree

1 file changed

+126
-104
lines changed

1 file changed

+126
-104
lines changed
 

Diff for: ‎base/multidimensional.jl

+126-104
Original file line numberDiff line numberDiff line change
@@ -181,9 +181,12 @@ using .IteratorsMD
181181

182182
### From array.jl
183183

184-
@ngenerate N Void function checksize(A::AbstractArray, I::NTuple{N, Any}...)
185-
@nexprs N d->(size(A, d) == length(I_d) || throw(DimensionMismatch("index $d has length $(length(I_d)), but size(A, $d) = $(size(A,d))")))
186-
nothing
184+
stagedfunction checksize(A::AbstractArray, I...)
185+
N = length(I)
186+
quote
187+
@nexprs $N d->(size(A, d) == length(I[d]) || throw(DimensionMismatch("index $d has length $(length(I[d])), but size(A, $d) = $(size(A,d))")))
188+
nothing
189+
end
187190
end
188191

189192
@inline unsafe_getindex(v::BitArray, ind::Int) = Base.unsafe_bitgetindex(v.chunks, ind)
@@ -259,17 +262,19 @@ end
259262
end
260263

261264

262-
@ngenerate N NTuple{N,Vector{Int}} function findn{T,N}(A::AbstractArray{T,N})
263-
nnzA = countnz(A)
264-
@nexprs N d->(I_d = Array(Int, nnzA))
265-
k = 1
266-
@nloops N i A begin
267-
@inbounds if (@nref N A i) != zero(T)
268-
@nexprs N d->(I_d[k] = i_d)
269-
k += 1
265+
stagedfunction findn{T,N}(A::AbstractArray{T,N})
266+
quote
267+
nnzA = countnz(A)
268+
@nexprs $N d->(I_d = Array(Int, nnzA))
269+
k = 1
270+
@nloops $N i A begin
271+
@inbounds if (@nref $N A i) != zero(T)
272+
@nexprs $N d->(I_d[k] = i_d)
273+
k += 1
274+
end
270275
end
276+
@ntuple $N I
271277
end
272-
@ntuple N I
273278
end
274279

275280

@@ -386,57 +391,70 @@ end
386391

387392

388393
cumsum(A::AbstractArray, axis::Integer=1) = cumsum!(similar(A, Base._cumsum_type(A)), A, axis)
394+
cumsum!(B, A::AbstractArray) = cumsum!(B, A, 1)
389395
cumprod(A::AbstractArray, axis::Integer=1) = cumprod!(similar(A), A, axis)
396+
cumprod!(B, A) = cumprod!(B, A, 1)
390397

391398
for (f, op) in ((:cumsum!, :+),
392399
(:cumprod!, :*))
393400
@eval begin
394-
@ngenerate N typeof(B) function ($f){T,N}(B, A::AbstractArray{T,N}, axis::Integer=1)
395-
if size(B, axis) < 1
396-
return B
397-
end
398-
size(B) == size(A) || throw(DimensionMismatch("size of B must match A"))
399-
if axis == 1
400-
# We can accumulate to a temporary variable, which allows register usage and will be slightly faster
401-
@inbounds @nloops N i d->(d > 1 ? (1:size(A,d)) : (1:1)) begin
402-
tmp = convert(eltype(B), @nref(N, A, i))
403-
@nref(N, B, i) = tmp
404-
for i_1 = 2:size(A,1)
405-
tmp = ($op)(tmp, @nref(N, A, i))
406-
@nref(N, B, i) = tmp
407-
end
401+
stagedfunction ($f){T,N}(B, A::AbstractArray{T,N}, axis::Integer)
402+
quote
403+
if size(B, axis) < 1
404+
return B
408405
end
409-
else
410-
@nexprs N d->(isaxis_d = axis == d)
411-
# Copy the initial element in each 1d vector along dimension `axis`
412-
@inbounds @nloops N i d->(d == axis ? (1:1) : (1:size(A,d))) @nref(N, B, i) = @nref(N, A, i)
413-
# Accumulate
414-
@inbounds @nloops N i d->((1+isaxis_d):size(A, d)) d->(j_d = i_d - isaxis_d) begin
415-
@nref(N, B, i) = ($op)(@nref(N, B, j), @nref(N, A, i))
406+
size(B) == size(A) || throw(DimensionMismatch("Size of B must match A"))
407+
if axis == 1
408+
# We can accumulate to a temporary variable, which allows register usage and will be slightly faster
409+
@inbounds @nloops $N i d->(d > 1 ? (1:size(A,d)) : (1:1)) begin
410+
tmp = convert(eltype(B), @nref($N, A, i))
411+
@nref($N, B, i) = tmp
412+
for i_1 = 2:size(A,1)
413+
tmp = ($($op))(tmp, @nref($N, A, i))
414+
@nref($N, B, i) = tmp
415+
end
416+
end
417+
else
418+
@nexprs $N d->(isaxis_d = axis == d)
419+
# Copy the initial element in each 1d vector along dimension `axis`
420+
@inbounds @nloops $N i d->(d == axis ? (1:1) : (1:size(A,d))) @nref($N, B, i) = @nref($N, A, i)
421+
# Accumulate
422+
@inbounds @nloops $N i d->((1+isaxis_d):size(A, d)) d->(j_d = i_d - isaxis_d) begin
423+
@nref($N, B, i) = ($($op))(@nref($N, B, j), @nref($N, A, i))
424+
end
416425
end
426+
B
417427
end
418-
B
419428
end
420429
end
421430
end
422431

423432
### from abstractarray.jl
424433

425-
@ngenerate N typeof(A) function fill!{T,N}(A::AbstractArray{T,N}, x)
426-
xT = convert(T, x)
427-
@nloops N i A begin
428-
@inbounds (@nref N A i) = xT
434+
function fill!(A::AbstractArray, x)
435+
for I in eachindex(A)
436+
@inbounds A[I] = x
429437
end
430438
A
431439
end
432440

433-
@ngenerate N typeof(dest) function copy!{T,N}(dest::AbstractArray{T,N}, src::AbstractArray{T,N})
434-
if @nall N d->(size(dest,d) == size(src,d))
435-
@nloops N i dest begin
436-
@inbounds (@nref N dest i) = (@nref N src i)
441+
function copy!{T,N}(dest::AbstractArray{T,N}, src::AbstractArray{T,N})
442+
samesize = true
443+
for d = 1:N
444+
if size(dest,d) != size(src,d)
445+
samesize = false
446+
break
447+
end
448+
end
449+
if samesize
450+
for I in eachindex(dest)
451+
@inbounds dest[I] = src[I]
437452
end
438453
else
439-
invoke(copy!, (typeof(dest), Any), dest, src)
454+
length(dest) == length(src) || throw(DimensionMismatch("Inconsistent lengths"))
455+
for (Idest, Isrc) in zip(eachindex(dest),eachindex(src))
456+
@inbounds dest[Idest] = src[Isrc]
457+
end
440458
end
441459
dest
442460
end
@@ -697,19 +715,21 @@ end
697715

698716
## findn
699717

700-
@ngenerate N NTuple{N,Vector{Int}} function findn{N}(B::BitArray{N})
701-
nnzB = countnz(B)
702-
I = ntuple(N, x->Array(Int, nnzB))
703-
if nnzB > 0
704-
count = 1
705-
@nloops N i B begin
706-
if (@nref N B i) # TODO: should avoid bounds checking
707-
@nexprs N d->(I[d][count] = i_d)
708-
count += 1
718+
stagedfunction findn{N}(B::BitArray{N})
719+
quote
720+
nnzB = countnz(B)
721+
I = ntuple($N, x->Array(Int, nnzB))
722+
if nnzB > 0
723+
count = 1
724+
@nloops $N i B begin
725+
if (@nref $N B i) # TODO: should avoid bounds checking
726+
@nexprs $N d->(I[d][count] = i_d)
727+
count += 1
728+
end
709729
end
710730
end
731+
return I
711732
end
712-
return I
713733
end
714734

715735
## isassigned
@@ -774,70 +794,72 @@ immutable Prehashed
774794
end
775795
hash(x::Prehashed) = x.hash
776796

777-
@ngenerate N typeof(A) function unique{T,N}(A::AbstractArray{T,N}, dim::Int)
778-
1 <= dim <= N || return copy(A)
779-
hashes = zeros(UInt, size(A, dim))
797+
stagedfunction unique{T,N}(A::AbstractArray{T,N}, dim::Int)
798+
quote
799+
1 <= dim <= $N || return copy(A)
800+
hashes = zeros(UInt, size(A, dim))
780801

781-
# Compute hash for each row
782-
k = 0
783-
@nloops N i A d->(if d == dim; k = i_d; end) begin
784-
@inbounds hashes[k] = hash(hashes[k], hash((@nref N A i)))
785-
end
802+
# Compute hash for each row
803+
k = 0
804+
@nloops $N i A d->(if d == dim; k = i_d; end) begin
805+
@inbounds hashes[k] = hash(hashes[k], hash((@nref $N A i)))
806+
end
786807

787-
# Collect index of first row for each hash
788-
uniquerow = Array(Int, size(A, dim))
789-
firstrow = Dict{Prehashed,Int}()
790-
for k = 1:size(A, dim)
791-
uniquerow[k] = get!(firstrow, Prehashed(hashes[k]), k)
792-
end
793-
uniquerows = collect(values(firstrow))
808+
# Collect index of first row for each hash
809+
uniquerow = Array(Int, size(A, dim))
810+
firstrow = Dict{Prehashed,Int}()
811+
for k = 1:size(A, dim)
812+
uniquerow[k] = get!(firstrow, Prehashed(hashes[k]), k)
813+
end
814+
uniquerows = collect(values(firstrow))
794815

795-
# Check for collisions
796-
collided = falses(size(A, dim))
797-
@inbounds begin
798-
@nloops N i A d->(if d == dim
816+
# Check for collisions
817+
collided = falses(size(A, dim))
818+
@inbounds begin
819+
@nloops $N i A d->(if d == dim
799820
k = i_d
800821
j_d = uniquerow[k]
801822
else
802823
j_d = i_d
803824
end) begin
804-
if (@nref N A j) != (@nref N A i)
805-
collided[k] = true
806-
end
825+
if (@nref $N A j) != (@nref $N A i)
826+
collided[k] = true
827+
end
828+
end
807829
end
808-
end
809830

810-
if any(collided)
811-
nowcollided = BitArray(size(A, dim))
812-
while any(collided)
813-
# Collect index of first row for each collided hash
814-
empty!(firstrow)
815-
for j = 1:size(A, dim)
816-
collided[j] || continue
817-
uniquerow[j] = get!(firstrow, Prehashed(hashes[j]), j)
818-
end
819-
for v in values(firstrow)
820-
push!(uniquerows, v)
821-
end
831+
if any(collided)
832+
nowcollided = BitArray(size(A, dim))
833+
while any(collided)
834+
# Collect index of first row for each collided hash
835+
empty!(firstrow)
836+
for j = 1:size(A, dim)
837+
collided[j] || continue
838+
uniquerow[j] = get!(firstrow, Prehashed(hashes[j]), j)
839+
end
840+
for v in values(firstrow)
841+
push!(uniquerows, v)
842+
end
822843

823-
# Check for collisions
824-
fill!(nowcollided, false)
825-
@nloops N i A d->begin
826-
if d == dim
827-
k = i_d
828-
j_d = uniquerow[k]
829-
(!collided[k] || j_d == k) && continue
830-
else
831-
j_d = i_d
832-
end
833-
end begin
834-
if (@nref N A j) != (@nref N A i)
835-
nowcollided[k] = true
844+
# Check for collisions
845+
fill!(nowcollided, false)
846+
@nloops $N i A d->begin
847+
if d == dim
848+
k = i_d
849+
j_d = uniquerow[k]
850+
(!collided[k] || j_d == k) && continue
851+
else
852+
j_d = i_d
853+
end
854+
end begin
855+
if (@nref $N A j) != (@nref $N A i)
856+
nowcollided[k] = true
857+
end
836858
end
859+
(collided, nowcollided) = (nowcollided, collided)
837860
end
838-
(collided, nowcollided) = (nowcollided, collided)
839861
end
840-
end
841862

842-
@nref N A d->d == dim ? sort!(uniquerows) : (1:size(A, d))
863+
@nref $N A d->d == dim ? sort!(uniquerows) : (1:size(A, d))
864+
end
843865
end

0 commit comments

Comments
 (0)
Please sign in to comment.