Skip to content

Commit 419e660

Browse files
committed
Implement map/map! for sparse matrices and reimplement broadcast/broadcast! over a single sparse matrix in terms of map/map!.
1 parent c80d523 commit 419e660

File tree

3 files changed

+348
-42
lines changed

3 files changed

+348
-42
lines changed

base/sparse/sparse.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
module SparseArrays
44

5-
using Base: ReshapedArray, promote_op, setindex_shape_check, to_shape
5+
using Base: ReshapedArray, promote_op, setindex_shape_check, to_shape, tail
66
using Base.Sort: Forward
77
using Base.LinAlg: AbstractTriangular, PosDefException
88

@@ -24,7 +24,7 @@ import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
2424
vcat, hcat, hvcat, cat, imag, indmax, ishermitian, kron, length, log, log1p, max, min,
2525
maximum, minimum, norm, one, promote_eltype, real, reinterpret, reshape, rot180,
2626
rotl90, rotr90, round, scale!, setindex!, similar, size, transpose, tril,
27-
triu, vec, permute!
27+
triu, vec, permute!, map, map!
2828

2929
import Base.Broadcast: broadcast_indices
3030

base/sparse/sparsematrix.jl

+282-40
Original file line numberDiff line numberDiff line change
@@ -1398,59 +1398,301 @@ end
13981398

13991399
sparse(S::UniformScaling, m::Integer, n::Integer=m) = speye_scaled(S.λ, m, n)
14001400

1401+
## map/map! over sparse matrices
14011402

1402-
## Broadcast operations involving a single sparse matrix and possibly broadcast scalars
1403-
1404-
function broadcast{Tf}(f::Tf, A::SparseMatrixCSC)
1405-
fofzero = f(zero(eltype(A)))
1406-
fpreszero = fofzero == zero(fofzero)
1407-
return fpreszero ? _broadcast_zeropres(f, A) : _broadcast_notzeropres(f, fofzero, A)
1408-
end
1409-
"Returns a `SparseMatrixCSC` storing only the nonzero entries of `broadcast(f, Matrix(A))`."
1410-
function _broadcast_zeropres{Tf}(f::Tf, A::SparseMatrixCSC)
1411-
Bcolptr = similar(A.colptr, A.n + 1)
1412-
Browval = similar(A.rowval, nnz(A))
1413-
Bnzval = similar(A.nzval, Base.Broadcast.promote_eltype_op(f, A), nnz(A))
1414-
Bk = 1
1415-
@inbounds for j in 1:A.n
1416-
Bcolptr[j] = Bk
1403+
# map/map! entry points
1404+
function map!{Tf,N}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
1405+
_checksameshape(C, A, Bs...)
1406+
fofzeros = f(_zeros_eltypes(A, Bs...)...)
1407+
fpreszeros = fofzeros == zero(fofzeros)
1408+
return fpreszeros ? _map_zeropres!(f, C, A, Bs...) :
1409+
_map_notzeropres!(f, fofzeros, C, A, Bs...)
1410+
end
1411+
function map{Tf,N}(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N})
1412+
_checksameshape(A, Bs...)
1413+
fofzeros = f(_zeros_eltypes(A, Bs...)...)
1414+
fpreszeros = fofzeros == zero(fofzeros)
1415+
maxnnzC = fpreszeros ? _sumnnzs(A, Bs...) : length(A)
1416+
entrytypeC = Base.Broadcast.promote_eltype_op(f, A, Bs...)
1417+
indextypeC = _promote_indtype(A, Bs...)
1418+
Ccolptr = Vector{indextypeC}(A.n + 1)
1419+
Crowval = Vector{indextypeC}(maxnnzC)
1420+
Cnzval = Vector{entrytypeC}(maxnnzC)
1421+
C = SparseMatrixCSC(A.m, A.n, Ccolptr, Crowval, Cnzval)
1422+
return fpreszeros ? _map_zeropres!(f, C, A, Bs...) :
1423+
_map_notzeropres!(f, fofzeros, C, A, Bs...)
1424+
end
1425+
# map/map! entry point helper functions
1426+
@inline _sumnnzs(A) = nnz(A)
1427+
@inline _sumnnzs(A, Bs...) = nnz(A) + _sumnnzs(Bs...)
1428+
@inline _zeros_eltypes(A) = (zero(eltype(A)),)
1429+
@inline _zeros_eltypes(A, Bs...) = (zero(eltype(A)), _zeros_eltypes(Bs...)...)
1430+
@inline _promote_indtype(A) = eltype(A.rowval)
1431+
@inline _promote_indtype(A, Bs...) = promote_type(eltype(A.rowval), _promote_indtype(Bs...))
1432+
@inline _aresameshape(A) = true
1433+
@inline _aresameshape(A, B) = size(A) == size(B)
1434+
@inline _aresameshape(A, B, Cs...) = _aresameshape(A, B) ? _aresameshape(B, Cs...) : false
1435+
@inline _checksameshape(As...) = _aresameshape(As...) || throw(DimensionMismatch("argument shapes must match"))
1436+
1437+
# _map_zeropres!/_map_notzeropres! specialized for a single sparse matrix
1438+
"Stores only the nonzero entries of `map(f, Matrix(A))` in `C`."
1439+
function _map_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC)
1440+
spaceC::Int = min(length(C.rowval), length(C.nzval))
1441+
Ck = 1
1442+
@inbounds for j in 1:C.n
1443+
C.colptr[j] = Ck
14171444
for Ak in nzrange(A, j)
1418-
x = f(A.nzval[Ak])
1419-
if x != 0
1420-
Browval[Bk] = A.rowval[Ak]
1421-
Bnzval[Bk] = x
1422-
Bk += 1
1445+
Cx = f(A.nzval[Ak])
1446+
if Cx != zero(eltype(C))
1447+
if Ck > spaceC
1448+
spaceC = maxnnzC = Ck + nnz(A) - (Ak - 1)
1449+
length(C.rowval) < maxnnzC && resize!(C.rowval, maxnnzC)
1450+
length(C.nzval) < maxnnzC && resize!(C.nzval, maxnnzC)
1451+
end
1452+
C.rowval[Ck] = A.rowval[Ak]
1453+
C.nzval[Ck] = Cx
1454+
Ck += 1
14231455
end
14241456
end
14251457
end
1426-
Bcolptr[A.n + 1] = Bk
1427-
resize!(Browval, Bk - 1)
1428-
resize!(Bnzval, Bk - 1)
1429-
return SparseMatrixCSC(A.m, A.n, Bcolptr, Browval, Bnzval)
1458+
C.colptr[C.n + 1] = Ck
1459+
return C
14301460
end
14311461
"""
1432-
Returns a (dense) `SparseMatrixCSC` with `fillvalue` stored in place of each unstored
1433-
entry in `A` and `f(A[i,j])` stored in place of each stored entry `A[i,j]` in `A`.
1462+
Densifies `C`, storing `fillvalue` in place of each unstored entry in `A` and
1463+
`f(A[i,j])` in place of each stored entry `A[i,j]` in `A`.
14341464
"""
1435-
function _broadcast_notzeropres{Tf}(f::Tf, fillvalue, A::SparseMatrixCSC)
1436-
nnzB = A.m * A.n
1437-
# Build structure
1438-
Bcolptr = similar(A.colptr, A.n + 1)
1439-
copy!(Bcolptr, 1:A.m:(nnzB + 1))
1440-
Browval = similar(A.rowval, nnzB)
1441-
for k in 1:A.m:(nnzB - A.m + 1)
1442-
copy!(Browval, k, 1:A.m)
1465+
function _map_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::SparseMatrixCSC)
1466+
nnzC = C.m * C.n
1467+
# Expand C's storage if necessary
1468+
length(C.rowval) < nnzC && resize!(C.rowval, nnzC)
1469+
length(C.nzval) < nnzC && resize!(C.nzval, nnzC)
1470+
# Build C's structure
1471+
copy!(C.colptr, 1:C.m:(nnzC + 1))
1472+
for k in 1:C.m:(nnzC - C.m + 1)
1473+
copy!(C.rowval, k, 1:C.m)
14431474
end
14441475
# Populate values
1445-
Bnzval = fill(fillvalue, nnzB)
1446-
@inbounds for (j, jo) in zip(1:A.n, 0:A.m:(nnzB - 1)), k in nzrange(A, j)
1447-
Bnzval[jo + A.rowval[k]] = f(A.nzval[k])
1476+
fill!(C.nzval, fillvalue)
1477+
@inbounds for (j, jo) in zip(1:C.n, 0:C.m:(nnzC - 1)), Ak in nzrange(A, j)
1478+
Cx = f(A.nzval[Ak])
1479+
Cx != fillvalue && (C.nzval[jo + A.rowval[Ak]] = Cx)
14481480
end
1449-
# NOTE: Combining the fill call into the loop above to avoid multiple sweeps over /
1450-
# nonsequential access of Bnzval does not appear to improve performance
1451-
return SparseMatrixCSC(A.m, A.n, Bcolptr, Browval, Bnzval)
1481+
# NOTE: Combining the fill! above into the loop above to avoid multiple sweeps over /
1482+
# nonsequential access of C.nzval does not appear to improve performance.
1483+
return C
14521484
end
14531485

1486+
# _map_zeropres!/_map_notzeropres! specialized for a pair of sparse matrices
1487+
function _map_zeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, B::SparseMatrixCSC)
1488+
spaceC::Int = min(length(C.rowval), length(C.nzval))
1489+
rowsentinelA = convert(eltype(A.rowval), C.m + 1)
1490+
rowsentinelB = convert(eltype(B.rowval), C.m + 1)
1491+
Ck = 1
1492+
@inbounds for j in 1:C.n
1493+
C.colptr[j] = Ck
1494+
Ak, stopAk = A.colptr[j], A.colptr[j + 1]
1495+
Bk, stopBk = B.colptr[j], B.colptr[j + 1]
1496+
Ai = Ak < stopAk ? A.rowval[Ak] : rowsentinelA
1497+
Bi = Bk < stopBk ? B.rowval[Bk] : rowsentinelB
1498+
while true
1499+
if Ai == Bi
1500+
Ai == rowsentinelA && break # column complete
1501+
Cx, Ci::eltype(C.rowval) = f(A.nzval[Ak], B.nzval[Bk]), Ai
1502+
Ak += one(Ak); Ai = Ak < stopAk ? A.rowval[Ak] : rowsentinelA
1503+
Bk += one(Bk); Bi = Bk < stopBk ? B.rowval[Bk] : rowsentinelB
1504+
elseif Ai < Bi
1505+
Cx, Ci = f(A.nzval[Ak], zero(eltype(B))), Ai
1506+
Ak += one(Ak); Ai = Ak < stopAk ? A.rowval[Ak] : rowsentinelA
1507+
else # Bi < Ai
1508+
Cx, Ci = f(zero(eltype(A)), B.nzval[Bk]), Bi
1509+
Bk += one(Bk); Bi = Bk < stopBk ? B.rowval[Bk] : rowsentinelB
1510+
end
1511+
# NOTE: The ordering of the conditional chain above impacts which matrices this
1512+
# method performs best for. The above provides good performance all around.
1513+
if Cx != zero(eltype(C))
1514+
if Ck > spaceC
1515+
spaceC = maxnnzC = Ck + (nnz(A) - (Ak - 1)) + (nnz(B) - (Bk - 1))
1516+
length(C.rowval) < maxnnzC && resize!(C.rowval, maxnnzC)
1517+
length(C.nzval) < maxnnzC && resize!(C.nzval, maxnnzC)
1518+
end
1519+
C.rowval[Ck] = Ci
1520+
C.nzval[Ck] = Cx
1521+
Ck += 1
1522+
end
1523+
end
1524+
end
1525+
C.colptr[C.n + 1] = Ck
1526+
return C
1527+
end
1528+
function _map_notzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::SparseMatrixCSC, B::SparseMatrixCSC)
1529+
nnzC = C.m * C.n
1530+
# Expand C's storage if necessary
1531+
length(C.rowval) < nnzC && resize!(C.rowval, nnzC)
1532+
length(C.nzval) < nnzC && resize!(C.nzval, nnzC)
1533+
# Build C's structure
1534+
copy!(C.colptr, 1:C.m:(nnzC + 1))
1535+
for k in 1:C.m:(nnzC - C.m + 1)
1536+
copy!(C.rowval, k, 1:C.m)
1537+
end
1538+
# Populate values
1539+
fill!(C.nzval, fillvalue)
1540+
# NOTE: Combining this fill! into the loop below to avoid multiple sweeps over /
1541+
# nonsequential access of C.nzval does not appear to improve performance.
1542+
rowsentinelA = convert(eltype(A.rowval), C.m + 1)
1543+
rowsentinelB = convert(eltype(B.rowval), C.m + 1)
1544+
@inbounds for (j, jo) in zip(1:C.n, 0:C.m:(nnzC - 1))
1545+
Ak, stopAk = A.colptr[j], A.colptr[j + 1]
1546+
Bk, stopBk = B.colptr[j], B.colptr[j + 1]
1547+
Ai = Ak < stopAk ? A.rowval[Ak] : rowsentinelA
1548+
Bi = Bk < stopBk ? B.rowval[Bk] : rowsentinelB
1549+
while true
1550+
if Ai == Bi
1551+
Ai == rowsentinelA && break # column complete
1552+
Cx, Ci::eltype(C.rowval) = f(A.nzval[Ak], B.nzval[Bk]), Ai
1553+
Ak += one(Ak); Ai = Ak < stopAk ? A.rowval[Ak] : rowsentinelA
1554+
Bk += one(Bk); Bi = Bk < stopBk ? B.rowval[Bk] : rowsentinelB
1555+
elseif Ai < Bi
1556+
Cx, Ci = f(A.nzval[Ak], zero(eltype(B))), Ai
1557+
Ak += one(Ak); Ai = Ak < stopAk ? A.rowval[Ak] : rowsentinelA
1558+
else # Bi < Ai
1559+
Cx, Ci = f(zero(eltype(A)), B.nzval[Bk]), Bi
1560+
Bk += one(Bk); Bi = Bk < stopBk ? B.rowval[Bk] : rowsentinelB
1561+
end
1562+
Cx != fillvalue && (C.nzval[jo + Ci] = Cx)
1563+
end
1564+
end
1565+
return C
1566+
end
1567+
1568+
# _map_zeropres!/_map_notzeropres! for more than two sparse matrices
1569+
function _map_zeropres!{Tf,N}(f::Tf, C::SparseMatrixCSC, As::Vararg{SparseMatrixCSC,N})
1570+
spaceC::Int = min(length(C.rowval), length(C.nzval))
1571+
rowsentinel = C.m + 1
1572+
Ck = 1
1573+
stopks = _indforcol_all(1, As)
1574+
@inbounds for j in 1:C.n
1575+
C.colptr[j] = Ck
1576+
ks = stopks
1577+
stopks = _indforcol_all(j + 1, As)
1578+
rows = _rowforind_all(rowsentinel, ks, stopks, As)
1579+
activerow = min(rows...)
1580+
while activerow < rowsentinel
1581+
# activerows = _isactiverow_all(activerow, rows)
1582+
# Cx = f(_gatherargs(activerows, ks, As)...)
1583+
# ks = _updateind_all(activerows, ks)
1584+
# rows = _updaterow_all(rowsentinel, activerows, rows, ks, stopks, As)
1585+
vals, ks, rows = _fusedupdate_all(rowsentinel, activerow, rows, ks, stopks, As)
1586+
Cx = f(vals...)
1587+
if Cx != zero(eltype(C))
1588+
if Ck > spaceC
1589+
spaceC = maxnnzC = Ck + _sumnnzs(As...) - (sum(ks) - N)
1590+
length(C.rowval) < maxnnzC && resize!(C.rowval, maxnnzC)
1591+
length(C.nzval) < maxnnzC && resize!(C.nzval, maxnnzC)
1592+
end
1593+
C.rowval[Ck] = activerow
1594+
C.nzval[Ck] = Cx
1595+
Ck += 1
1596+
end
1597+
activerow = min(rows...)
1598+
end
1599+
end
1600+
C.colptr[C.n + 1] = Ck
1601+
return C
1602+
end
1603+
function _map_notzeropres!{Tf,N}(f::Tf, fillvalue, C::SparseMatrixCSC, As::Vararg{SparseMatrixCSC,N})
1604+
nnzC = C.m * C.n
1605+
# Expand C's storage if necessary
1606+
length(C.rowval) < nnzC && resize!(C.rowval, nnzC)
1607+
length(C.nzval) < nnzC && resize!(C.nzval, nnzC)
1608+
# Build C's structure
1609+
copy!(C.colptr, 1:C.m:(nnzC + 1))
1610+
for k in 1:C.m:(nnzC - C.m + 1)
1611+
copy!(C.rowval, k, 1:C.m)
1612+
end
1613+
# Populate values
1614+
fill!(C.nzval, fillvalue)
1615+
# NOTE: Combining this fill! into the loop below to avoid multiple sweeps over /
1616+
# nonsequential access of C.nzval does not appear to improve performance.
1617+
rowsentinel = C.m + 1
1618+
stopks = _indforcol_all(1, As)
1619+
@inbounds for (j, jo) in zip(1:C.n, 0:C.m:(nnzC - 1))
1620+
ks = stopks
1621+
stopks = _indforcol_all(j + 1, As)
1622+
rows = _rowforind_all(rowsentinel, ks, stopks, As)
1623+
activerow = min(rows...)
1624+
while activerow < rowsentinel
1625+
# activerows = _isactiverow_all(activerow, rows)
1626+
# Cx = f(_gatherargs(activerows, ks, As)...)
1627+
# ks = _updateind_all(activerows, ks)
1628+
# rows = _updaterow_all(rowsentinel, activerows, rows, ks, stopks, As)
1629+
vals, ks, rows = _fusedupdate_all(rowsentinel, activerow, rows, ks, stopks, As)
1630+
Cx = f(vals...)
1631+
Cx != fillvalue && (C.nzval[jo + activerow] = Cx)
1632+
activerow = min(rows...)
1633+
end
1634+
end
1635+
return C
1636+
end
1637+
# helper methods
1638+
@inline _indforcol(j, A) = A.colptr[j]
1639+
@inline _indforcol_all(j, ::Tuple{}) = ()
1640+
@inline _indforcol_all(j, As) = (
1641+
_indforcol(j, first(As)),
1642+
_indforcol_all(j, tail(As))...)
1643+
@inline _rowforind(rowsentinel, k, stopk, A) =
1644+
k < stopk ? A.rowval[k] : convert(eltype(A.rowval), rowsentinel)
1645+
@inline _rowforind_all(rowsentinel, ::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
1646+
@inline _rowforind_all(rowsentinel, ks, stopks, As) = (
1647+
_rowforind(rowsentinel, first(ks), first(stopks), first(As)),
1648+
_rowforind_all(rowsentinel, tail(ks), tail(stopks), tail(As))...)
1649+
# fusing the following defs. avoids a few branches, yielding 5-30% runtime reduction
1650+
# @inline _isactiverow(activerow, row) = row == activerow
1651+
# @inline _isactiverow_all(activerow, ::Tuple{}) = ()
1652+
# @inline _isactiverow_all(activerow, rows) = (
1653+
# _isactiverow(activerow, first(rows)),
1654+
# _isactiverow_all(activerow, tail(rows))...)
1655+
# @inline _gatherarg(isactiverow, k, A) = isactiverow ? A.nzval[k] : zero(eltype(A))
1656+
# @inline _gatherargs(::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
1657+
# @inline _gatherargs(activerows, ks, As) = (
1658+
# _gatherarg(first(activerows), first(ks), first(As)),
1659+
# _gatherargs(tail(activerows), tail(ks), tail(As))...)
1660+
# @inline _updateind(isactiverow, k) = isactiverow ? (k + one(k)) : k
1661+
# @inline _updateind_all(::Tuple{}, ::Tuple{}) = ()
1662+
# @inline _updateind_all(activerows, ks) = (
1663+
# _updateind(first(activerows), first(ks)),
1664+
# _updateind_all(tail(activerows), tail(ks))...)
1665+
# @inline _updaterow(rowsentinel, isrowactive, presrow, k, stopk, A) =
1666+
# isrowactive ? (k < stopk ? A.rowval[k] : oftype(presrow, rowsentinel)) : presrow
1667+
# @inline _updaterow_all(rowsentinel, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
1668+
# @inline _updaterow_all(rowsentinel, activerows, rows, ks, stopks, As) = (
1669+
# _updaterow(rowsentinel, first(activerows), first(rows), first(ks), first(stopks), first(As)),
1670+
# _updaterow_all(rowsentinel, tail(activerows), tail(rows), tail(ks), tail(stopks), tail(As))...)
1671+
@inline function _fusedupdate(rowsentinel, activerow, row, k, stopk, A)
1672+
# returns (val, nextk, nextrow)
1673+
if row == activerow
1674+
nextk = k + one(k)
1675+
(A.nzval[k], nextk, (nextk < stopk ? A.rowval[nextk] : oftype(row, rowsentinel)))
1676+
else
1677+
(zero(eltype(A)), k, row)
1678+
end
1679+
end
1680+
@inline _fusedupdate_all(rowsentinel, activerow, rows, ks, stopks, As) =
1681+
_fusedupdate_all((#=vals=#), (#=nextks=#), (#=nextrows=#), rowsentinel, activerow, rows, ks, stopks, As)
1682+
@inline _fusedupdate_all(vals, nextks, nextrows, rowsent, activerow, ::Tuple{}, ::Tuple{}, ::Tuple{}, ::Tuple{}) =
1683+
(vals, nextks, nextrows)
1684+
@inline function _fusedupdate_all(vals, nextks, nextrows, rowsentinel, activerow, rows, ks, stopks, As)
1685+
val, nextk, nextrow = _fusedupdate(rowsentinel, activerow, first(rows), first(ks), first(stopks), first(As))
1686+
return _fusedupdate_all((vals..., val), (nextks..., nextk), (nextrows..., nextrow),
1687+
rowsentinel, activerow, tail(rows), tail(ks), tail(stopks), tail(As))
1688+
end
1689+
1690+
1691+
## Broadcast operations involving a single sparse matrix and possibly broadcast scalars
1692+
1693+
broadcast{Tf}(f::Tf, A::SparseMatrixCSC) = map(f, A)
1694+
broadcast!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC) = map!(f, C, A)
1695+
14541696
# Cover common broadcast operations involving a single sparse matrix and one or more
14551697
# broadcast scalars.
14561698
#

0 commit comments

Comments
 (0)