Skip to content

Commit d16d994

Browse files
pabloferzstevengj
authored andcommitted
Restructure of the promotion mechanism for broadcast (#18642)
* Restructure the promotion mechanism for broadcast * More broadcast tests * Use broadcast for element wise operators where appropriate
1 parent 410b39c commit d16d994

8 files changed

+93
-56
lines changed

base/abstractarray.jl

+1-5
Original file line numberDiff line numberDiff line change
@@ -1762,12 +1762,8 @@ end
17621762
# These are needed because map(eltype, As) is not inferrable
17631763
promote_eltype_op(::Any) = (@_pure_meta; Any)
17641764
promote_eltype_op(op, A) = (@_pure_meta; promote_op(op, eltype(A)))
1765-
promote_eltype_op{T}(op, ::AbstractArray{T}) = (@_pure_meta; promote_op(op, T))
1766-
promote_eltype_op{T}(op, ::AbstractArray{T}, A) = (@_pure_meta; promote_op(op, T, eltype(A)))
1767-
promote_eltype_op{T}(op, A, ::AbstractArray{T}) = (@_pure_meta; promote_op(op, eltype(A), T))
1768-
promote_eltype_op{R,S}(op, ::AbstractArray{R}, ::AbstractArray{S}) = (@_pure_meta; promote_op(op, R, S))
17691765
promote_eltype_op(op, A, B) = (@_pure_meta; promote_op(op, eltype(A), eltype(B)))
1770-
promote_eltype_op(op, A, B, C, D...) = (@_pure_meta; promote_eltype_op(op, promote_eltype_op(op, A, B), C, D...))
1766+
promote_eltype_op(op, A, B, C, D...) = (@_pure_meta; promote_eltype_op(op, eltype(A), promote_eltype_op(op, B, C, D...)))
17711767

17721768
## 1 argument
17731769

base/broadcast.jl

+47-22
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
module Broadcast
44

55
using Base.Cartesian
6-
using Base: promote_eltype_op, linearindices, tail, OneTo, to_shape,
6+
using Base: promote_eltype_op, _default_eltype, linearindices, tail, OneTo, to_shape,
77
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache
88
import Base: .+, .-, .*, ./, .\, .//, .==, .<, .!=, .<=, , .%, .<<, .>>, .^
99
import Base: broadcast
@@ -16,7 +16,7 @@ export broadcast_getindex, broadcast_setindex!
1616
broadcast(f) = f()
1717
@inline broadcast(f, x::Number...) = f(x...)
1818
@inline broadcast{N}(f, t::NTuple{N}, ts::Vararg{NTuple{N}}) = map(f, t, ts...)
19-
@inline broadcast(f, As::AbstractArray...) = broadcast_t(f, promote_eltype_op(f, As...), As...)
19+
@inline broadcast(f, As::AbstractArray...) = broadcast_c(f, Array, As...)
2020

2121
# special cases for "X .= ..." (broadcast!) assignments
2222
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x)
@@ -127,14 +127,14 @@ Base.@propagate_inbounds _broadcast_getindex(::Any, A, I) = A[I]
127127
## Broadcasting core
128128
# nargs encodes the number of As arguments (which matches the number
129129
# of keeps). The first two type parameters are to ensure specialization.
130-
@generated function _broadcast!{K,ID,AT,nargs}(f, B::AbstractArray, keeps::K, Idefaults::ID, As::AT, ::Type{Val{nargs}})
130+
@generated function _broadcast!{K,ID,AT,nargs}(f, B::AbstractArray, keeps::K, Idefaults::ID, As::AT, ::Type{Val{nargs}}, iter)
131131
quote
132132
$(Expr(:meta, :noinline))
133133
# destructure the keeps and As tuples
134134
@nexprs $nargs i->(A_i = As[i])
135135
@nexprs $nargs i->(keep_i = keeps[i])
136136
@nexprs $nargs i->(Idefault_i = Idefaults[i])
137-
@simd for I in CartesianRange(indices(B))
137+
@simd for I in iter
138138
# reverse-broadcast the indices
139139
@nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i))
140140
# extract array values
@@ -148,7 +148,7 @@ end
148148

149149
# For BitArray outputs, we cache the result in a "small" Vector{Bool},
150150
# and then copy in chunks into the output
151-
@generated function _broadcast!{K,ID,AT,nargs}(f, B::BitArray, keeps::K, Idefaults::ID, As::AT, ::Type{Val{nargs}})
151+
@generated function _broadcast!{K,ID,AT,nargs}(f, B::BitArray, keeps::K, Idefaults::ID, As::AT, ::Type{Val{nargs}}, iter)
152152
quote
153153
$(Expr(:meta, :noinline))
154154
# destructure the keeps and As tuples
@@ -159,7 +159,7 @@ end
159159
Bc = B.chunks
160160
ind = 1
161161
cind = 1
162-
@simd for I in CartesianRange(indices(B))
162+
@simd for I in iter
163163
# reverse-broadcast the indices
164164
@nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i))
165165
# extract array values
@@ -193,12 +193,12 @@ as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
193193
shape = indices(B)
194194
check_broadcast_indices(shape, As...)
195195
keeps, Idefaults = map_newindexer(shape, As)
196-
_broadcast!(f, B, keeps, Idefaults, As, Val{nargs})
197-
B
196+
iter = CartesianRange(shape)
197+
_broadcast!(f, B, keeps, Idefaults, As, Val{nargs}, iter)
198+
return B
198199
end
199200

200201
# broadcast with computed element type
201-
202202
@generated function _broadcast!{K,ID,AT,nargs}(f, B::AbstractArray, keeps::K, Idefaults::ID, As::AT, ::Type{Val{nargs}}, iter, st, count)
203203
quote
204204
$(Expr(:meta, :noinline))
@@ -233,12 +233,8 @@ end
233233
end
234234
end
235235

236-
function broadcast_t(f, ::Type{Any}, As...)
237-
shape = broadcast_indices(As...)
238-
iter = CartesianRange(shape)
239-
if isempty(iter)
240-
return similar(Array{Any}, shape)
241-
end
236+
# broadcast methods that dispatch on the type found by inference
237+
function broadcast_t(f, ::Type{Any}, shape, iter, As...)
242238
nargs = length(As)
243239
keeps, Idefaults = map_newindexer(shape, As)
244240
st = start(iter)
@@ -248,17 +244,46 @@ function broadcast_t(f, ::Type{Any}, As...)
248244
B[I] = val
249245
return _broadcast!(f, B, keeps, Idefaults, As, Val{nargs}, iter, st, 1)
250246
end
247+
@inline function broadcast_t(f, T, shape, iter, As...)
248+
B = similar(Array{T}, shape)
249+
nargs = length(As)
250+
keeps, Idefaults = map_newindexer(shape, As)
251+
_broadcast!(f, B, keeps, Idefaults, As, Val{nargs}, iter)
252+
return B
253+
end
251254

252-
@inline broadcast_t(f, T, As...) = broadcast!(f, similar(Array{T}, broadcast_indices(As...)), As...)
253-
255+
# broadcast method that uses inference to find the type, but preserves abstract
256+
# container types when possible (used by binary elementwise operators)
257+
@inline broadcast_elwise_op(f, As...) =
258+
broadcast!(f, similar(Array{promote_eltype_op(f, As...)}, broadcast_indices(As...)), As...)
259+
260+
ftype(f, A) = typeof(a -> f(a))
261+
ftype(f, A...) = typeof(a -> f(a...))
262+
ftype(T::DataType, A) = Type{T}
263+
ftype(T::DataType, A...) = Type{T}
264+
ziptype(A) = Tuple{eltype(A)}
265+
ziptype(A, B) = Iterators.Zip2{Tuple{eltype(A)}, Tuple{eltype(B)}}
266+
@inline ziptype(A, B, C, D...) = Iterators.Zip{Tuple{eltype(A)}, ziptype(B, C, D...)}
267+
268+
# broadcast methods that dispatch on the type of the final container
269+
@inline function broadcast_c(f, ::Type{Array}, As...)
270+
T = _default_eltype(Base.Generator{ziptype(As...), ftype(f, As...)})
271+
shape = broadcast_indices(As...)
272+
iter = CartesianRange(shape)
273+
if isleaftype(T)
274+
return broadcast_t(f, T, shape, iter, As...)
275+
end
276+
if isempty(iter)
277+
return similar(Array{T}, shape)
278+
end
279+
return broadcast_t(f, Any, shape, iter, As...)
280+
end
254281
function broadcast_c(f, ::Type{Tuple}, As...)
255282
shape = broadcast_indices(As...)
256-
check_broadcast_indices(shape, As...)
257283
n = length(shape[1])
258284
return ntuple(k->f((_broadcast_getindex(A, k) for A in As)...), n)
259285
end
260286
@inline broadcast_c(f, ::Type{Any}, a...) = f(a...)
261-
@inline broadcast_c(f, ::Type{Array}, As...) = broadcast_t(f, promote_eltype_op(f, As...), As...)
262287

263288
"""
264289
broadcast(f, As...)
@@ -441,10 +466,10 @@ end
441466
## elementwise operators ##
442467

443468
for op in (:÷, :%, :<<, :>>, :-, :/, :\, ://, :^)
444-
@eval $(Symbol(:., op))(A::AbstractArray, B::AbstractArray) = broadcast($op, A, B)
469+
@eval $(Symbol(:., op))(A::AbstractArray, B::AbstractArray) = broadcast_elwise_op($op, A, B)
445470
end
446-
.+(As::AbstractArray...) = broadcast(+, As...)
447-
.*(As::AbstractArray...) = broadcast(*, As...)
471+
.+(As::AbstractArray...) = broadcast_elwise_op(+, As...)
472+
.*(As::AbstractArray...) = broadcast_elwise_op(*, As...)
448473

449474
# ## element-wise comparison operators returning BitArray ##
450475

base/deprecated.jl

+11
Original file line numberDiff line numberDiff line change
@@ -1035,6 +1035,17 @@ end))
10351035
@deprecate_binding cycle Iterators.cycle
10361036
@deprecate_binding repeated Iterators.repeated
10371037

1038+
# promote_op method where the operator is also a type
1039+
function promote_op(op::Type, Ts::Type...)
1040+
depwarn("promote_op(op::Type, ::Type...) is deprecated as it is no " *
1041+
"longer needed in Base. If you need its functionality, consider " *
1042+
"defining it locally.", :promote_op)
1043+
if isdefined(Core, :Inference)
1044+
return Core.Inference.return_type(op, Tuple{Ts...})
1045+
end
1046+
return op
1047+
end
1048+
10381049
# NOTE: Deprecation of Channel{T}() is implemented in channels.jl.
10391050
# To be removed from there when 0.6 deprecations are removed.
10401051

base/nullable.jl

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ convert( ::Type{Nullable }, ::Void) = Nullable{Union{}}()
5454
promote_rule{S,T}(::Type{Nullable{S}}, ::Type{T}) = Nullable{promote_type(S, T)}
5555
promote_rule{S,T}(::Type{Nullable{S}}, ::Type{Nullable{T}}) = Nullable{promote_type(S, T)}
5656
promote_op{S,T}(op::Any, ::Type{Nullable{S}}, ::Type{Nullable{T}}) = Nullable{promote_op(op, S, T)}
57+
promote_op{S,T}(op::Type, ::Type{Nullable{S}}, ::Type{Nullable{T}}) = Nullable{promote_op(op, S, T)}
5758

5859
function show{T}(io::IO, x::Nullable{T})
5960
if get(io, :compact, false)

base/promotion.jl

+9-20
Original file line numberDiff line numberDiff line change
@@ -217,34 +217,23 @@ max(x::Real, y::Real) = max(promote(x,y)...)
217217
min(x::Real, y::Real) = min(promote(x,y)...)
218218
minmax(x::Real, y::Real) = minmax(promote(x, y)...)
219219

220-
# "Promotion" that takes a function into account. These are meant to be
221-
# used mainly by broadcast methods, so it is advised against overriding them
222-
if isdefined(Core, :Inference)
223-
function _promote_op(op, T::ANY)
224-
G = Tuple{Generator{Tuple{T},typeof(op)}}
225-
return Core.Inference.return_type(first, G)
226-
end
227-
function _promote_op(op, R::ANY, S::ANY)
228-
F = typeof(a -> op(a...))
229-
G = Tuple{Generator{Iterators.Zip2{Tuple{R},Tuple{S}},F}}
230-
return Core.Inference.return_type(first, G)
231-
end
232-
else
233-
_promote_op(::ANY...) = (@_pure_meta; Any)
234-
end
220+
# "Promotion" that takes a function into account and tries to preserve
221+
# non-concrete types. These are meant to be used mainly by elementwise
222+
# operations, so it is advised against overriding them
235223
_default_type(T::Type) = (@_pure_meta; T)
236224

237225
promote_op(::Any...) = (@_pure_meta; Any)
238-
promote_op(T::Type, ::Any) = (@_pure_meta; T)
239-
promote_op(T::Type, ::Type) = (@_pure_meta; T) # To handle ambiguities
240-
# Promotion that tries to preserve non-concrete types
241226
function promote_op{S}(f, ::Type{S})
242-
T = _promote_op(f, _default_type(S))
227+
@_pure_meta
228+
Z = Tuple{_default_type(S)}
229+
T = _default_eltype(Generator{Z, typeof(a -> f(a))})
243230
isleaftype(S) && return isleaftype(T) ? T : Any
244231
return typejoin(S, T)
245232
end
246233
function promote_op{R,S}(f, ::Type{R}, ::Type{S})
247-
T = _promote_op(f, _default_type(R), _default_type(S))
234+
@_pure_meta
235+
Z = Iterators.Zip2{Tuple{_default_type(R)}, Tuple{_default_type(S)}}
236+
T = _default_eltype(Generator{Z, typeof(a -> f(a...))})
248237
isleaftype(R) && isleaftype(S) && return isleaftype(T) ? T : Any
249238
return typejoin(R, S, T)
250239
end

base/sparse/sparsematrix.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1740,7 +1740,7 @@ broadcast{Tv1,Ti1,Tv2,Ti2}(f::Function, A_1::SparseMatrixCSC{Tv1,Ti1}, A_2::Spar
17401740
broadcast!(f, spzeros(promote_type(Tv1, Tv2), promote_type(Ti1, Ti2), to_shape(broadcast_indices(A_1, A_2))), A_1, A_2)
17411741

17421742
@inline broadcast_zpreserving!(args...) = broadcast!(args...)
1743-
@inline broadcast_zpreserving(args...) = broadcast(args...)
1743+
@inline broadcast_zpreserving(args...) = Base.Broadcast.broadcast_elwise_op(args...)
17441744
broadcast_zpreserving{Tv1,Ti1,Tv2,Ti2}(f::Function, A_1::SparseMatrixCSC{Tv1,Ti1}, A_2::SparseMatrixCSC{Tv2,Ti2}) =
17451745
broadcast_zpreserving!(f, spzeros(promote_type(Tv1, Tv2), promote_type(Ti1, Ti2), to_shape(broadcast_indices(A_1, A_2))), A_1, A_2)
17461746
broadcast_zpreserving{Tv,Ti}(f::Function, A_1::SparseMatrixCSC{Tv,Ti}, A_2::Union{Array,BitArray,Number}) =

test/broadcast.jl

+17-2
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,6 @@ import Base.Meta: isexpr
300300
# PR 16988
301301
@test Base.promote_op(+, Bool) === Int
302302
@test isa(broadcast(+, [true]), Array{Int,1})
303-
@test Base.promote_op(Float64, Bool) === Float64
304303

305304
# issue #17304
306305
let foo = [[1,2,3],[4,5,6],[7,8,9]]
@@ -312,7 +311,7 @@ end
312311
let f17314 = x -> x < 0 ? false : x
313312
@test eltype(broadcast(f17314, 1:3)) === Int
314313
@test eltype(broadcast(f17314, -1:1)) === Integer
315-
@test eltype(broadcast(f17314, Int[])) === Any
314+
@test eltype(broadcast(f17314, Int[])) === Union{Bool,Int}
316315
end
317316
let io = IOBuffer()
318317
broadcast(x->print(io,x), 1:5) # broadcast with side effects
@@ -337,3 +336,19 @@ end
337336
@test broadcast(+, 1.0, (0, -2.0)) == (1.0,-1.0)
338337
@test broadcast(+, 1.0, (0, -2.0), [1]) == [2.0, 0.0]
339338
@test broadcast(*, ["Hello"], ", ", ["World"], "!") == ["Hello, World!"]
339+
340+
# Ensure that even strange constructors that break `T(x)::T` work with broadcast
341+
immutable StrangeType18623 end
342+
StrangeType18623(x) = x
343+
StrangeType18623(x,y) = (x,y)
344+
@test @inferred broadcast(StrangeType18623, 1:3) == [1,2,3]
345+
@test @inferred broadcast(StrangeType18623, 1:3, 4:6) == [(1,4),(2,5),(3,6)]
346+
347+
@test typeof(Int.(Number[1, 2, 3])) === typeof((x->Int(x)).(Number[1, 2, 3]))
348+
349+
@test @inferred broadcast(CartesianIndex, 1:2) == [CartesianIndex(1), CartesianIndex(2)]
350+
@test @inferred broadcast(CartesianIndex, 1:2, 3:4) == [CartesianIndex(1,3), CartesianIndex(2,4)]
351+
352+
# Issue 18622
353+
@test @inferred muladd.([1.0], [2.0], [3.0])::Vector{Float64} == [5.0]
354+
@test @inferred tuple.(1:3, 4:6, 7:9)::Vector{Tuple{Int,Int,Int}} == [(1,4,7), (2,5,8), (3,6,9)]

test/numbers.jl

+6-6
Original file line numberDiff line numberDiff line change
@@ -2814,42 +2814,42 @@ let types = (Base.BitInteger_types..., BigInt, Bool,
28142814
Complex{Int}, Complex{UInt}, Complex32, Complex64, Complex128)
28152815
for S in types
28162816
for op in (+, -)
2817-
T = @inferred Base._promote_op(op, S)
2817+
T = @inferred Base.promote_op(op, S)
28182818
t = @inferred op(one(S))
28192819
@test T === typeof(t)
28202820
end
28212821

28222822
for R in types
28232823
for op in (+, -, *, /, ^)
2824-
T = @inferred Base._promote_op(op, S, R)
2824+
T = @inferred Base.promote_op(op, S, R)
28252825
t = @inferred op(one(S), one(R))
28262826
@test T === typeof(t)
28272827
end
28282828
end
28292829
end
28302830

2831-
@test @inferred(Base._promote_op(!, Bool)) === Bool
2831+
@test @inferred(Base.promote_op(!, Bool)) === Bool
28322832
end
28332833

28342834
let types = (Base.BitInteger_types..., BigInt, Bool,
28352835
Rational{Int}, Rational{BigInt},
28362836
Float16, Float32, Float64, BigFloat)
28372837
for S in types, T in types
28382838
for op in (<, >, <=, >=, (==))
2839-
@test @inferred(Base._promote_op(op, S, T)) === Bool
2839+
@test @inferred(Base.promote_op(op, S, T)) === Bool
28402840
end
28412841
end
28422842
end
28432843

28442844
let types = (Base.BitInteger_types..., BigInt, Bool)
28452845
for S in types
2846-
T = @inferred Base._promote_op(~, S)
2846+
T = @inferred Base.promote_op(~, S)
28472847
t = @inferred ~one(S)
28482848
@test T === typeof(t)
28492849

28502850
for R in types
28512851
for op in (&, |, <<, >>, (>>>), %, ÷)
2852-
T = @inferred Base._promote_op(op, S, R)
2852+
T = @inferred Base.promote_op(op, S, R)
28532853
t = @inferred op(one(S), one(R))
28542854
@test T === typeof(t)
28552855
end

0 commit comments

Comments
 (0)