Skip to content

Commit 9fc9e6c

Browse files
committed
Support 2-arg argmin/argmax/findmin/findmax
xref JuliaLang/julia#35316
1 parent 37000b3 commit 9fc9e6c

File tree

3 files changed

+52
-1
lines changed

3 files changed

+52
-1
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Compat"
22
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
3-
version = "3.25.0"
3+
version = "3.26.0"
44

55
[deps]
66
Base64 = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"

src/Compat.jl

+18
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,24 @@ if VERSION < v"1.2.0-DEV.246"
859859
end
860860
end
861861

862+
if VERSION < v"1.7.0-DEV.119"
863+
# Part of https://github.com/JuliaLang/julia/pull/35316
864+
isunordered(x) = false
865+
isunordered(x::AbstractFloat) = isnan(x)
866+
isunordered(x::Missing) = true
867+
868+
isgreater(x, y) = isunordered(x) || isunordered(y) ? isless(x, y) : isless(y, x)
869+
870+
Base.findmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)
871+
_rf_findmax((fm, m), (fx, x)) = isless(fm, fx) ? (fx, x) : (fm, m)
872+
873+
Base.findmin(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmin, domain)
874+
_rf_findmin((fm, m), (fx, x)) = isgreater(fm, fx) ? (fx, x) : (fm, m)
875+
876+
Base.argmax(f, domain) = findmax(f, domain)[2]
877+
Base.argmin(f, domain) = findmin(f, domain)[2]
878+
end
879+
862880
include("iterators.jl")
863881
include("deprecated.jl")
864882

test/runtests.jl

+33
Original file line numberDiff line numberDiff line change
@@ -833,3 +833,36 @@ end
833833
@test endswith("abc", r"C"i)
834834
@test endswith("abc", r"Bc"i)
835835
end
836+
837+
# https://github.com/JuliaLang/julia/pull/35316
838+
@testset "2arg" begin
839+
@testset "findmin(f, domain)" begin
840+
@test findmin(-, 1:10) == (-10, 10)
841+
@test findmin(identity, [1, 2, 3, missing]) === (missing, missing)
842+
@test findmin(identity, [1, NaN, 3, missing]) === (missing, missing)
843+
@test findmin(identity, [1, missing, NaN, 3]) === (missing, missing)
844+
@test findmin(identity, [1, NaN, 3]) === (NaN, NaN)
845+
@test findmin(identity, [1, 3, NaN]) === (NaN, NaN)
846+
@test all(findmin(cos, 0:π/2:2π) .≈ (-1.0, π))
847+
end
848+
849+
@testset "findmax(f, domain)" begin
850+
@test findmax(-, 1:10) == (-1, 1)
851+
@test findmax(identity, [1, 2, 3, missing]) === (missing, missing)
852+
@test findmax(identity, [1, NaN, 3, missing]) === (missing, missing)
853+
@test findmax(identity, [1, missing, NaN, 3]) === (missing, missing)
854+
@test findmax(identity, [1, NaN, 3]) === (NaN, NaN)
855+
@test findmax(identity, [1, 3, NaN]) === (NaN, NaN)
856+
@test findmax(cos, 0:π/2:2π) == (1.0, 0.0)
857+
end
858+
859+
@testset "argmin(f, domain)" begin
860+
@test argmin(-, 1:10) == 10
861+
@test argmin(sum, Iterators.product(1:5, 1:5)) == (1, 1)
862+
end
863+
864+
@testset "argmax(f, domain)" begin
865+
@test argmax(-, 1:10) == 1
866+
@test argmax(sum, Iterators.product(1:5, 1:5)) == (5, 5)
867+
end
868+
end

0 commit comments

Comments
 (0)