Skip to content

Commit e0bef65

Browse files
KlausCViralBShah
authored andcommittedApr 3, 2019
sparse findnext findprev hash performance improved (#31354)
* sparse findnext findprev hash performance improved * added tests and minor changes
1 parent aee211b commit e0bef65

File tree

2 files changed

+90
-4
lines changed

2 files changed

+90
-4
lines changed
 

‎stdlib/SparseArrays/src/sparsematrix.jl

+71
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,77 @@ function sparse_sortedlinearindices!(I::Vector{Ti}, V::Vector, m::Int, n::Int) w
13641364
return SparseMatrixCSC(m, n, colptr, I, V)
13651365
end
13661366

1367+
# findfirst/next/prev/last
1368+
function _idxfirstnz(A::SparseMatrixCSC, ij::CartesianIndex{2})
1369+
nzr = nzrange(A, ij[2])
1370+
searchk = searchsortedfirst(A.rowval, ij[1], first(nzr), last(nzr), Forward)
1371+
return _idxnextnz(A, searchk)
1372+
end
1373+
1374+
function _idxlastnz(A::SparseMatrixCSC, ij::CartesianIndex{2})
1375+
nzr = nzrange(A, ij[2])
1376+
searchk = searchsortedlast(A.rowval, ij[1], first(nzr), last(nzr), Forward)
1377+
return _idxprevnz(A, searchk)
1378+
end
1379+
1380+
function _idxnextnz(A::SparseMatrixCSC, idx::Integer)
1381+
nnza = nnz(A)
1382+
nzval = nonzeros(A)
1383+
z = zero(eltype(A))
1384+
while idx <= nnza
1385+
nzv = nzval[idx]
1386+
!isequal(nzv, z) && return idx, nzv
1387+
idx += 1
1388+
end
1389+
return zero(idx), z
1390+
end
1391+
1392+
function _idxprevnz(A::SparseMatrixCSC, idx::Integer)
1393+
nzval = nonzeros(A)
1394+
z = zero(eltype(A))
1395+
while idx > 0
1396+
nzv = nzval[idx]
1397+
!isequal(nzv, z) && return idx, nzv
1398+
idx -= 1
1399+
end
1400+
return zero(idx), z
1401+
end
1402+
1403+
function _idx_to_cartesian(A::SparseMatrixCSC, idx::Integer)
1404+
rowval = rowvals(A)
1405+
i = rowval[idx]
1406+
j = searchsortedlast(A.colptr, idx, 1, size(A, 2), Base.Order.Forward)
1407+
return CartesianIndex(i, j)
1408+
end
1409+
1410+
function Base.findnext(pred::Function, A::SparseMatrixCSC, ij::CartesianIndex{2})
1411+
if nnz(A) == length(A) || pred(zero(eltype(A)))
1412+
return invoke(findnext, Tuple{Function,Any,Any}, pred, A, ij)
1413+
end
1414+
idx, nzv = _idxfirstnz(A, ij)
1415+
while idx > 0
1416+
if pred(nzv)
1417+
return _idx_to_cartesian(A, idx)
1418+
end
1419+
idx, nzv = _idxnextnz(A, idx + 1)
1420+
end
1421+
return nothing
1422+
end
1423+
1424+
function Base.findprev(pred::Function, A::SparseMatrixCSC, ij::CartesianIndex{2})
1425+
if nnz(A) == length(A) || pred(zero(eltype(A)))
1426+
return invoke(findprev, Tuple{Function,Any,Any}, pred, A, ij)
1427+
end
1428+
idx, nzv = _idxlastnz(A, ij)
1429+
while idx > 0
1430+
if pred(nzv)
1431+
return _idx_to_cartesian(A, idx)
1432+
end
1433+
idx, nzv = _idxprevnz(A, idx - 1)
1434+
end
1435+
return nothing
1436+
end
1437+
13671438
"""
13681439
sprand([rng],[type],m,[n],p::AbstractFloat,[rfn])
13691440

‎stdlib/SparseArrays/test/sparse.jl

+19-4
Original file line numberDiff line numberDiff line change
@@ -2245,16 +2245,20 @@ end
22452245
@test findprev(!iszero, x,i) == findprev(!iszero, x_sp,i)
22462246
end
22472247

2248-
y = [0 0 0 0 0;
2248+
y = [7 0 0 0 0;
22492249
1 0 1 0 0;
2250-
1 0 0 0 1;
2250+
1 7 0 7 1;
22512251
0 0 1 0 0;
2252-
1 0 1 1 0]
2253-
y_sp = sparse(y)
2252+
1 0 1 1 0.0]
2253+
y_sp = [x == 7 ? -0.0 : x for x in sparse(y)]
2254+
y = Array(y_sp)
2255+
@test isequal(y_sp[1,1], -0.0)
22542256

22552257
for i in keys(y)
22562258
@test findnext(!iszero, y,i) == findnext(!iszero, y_sp,i)
22572259
@test findprev(!iszero, y,i) == findprev(!iszero, y_sp,i)
2260+
@test findnext(iszero, y,i) == findnext(iszero, y_sp,i)
2261+
@test findprev(iszero, y,i) == findprev(iszero, y_sp,i)
22582262
end
22592263

22602264
z_sp = sparsevec(Dict(1=>1, 5=>1, 8=>0, 10=>1))
@@ -2264,6 +2268,17 @@ end
22642268
@test findnext(!iszero, z,i) == findnext(!iszero, z_sp,i)
22652269
@test findprev(!iszero, z,i) == findprev(!iszero, z_sp,i)
22662270
end
2271+
2272+
w = [ "a" ""; "" "b"]
2273+
w_sp = sparse(w)
2274+
2275+
for i in keys(w)
2276+
@test findnext(!isequal(""), w,i) == findnext(!isequal(""), w_sp,i)
2277+
@test findprev(!isequal(""), w,i) == findprev(!isequal(""), w_sp,i)
2278+
@test findnext(isequal(""), w,i) == findnext(isequal(""), w_sp,i)
2279+
@test findprev(isequal(""), w,i) == findprev(isequal(""), w_sp,i)
2280+
end
2281+
22672282
end
22682283

22692284
# #20711

0 commit comments

Comments
 (0)