Skip to content

Commit 4f6d1db

Browse files
committed
Merge pull request #7131 from mauro3/sparse_indexing2
Improved sparse getindex
2 parents 1ef93a0 + 4fd274e commit 4f6d1db

File tree

2 files changed

+112
-103
lines changed

2 files changed

+112
-103
lines changed

base/sparse/sparsematrix.jl

+61 -103
Original file line number | Diff line number | Diff line change
@@ -712,31 +712,40 @@ prod{T}(A::SparseMatrixCSC{T}, region) = reducedim(*,A,region,one(T))
712712
#sum(A::SparseMatrixCSC{Bool}) = countnz(A)
713713

714714
## getindex
715-
getindex(A::SparseMatrixCSC, i::Integer) = getindex(A, ind2sub(size(A),i)...)
716-
717-
function getindex{T}(A::SparseMatrixCSC{T}, i0::Integer, i1::Integer)
718-
if !(1 <= i0 <= A.m && 1 <= i1 <= A.n); throw(BoundsError()); end
719-
first = A.colptr[i1]
720-
last = A.colptr[i1+1]-1
721-
while first <= last
722-
mid = (first + last) >> 1
723-
t = A.rowval[mid]
724-
if t == i0
725-
return A.nzval[mid]
726-
elseif t > i0
727-
last = mid - 1
715+
function binarysearch(haystack::AbstractVector, needle, lo::Int, hi::Int)
716+
# Finds the first occurrence of needle in haystack[lo:hi]
717+
lo = lo-1
718+
hi2 = hi
719+
hi = hi+1
720+
@inbounds while lo < hi-1
721+
m = (lo+hi)>>>1
722+
if haystack[m] < needle
723+
lo = m
728724
else
729-
first = mid + 1
725+
hi = m
730726
end
731727
end
732-
return zero(T)
728+
(hi==hi2+1 || haystack[hi]!=needle) ? -1 : hi
729+
end
730+
function rangesearch(haystack::Range, needle)
731+
(i,rem) = divrem(needle - first(haystack), step(haystack))
732+
(rem==0 && 1<=i+1<=length(haystack)) ? i+1 : -1
733+
end
734+
735+
getindex(A::SparseMatrixCSC, i::Integer) = getindex(A, ind2sub(size(A),i))
736+
getindex(A::SparseMatrixCSC, I::(Integer,Integer)) = getindex(A, I[1], I[2])
737+
738+
function getindex{T}(A::SparseMatrixCSC{T}, i0::Integer, i1::Integer)
739+
if !(1 <= i0 <= A.m && 1 <= i1 <= A.n); throw(BoundsError()); end
740+
ind = binarysearch(A.rowval, i0, A.colptr[i1], A.colptr[i1+1]-1)
741+
ind > -1 ? A.nzval[ind] : zero(T)
733742
end
734743

735744
getindex{T<:Integer}(A::SparseMatrixCSC, I::AbstractVector{T}, j::Integer) = getindex(A,I,[j])
736745
getindex{T<:Integer}(A::SparseMatrixCSC, i::Integer, J::AbstractVector{T}) = getindex(A,[i],J)
737746

738747
function getindex_cols{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, J::AbstractVector)
739-
748+
# for indexing whole columns
740749
(m, n) = size(A)
741750
nJ = length(J)
742751

@@ -758,52 +767,35 @@ function getindex_cols{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, J::AbstractVector)
758767

759768
for j = 1:nJ
760769
col = J[j]
761-
762770
for k = colptrA[col]:colptrA[col+1]-1
763771
ptrS += 1
764772
rowvalS[ptrS] = rowvalA[k]
765773
nzvalS[ptrS] = nzvalA[k]
766774
end
767775
end
768-
769776
return SparseMatrixCSC(m, nJ, colptrS, rowvalS, nzvalS)
770-
771777
end
772778

773-
# TODO: See if growing arrays is faster than pre-computing structure
774-
# and then populating nonzeros
775-
# TODO: Use binary search in cases where nI >> nnz(A[:,j]) or nI << nnz(A[:,j])
776-
function getindex_I_sorted{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::AbstractVector, J::AbstractVector)
777-
779+
function getindex{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv,Ti}, I::Range, J::AbstractVector)
780+
# Ranges for indexing rows
778781
(m, n) = size(A)
782+
# whole columns:
783+
if I == 1:m
784+
return getindex_cols(A, J)
785+
end
786+
779787
nI = length(I)
780788
nJ = length(J)
781-
782789
colptrA = A.colptr; rowvalA = A.rowval; nzvalA = A.nzval
783-
784-
I_ref = falses(m)
785-
I_ref[I] = true
786-
787-
I_repeat = zeros(Int, m)
788-
for i=1:nI; I_repeat[I[i]] += 1; end
789-
790790
colptrS = Array(Ti, nJ+1)
791791
colptrS[1] = 1
792792
nnzS = 0
793793

794794
# Form the structure of the result and compute space
795795
for j = 1:nJ
796796
col = J[j]
797-
798-
for k = colptrA[col]:colptrA[col+1]-1
799-
rowA = rowvalA[k]
800-
801-
if I_ref[rowA]
802-
for r = 1:I_repeat[rowA]
803-
nnzS += 1
804-
end
805-
end
806-
797+
for k in colptrA[col]:colptrA[col+1]-1
798+
if rowvalA[k] in I; nnzS += 1 end # `in` is fast for ranges
807799
end
808800
colptrS[j+1] = nnzS+1
809801
end
@@ -813,54 +805,40 @@ function getindex_I_sorted{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::AbstractVector,
813805
nzvalS = Array(Tv, nnzS)
814806
ptrS = 1
815807

816-
fI = zeros(Ti, m)
817-
for k=1:nI
818-
Ik = I[k]
819-
if fI[Ik] == 0; fI[Ik] = k; end
820-
end
821-
822808
for j = 1:nJ
823809
col = J[j]
824-
825810
for k = colptrA[col]:colptrA[col+1]-1
826811
rowA = rowvalA[k]
827-
828-
if I_ref[rowA]
829-
for r = 1:I_repeat[rowA]
830-
rowvalS[ptrS] = fI[rowA] + r - 1
831-
nzvalS[ptrS] = nzvalA[k]
832-
ptrS += 1
833-
end
812+
i = rangesearch(I, rowA)
813+
if i > -1
814+
rowvalS[ptrS] = i
815+
nzvalS[ptrS] = nzvalA[k]
816+
ptrS += 1
834817
end
835-
836818
end
837819
end
838820

839821
return SparseMatrixCSC(nI, nJ, colptrS, rowvalS, nzvalS)
840822
end
841823

842-
# getindex_I_sorted based on merging of sorted lists
843-
function getindex_I_sorted_old{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::Vector, J::AbstractVector)
844-
824+
function getindex_I_sorted{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::AbstractVector, J::AbstractVector)
825+
# Sorted vectors for indexing rows.
826+
# Similar to getindex_general but without the transpose trick.
845827
(m, n) = size(A)
846828
nI = length(I)
847829
nJ = length(J)
848830

849831
colptrA = A.colptr; rowvalA = A.rowval; nzvalA = A.nzval
850-
851832
colptrS = Array(Ti, nJ+1)
852833
colptrS[1] = 1
853834
nnzS = 0
854835

855836
# Form the structure of the result and compute space
856837
for j = 1:nJ
857838
col = J[j]
858-
859-
ptrI::Int = 1
860-
839+
ptrI::Int = 1 # runs through I
861840
ptrA::Int = colptrA[col]
862841
stopA::Int = colptrA[col+1]
863-
864842
while ptrI <= nI && ptrA < stopA
865843
rowA = rowvalA[ptrA]
866844
rowI = I[ptrI]
@@ -874,22 +852,17 @@ function getindex_I_sorted_old{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::Vector, J::A
874852
ptrI += 1
875853
end
876854
end
877-
colptrS[j+1] = nnzS+1
878-
855+
colptrS[j+1] = nnzS+1
879856
end
880857

881-
fI = find(I)
882-
883858
# Populate the values in the result
884859
rowvalS = Array(Ti, nnzS)
885860
nzvalS = Array(Tv, nnzS)
886-
ptrS = 0
861+
ptrS = 1
887862

888863
for j = 1:nJ
889864
col = J[j]
890-
891-
ptrI::Int = 1
892-
865+
ptrI::Int = 1 # runs through I
893866
ptrA::Int = colptrA[col]
894867
stopA::Int = colptrA[col+1]
895868

@@ -902,30 +875,27 @@ function getindex_I_sorted_old{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::Vector, J::A
902875
elseif rowI < rowA
903876
ptrI += 1
904877
else
905-
ptrS += 1
906-
rowvalS[ptrS] = fI[ptrI]
878+
rowvalS[ptrS] = ptrI
907879
nzvalS[ptrS] = nzvalA[ptrA]
880+
ptrS += 1
908881
ptrI += 1
909882
end
910883
end
911-
912884
end
913-
914885
return SparseMatrixCSC(nI, nJ, colptrS, rowvalS, nzvalS)
915886
end
916887

917-
function getindex_general{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::Vector, J::AbstractVector)
888+
function getindex_general{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::AbstractVector, J::AbstractVector)
889+
# Anything for indexing rows.
890+
# This sorts I first then does some trick with constructing the transpose.
918891
(m, n) = size(A)
919892
nI = length(I)
920893
nJ = length(J)
921894

922895
colptrA = A.colptr; rowvalA = A.rowval; nzvalA = A.nzval
923-
924896
nnzS = 0
925-
926897
pI = sortperm(I); I = I[pI]
927898
fI = find(I)
928-
929899
W = zeros(Int, nI + 1) # Keep row counts
930900
W[1] = 1 # For cumsum later
931901

@@ -993,36 +963,24 @@ function getindex_general{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::Vector, J::Abstra
993963
# Transpose so that rows are in sorted order and return
994964
S_T = SparseMatrixCSC(nJ, nI, colptrS_T, rowvalS_T, nzvalS_T)
995965
return S_T.'
996-
997966
end
998967

999-
# S = A[I, J]
968+
969+
# the general case:
1000970
function getindex{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::AbstractVector, J::AbstractVector)
1001-
m = size(A, 1)
1002-
if isa(I, Range)
1003-
if I == 1:m # whole columns
1004-
return getindex_cols(A, J)
1005-
else # ranges are always sorted, but maybe in reverse
1006-
if step(I)>0
1007-
return getindex_I_sorted(A, I, J)
1008-
else
1009-
I = [I]
1010-
return getindex_general(A, I, J)
1011-
# todo:
1012-
# return reverse(getindex_I_sorted(A, reverse(I), J))
1013-
end
1014-
end
1015-
else
1016-
if issorted(I)
1017-
return getindex_I_sorted(A, I, J)
1018-
else
1019-
return getindex_general(A, I, J)
1020-
end
971+
loop_overI_threshold = 0 # ~40 but doesn't matter much.
972+
if issorted(I)
973+
return getindex_I_sorted(A, I, J)
974+
else
975+
return getindex_general(A, I, J)
1021976
end
1022977
end
1023978

1024979
# logical getindex
980+
getindex{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv,Ti}, I::Range{Bool}, J::AbstractVector{Bool}) = error("Cannot index with Range{Bool}")
981+
getindex{Tv,Ti<:Integer,T<:Integer}(A::SparseMatrixCSC{Tv,Ti}, I::Range{Bool}, J::AbstractVector{T}) = error("Cannot index with Range{Bool}")
1025982

983+
getindex{T<:Integer}(A::SparseMatrixCSC, I::Range{T}, J::AbstractVector{Bool}) = A[I,find(J)]
1026984
getindex(A::SparseMatrixCSC, I::Integer, J::AbstractVector{Bool}) = A[I,find(J)]
1027985
getindex(A::SparseMatrixCSC, I::AbstractVector{Bool}, J::Integer) = A[find(I),J]
1028986
getindex(A::SparseMatrixCSC, I::AbstractVector{Bool}, J::AbstractVector{Bool}) = A[find(I),find(J)]

test/sparse.jl

+51
Original file line number | Diff line number | Diff line change
@@ -239,6 +239,57 @@ for op in (:sin, :cos, :tan, :iceil, :ifloor, :ceil, :floor, :abs, :abs2)
239239
end
240240
end
241241

242+
# getindex tests
243+
ni = 23
244+
nj = 32
245+
a116 = reshape(1:(ni*nj), ni, nj)
246+
s116 = sparse(a116)
247+
248+
ad116 = diagm(diag(a116))
249+
sd116 = sparse(ad116)
250+
251+
for (aa116, ss116) in [(a116, s116), (ad116, sd116)]
252+
ij=11; i=3; j=2
253+
@test ss116[ij] == aa116[ij]
254+
@test ss116[(i,j)] == aa116[i,j]
255+
@test ss116[i,j] == aa116[i,j]
256+
@test ss116[i-1,j] == aa116[i-1,j]
257+
ss116[i,j] = 0
258+
@test ss116[i,j] == 0
259+
ss116 = sparse(aa116)
260+
261+
# range indexing
262+
@test full(ss116[i,:]) == aa116[i,:]
263+
@test full(ss116[:,j]) == aa116[:,j]'' # sparse matrices/vectors always have ndims==2:
264+
@test full(ss116[i,1:2:end]) == aa116[i,1:2:end]
265+
@test full(ss116[1:2:end,j]) == aa116[1:2:end,j]''
266+
@test full(ss116[i,end:-2:1]) == aa116[i,end:-2:1]
267+
@test full(ss116[end:-2:1,j]) == aa116[end:-2:1,j]''
268+
# float-range indexing is not supported
269+
270+
# sorted vector indexing
271+
@test full(ss116[i,[3:2:end-3]]) == aa116[i,[3:2:end-3]]
272+
@test full(ss116[[3:2:end-3],j]) == aa116[[3:2:end-3],j]''
273+
@test full(ss116[i,[end-3:-2:1]]) == aa116[i,[end-3:-2:1]]
274+
@test full(ss116[[end-3:-2:1],j]) == aa116[[end-3:-2:1],j]''
275+
276+
# unsorted vector indexing with repetition
277+
p = [4, 1, 2, 3, 2, 6]
278+
@test full(ss116[p,:]) == aa116[p,:]
279+
@test full(ss116[:,p]) == aa116[:,p]
280+
@test full(ss116[p,p]) == aa116[p,p]
281+
282+
# bool indexing
283+
li = randbool(size(aa116,1))
284+
lj = randbool(size(aa116,2))
285+
@test full(ss116[li,j]) == aa116[li,j]''
286+
@test full(ss116[li,:]) == aa116[li,:]
287+
@test full(ss116[i,lj]) == aa116[i,lj]
288+
@test full(ss116[:,lj]) == aa116[:,lj]
289+
@test full(ss116[li,lj]) == aa116[li,lj]
290+
end
291+
292+
242293
# setindex tests
243294
let a = spzeros(Int, 10, 10)
244295
@test countnz(a) == 0

0 commit comments

Comments
 (0)