Skip to content

Commit ae4b229

Browse files
committed
Rework replace and replace!
Introduce a new replace!(new::Callable, res::T, A::T, count::Union{Nothing,Int}) method which custom types can implement to support all replace and replace! methods automatically, instead of the current replace!(new::Callable, A::T, count::Int). This offers several advantages: - For arrays, instead of copying the input and then replace elements, we can do the copy and replace operations at the same time, which is quite faster for arrays when count=nothing. - For dicts and sets, copying up-front is still faster as long as most original elements are preserved, but for replace(), we can apply replacements directly instead of storing a them in a temporary vector. - When the LHS of a pair contains a singleton type, we can subtract it from the element type of the result, e.g. Union{T,Missing} becomes T. Also simplify the dispatch logic by removing the internal _replace! method in favor of replace!.
1 parent 5f427c7 commit ae4b229

File tree

2 files changed

+94
-36
lines changed

2 files changed

+94
-36
lines changed

base/set.jl

+85-36
Original file line numberDiff line numberDiff line change
@@ -574,8 +574,14 @@ _copy_oftype(x::AbstractArray{T}, ::Type{T}) where {T} = copy(x)
574574
_copy_oftype(x::AbstractDict{K,V}, ::Type{Pair{K,V}}) where {K,V} = copy(x)
575575
_copy_oftype(x::AbstractSet{T}, ::Type{T}) where {T} = copy(x)
576576

577+
_similar_or_copy(x::Any) = similar(x)
578+
_similar_or_copy(x::Any, ::Type{T}) where {T} = similar(x, T)
579+
# Make a copy on construction since it is faster than inserting elements separately
580+
_similar_or_copy(x::Union{AbstractDict,AbstractSet}) where {T} = copy(x)
581+
_similar_or_copy(x::Union{AbstractDict,AbstractSet}, ::Type{T}) where {T} = _copy_oftype(x, T)
582+
577583
# to make replace/replace! work for a new container type Cont, only
578-
# replace!(new::Callable, A::Cont; count::Integer=typemax(Int))
584+
# replace!(new::Callable, res::Cont, A::Cont; count::Integer=typemax(Int))
579585
# has to be implemented
580586

581587
"""
@@ -600,16 +606,17 @@ julia> replace!(Set([1, 2, 3]), 1=>0)
600606
Set([0, 2, 3])
601607
```
602608
"""
603-
replace!(A, old_new::Pair...; count::Integer=typemax(Int)) = _replace!(A, count, old_new)
609+
replace!(A, old_new::Pair...; count::Union{Integer,Nothing}=nothing) =
610+
replace!(A, A, count, old_new)
604611

605-
function _replace!(A, count::Integer, old_new::Tuple{Vararg{Pair}})
612+
function replace!(res, A, count::Union{Integer,Nothing}, old_new::Tuple{Vararg{Pair}})
606613
@inline function new(x)
607614
for o_n in old_new
608615
isequal(first(o_n), x) && return last(o_n)
609616
end
610617
return x # no replace
611618
end
612-
replace!(new, A, count=count)
619+
replace!(new, res, A, count)
613620
end
614621

615622
"""
@@ -630,7 +637,7 @@ julia> replace!(isodd, A, 0, count=2)
630637
1
631638
```
632639
"""
633-
replace!(pred::Callable, A, new; count::Integer=typemax(Int)) =
640+
replace!(pred::Callable, A, new; count::Union{Integer,Nothing}=nothing) =
634641
replace!(x -> ifelse(pred(x), new, x), A, count=count)
635642

636643
"""
@@ -661,9 +668,14 @@ Set([6, 12])
661668
```
662669
"""
663670
function replace!(new::Callable, A::Union{AbstractArray,AbstractDict,AbstractSet};
664-
count::Integer=typemax(Int))
665-
count < 0 && throw(DomainError(count, "`count` must not be negative"))
666-
count != 0 && _replace!(new, A, min(count, typemax(Int)) % Int)
671+
count::Union{Integer,Nothing}=nothing)
672+
if count === nothing
673+
replace!(new, A, A, nothing)
674+
elseif count < 0
675+
throw(DomainError(count, "`count` must not be negative"))
676+
elseif count != 0
677+
replace!(new, A, A, min(count, typemax(Int)) % Int)
678+
end
667679
A
668680
end
669681

@@ -686,16 +698,33 @@ julia> replace([1, 2, 1, 3], 1=>0, 2=>4, count=2)
686698
3
687699
```
688700
"""
689-
function replace(A, old_new::Pair...; count::Integer=typemax(Int))
701+
function replace(A, old_new::Pair...; count::Union{Integer,Nothing}=nothing)
690702
V = promote_valuetype(old_new...)
691-
T = promote_type(eltype(A), V)
692-
_replace!(_copy_oftype(A, T), count, old_new)
703+
if count isa Nothing
704+
T = promote_type(subtract_singletontype(eltype(A), old_new...), V)
705+
replace!(_similar_or_copy(A, T), A, nothing, old_new)
706+
else
707+
U = promote_type(eltype(A), V)
708+
replace!(_similar_or_copy(A, U), A, min(count, typemax(Int)) % Int, old_new)
709+
end
693710
end
694711

695712
promote_valuetype(x::Pair{K, V}) where {K, V} = V
696713
promote_valuetype(x::Pair{K, V}, y::Pair...) where {K, V} =
697714
promote_type(V, promote_valuetype(y...))
698715

716+
# Subtract singleton types which are going to be replaced
717+
@pure issingletontype(::Type{T}) where {T} = isdefined(T, :instance)
718+
function subtract_singletontype(::Type{T}, x::Pair{K}) where {T, K}
719+
if issingletontype(K) # singleton type
720+
Core.Compiler.typesubtract(T, K)
721+
else
722+
T
723+
end
724+
end
725+
subtract_singletontype(::Type{T}, x::Pair{K}, y::Pair...) where {T, K} =
726+
subtract_singletontype(subtract_singletontype(T, y...), x)
727+
699728
"""
700729
replace(pred::Function, A, new; [count::Integer])
701730
@@ -713,9 +742,10 @@ julia> replace(isodd, [1, 2, 3, 1], 0, count=2)
713742
1
714743
```
715744
"""
716-
function replace(pred::Callable, A, new; count::Integer=typemax(Int))
745+
function replace(pred::Callable, A, new; count::Union{Integer,Nothing}=nothing)
717746
T = promote_type(eltype(A), typeof(new))
718-
replace!(pred, _copy_oftype(A, T), new, count=count)
747+
replace!(x -> ifelse(pred(x), new, x), _similar_or_copy(A, T), A,
748+
count === nothing ? nothing : min(count, typemax(Int)) % Int)
719749
end
720750

721751
"""
@@ -742,7 +772,9 @@ Dict{Int64,Int64} with 2 entries:
742772
1 => 3
743773
```
744774
"""
745-
replace(new::Callable, A; count::Integer=typemax(Int)) = replace!(new, copy(A), count=count)
775+
replace(new::Callable, A; count::Union{Integer,Nothing}=nothing) =
776+
replace!(new, _similar_or_copy(A), A,
777+
count === nothing ? nothing : min(count, typemax(Int)) % Int)
746778

747779
# Handle ambiguities
748780
replace!(a::Callable, b::Pair; count::Integer=-1) = throw(MethodError(replace!, (a, b)))
@@ -757,36 +789,53 @@ replace(a::AbstractString, b::Pair, c::Pair) = throw(MethodError(replace, (a, b,
757789
askey(k, ::AbstractDict) = k.first
758790
askey(k, ::AbstractSet) = k
759791

760-
function _replace!(new::Callable, A::Union{AbstractDict,AbstractSet}, count::Int)
761-
repl = Pair{eltype(A),eltype(A)}[]
792+
function replace!(new::Callable, res::T, A::T,
793+
count::Union{Int,Nothing}) where T<:Union{AbstractDict,AbstractSet}
762794
c = 0
763-
for x in A
764-
y = new(x)
765-
if x !== y
766-
push!(repl, x => y)
767-
c += 1
795+
if res === A
796+
repl = Pair{eltype(A),eltype(A)}[]
797+
for x in A
798+
y = new(x)
799+
if x !== y
800+
push!(repl, x => y)
801+
c += 1
802+
end
803+
c == count && break
804+
end
805+
for oldnew in repl
806+
pop!(res, askey(first(oldnew), res))
807+
end
808+
for oldnew in repl
809+
push!(res, last(oldnew))
810+
end
811+
else
812+
for x in A
813+
y = new(x)
814+
if x !== y
815+
pop!(res, askey(x, res))
816+
push!(res, y)
817+
c += 1
818+
end
819+
c == count && break
768820
end
769-
c == count && break
770-
end
771-
for oldnew in repl
772-
pop!(A, askey(first(oldnew), A))
773-
end
774-
for oldnew in repl
775-
push!(A, last(oldnew))
776821
end
822+
res
777823
end
778824

779-
### AbstractArray
825+
### replace! for AbstractArray
780826

781-
function _replace!(new::Callable, A::AbstractArray, count::Int)
827+
function replace!(new::Callable, res::AbstractArray, A::AbstractArray,
828+
count::Union{Int,Nothing})
782829
c = 0
783830
for i in eachindex(A)
784831
@inbounds Ai = A[i]
785-
y = new(Ai)
786-
if Ai !== y
787-
@inbounds A[i] = y
788-
c += 1
832+
if count === nothing || c < count
833+
y = new(Ai)
834+
@inbounds res[i] = y
835+
c += (Ai !== y)
836+
else
837+
@inbounds res[i] = Ai
789838
end
790-
c == count && break
791839
end
792-
end
840+
res
841+
end

test/sets.jl

+9
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,15 @@ end
548548
x = @inferred replace(x -> x > 1, [1, 2], missing)
549549
@test isequal(x, [1, missing]) && x isa Vector{Union{Int, Missing}}
550550

551+
x = @inferred replace([1, missing], missing=>2)
552+
@test x == [1, 2] && x isa Vector{Int}
553+
x = @inferred replace([1, missing], missing=>2, count=1)
554+
@test x == [1, 2] && x isa Vector{Union{Int, Missing}}
555+
x = @inferred replace([1, missing], missing=>missing)
556+
@test isequal(x, [1, missing]) && x isa Vector{Union{Int, Missing}}
557+
x = @inferred replace([1, missing], missing=>2, 1=>missing)
558+
@test isequal(x, [missing, 2]) && x isa Vector{Union{Int, Missing}}
559+
551560
# test that isequal is used
552561
@test replace([NaN, 1.0], NaN=>0.0) == [0.0, 1.0]
553562
@test replace([1, missing], missing=>0) == [1, 0]

0 commit comments

Comments
 (0)