Skip to content

Commit ec797ef

Browse files
tkluckJeffBezanson
authored andcommitted
Sparse matrix: fix fast implementation of findnext and findprev for cartesian coordinates (#32007)
Revert "sparse findnext findprev hash performance improved (#31354)" This seems to duplicate work from #23317 and it causes performance degradation in the cases that one was designed for. See #31354 (comment) This reverts commit e0bef65. Thanks to @mbauman for spotting this issue in #32007 (comment).
1 parent 826bb8b commit ec797ef

File tree

2 files changed

+33
-96
lines changed

2 files changed

+33
-96
lines changed

stdlib/SparseArrays/src/abstractsparse.jl

+17-7
Original file line numberDiff line numberDiff line change
@@ -63,21 +63,31 @@ end
6363

6464
# The following two methods should be overloaded by concrete types to avoid
6565
# allocating the I = findall(...)
66-
_sparse_findnextnz(v::AbstractSparseArray, i::Integer) = (I = findall(!iszero, v); n = searchsortedfirst(I, i); n<=length(I) ? I[n] : nothing)
67-
_sparse_findprevnz(v::AbstractSparseArray, i::Integer) = (I = findall(!iszero, v); n = searchsortedlast(I, i); !iszero(n) ? I[n] : nothing)
68-
69-
function findnext(f::typeof(!iszero), v::AbstractSparseArray, i::Integer)
66+
_sparse_findnextnz(v::AbstractSparseArray, i) = (I = findall(!iszero, v); n = searchsortedfirst(I, i); n<=length(I) ? I[n] : nothing)
67+
_sparse_findprevnz(v::AbstractSparseArray, i) = (I = findall(!iszero, v); n = searchsortedlast(I, i); !iszero(n) ? I[n] : nothing)
68+
69+
function findnext(f::Function, v::AbstractSparseArray, i)
70+
# short-circuit the case f == !iszero because that avoids
71+
# allocating e.g. zero(BigInt) for the f(zero(...)) test.
72+
if nnz(v) == length(v) || (f != (!iszero) && f(zero(eltype(v))))
73+
return invoke(findnext, Tuple{Function,Any,Any}, f, v, i)
74+
end
7075
j = _sparse_findnextnz(v, i)
7176
while j !== nothing && !f(v[j])
72-
j = _sparse_findnextnz(v, j+1)
77+
j = _sparse_findnextnz(v, nextind(v, j))
7378
end
7479
return j
7580
end
7681

77-
function findprev(f::typeof(!iszero), v::AbstractSparseArray, i::Integer)
82+
function findprev(f::Function, v::AbstractSparseArray, i)
83+
# short-circuit the case f == !iszero because that avoids
84+
# allocating e.g. zero(BigInt) for the f(zero(...)) test.
85+
if nnz(v) == length(v) || (f != (!iszero) && f(zero(eltype(v))))
86+
return invoke(findprev, Tuple{Function,Any,Any}, f, v, i)
87+
end
7888
j = _sparse_findprevnz(v, i)
7989
while j !== nothing && !f(v[j])
80-
j = _sparse_findprevnz(v, j-1)
90+
j = _sparse_findprevnz(v, prevind(v, j))
8191
end
8292
return j
8393
end

stdlib/SparseArrays/src/sparsematrix.jl

+16-89
Original file line numberDiff line numberDiff line change
@@ -1312,36 +1312,34 @@ function findnz(S::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
13121312
return (I, J, V)
13131313
end
13141314

1315-
function _sparse_findnextnz(m::SparseMatrixCSC, i::Integer)
1316-
if i > length(m)
1317-
return nothing
1318-
end
1319-
row, col = Tuple(CartesianIndices(m)[i])
1315+
function _sparse_findnextnz(m::SparseMatrixCSC, ij::CartesianIndex{2})
1316+
row, col = Tuple(ij)
1317+
col > m.n && return nothing
1318+
13201319
lo, hi = m.colptr[col], m.colptr[col+1]
13211320
n = searchsortedfirst(m.rowval, row, lo, hi-1, Base.Order.Forward)
13221321
if lo <= n <= hi-1
1323-
return LinearIndices(m)[m.rowval[n], col]
1322+
return CartesianIndex(m.rowval[n], col)
13241323
end
1325-
nextcol = findnext(c->(c>hi), m.colptr, col+1)
1326-
nextcol === nothing && return nothing
1324+
nextcol = searchsortedfirst(m.colptr, hi + 1, col + 1, length(m.colptr), Base.Order.Forward)
1325+
nextcol > length(m.colptr) && return nothing
13271326
nextlo = m.colptr[nextcol-1]
1328-
return LinearIndices(m)[m.rowval[nextlo], nextcol-1]
1327+
return CartesianIndex(m.rowval[nextlo], nextcol - 1)
13291328
end
13301329

1331-
function _sparse_findprevnz(m::SparseMatrixCSC, i::Integer)
1332-
if iszero(i)
1333-
return nothing
1334-
end
1335-
row, col = Tuple(CartesianIndices(m)[i])
1330+
function _sparse_findprevnz(m::SparseMatrixCSC, ij::CartesianIndex{2})
1331+
row, col = Tuple(ij)
1332+
iszero(col) && return nothing
1333+
13361334
lo, hi = m.colptr[col], m.colptr[col+1]
13371335
n = searchsortedlast(m.rowval, row, lo, hi-1, Base.Order.Forward)
13381336
if lo <= n <= hi-1
1339-
return LinearIndices(m)[m.rowval[n], col]
1337+
return CartesianIndex(m.rowval[n], col)
13401338
end
1341-
prevcol = findprev(c->(c<lo), m.colptr, col-1)
1342-
prevcol === nothing && return nothing
1339+
prevcol = searchsortedlast(m.colptr, lo - 1, 1, col - 1, Base.Order.Forward)
1340+
prevcol < 1 && return nothing
13431341
prevhi = m.colptr[prevcol+1]
1344-
return LinearIndices(m)[m.rowval[prevhi-1], prevcol]
1342+
return CartesianIndex(m.rowval[prevhi-1], prevcol)
13451343
end
13461344

13471345

@@ -1361,77 +1359,6 @@ function sparse_sortedlinearindices!(I::Vector{Ti}, V::Vector, m::Int, n::Int) w
13611359
return SparseMatrixCSC(m, n, colptr, I, V)
13621360
end
13631361

1364-
# findfirst/next/prev/last
1365-
function _idxfirstnz(A::SparseMatrixCSC, ij::CartesianIndex{2})
1366-
nzr = nzrange(A, ij[2])
1367-
searchk = searchsortedfirst(A.rowval, ij[1], first(nzr), last(nzr), Forward)
1368-
return _idxnextnz(A, searchk)
1369-
end
1370-
1371-
function _idxlastnz(A::SparseMatrixCSC, ij::CartesianIndex{2})
1372-
nzr = nzrange(A, ij[2])
1373-
searchk = searchsortedlast(A.rowval, ij[1], first(nzr), last(nzr), Forward)
1374-
return _idxprevnz(A, searchk)
1375-
end
1376-
1377-
function _idxnextnz(A::SparseMatrixCSC, idx::Integer)
1378-
nnza = nnz(A)
1379-
nzval = nonzeros(A)
1380-
z = zero(eltype(A))
1381-
while idx <= nnza
1382-
nzv = nzval[idx]
1383-
!isequal(nzv, z) && return idx, nzv
1384-
idx += 1
1385-
end
1386-
return zero(idx), z
1387-
end
1388-
1389-
function _idxprevnz(A::SparseMatrixCSC, idx::Integer)
1390-
nzval = nonzeros(A)
1391-
z = zero(eltype(A))
1392-
while idx > 0
1393-
nzv = nzval[idx]
1394-
!isequal(nzv, z) && return idx, nzv
1395-
idx -= 1
1396-
end
1397-
return zero(idx), z
1398-
end
1399-
1400-
function _idx_to_cartesian(A::SparseMatrixCSC, idx::Integer)
1401-
rowval = rowvals(A)
1402-
i = rowval[idx]
1403-
j = searchsortedlast(A.colptr, idx, 1, size(A, 2), Base.Order.Forward)
1404-
return CartesianIndex(i, j)
1405-
end
1406-
1407-
function Base.findnext(pred::Function, A::SparseMatrixCSC, ij::CartesianIndex{2})
1408-
if nnz(A) == length(A) || pred(zero(eltype(A)))
1409-
return invoke(findnext, Tuple{Function,Any,Any}, pred, A, ij)
1410-
end
1411-
idx, nzv = _idxfirstnz(A, ij)
1412-
while idx > 0
1413-
if pred(nzv)
1414-
return _idx_to_cartesian(A, idx)
1415-
end
1416-
idx, nzv = _idxnextnz(A, idx + 1)
1417-
end
1418-
return nothing
1419-
end
1420-
1421-
function Base.findprev(pred::Function, A::SparseMatrixCSC, ij::CartesianIndex{2})
1422-
if nnz(A) == length(A) || pred(zero(eltype(A)))
1423-
return invoke(findprev, Tuple{Function,Any,Any}, pred, A, ij)
1424-
end
1425-
idx, nzv = _idxlastnz(A, ij)
1426-
while idx > 0
1427-
if pred(nzv)
1428-
return _idx_to_cartesian(A, idx)
1429-
end
1430-
idx, nzv = _idxprevnz(A, idx - 1)
1431-
end
1432-
return nothing
1433-
end
1434-
14351362
"""
14361363
sprand([rng],[type],m,[n],p::AbstractFloat,[rfn])
14371364

0 commit comments

Comments
 (0)