Skip to content

Commit 4923992

Browse files
alystElOceanografo
authored andcommitted
Allow any f for sum/minimum/maximum(f, v::AbstractSparseVector) (JuliaLang#29884)
* sum/minimum/maximum(f, v::AbstractSparseVector) generalize sum/minimum/maximum(abs/abs2, v::AbstractSparseVector) to arbitrary f * add broken sum(f, [])==0 tests for reference
1 parent d3f4e93 commit 4923992

File tree

2 files changed

+62
-18
lines changed

2 files changed

+62
-18
lines changed

stdlib/SparseArrays/src/sparsevector.jl

+34-18
Original file line numberDiff line numberDiff line change
@@ -1354,34 +1354,50 @@ end
13541354

13551355
### Reduction
13561356

1357+
function _sum(f, x::AbstractSparseVector)
1358+
n = length(x)
1359+
n > 0 || return sum(f, nonzeros(x)) # return zero() of proper type
1360+
m = nnz(x)
1361+
(m == 0 ? n * f(zero(eltype(x))) :
1362+
m == n ? sum(f, nonzeros(x)) :
1363+
Base.add_sum((n - m) * f(zero(eltype(x))), sum(f, nonzeros(x))))
1364+
end
1365+
1366+
sum(f::Union{Function, Type}, x::AbstractSparseVector) = _sum(f, x) # resolve ambiguity
1367+
sum(f, x::AbstractSparseVector) = _sum(f, x)
13571368
sum(x::AbstractSparseVector) = sum(nonzeros(x))
13581369

1359-
function maximum(x::AbstractSparseVector{T}) where T<:Real
1370+
function _maximum(f, x::AbstractSparseVector)
13601371
n = length(x)
1361-
n > 0 || throw(ArgumentError("maximum over empty array is not allowed."))
1372+
if n == 0
1373+
if f === abs || f === abs2
1374+
return zero(eltype(x)) # preserving maximum(abs/abs2, x) behaviour in 1.0.x
1375+
else
1376+
throw(ArgumentError("maximum over an empty array is not allowed."))
1377+
end
1378+
end
13621379
m = nnz(x)
1363-
(m == 0 ? zero(T) :
1364-
m == n ? maximum(nonzeros(x)) :
1365-
max(zero(T), maximum(nonzeros(x))))::T
1380+
(m == 0 ? f(zero(eltype(x))) :
1381+
m == n ? maximum(f, nonzeros(x)) :
1382+
max(f(zero(eltype(x))), maximum(f, nonzeros(x))))
13661383
end
13671384

1368-
function minimum(x::AbstractSparseVector{T}) where T<:Real
1385+
maximum(f::Union{Function, Type}, x::AbstractSparseVector) = _maximum(f, x) # resolve ambiguity
1386+
maximum(f, x::AbstractSparseVector) = _maximum(f, x)
1387+
maximum(x::AbstractSparseVector) = maximum(identity, x)
1388+
1389+
function _minimum(f, x::AbstractSparseVector)
13691390
n = length(x)
1370-
n > 0 || throw(ArgumentError("minimum over empty array is not allowed."))
1391+
n > 0 || throw(ArgumentError("minimum over an empty array is not allowed."))
13711392
m = nnz(x)
1372-
(m == 0 ? zero(T) :
1373-
m == n ? minimum(nonzeros(x)) :
1374-
min(zero(T), minimum(nonzeros(x))))::T
1393+
(m == 0 ? f(zero(eltype(x))) :
1394+
m == n ? minimum(f, nonzeros(x)) :
1395+
min(f(zero(eltype(x))), minimum(f, nonzeros(x))))
13751396
end
13761397

1377-
for f in [:sum, :maximum, :minimum], op in [:abs, :abs2]
1378-
SV = :AbstractSparseVector
1379-
if f === :minimum
1380-
@eval ($f)(::typeof($op), x::$SV{T}) where {T<:Number} = nnz(x) < length(x) ? ($op)(zero(T)) : ($f)($op, nonzeros(x))
1381-
else
1382-
@eval ($f)(::typeof($op), x::$SV) = ($f)($op, nonzeros(x))
1383-
end
1384-
end
1398+
minimum(f::Union{Function, Type}, x::AbstractSparseVector) = _minimum(f, x) # resolve ambiguity
1399+
minimum(f, x::AbstractSparseVector) = _minimum(f, x)
1400+
minimum(x::AbstractSparseVector) = minimum(identity, x)
13851401

13861402
norm(x::SparseVectorUnion, p::Real=2) = norm(nonzeros(x), p)
13871403

stdlib/SparseArrays/test/sparsevector.jl

+28
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,19 @@ end
789789
@test sum(x) == 4.0
790790
@test sum(abs, x) == 5.5
791791
@test sum(abs2, x) == 14.375
792+
@test @inferred(sum(t -> true, x)) === 8
793+
@test @inferred(sum(t -> abs(t) + one(t), x)) == 13.5
794+
795+
@test @inferred(sum(t -> true, spzeros(Float64, 8))) === 8
796+
@test @inferred(sum(t -> abs(t) + one(t), spzeros(Float64, 8))) === 8.0
797+
798+
# reducing over an empty collection
799+
# FIXME sum(f, []) throws, should be fixed both for generic and sparse vectors
800+
@test_broken sum(t -> true, zeros(Float64, 0)) === 0
801+
@test_broken sum(t -> true, spzeros(Float64, 0)) === 0
802+
@test @inferred(sum(abs2, spzeros(Float64, 0))) === 0.0
803+
@test_broken sum(t -> abs(t) + one(t), zeros(Float64, 0)) === 0.0
804+
@test_broken sum(t -> abs(t) + one(t), spzeros(Float64, 0)) === 0.0
792805

793806
@test norm(x) == sqrt(14.375)
794807
@test norm(x, 1) == 5.5
@@ -802,6 +815,12 @@ end
802815
@test minimum(x) == -0.75
803816
@test maximum(abs, x) == 3.5
804817
@test minimum(abs, x) == 0.0
818+
@test @inferred(minimum(t -> true, x)) === true
819+
@test @inferred(maximum(t -> true, x)) === true
820+
@test @inferred(minimum(t -> abs(t) + one(t), x)) == 1.0
821+
@test @inferred(maximum(t -> abs(t) + one(t), x)) == 4.5
822+
@test @inferred(minimum(t -> t + one(t), x)) == 0.25
823+
@test @inferred(maximum(t -> -abs(t) + one(t), x)) == 1.0
805824
end
806825

807826
let x = abs.(spv_x1)
@@ -826,6 +845,15 @@ end
826845
@test minimum(x) == 0.0
827846
@test maximum(abs, x) == 0.0
828847
@test minimum(abs, x) == 0.0
848+
@test @inferred(minimum(t -> true, x)) === true
849+
@test @inferred(maximum(t -> true, x)) === true
850+
@test @inferred(minimum(t -> abs(t) + one(t), x)) === 1.0
851+
@test @inferred(maximum(t -> abs(t) + one(t), x)) === 1.0
852+
end
853+
854+
let x = spzeros(Float64, 0)
855+
@test_throws ArgumentError minimum(t -> true, x)
856+
@test_throws ArgumentError maximum(t -> true, x)
829857
end
830858
end
831859

0 commit comments

Comments
 (0)