Skip to content

Commit f4dbc15

Browse files
stev47mateuszbaran
andauthored
implement Statistics.median (#973)
* implement Statistics.median This implements `Statistics.median` based on the existing bitonic sorting, avoiding unnecessary allocation. While it is generally suboptimal to sort the whole array, the compiler manages to skip some branches since only the middle element(s) are used. Thus `median` is generally faster than `sort`. Using a dedicated median selection network could yield better performance and might be considered for future improvement. * add median tests from Statistics.jl * bump version --------- Co-authored-by: Mateusz Baran <[email protected]>
1 parent b23d668 commit f4dbc15

File tree

4 files changed

+118
-4
lines changed

4 files changed

+118
-4
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StaticArrays"
22
uuid = "90137ffa-7385-5640-81b9-e52037218182"
3-
version = "1.9.9"
3+
version = "1.9.10"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

ext/StaticArraysStatisticsExt.jl

+38-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
module StaticArraysStatisticsExt
22

3-
import Statistics: mean
3+
import Statistics: mean, median
4+
5+
using Base.Order: Forward, ord
6+
using Statistics: median!, middle
47

58
using StaticArrays
6-
using StaticArrays: _InitialValue, _reduce, _mapreduce
9+
using StaticArrays: BitonicSort, _InitialValue, _reduce, _mapreduce, _bitonic_sort_limit, _sort
710

811
_mean_denom(a, ::Colon) = length(a)
912
_mean_denom(a, dims::Int) = size(a, dims)
@@ -12,4 +15,37 @@ _mean_denom(a, ::Val{D}) where {D} = size(a, D)
1215
@inline mean(a::StaticArray; dims=:) = _reduce(+, a, dims) / _mean_denom(a, dims)
1316
@inline mean(f::Function, a::StaticArray; dims=:) = _mapreduce(f, +, dims, _InitialValue(), Size(a), a) / _mean_denom(a, dims)
1417

18+
@inline function median(a::StaticArray; dims = :)
19+
if dims == Colon()
20+
median(vec(a))
21+
else
22+
# FIXME: Implement `mapslices` correctly on `StaticArray` to remove
23+
# this fallback.
24+
median(Array(a); dims)
25+
end
26+
end
27+
28+
@inline function median(a::StaticVector)
29+
(isimmutable(a) && length(a) <= _bitonic_sort_limit) ||
30+
return median!(Base.copymutable(a))
31+
32+
# following Statistics.median
33+
isempty(a) &&
34+
throw(ArgumentError("median of empty vector is undefined, $(repr(a))"))
35+
eltype(a) >: Missing && any(ismissing, a) &&
36+
return missing
37+
nanix = findfirst(x -> x isa Number && isnan(x), a)
38+
isnothing(nanix) ||
39+
return a[nanix]
40+
41+
order = ord(isless, identity, nothing, Forward)
42+
sa = _sort(Tuple(a), BitonicSort, order)
43+
44+
n = length(a)
45+
# sa is 1-indexed
46+
return isodd(n) ?
47+
middle(sa[n ÷ 2 + 1]) :
48+
middle(sa[n ÷ 2], sa[n ÷ 2 + 1])
49+
end
50+
1551
end # module

src/sort.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@ const BitonicSort = BitonicSortAlg()
99

1010
# BitonicSort has non-optimal asymptotic behaviour, so we define a cutoff
1111
# length. This also prevents compilation time to skyrocket for larger vectors.
12+
const _bitonic_sort_limit = 20
1213
defalg(a::StaticVector) =
13-
isimmutable(a) && length(a) <= 20 ? BitonicSort : QuickSort
14+
isimmutable(a) && length(a) <= _bitonic_sort_limit ? BitonicSort : QuickSort
1415

1516
@inline function sort(a::StaticVector;
1617
alg::Algorithm = defalg(a),

test/sort.jl

+77
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using StaticArrays, Test
2+
using Statistics: Statistics, median, median!, middle
23

34
@testset "sort" begin
45

@@ -30,4 +31,80 @@ using StaticArrays, Test
3031
@test sortperm(SA[1, 1, 1, 0]) == SA[4, 1, 2, 3]
3132
end
3233

34+
@testset "median" begin
35+
@test_throws ArgumentError median(SA[])
36+
@test ismissing(median(SA[1, missing]))
37+
@test isnan(median(SA[1., NaN]))
38+
39+
@testset for T in (Int, Float64)
40+
for N in (1, 2, 3, 10, 20, 30)
41+
v = rand(SVector{N,T})
42+
mref = median(Vector(v))
43+
44+
@test @inferred(median(v) == mref)
45+
end
46+
end
47+
48+
# Tests based on upstream `Statistics.jl`.
49+
# https://github.com/JuliaStats/Statistics.jl/blob/d49c2bf4f81e1efb4980a35fe39c815ef8396297/test/runtests.jl#L31-L92
50+
@test median(SA[1.]) === 1.
51+
@test median(SA[1.,3]) === 2.
52+
@test median(SA[1.,3,2]) === 2.
53+
54+
@test median(SA[1,3,2]) === 2.0
55+
@test median(SA[1,3,2,4]) === 2.5
56+
57+
@test median(SA[0.0,Inf]) == Inf
58+
@test median(SA[0.0,-Inf]) == -Inf
59+
@test median(SA[0.,Inf,-Inf]) == 0.0
60+
@test median(SA[1.,-1.,Inf,-Inf]) == 0.0
61+
@test isnan(median(SA[-Inf,Inf]))
62+
63+
X = SA[2 3 1 -1; 7 4 5 -4]
64+
@test all(median(X, dims=2) .== SA[1.5, 4.5])
65+
@test all(median(X, dims=1) .== SA[4.5 3.5 3.0 -2.5])
66+
@test X == SA[2 3 1 -1; 7 4 5 -4] # issue #17153
67+
68+
@test_throws ArgumentError median(SA[])
69+
@test isnan(median(SA[NaN]))
70+
@test isnan(median(SA[0.0,NaN]))
71+
@test isnan(median(SA[NaN,0.0]))
72+
@test isnan(median(SA[NaN,0.0,1.0]))
73+
@test isnan(median(SA{Any}[NaN,0.0,1.0]))
74+
@test isequal(median(SA[NaN 0.0; 1.2 4.5], dims=2), reshape(SA[NaN; 2.85], 2, 1))
75+
76+
# the specific NaN value is propagated from the input
77+
@test median(SA[NaN]) === NaN
78+
@test median(SA[0.0,NaN]) === NaN
79+
@test median(SA[0.0,NaN,NaN]) === NaN
80+
@test median(SA[-NaN]) === -NaN
81+
@test median(SA[0.0,-NaN]) === -NaN
82+
@test median(SA[0.0,-NaN,-NaN]) === -NaN
83+
84+
@test ismissing(median(SA[1, missing]))
85+
@test ismissing(median(SA[1, 2, missing]))
86+
@test ismissing(median(SA[NaN, 2.0, missing]))
87+
@test ismissing(median(SA[NaN, missing]))
88+
@test ismissing(median(SA[missing, NaN]))
89+
@test ismissing(median(SA{Any}[missing, 2.0, 3.0, 4.0, NaN]))
90+
@test median(skipmissing(SA[1, missing, 2])) === 1.5
91+
92+
@test median!(Base.copymutable(SA[1 2 3 4])) == 2.5
93+
@test median!(Base.copymutable(SA[1 2; 3 4])) == 2.5
94+
95+
@test @inferred(median(SA{Float16}[1, 2, NaN])) === Float16(NaN)
96+
@test @inferred(median(SA{Float16}[1, 2, 3])) === Float16(2)
97+
@test @inferred(median(SA{Float32}[1, 2, NaN])) === NaN32
98+
@test @inferred(median(SA{Float32}[1, 2, 3])) === 2.0f0
99+
100+
# custom type implementing minimal interface
101+
struct A
102+
x
103+
end
104+
Statistics.middle(x::A, y::A) = A(middle(x.x, y.x))
105+
Base.isless(x::A, y::A) = isless(x.x, y.x)
106+
@test median(SA[A(1), A(2)]) === A(1.5)
107+
@test median(SA{Any}[A(1), A(2)]) === A(1.5)
108+
end
109+
33110
end

0 commit comments

Comments
 (0)