Skip to content

Commit 2193638

Browse files
pabloferzstevengj
authored andcommitted
Generalize broadcast to handle tuples and scalars (#16986)
* Generalized broadcast arguments * Naming fixes * Add some tests * News and documentation
1 parent a648f4a commit 2193638

12 files changed

+209
-91
lines changed

NEWS.md

+4
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ This section lists changes that do not have deprecation warnings.
3333
for `real(z) < 0`, which differs from `log(gamma(z))` by multiples of 2π
3434
in the imaginary part ([#18330]).
3535

36+
* `broadcast` now handles tuples, and treats any argument that is not a tuple
37+
or an array as a "scalar" ([#16986]).
38+
3639
Library improvements
3740
--------------------
3841

@@ -646,6 +649,7 @@ Language tooling improvements
646649
[#16854]: https://github.com/JuliaLang/julia/issues/16854
647650
[#16953]: https://github.com/JuliaLang/julia/issues/16953
648651
[#16972]: https://github.com/JuliaLang/julia/issues/16972
652+
[#16986]: https://github.com/JuliaLang/julia/issues/16986
649653
[#17033]: https://github.com/JuliaLang/julia/issues/17033
650654
[#17037]: https://github.com/JuliaLang/julia/issues/17037
651655
[#17075]: https://github.com/JuliaLang/julia/issues/17075

base/arraymath.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ promote_array_type{S<:Integer}(::typeof(.\), ::Type{S}, ::Type{Bool}, T::Type) =
5353
promote_array_type{S<:Integer}(F, ::Type{S}, ::Type{Bool}, T::Type) = T
5454

5555
for f in (:+, :-, :div, :mod, :&, :|, :$)
56-
@eval ($f){R,S}(A::AbstractArray{R}, B::AbstractArray{S}) =
57-
_elementwise($f, promote_op($f, R, S), A, B)
56+
@eval ($f)(A::AbstractArray, B::AbstractArray) =
57+
_elementwise($f, promote_eltype_op($f, A, B), A, B)
5858
end
5959
function _elementwise(op, ::Type{Any}, A::AbstractArray, B::AbstractArray)
6060
promote_shape(A, B) # check size compatibility

base/broadcast.jl

+105-54
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,44 @@ export broadcast_getindex, broadcast_setindex!
1010

1111
## Broadcasting utilities ##
1212

13-
# fallback routines for broadcasting with no arguments or with scalars
14-
# to just produce a scalar result:
13+
# fallback for broadcasting with zero arguments and some special cases
1514
broadcast(f) = f()
16-
broadcast(f, x::Number...) = f(x...)
15+
@inline broadcast(f, x::Number...) = f(x...)
16+
@inline broadcast{N}(f, t::NTuple{N}, ts::Vararg{NTuple{N}}) = map(f, t, ts...)
17+
@inline broadcast(f, As::AbstractArray...) = broadcast_t(f, promote_eltype_op(f, As...), As...)
1718

1819
# special cases for "X .= ..." (broadcast!) assignments
1920
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x)
2021
broadcast!(f, X::AbstractArray) = fill!(X, f())
2122
broadcast!(f, X::AbstractArray, x::Number...) = fill!(X, f(x...))
2223
function broadcast!{T,S,N}(::typeof(identity), x::AbstractArray{T,N}, y::AbstractArray{S,N})
23-
check_broadcast_shape(size(x), size(y))
24+
check_broadcast_shape(broadcast_indices(x), broadcast_indices(y))
2425
copy!(x, y)
2526
end
2627

27-
## Calculate the broadcast shape of the arguments, or error if incompatible
28+
# logic for deciding the resulting container type
29+
containertype(x) = containertype(typeof(x))
30+
containertype(::Type) = Any
31+
containertype{T<:Tuple}(::Type{T}) = Tuple
32+
containertype{T<:AbstractArray}(::Type{T}) = Array
33+
containertype(ct1, ct2) = promote_containertype(containertype(ct1), containertype(ct2))
34+
@inline containertype(ct1, ct2, cts...) = promote_containertype(containertype(ct1), containertype(ct2, cts...))
35+
36+
promote_containertype(::Type{Array}, ::Type{Array}) = Array
37+
promote_containertype(::Type{Array}, ct) = Array
38+
promote_containertype(ct, ::Type{Array}) = Array
39+
promote_containertype(::Type{Tuple}, ::Type{Any}) = Tuple
40+
promote_containertype(::Type{Any}, ::Type{Tuple}) = Tuple
41+
promote_containertype{T}(::Type{T}, ::Type{T}) = T
42+
43+
## Calculate the broadcast indices of the arguments, or error if incompatible
2844
# array inputs
29-
broadcast_shape() = ()
30-
broadcast_shape(A) = indices(A)
31-
@inline broadcast_shape(A, B...) = broadcast_shape((), indices(A), map(indices, B)...)
45+
broadcast_indices() = ()
46+
broadcast_indices(A) = broadcast_indices(containertype(A), A)
47+
broadcast_indices(::Type{Any}, A) = ()
48+
broadcast_indices(::Type{Tuple}, A) = (OneTo(length(A)),)
49+
broadcast_indices(::Type{Array}, A) = indices(A)
50+
@inline broadcast_indices(A, B...) = broadcast_shape((), broadcast_indices(A), map(broadcast_indices, B)...)
3251
# shape (i.e., tuple-of-indices) inputs
3352
broadcast_shape(shape::Tuple) = shape
3453
@inline broadcast_shape(shape::Tuple, shape1::Tuple, shapes::Tuple...) = broadcast_shape(_bcs((), shape, shape1), shapes...)
@@ -50,24 +69,21 @@ _bcsm(a, b) = a == b || length(b) == 1
5069
_bcsm(a, b::Number) = b == 1
5170
_bcsm(a::Number, b::Number) = a == b || b == 1
5271

53-
## Check that all arguments are broadcast compatible with shape
5472
## Check that all arguments are broadcast compatible with shape
5573
# comparing one input against a shape
56-
check_broadcast_shape(::Tuple{}) = nothing
57-
check_broadcast_shape(::Tuple{}, A::Union{AbstractArray,Number}) = check_broadcast_shape((), indices(A))
5874
check_broadcast_shape(shp) = nothing
59-
check_broadcast_shape(shp, A) = check_broadcast_shape(shp, indices(A))
60-
check_broadcast_shape(::Tuple{}, ::Tuple{}) = nothing
6175
check_broadcast_shape(shp, ::Tuple{}) = nothing
76+
check_broadcast_shape(::Tuple{}, ::Tuple{}) = nothing
6277
check_broadcast_shape(::Tuple{}, Ashp::Tuple) = throw(DimensionMismatch("cannot broadcast array to have fewer dimensions"))
6378
function check_broadcast_shape(shp, Ashp::Tuple)
6479
_bcsm(shp[1], Ashp[1]) || throw(DimensionMismatch("array could not be broadcast to match destination"))
6580
check_broadcast_shape(tail(shp), tail(Ashp))
6681
end
82+
check_broadcast_indices(shp, A) = check_broadcast_shape(shp, broadcast_indices(A))
6783
# comparing many inputs
68-
@inline function check_broadcast_shape(shp, A, As...)
69-
check_broadcast_shape(shp, A)
70-
check_broadcast_shape(shp, As...)
84+
@inline function check_broadcast_indices(shp, A, As...)
85+
check_broadcast_indices(shp, A)
86+
check_broadcast_indices(shp, As...)
7187
end
7288

7389
## Indexing manipulations
@@ -83,14 +99,13 @@ end
8399

84100
# newindexer(shape, A) generates `keep` and `Idefault` (for use by
85101
# `newindex` above) for a particular array `A`, given the
86-
# broadcast_shape `shape`
102+
# broadcast_indices `shape`
87103
# `keep` is equivalent to map(==, indices(A), shape) (but see #17126)
88-
newindexer(shape, x::Number) = (), ()
89-
@inline newindexer(shape, A) = newindexer(shape, indices(A))
90-
@inline newindexer(shape, indsA::Tuple{}) = (), ()
91-
@inline function newindexer(shape, indsA::Tuple)
104+
@inline newindexer(shape, A) = shapeindexer(shape, broadcast_indices(A))
105+
@inline shapeindexer(shape, indsA::Tuple{}) = (), ()
106+
@inline function shapeindexer(shape, indsA::Tuple)
92107
ind1 = indsA[1]
93-
keep, Idefault = newindexer(tail(shape), tail(indsA))
108+
keep, Idefault = shapeindexer(tail(shape), tail(indsA))
94109
(shape[1] == ind1, keep...), (first(ind1), Idefault...)
95110
end
96111

@@ -110,6 +125,10 @@ const bitcache_size = 64 * bitcache_chunks # do not change this
110125
dumpbitcache(Bc::Vector{UInt64}, bind::Int, C::Vector{Bool}) =
111126
Base.copy_to_bitarray_chunks!(Bc, ((bind - 1) << 6) + 1, C, 1, min(bitcache_size, (length(Bc)-bind+1) << 6))
112127

128+
@inline _broadcast_getindex(A, I) = _broadcast_getindex(containertype(A), A, I)
129+
@inline _broadcast_getindex(::Type{Any}, A, I) = A
130+
@inline _broadcast_getindex(::Any, A, I) = A[I]
131+
113132
## Broadcasting core
114133
# nargs encodes the number of As arguments (which matches the number
115134
# of keeps). The first two type parameters are to ensure specialization.
@@ -124,7 +143,7 @@ dumpbitcache(Bc::Vector{UInt64}, bind::Int, C::Vector{Bool}) =
124143
# reverse-broadcast the indices
125144
@nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i))
126145
# extract array values
127-
@nexprs $nargs i->(@inbounds val_i = A_i[I_i])
146+
@nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i))
128147
# call the function and store the result
129148
@inbounds B[I] = @ncall $nargs f val
130149
end
@@ -148,7 +167,7 @@ end
148167
# reverse-broadcast the indices
149168
@nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i))
150169
# extract array values
151-
@nexprs $nargs i->(@inbounds val_i = A_i[I_i])
170+
@nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i))
152171
# call the function and store the result
153172
@inbounds C[ind] = @ncall $nargs f val
154173
ind += 1
@@ -176,7 +195,7 @@ as in `broadcast!(f, A, A, B)` to perform `A[:] = broadcast(f, A, B)`.
176195
"""
177196
@inline function broadcast!{nargs}(f, B::AbstractArray, As::Vararg{Any,nargs})
178197
shape = indices(B)
179-
check_broadcast_shape(shape, As...)
198+
check_broadcast_indices(shape, As...)
180199
keeps, Idefaults = map_newindexer(shape, As)
181200
_broadcast!(f, B, keeps, Idefaults, As, Val{nargs})
182201
B
@@ -196,7 +215,7 @@ end
196215
# reverse-broadcast the indices
197216
@nexprs $nargs i->(I_i = newindex(I, keep_i, Idefault_i))
198217
# extract array values
199-
@nexprs $nargs i->(@inbounds val_i = A_i[I_i])
218+
@nexprs $nargs i->(@inbounds val_i = _broadcast_getindex(A_i, I_i))
200219
# call the function
201220
V = @ncall $nargs f val
202221
S = typeof(V)
@@ -219,7 +238,7 @@ end
219238
end
220239

221240
function broadcast_t(f, ::Type{Any}, As...)
222-
shape = broadcast_shape(As...)
241+
shape = broadcast_indices(As...)
223242
iter = CartesianRange(shape)
224243
if isempty(iter)
225244
return similar(Array{Any}, shape)
@@ -228,19 +247,46 @@ function broadcast_t(f, ::Type{Any}, As...)
228247
keeps, Idefaults = map_newindexer(shape, As)
229248
st = start(iter)
230249
I, st = next(iter, st)
231-
val = f([ As[i][newindex(I, keeps[i], Idefaults[i])] for i=1:nargs ]...)
250+
val = f([ _broadcast_getindex(As[i], newindex(I, keeps[i], Idefaults[i])) for i=1:nargs ]...)
232251
B = similar(Array{typeof(val)}, shape)
233252
B[I] = val
234253
return _broadcast!(f, B, keeps, Idefaults, As, Val{nargs}, iter, st, 1)
235254
end
236255

237-
@inline broadcast_t(f, T, As...) = broadcast!(f, similar(Array{T}, broadcast_shape(As...)), As...)
256+
@inline broadcast_t(f, T, As...) = broadcast!(f, similar(Array{T}, broadcast_indices(As...)), As...)
257+
258+
@generated function broadcast_tup{AT,nargs}(f, As::AT, ::Type{Val{nargs}}, n)
259+
quote
260+
ntuple(n -> (@ncall $nargs f i->_broadcast_getindex(As[i], n)), Val{n})
261+
end
262+
end
263+
264+
function broadcast_c(f, ::Type{Tuple}, As...)
265+
shape = broadcast_indices(As...)
266+
check_broadcast_indices(shape, As...)
267+
n = length(shape[1])
268+
nargs = length(As)
269+
return broadcast_tup(f, As, Val{nargs}, n)
270+
end
271+
@inline broadcast_c(f, ::Type{Any}, a...) = f(a...)
272+
@inline broadcast_c(f, ::Type{Array}, As...) = broadcast_t(f, promote_eltype_op(f, As...), As...)
238273

239274
"""
240275
broadcast(f, As...)
241276
242-
Broadcasts the arrays `As` to a common size by expanding singleton dimensions, and returns
243-
an array of the results `f(as...)` for each position.
277+
Broadcasts the arrays, tuples and/or scalars `As` to a container of the
278+
appropriate type and dimensions. In this context, anything that is not a
279+
subtype of `AbstractArray` or `Tuple` is considered a scalar. The resulting
280+
container is stablished by the following rules:
281+
282+
- If all the arguments are scalars, it returns a scalar.
283+
- If the arguments are tuples and zero or more scalars, it returns a tuple.
284+
- If there is at least an array in the arguments, it returns an array
285+
(and treats tuples as 1-dimensional arrays) expanding singleton dimensions.
286+
287+
A special syntax exists for broadcasting: `f.(args...)` is equivalent to
288+
`broadcast(f, args...)`, and nested `f.(g.(args...))` calls are fused into a
289+
single broadcast loop.
244290
245291
```jldoctest
246292
julia> A = [1, 2, 3, 4, 5]
@@ -266,27 +312,32 @@ julia> broadcast(+, A, B)
266312
8 9
267313
11 12
268314
14 15
269-
```
270-
"""
271-
@inline broadcast(f, As...) = broadcast_t(f, promote_eltype_op(f, As...), As...)
272315
273-
# alternate, more compact implementation; unfortunately slower.
274-
# also the `collect` machinery doesn't yet support arbitrary index bases.
275-
#=
276-
@generated function _broadcast{nargs}(f, keeps, As, ::Type{Val{nargs}}, iter)
277-
quote
278-
collect((@ncall $nargs f i->As[i][newindex(I, keeps[i])]) for I in iter)
279-
end
280-
end
316+
julia> parse.(Int, ["1", "2"])
317+
2-element Array{Int64,1}:
318+
1
319+
2
281320
282-
function broadcast(f, As...)
283-
shape = broadcast_shape(As...)
284-
iter = CartesianRange(shape)
285-
keeps, Idefaults = map_newindexer(shape, As)
286-
naT = Val{nfields(As)}
287-
_broadcast(f, keeps, Idefaults, As, naT, iter)
288-
end
289-
=#
321+
julia> abs.((1, -2))
322+
(1,2)
323+
324+
julia> broadcast(+, 1.0, (0, -2.0))
325+
(1.0,-1.0)
326+
327+
julia> broadcast(+, 1.0, (0, -2.0), [1])
328+
2-element Array{Float64,1}:
329+
2.0
330+
0.0
331+
332+
julia> string.(("one","two","three","four"), ": ", 1:4)
333+
4-element Array{String,1}:
334+
"one: 1"
335+
"two: 2"
336+
"three: 3"
337+
"four: 4"
338+
```
339+
"""
340+
@inline broadcast(f, As...) = broadcast_c(f, containertype(As...), As...)
290341

291342
"""
292343
bitbroadcast(f, As...)
@@ -304,7 +355,7 @@ julia> bitbroadcast(isodd,[1,2,3,4,5])
304355
true
305356
```
306357
"""
307-
@inline bitbroadcast(f, As...) = broadcast!(f, similar(BitArray, broadcast_shape(As...)), As...)
358+
@inline bitbroadcast(f, As...) = broadcast!(f, similar(BitArray, broadcast_indices(As...)), As...)
308359

309360
"""
310361
broadcast_getindex(A, inds...)
@@ -345,13 +396,13 @@ julia> broadcast_getindex(C,[1,2,10])
345396
15
346397
```
347398
"""
348-
broadcast_getindex(src::AbstractArray, I::AbstractArray...) = broadcast_getindex!(similar(Array{eltype(src)}, broadcast_shape(I...)), src, I...)
399+
broadcast_getindex(src::AbstractArray, I::AbstractArray...) = broadcast_getindex!(similar(Array{eltype(src)}, broadcast_indices(I...)), src, I...)
349400
@generated function broadcast_getindex!(dest::AbstractArray, src::AbstractArray, I::AbstractArray...)
350401
N = length(I)
351402
Isplat = Expr[:(I[$d]) for d = 1:N]
352403
quote
353404
@nexprs $N d->(I_d = I[d])
354-
check_broadcast_shape(indices(dest), $(Isplat...)) # unnecessary if this function is never called directly
405+
check_broadcast_indices(indices(dest), $(Isplat...)) # unnecessary if this function is never called directly
355406
checkbounds(src, $(Isplat...))
356407
@nexprs $N d->(@nexprs $N k->(Ibcast_d_k = indices(I_k, d) == OneTo(1)))
357408
@nloops $N i dest d->(@nexprs $N k->(j_d_k = Ibcast_d_k ? 1 : i_d)) begin
@@ -374,7 +425,7 @@ position in `X` at the indices in `A` given by the same positions in `inds`.
374425
quote
375426
@nexprs $N d->(I_d = I[d])
376427
checkbounds(A, $(Isplat...))
377-
shape = broadcast_shape($(Isplat...))
428+
shape = broadcast_indices($(Isplat...))
378429
@nextract $N shape d->(length(shape) < d ? OneTo(1) : shape[d])
379430
@nexprs $N d->(@nexprs $N k->(Ibcast_d_k = indices(I_k, d) == 1:1))
380431
if !isa(x, AbstractArray)

base/multidimensional.jl

+1
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ maybe_oneto() = OneTo(1)
282282

283283
### From abstractarray.jl: Internal multidimensional indexing definitions ###
284284
getindex(x::Number, i::CartesianIndex{0}) = x
285+
getindex(t::Tuple, I...) = getindex(t, IteratorsMD.flatten(I)...)
285286

286287
# These are not defined on directly on getindex to avoid
287288
# ambiguities for AbstractArray subtypes. See the note in abstractarray.jl

base/number.jl

+2
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,8 @@ zero{T<:Number}(::Type{T}) = convert(T,0)
9999
one(x::Number) = oftype(x,1)
100100
one{T<:Number}(::Type{T}) = convert(T,1)
101101

102+
_default_type(::Type{Number}) = Int
103+
102104
"""
103105
factorial(n)
104106

base/reducedim.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ function _mapreducedim!{T,N}(f, op, R::AbstractArray, A::AbstractArray{T,N})
196196
return R
197197
end
198198
indsAt, indsRt = safe_tail(indices(A)), safe_tail(indices(R)) # handle d=1 manually
199-
keep, Idefault = Broadcast.newindexer(indsAt, indsRt)
199+
keep, Idefault = Broadcast.shapeindexer(indsAt, indsRt)
200200
if reducedim1(R, A)
201201
# keep the accumulator as a local variable when reducing along the first dimension
202202
i1 = first(indices1(R))
@@ -331,7 +331,7 @@ function findminmax!{T,N}(f, Rval, Rind, A::AbstractArray{T,N})
331331
# If we're reducing along dimension 1, for efficiency we can make use of a temporary.
332332
# Otherwise, keep the result in Rval/Rind so that we traverse A in storage order.
333333
indsAt, indsRt = safe_tail(indices(A)), safe_tail(indices(Rval))
334-
keep, Idefault = Broadcast.newindexer(indsAt, indsRt)
334+
keep, Idefault = Broadcast.shapeindexer(indsAt, indsRt)
335335
k = 0
336336
if reducedim1(Rval, A)
337337
i1 = first(indices1(Rval))

base/sparse/sparse.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
2626
rotl90, rotr90, round, scale!, setindex!, similar, size, transpose, tril,
2727
triu, vec, permute!
2828

29-
import Base.Broadcast: broadcast_shape
29+
import Base.Broadcast: broadcast_indices
3030

3131
export AbstractSparseArray, AbstractSparseMatrix, AbstractSparseVector,
3232
SparseMatrixCSC, SparseVector, blkdiag, dense, droptol!, dropzeros!, dropzeros,

0 commit comments

Comments
 (0)