Skip to content

Commit 93cdc0e

Browse files
committed
Make use of fast sparse outer products from Julia (closes #4)
The feature was added to Julia in time for v1.2 in JuliaLang/julia#24980, so get rid of the custom `outer()` method here and rewrite `quadprod()` in terms of just standard matrix methods. Julia v1.2 is the minimum-supported version at this point, so no need to worry about backporting the functionality. In the future, this function may yet still go away since the implementation is nearly trivial at this point, but that can be a follow-up PR.
1 parent c229f2d commit 93cdc0e

File tree

2 files changed

+7
-88
lines changed

2 files changed

+7
-88
lines changed

src/numerics.jl

+3-86
Original file line numberDiff line numberDiff line change
@@ -1,100 +1,17 @@
11
using SparseArrays
22

3-
"""
4-
Computes the outer product between a given column of a sparse matrix and a vector.
5-
"""
6-
function outer end
7-
8-
"""
9-
outer(A::SparseMatrixCSC, n::Integer, w::AbstractVector)
10-
11-
Performs the equivalent of ``\\vec a_n \\vec w^\\dagger`` where ``\\vec a_n`` is the
12-
column `A[:,n]`.
13-
"""
14-
function outer(A::SparseMatrixCSC{Tv,Ti}, n::Integer, w::AbstractVector{Tv}) where {Tv,Ti}
15-
colptrn = nzrange(A, n)
16-
rowvalA = rowvals(A)
17-
nzvalsA = nonzeros(A)
18-
19-
nnza = length(colptrn)
20-
nnzw = length(w)
21-
numnz = nnza * nnzw
22-
23-
colptr = Vector{Ti}(undef, nnzw+1)
24-
rowval = Vector{Ti}(undef, numnz)
25-
nzvals = Vector{Tv}(undef, numnz)
26-
27-
idx = 0
28-
@inbounds for jj = 1:nnzw
29-
colptr[jj] = idx + 1
30-
31-
wv = conj(w[jj])
32-
iszero(wv) && continue
33-
34-
for ii = colptrn
35-
idx += 1
36-
rowval[idx] = rowvalA[ii] # copy row index from A
37-
nzvals[idx] = wv * nzvalsA[ii] # outer product values
38-
end
39-
end
40-
@inbounds colptr[nnzw+1] = idx + 1
41-
return SparseMatrixCSC(size(A,1), nnzw, colptr, rowval, nzvals)
42-
end
43-
44-
"""
45-
outer(w::AbstractVector, A::SparseMatrixCSC, n::Integer)
46-
47-
Performs the equivalent of ``\\vec w \\vec{a}_n^\\dagger`` where ``\\vec a_n`` is the
48-
column `A[:,n]`.
49-
"""
50-
function outer(w::AbstractVector{Tv}, A::SparseMatrixCSC{Tv,Ti}, n::Integer) where {Tv,Ti}
51-
colptrn = nzrange(A, n)
52-
rowvalA = rowvals(A)
53-
nzvalsA = nonzeros(A)
54-
55-
nnza = length(colptrn)
56-
nnzw = length(w)
57-
numnz = nnza * nnzw
58-
59-
colptr = zeros(Ti, size(A,1)+1)
60-
rowval = Vector{Ti}(undef, numnz)
61-
nzvals = Vector{Tv}(undef, numnz)
62-
63-
idx = 0
64-
@inbounds colptr[1] = 1 # col 1 always at index 1
65-
@inbounds for jj = colptrn
66-
av = conj(nzvalsA[jj])
67-
rv = rowvalA[jj]
68-
69-
for ii = 1:nnzw
70-
wv = w[ii]
71-
iszero(wv) && continue
72-
73-
idx += 1
74-
colptr[rv+1] += 1 # count num of entries in column
75-
rowval[idx] = ii
76-
nzvals[idx] = w[ii] * av # outer product values
77-
end
78-
end
79-
cumsum!(colptr, colptr) # offsets are sum of all previous
80-
81-
return SparseMatrixCSC(nnzw, size(A,1), colptr, rowval, nzvals)
82-
end
83-
843
"""
854
quadprod(A, b, n, dir=:col)
865
876
Computes the quadratic product ``ABA^T`` efficiently for the case where ``B`` is all zero
887
except for the `n`th column or row vector `b`, for `dir = :col` or `dir = :row`,
898
respectively.
909
"""
91-
function quadprod(A, b, n, dir::Symbol=:col)
10+
@inline function quadprod(A, b, n, dir::Symbol=:col)
9211
if dir == :col
93-
w = A * b
94-
return outer(w, A, n)
12+
return (A * sparse(b)) * view(A, :, n)'
9513
elseif dir == :row
96-
w = A * b
97-
return outer(A, n, w)
14+
return view(A, :, n) * (A * sparse(b))'
9815
else
9916
error("Unrecognized direction `dir = $(repr(dir))`.")
10017
end

test/numerics.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ Br = sparse(fill(i,n), collect(1:n), b, n, n)
1616
bt = convert(Vector{T}, b)
1717
Bct = convert(SparseMatrixCSC{T}, Bc)
1818
Brt = convert(SparseMatrixCSC{T}, Br)
19-
@test At * Bct * At' == @inferred quadprod(At, bt, i, :col)
20-
@test At * Brt * At' @inferred quadprod(At, bt, i, :row)
19+
@test At * Bct * At' == quadprod(At, bt, i, :col)
20+
@test @inferred(quadprod(At, bt, i, :col)) isa SparseMatrixCSC
21+
@test At * Brt * At' quadprod(At, bt, i, :row)
22+
@test @inferred(quadprod(At, bt, i, :row)) isa SparseMatrixCSC
2123
end

0 commit comments

Comments
 (0)