@@ -712,31 +712,40 @@ prod{T}(A::SparseMatrixCSC{T}, region) = reducedim(*,A,region,one(T))
712
712
# sum(A::SparseMatrixCSC{Bool}) = countnz(A)
713
713
714
714
# # getindex
715
- getindex (A:: SparseMatrixCSC , i:: Integer ) = getindex (A, ind2sub (size (A),i)... )
716
-
717
- function getindex {T} (A:: SparseMatrixCSC{T} , i0:: Integer , i1:: Integer )
718
- if ! (1 <= i0 <= A. m && 1 <= i1 <= A. n); throw (BoundsError ()); end
719
- first = A. colptr[i1]
720
- last = A. colptr[i1+ 1 ]- 1
721
- while first <= last
722
- mid = (first + last) >> 1
723
- t = A. rowval[mid]
724
- if t == i0
725
- return A. nzval[mid]
726
- elseif t > i0
727
- last = mid - 1
715
+ function binarysearch (haystack:: AbstractVector , needle, lo:: Int , hi:: Int )
716
+ # Finds the first occurrence of needle in haystack[lo:hi]
717
+ lo = lo- 1
718
+ hi2 = hi
719
+ hi = hi+ 1
720
+ @inbounds while lo < hi- 1
721
+ m = (lo+ hi)>>> 1
722
+ if haystack[m] < needle
723
+ lo = m
728
724
else
729
- first = mid + 1
725
+ hi = m
730
726
end
731
727
end
732
- return zero (T)
728
+ (hi== hi2+ 1 || haystack[hi]!= needle) ? - 1 : hi
729
+ end
730
+ function rangesearch (haystack:: Range , needle)
731
+ (i,rem) = divrem (needle - first (haystack), step (haystack))
732
+ (rem== 0 && 1 <= i+ 1 <= length (haystack)) ? i+ 1 : - 1
733
+ end
734
+
735
+ getindex (A:: SparseMatrixCSC , i:: Integer ) = getindex (A, ind2sub (size (A),i))
736
+ getindex (A:: SparseMatrixCSC , I:: (Integer,Integer) ) = getindex (A, I[1 ], I[2 ])
737
+
738
+ function getindex {T} (A:: SparseMatrixCSC{T} , i0:: Integer , i1:: Integer )
739
+ if ! (1 <= i0 <= A. m && 1 <= i1 <= A. n); throw (BoundsError ()); end
740
+ ind = binarysearch (A. rowval, i0, A. colptr[i1], A. colptr[i1+ 1 ]- 1 )
741
+ ind > - 1 ? A. nzval[ind] : zero (T)
733
742
end
734
743
735
744
getindex {T<:Integer} (A:: SparseMatrixCSC , I:: AbstractVector{T} , j:: Integer ) = getindex (A,I,[j])
736
745
getindex {T<:Integer} (A:: SparseMatrixCSC , i:: Integer , J:: AbstractVector{T} ) = getindex (A,[i],J)
737
746
738
747
function getindex_cols {Tv,Ti} (A:: SparseMatrixCSC{Tv,Ti} , J:: AbstractVector )
739
-
748
+ # for indexing whole columns
740
749
(m, n) = size (A)
741
750
nJ = length (J)
742
751
@@ -758,52 +767,35 @@ function getindex_cols{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, J::AbstractVector)
758
767
759
768
for j = 1 : nJ
760
769
col = J[j]
761
-
762
770
for k = colptrA[col]: colptrA[col+ 1 ]- 1
763
771
ptrS += 1
764
772
rowvalS[ptrS] = rowvalA[k]
765
773
nzvalS[ptrS] = nzvalA[k]
766
774
end
767
775
end
768
-
769
776
return SparseMatrixCSC (m, nJ, colptrS, rowvalS, nzvalS)
770
-
771
777
end
772
778
773
- # TODO : See if growing arrays is faster than pre-computing structure
774
- # and then populating nonzeros
775
- # TODO : Use binary search in cases where nI >> nnz(A[:,j]) or nI << nnz(A[:,j])
776
- function getindex_I_sorted {Tv,Ti} (A:: SparseMatrixCSC{Tv,Ti} , I:: AbstractVector , J:: AbstractVector )
777
-
779
+ function getindex {Tv,Ti<:Integer} (A:: SparseMatrixCSC{Tv,Ti} , I:: Range , J:: AbstractVector )
780
+ # Ranges for indexing rows
778
781
(m, n) = size (A)
782
+ # whole columns:
783
+ if I == 1 : m
784
+ return getindex_cols (A, J)
785
+ end
786
+
779
787
nI = length (I)
780
788
nJ = length (J)
781
-
782
789
colptrA = A. colptr; rowvalA = A. rowval; nzvalA = A. nzval
783
-
784
- I_ref = falses (m)
785
- I_ref[I] = true
786
-
787
- I_repeat = zeros (Int, m)
788
- for i= 1 : nI; I_repeat[I[i]] += 1 ; end
789
-
790
790
colptrS = Array (Ti, nJ+ 1 )
791
791
colptrS[1 ] = 1
792
792
nnzS = 0
793
793
794
794
# Form the structure of the result and compute space
795
795
for j = 1 : nJ
796
796
col = J[j]
797
-
798
- for k = colptrA[col]: colptrA[col+ 1 ]- 1
799
- rowA = rowvalA[k]
800
-
801
- if I_ref[rowA]
802
- for r = 1 : I_repeat[rowA]
803
- nnzS += 1
804
- end
805
- end
806
-
797
+ for k in colptrA[col]: colptrA[col+ 1 ]- 1
798
+ if rowvalA[k] in I; nnzS += 1 end # `in` is fast for ranges
807
799
end
808
800
colptrS[j+ 1 ] = nnzS+ 1
809
801
end
@@ -813,54 +805,40 @@ function getindex_I_sorted{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::AbstractVector,
813
805
nzvalS = Array (Tv, nnzS)
814
806
ptrS = 1
815
807
816
- fI = zeros (Ti, m)
817
- for k= 1 : nI
818
- Ik = I[k]
819
- if fI[Ik] == 0 ; fI[Ik] = k; end
820
- end
821
-
822
808
for j = 1 : nJ
823
809
col = J[j]
824
-
825
810
for k = colptrA[col]: colptrA[col+ 1 ]- 1
826
811
rowA = rowvalA[k]
827
-
828
- if I_ref[rowA]
829
- for r = 1 : I_repeat[rowA]
830
- rowvalS[ptrS] = fI[rowA] + r - 1
831
- nzvalS[ptrS] = nzvalA[k]
832
- ptrS += 1
833
- end
812
+ i = rangesearch (I, rowA)
813
+ if i > - 1
814
+ rowvalS[ptrS] = i
815
+ nzvalS[ptrS] = nzvalA[k]
816
+ ptrS += 1
834
817
end
835
-
836
818
end
837
819
end
838
820
839
821
return SparseMatrixCSC (nI, nJ, colptrS, rowvalS, nzvalS)
840
822
end
841
823
842
- # getindex_I_sorted based on merging of sorted lists
843
- function getindex_I_sorted_old {Tv,Ti} (A :: SparseMatrixCSC{Tv,Ti} , I :: Vector , J :: AbstractVector )
844
-
824
+ function getindex_I_sorted {Tv,Ti} (A :: SparseMatrixCSC{Tv,Ti} , I :: AbstractVector , J :: AbstractVector )
825
+ # Sorted vectors for indexing rows.
826
+ # Similar to getindex_general but without the transpose trick.
845
827
(m, n) = size (A)
846
828
nI = length (I)
847
829
nJ = length (J)
848
830
849
831
colptrA = A. colptr; rowvalA = A. rowval; nzvalA = A. nzval
850
-
851
832
colptrS = Array (Ti, nJ+ 1 )
852
833
colptrS[1 ] = 1
853
834
nnzS = 0
854
835
855
836
# Form the structure of the result and compute space
856
837
for j = 1 : nJ
857
838
col = J[j]
858
-
859
- ptrI:: Int = 1
860
-
839
+ ptrI:: Int = 1 # runs through I
861
840
ptrA:: Int = colptrA[col]
862
841
stopA:: Int = colptrA[col+ 1 ]
863
-
864
842
while ptrI <= nI && ptrA < stopA
865
843
rowA = rowvalA[ptrA]
866
844
rowI = I[ptrI]
@@ -874,22 +852,17 @@ function getindex_I_sorted_old{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::Vector, J::A
874
852
ptrI += 1
875
853
end
876
854
end
877
- colptrS[j+ 1 ] = nnzS+ 1
878
-
855
+ colptrS[j+ 1 ] = nnzS+ 1
879
856
end
880
857
881
- fI = find (I)
882
-
883
858
# Populate the values in the result
884
859
rowvalS = Array (Ti, nnzS)
885
860
nzvalS = Array (Tv, nnzS)
886
- ptrS = 0
861
+ ptrS = 1
887
862
888
863
for j = 1 : nJ
889
864
col = J[j]
890
-
891
- ptrI:: Int = 1
892
-
865
+ ptrI:: Int = 1 # runs through I
893
866
ptrA:: Int = colptrA[col]
894
867
stopA:: Int = colptrA[col+ 1 ]
895
868
@@ -902,30 +875,27 @@ function getindex_I_sorted_old{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::Vector, J::A
902
875
elseif rowI < rowA
903
876
ptrI += 1
904
877
else
905
- ptrS += 1
906
- rowvalS[ptrS] = fI[ptrI]
878
+ rowvalS[ptrS] = ptrI
907
879
nzvalS[ptrS] = nzvalA[ptrA]
880
+ ptrS += 1
908
881
ptrI += 1
909
882
end
910
883
end
911
-
912
884
end
913
-
914
885
return SparseMatrixCSC (nI, nJ, colptrS, rowvalS, nzvalS)
915
886
end
916
887
917
- function getindex_general {Tv,Ti} (A:: SparseMatrixCSC{Tv,Ti} , I:: Vector , J:: AbstractVector )
888
+ function getindex_general {Tv,Ti} (A:: SparseMatrixCSC{Tv,Ti} , I:: AbstractVector , J:: AbstractVector )
889
+ # Anything for indexing rows.
890
+ # This sorts I first then does some trick with constructing the transpose.
918
891
(m, n) = size (A)
919
892
nI = length (I)
920
893
nJ = length (J)
921
894
922
895
colptrA = A. colptr; rowvalA = A. rowval; nzvalA = A. nzval
923
-
924
896
nnzS = 0
925
-
926
897
pI = sortperm (I); I = I[pI]
927
898
fI = find (I)
928
-
929
899
W = zeros (Int, nI + 1 ) # Keep row counts
930
900
W[1 ] = 1 # For cumsum later
931
901
@@ -993,36 +963,24 @@ function getindex_general{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, I::Vector, J::Abstra
993
963
# Transpose so that rows are in sorted order and return
994
964
S_T = SparseMatrixCSC (nJ, nI, colptrS_T, rowvalS_T, nzvalS_T)
995
965
return S_T.'
996
-
997
966
end
998
967
999
- # S = A[I, J]
968
+
969
+ # the general case:
1000
970
function getindex {Tv,Ti} (A:: SparseMatrixCSC{Tv,Ti} , I:: AbstractVector , J:: AbstractVector )
1001
- m = size (A, 1 )
1002
- if isa (I, Range)
1003
- if I == 1 : m # whole columns
1004
- return getindex_cols (A, J)
1005
- else # ranges are always sorted, but maybe in reverse
1006
- if step (I)> 0
1007
- return getindex_I_sorted (A, I, J)
1008
- else
1009
- I = [I]
1010
- return getindex_general (A, I, J)
1011
- # todo:
1012
- # return reverse(getindex_I_sorted(A, reverse(I), J))
1013
- end
1014
- end
1015
- else
1016
- if issorted (I)
1017
- return getindex_I_sorted (A, I, J)
1018
- else
1019
- return getindex_general (A, I, J)
1020
- end
971
+ loop_overI_threshold = 0 # ~40 but doesn't matter much.
972
+ if issorted (I)
973
+ return getindex_I_sorted (A, I, J)
974
+ else
975
+ return getindex_general (A, I, J)
1021
976
end
1022
977
end
1023
978
1024
979
# logical getindex
980
+ getindex {Tv,Ti<:Integer} (A:: SparseMatrixCSC{Tv,Ti} , I:: Range{Bool} , J:: AbstractVector{Bool} ) = error (" Cannot index with Range{Bool}" )
981
+ getindex {Tv,Ti<:Integer,T<:Integer} (A:: SparseMatrixCSC{Tv,Ti} , I:: Range{Bool} , J:: AbstractVector{T} ) = error (" Cannot index with Range{Bool}" )
1025
982
983
+ getindex {T<:Integer} (A:: SparseMatrixCSC , I:: Range{T} , J:: AbstractVector{Bool} ) = A[I,find (J)]
1026
984
getindex (A:: SparseMatrixCSC , I:: Integer , J:: AbstractVector{Bool} ) = A[I,find (J)]
1027
985
getindex (A:: SparseMatrixCSC , I:: AbstractVector{Bool} , J:: Integer ) = A[find (I),J]
1028
986
getindex (A:: SparseMatrixCSC , I:: AbstractVector{Bool} , J:: AbstractVector{Bool} ) = A[find (I),find (J)]
0 commit comments