Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC/WIP: make setindex! not remove zeros from sparsity pattern #15568

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 35 additions & 107 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2240,19 +2240,7 @@ function setindex!{T,Ti}(A::SparseMatrixCSC{T,Ti}, v, i0::Integer, i1::Integer)
v = convert(T, v)
r1 = Int(A.colptr[i1])
r2 = Int(A.colptr[i1+1]-1)
if v == 0 #either do nothing or delete entry if it exists
if r1 <= r2
r1 = searchsortedfirst(A.rowval, i0, r1, r2, Forward)
if (r1 <= r2) && (A.rowval[r1] == i0)
deleteat!(A.rowval, r1)
deleteat!(A.nzval, r1)
@simd for j = (i1+1):(A.n+1)
@inbounds A.colptr[j] -= 1
end
end
end
return A
end

i = (r1 > r2) ? r1 : searchsortedfirst(A.rowval, i0, r1, r2, Forward)

if (i <= r2) && (A.rowval[i] == i0)
Expand All @@ -2279,8 +2267,7 @@ setindex!(A::SparseMatrixCSC, x, ::Colon, ::Colon) = setindex!(A, x, 1:size(A, 1
setindex!(A::SparseMatrixCSC, x, ::Colon, j::Union{Integer, AbstractVector}) = setindex!(A, x, 1:size(A, 1), j)
setindex!(A::SparseMatrixCSC, x, i::Union{Integer, AbstractVector}, ::Colon) = setindex!(A, x, i, 1:size(A, 2))

setindex!{Tv,T<:Integer}(A::SparseMatrixCSC{Tv}, x::Number, I::AbstractVector{T}, J::AbstractVector{T}) =
(0 == x) ? spdelete!(A, I, J) : spset!(A, convert(Tv,x), I, J)
setindex!{Tv,T<:Integer}(A::SparseMatrixCSC{Tv}, x::Number, I::AbstractVector{T}, J::AbstractVector{T}) = spset!(A, convert(Tv,x), I, J)

function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector{Ti}, J::AbstractVector{Ti})
!issorted(I) && (@inbounds I = I[sortperm(I)])
Expand Down Expand Up @@ -2390,63 +2377,6 @@ function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector
return A
end

function spdelete!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, I::AbstractVector{Ti}, J::AbstractVector{Ti})
m, n = size(A)
nnzA = nnz(A)
(nnzA == 0) && (return A)

!issorted(I) && (@inbounds I = I[sortperm(I)])
!issorted(J) && (@inbounds J = J[sortperm(J)])

((I[end] > m) || (J[end] > n)) && throw(DimensionMismatch(""))

colptr = colptrA = A.colptr
rowval = rowvalA = A.rowval
nzval = nzvalA = A.nzval
rowidx = 1
ndel = 0
@inbounds for col in 1:n
rrange = colptr[col]:(colptr[col+1]-1)
(ndel > 0) && (colptrA[col] = colptr[col] - ndel)
if isempty(rrange) || !(col in J)
nincl = length(rrange)
if(ndel > 0) && !isempty(rrange)
copy!(rowvalA, rowidx, rowval, rrange[1], nincl)
copy!(nzvalA, rowidx, nzval, rrange[1], nincl)
end
rowidx += nincl
else
for ridx in rrange
if rowval[ridx] in I
if ndel == 0
colptrA = copy(colptr)
rowvalA = copy(rowval)
nzvalA = copy(nzval)
end
ndel += 1
else
if ndel > 0
rowvalA[rowidx] = rowval[ridx]
nzvalA[rowidx] = nzval[ridx]
end
rowidx += 1
end
end
end
end

if ndel > 0
colptrA[n+1] = rowidx
deleteat!(rowvalA, rowidx:nnzA)
deleteat!(nzvalA, rowidx:nnzA)

A.colptr = colptrA
A.rowval = rowvalA
A.nzval = nzvalA
end
return A
end

setindex!{Tv,Ti,T<:Integer}(A::SparseMatrixCSC{Tv,Ti}, S::Matrix, I::AbstractVector{T}, J::AbstractVector{T}) =
setindex!(A, convert(SparseMatrixCSC{Tv,Ti}, S), I, J)

Expand Down Expand Up @@ -2596,7 +2526,7 @@ function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractMatrix{Bool})

colptrA = A.colptr; rowvalA = A.rowval; nzvalA = A.nzval
colptrB = colptrA; rowvalB = rowvalA; nzvalB = nzvalA
nadd = ndel = 0
nadd = 0
bidx = xidx = 1
r1 = r2 = 0

Expand All @@ -2612,7 +2542,7 @@ function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractMatrix{Bool})
if r1 <= r2
copylen = searchsortedfirst(rowvalA, row, r1, r2, Forward) - r1
if (copylen > 0)
if (nadd > 0) || (ndel > 0)
if (nadd > 0)
copy!(rowvalB, bidx, rowvalA, r1, copylen)
copy!(nzvalB, bidx, nzvalA, r1, copylen)
end
Expand All @@ -2621,25 +2551,25 @@ function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractMatrix{Bool})
end
end

# 0: no change, 1: update, 2: delete, 3: add new
mode = ((r1 <= r2) && (rowvalA[r1] == row)) ? ((v == 0) ? 2 : 1) : ((v == 0) ? 0 : 3)
# 0: update, 1: add new
if r1 <= r2 && rowvalA[r1] == row
mode = 0
else
mode = 1
end

if (mode > 1) && (nadd == 0) && (ndel == 0)
if (mode == 1) && (nadd == 0)
# copy storage to take changes
colptrB = copy(colptrA)
memreq = (x == 0) ? 0 : n
rowvalB = Array(Ti, length(rowvalA)+memreq); copy!(rowvalB, 1, rowvalA, 1, r1-1)
nzvalB = Array(Tv, length(nzvalA)+memreq); copy!(nzvalB, 1, nzvalA, 1, r1-1)
rowvalB = Array(Ti, length(rowvalA)+n); copy!(rowvalB, 1, rowvalA, 1, r1-1)
nzvalB = Array(Tv, length(nzvalA)+n); copy!(nzvalB, 1, nzvalA, 1, r1-1)
end
if mode == 1
if mode == 0
rowvalB[bidx] = row
nzvalB[bidx] = v
bidx += 1
r1 += 1
elseif mode == 2
r1 += 1
ndel += 1
elseif mode == 3
elseif mode == 1
rowvalB[bidx] = row
nzvalB[bidx] = v
bidx += 1
Expand All @@ -2649,7 +2579,7 @@ function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractMatrix{Bool})
end # if I[row, col]
end # for row in 1:A.m

if ((nadd != 0) || (ndel != 0))
if (nadd != 0)
l = r2-r1+1
if l > 0
copy!(rowvalB, bidx, rowvalA, r1, l)
Expand All @@ -2659,8 +2589,7 @@ function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractMatrix{Bool})
colptrB[col+1] = bidx

if (xidx > n) && (length(colptrB) > (col+1))
diff = nadd - ndel
colptrB[(col+2):end] = colptrA[(col+2):end] .+ diff
colptrB[(col+2):end] = colptrA[(col+2):end] .+ nadd
r1 = colptrA[col+1]
r2 = colptrA[end]-1
l = r2-r1+1
Expand All @@ -2676,7 +2605,7 @@ function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractMatrix{Bool})
(xidx > n) && break
end # for col in 1:A.n

if (nadd != 0) || (ndel != 0)
if (nadd != 0)
n = length(nzvalB)
if n > (bidx-1)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Can this happen anymore?

deleteat!(nzvalB, bidx:n)
Expand All @@ -2694,7 +2623,7 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto

colptrA = A.colptr; rowvalA = A.rowval; nzvalA = A.nzval; szA = size(A)
colptrB = colptrA; rowvalB = rowvalA; nzvalB = nzvalA
nadd = ndel = 0
nadd = 0
bidx = aidx = 1

S = issorted(I) ? (1:n) : sortperm(I)
Expand All @@ -2715,8 +2644,8 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
r2 = Int(colptrA[col+1] - 1)

# copy from last position till current column
if (nadd > 0) || (ndel > 0)
colptrB[(lastcol+1):col] = colptrA[(lastcol+1):col] .+ (nadd - ndel)
if (nadd > 0)
colptrB[(lastcol+1):col] = colptrA[(lastcol+1):col] .+ nadd
copylen = r1 - aidx
if copylen > 0
copy!(rowvalB, bidx, rowvalA, aidx, copylen)
Expand All @@ -2733,7 +2662,7 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
if r1 <= r2
copylen = searchsortedfirst(rowvalA, row, r1, r2, Forward) - r1
if (copylen > 0)
if (nadd > 0) || (ndel > 0)
if (nadd > 0)
copy!(rowvalB, bidx, rowvalA, r1, copylen)
copy!(nzvalB, bidx, nzvalA, r1, copylen)
end
Expand All @@ -2743,27 +2672,26 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
end
end

# 0: no change, 1: update, 2: delete, 3: add new
mode = ((r1 <= r2) && (rowvalA[r1] == row)) ? ((v == 0) ? 2 : 1) : ((v == 0) ? 0 : 3)
# 0: update, 1: add new
if r1 <= r2 && rowvalA[r1] == row
mode = 0
else
mode = 1
end

if (mode > 1) && (nadd == 0) && (ndel == 0)
if (mode == 1) && (nadd == 0)
# copy storage to take changes
colptrB = copy(colptrA)
memreq = (x == 0) ? 0 : n
rowvalB = Array(Ti, length(rowvalA)+memreq); copy!(rowvalB, 1, rowvalA, 1, r1-1)
nzvalB = Array(Tv, length(nzvalA)+memreq); copy!(nzvalB, 1, nzvalA, 1, r1-1)
rowvalB = Array(Ti, length(rowvalA)+n); copy!(rowvalB, 1, rowvalA, 1, r1-1)
nzvalB = Array(Tv, length(nzvalA)+n); copy!(nzvalB, 1, nzvalA, 1, r1-1)
end
if mode == 1
if mode == 0
rowvalB[bidx] = row
nzvalB[bidx] = v
bidx += 1
aidx += 1
r1 += 1
elseif mode == 2
r1 += 1
aidx += 1
ndel += 1
elseif mode == 3
elseif mode == 1
rowvalB[bidx] = row
nzvalB[bidx] = v
bidx += 1
Expand All @@ -2772,8 +2700,8 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
end

# copy the rest
@inbounds if (nadd > 0) || (ndel > 0)
colptrB[(lastcol+1):end] = colptrA[(lastcol+1):end] .+ (nadd - ndel)
@inbounds if (nadd > 0)
colptrB[(lastcol+1):end] = colptrA[(lastcol+1):end] .+ (nadd)
r1 = colptrA[end]-1
copylen = r1 - aidx + 1
if copylen > 0
Expand Down
28 changes: 22 additions & 6 deletions test/sparsedir/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -521,19 +521,24 @@ for (aa116, ss116) in [(a116, s116), (ad116, sd116)]
end
end

# workaround issue #7197: comment out let-block
#let S = SparseMatrixCSC(3, 3, UInt8[1,1,1,1], UInt8[], Int64[])
S1290 = SparseMatrixCSC(3, 3, UInt8[1,1,1,1], UInt8[], Int64[])

let S1290 = SparseMatrixCSC(3, 3, UInt8[1,1,1,1], UInt8[], Int64[])
S1290[1,1] = 1
S1290[5] = 2
S1290[end] = 3
@test S1290[end] == (S1290[1] + S1290[2,2])
@test 6 == sum(diag(S1290))
@test full(S1290)[[3,1],1] == full(S1290[[3,1],1])
# end
end


# setindex tests
let a = spzeros(5, 5)
a[3,2] = 0.0
@test countnz(a) == 0
@test nnz(a) == 1
end

let a = spzeros(Int, 10, 10)
@test countnz(a) == 0
a[1,:] = 1
Expand All @@ -547,6 +552,9 @@ let a = spzeros(Int, 10, 10)
@test a[1,:] == sparse([1:10;])
a[:,2] = 1:10
@test a[:,2] == sparse([1:10;])
a[:,2] = 0
@test countnz(a) == 9
@test nnz(a) == 19
end

let A = spzeros(Int, 10, 20)
Expand All @@ -559,8 +567,11 @@ let A = spzeros(Int, 10, 20)
A[6:10,11:20] = 20
@test countnz(A) == 100
@test A[6:10,11:20] == 20 * ones(Int, 5, 10)
# Storing zeros in structural nonzeros doesn't modify sparsity pattern
A[6:10,11:20] = 0
@test nnz(A) == 100
A[4:8,8:16] = 15
@test countnz(A) == 121
@test nnz(A) == 121
@test A[4:8,8:16] == 15 * ones(Int, 5, 9)
end

Expand All @@ -587,6 +598,8 @@ let A = speye(Int, 5), I=1:10, X=reshape([trues(10); falses(15)],5,5)
@test A[I] == A[X] == [1,0,0,0,0,0,1,0,0,0]
A[I] = [1:10;]
@test A[I] == A[X] == collect(1:10)
A[I] = zeros(Int, 10)
@test A[I] == A[X] == zeros(Int, 10)
end

let S = sprand(50, 30, 0.5, x->round(Int,rand(x)*100)), I = sprandbool(50, 30, 0.2)
Expand All @@ -603,10 +616,12 @@ let S = sprand(50, 30, 0.5, x->round(Int,rand(x)*100)), I = sprandbool(50, 30, 0
@test (sum(S) + sumFI) == sumS1

S[FI] = 10
nnz_S1 = nnz(S)
@test sum(S) == sumS2 + 10*sum(FI)
S[FI] = 0
nnz_S2 = nnz(S)
@test sum(S) == sumS2

@test nnz_S1 == nnz_S2
S[FI] = [1:sum(FI);]
@test sum(S) == sumS2 + sum(1:sum(FI))
end
Expand Down Expand Up @@ -1291,6 +1306,7 @@ let
x = UpperTriangular(A)*ones(n)
@test UpperTriangular(A)\x ≈ ones(n)
A[2,2] = 0
Base.SparseArrays.dropzeros!(A)
@test_throws LinAlg.SingularException LowerTriangular(A)\ones(n)
@test_throws LinAlg.SingularException UpperTriangular(A)\ones(n)
end
Expand Down