Skip to content

Commit d5eb967

Browse files
committed
add count(itr) (fixes JuliaLang#20403) and throw and error in count if non-boolean values are encountered
1 parent ad5cd7b commit d5eb967

File tree

8 files changed

+37
-3
lines changed

8 files changed

+37
-3
lines changed

NEWS.md

+4
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ This section lists changes that do not have deprecation warnings.
161161
that internally uses twice-precision arithmetic. These two
162162
outcomes exhibit differences in both precision and speed.
163163

164+
* The `count` function no longer sums non-boolean values ([#20404])
165+
164166
Library improvements
165167
--------------------
166168

@@ -231,6 +233,8 @@ Library improvements
231233

232234
* Additional methods for `ones` and `zeros` functions to support the same signature as the `similar` function ([#19635]).
233235

236+
* `count` now has a `count(itr)` method equivalent to `count(identity, itr)` ([#20403]).
237+
234238
* Methods for `map` and `filter` with `Nullable` arguments have been
235239
implemented; the semantics are as if the `Nullable` were a container with
236240
zero or one elements ([#16961]).

base/bitarray.jl

+1
Original file line numberDiff line numberDiff line change
@@ -1575,6 +1575,7 @@ function countnz(B::BitArray)
15751575
end
15761576
return n
15771577
end
1578+
count(B::BitArray) = countnz(B)
15781579

15791580
# returns the index of the next non-zero element, or 0 if all zeros
15801581
function findnext(B::BitArray, start::Integer)

base/reduce.jl

+15-1
Original file line numberDiff line numberDiff line change
@@ -642,21 +642,35 @@ end
642642

643643
"""
644644
count(p, itr) -> Integer
645+
count(itr) -> Integer
645646
646647
Count the number of elements in `itr` for which predicate `p` returns `true`.
648+
If `p` is omitted, counts the number of `true` elements in `itr` (which
649+
should be a collection of boolean values).
647650
648651
```jldoctest
649652
julia> count(i->(4<=i<=6), [2,3,4,5,6])
650653
3
654+
655+
julia> count([true, false, true, true])
656+
3
651657
```
652658
"""
653659
function count(pred, itr)
654660
n = 0
655661
for x in itr
656-
n += pred(x)
662+
n += pred(x)::Bool
663+
end
664+
return n
665+
end
666+
function count(pred, a::AbstractArray)
667+
n = 0
668+
for i in eachindex(a)
669+
@inbounds n += pred(a[i])::Bool
657670
end
658671
return n
659672
end
673+
count(itr) = count(identity, itr)
660674

661675
"""
662676
countnz(A) -> Integer

base/sparse/sparsematrix.jl

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ julia> nnz(A)
4444
"""
4545
nnz(S::SparseMatrixCSC) = Int(S.colptr[end]-1)
4646
countnz(S::SparseMatrixCSC) = countnz(S.nzval)
47+
count(S::SparseMatrixCSC) = count(S.nzval)
4748

4849
"""
4950
nonzeros(A)

base/sparse/sparsevector.jl

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ length(x::SparseVector) = x.n
3131
size(x::SparseVector) = (x.n,)
3232
nnz(x::SparseVector) = length(x.nzval)
3333
countnz(x::SparseVector) = countnz(x.nzval)
34+
count(x::SparseVector) = count(x.nzval)
3435
nonzeros(x::SparseVector) = x.nzval
3536
nonzeroinds(x::SparseVector) = x.nzind
3637

test/reduce.jl

+12-2
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,18 @@ immutable SomeFunctor end
260260

261261
# count & countnz
262262

263-
@test count(x->x>0, Int[]) == 0
264-
@test count(x->x>0, -3:5) == 5
263+
@test count(x->x>0, Int[]) == count(Bool[]) == 0
264+
@test count(x->x>0, -3:5) == count((-3:5) .> 0) == 5
265+
@test count([true, true, false, true]) == count(BitVector([true, true, false, true])) == 3
266+
@test_throws TypeError count(sqrt, [1])
267+
@test_throws TypeError count([1])
268+
let itr = (x for x in 1:10 if x < 7)
269+
@test count(iseven, itr) == 3
270+
@test_throws TypeError count(itr)
271+
@test_throws TypeError count(sqrt, itr)
272+
end
273+
@test count(iseven(x) for x in 1:10 if x < 7) == 3
274+
@test count(iseven(x) for x in 1:10 if x < -7) == 0
265275

266276
@test countnz(Int[]) == 0
267277
@test countnz(Int[0]) == 0

test/sparse/sparse.jl

+1
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,7 @@ end
806806
FI = Array(I)
807807
@test sparse(FS[FI]) == S[I] == S[FI]
808808
@test sum(S[FI]) + sum(S[!FI]) == sum(S)
809+
@test countnz(I) == count(I)
809810

810811
sumS1 = sum(S)
811812
sumFI = sum(S[FI])

test/sparse/sparsevector.jl

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ let x = spv_x1
2626
@test nonzeros(x) == [1.25, -0.75, 3.5]
2727
end
2828

29+
@test count(SparseVector(8, [2, 5, 6], [true,false,true])) == 2
30+
2931
# full
3032

3133
for (x, xf) in [(spv_x1, x1_full)]

0 commit comments

Comments
 (0)