Skip to content

Commit 0b02d79

Browse files
committed
make flattened Broadcasted more compiler friendly.
1. make `cat_nested` better inferred by switching to direct self-recursion. 2. `make_makeargs` now create a tuple of functions which take in the whole argument list and return the corresponding input for the broadcasted function.
1 parent ba146fb commit 0b02d79

File tree

2 files changed

+59
-81
lines changed

2 files changed

+59
-81
lines changed

base/broadcast.jl

+43-78
Original file line numberDiff line numberDiff line change
@@ -329,20 +329,16 @@ function flatten(bc::Broadcasted{Style}) where {Style}
329329
isflat(bc) && return bc
330330
# concatenate the nested arguments into {a, b, c, d}
331331
args = cat_nested(bc)
332-
# build a function `makeargs` that takes a "flat" argument list and
333-
# and creates the appropriate input arguments for `f`, e.g.,
334-
# makeargs = (w, x, y, z) -> (w, g(x, y), z)
335-
#
336-
# `makeargs` is built recursively and looks a bit like this:
337-
# makeargs(w, x, y, z) = (w, makeargs1(x, y, z)...)
338-
# = (w, g(x, y), makeargs2(z)...)
339-
# = (w, g(x, y), z)
340-
let makeargs = make_makeargs(()->(), bc.args), f = bc.f
341-
newf = @inline function(args::Vararg{Any,N}) where N
342-
f(makeargs(args...)...)
343-
end
344-
return Broadcasted{Style}(newf, args, bc.axes)
345-
end
332+
# build a tuple of functions `makeargs`. Its elements take
333+
# the whole "flat" argument list and and generate the appropriate
334+
# input arguments for the broadcasted function `f`, e.g.,
335+
# makeargs[1] = ((w, x, y, z)) -> w
336+
# makeargs[2] = ((w, x, y, z)) -> g(x, y)
337+
# makeargs[3] = ((w, x, y, z)) -> z
338+
makeargs = make_makeargs(bc.args)
339+
f = Base.maybeconstructor(bc.f)
340+
newf = (args...) -> (@inline; f(prepare_args(makeargs, args)...))
341+
return Broadcasted{Style}(newf, args, bc.axes)
346342
end
347343

348344
const NestedTuple = Tuple{<:Broadcasted,Vararg{Any}}
@@ -351,78 +347,47 @@ _isflat(args::NestedTuple) = false
351347
_isflat(args::Tuple) = _isflat(tail(args))
352348
_isflat(args::Tuple{}) = true
353349

354-
cat_nested(t::Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...)
355-
cat_nested(t::Any, rest...) = (t, cat_nested(rest...)...)
356-
cat_nested() = ()
350+
cat_nested(bc::Broadcasted) = cat_nested_args(bc.args)
351+
cat_nested_args(::Tuple{}) = ()
352+
cat_nested_args(t::Tuple{Any}) = cat_nested(t[1])
353+
cat_nested_args(t::Tuple) = (cat_nested(t[1])..., cat_nested_args(tail(t))...)
354+
cat_nested(a) = (a,)
357355

358356
"""
359-
make_makeargs(makeargs_tail::Function, t::Tuple) -> Function
357+
make_makeargs(t::Tuple) -> Tuple{Vararg{Function}}
360358
361359
Each element of `t` is one (consecutive) node in a broadcast tree.
362-
Ignoring `makeargs_tail` for the moment, the job of `make_makeargs` is
363-
to return a function that takes in flattened argument list and returns a
364-
tuple (each entry corresponding to an entry in `t`, having evaluated
365-
the corresponding element in the broadcast tree). As an additional
366-
complication, the passed in tuple may be longer than the number of leaves
367-
in the subtree described by `t`. The `makeargs_tail` function should
368-
be called on such additional arguments (but not the arguments consumed
369-
by `t`).
360+
The returned `Tuple` are functions which take in the (whole) flattened
361+
list and generate the inputs for the corresponding broadcasted function.
370362
"""
371-
@inline make_makeargs(makeargs_tail, t::Tuple{}) = makeargs_tail
372-
@inline function make_makeargs(makeargs_tail, t::Tuple)
373-
makeargs = make_makeargs(makeargs_tail, tail(t))
374-
(head, tail...)->(head, makeargs(tail...)...)
363+
make_makeargs(args::Tuple) = _make_makeargs(args, 1)[1]
364+
365+
# We build `makeargs` by traversing the broadcast nodes recursively.
366+
# note: `n` indicates the flattened index of the next unused argument.
367+
@inline function _make_makeargs(args::Tuple, n::Int)
368+
head, n = _make_makeargs1(args[1], n)
369+
rest, n = _make_makeargs(tail(args), n)
370+
(head, rest...), n
375371
end
376-
function make_makeargs(makeargs_tail, t::Tuple{<:Broadcasted, Vararg{Any}})
377-
bc = t[1]
378-
# c.f. the same expression in the function on leaf nodes above. Here
379-
# we recurse into siblings in the broadcast tree.
380-
let makeargs_tail = make_makeargs(makeargs_tail, tail(t)),
381-
# Here we recurse into children. It would be valid to pass in makeargs_tail
382-
# here, and not use it below. However, in that case, our recursion is no
383-
# longer purely structural because we're building up one argument (the closure)
384-
# while destructuing another.
385-
makeargs_head = make_makeargs((args...)->args, bc.args),
386-
f = bc.f
387-
# Create two functions, one that splits of the first length(bc.args)
388-
# elements from the tuple and one that yields the remaining arguments.
389-
# N.B. We can't call headargs on `args...` directly because
390-
# args is flattened (i.e. our children have not been evaluated
391-
# yet).
392-
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
393-
return @inline function(args::Vararg{Any,N}) where N
394-
args1 = makeargs_head(args...)
395-
a, b = headargs(args1...), makeargs_tail(tailargs(args1...)...)
396-
(f(a...), b...)
397-
end
398-
end
372+
_make_makeargs(::Tuple{}, n::Int) = (), n
373+
374+
# A help struct to store the flattened index staticly
375+
struct Pick{N} <: Function end
376+
(::Pick{N})(@nospecialize(args::Tuple)) where {N} = args[N]
377+
378+
# For flat nodes, we just consume one argument (n += 1), and return the "Pick" function
379+
@inline _make_makeargs1(_, n::Int) = Pick{n}(), n + 1
380+
# For nested nodes, we form the `makeargs1` based on the child `makeargs` (n += length(cat_nested(bc)))
381+
@inline function _make_makeargs1(bc::Broadcasted, n::Int)
382+
makeargs, n = _make_makeargs(bc.args, n)
383+
f = Base.maybeconstructor(bc.f)
384+
makeargs1 = (args::Tuple) -> (@inline; f(prepare_args(makeargs, args)...))
385+
makeargs1, n
399386
end
400387

401-
@inline function make_headargs(t::Tuple)
402-
let headargs = make_headargs(tail(t))
403-
return @inline function(head, tail::Vararg{Any,N}) where N
404-
(head, headargs(tail...)...)
405-
end
406-
end
407-
end
408-
@inline function make_headargs(::Tuple{})
409-
return @inline function(tail::Vararg{Any,N}) where N
410-
()
411-
end
412-
end
413-
414-
@inline function make_tailargs(t::Tuple)
415-
let tailargs = make_tailargs(tail(t))
416-
return @inline function(head, tail::Vararg{Any,N}) where N
417-
tailargs(tail...)
418-
end
419-
end
420-
end
421-
@inline function make_tailargs(::Tuple{})
422-
return @inline function(tail::Vararg{Any,N}) where N
423-
tail
424-
end
425-
end
388+
@inline prepare_args(makeargs::Tuple, @nospecialize(x::Tuple)) = (makeargs[1](x), prepare_args(tail(makeargs), x)...)
389+
@inline prepare_args(makeargs::Tuple{Any}, @nospecialize(x::Tuple)) = (makeargs[1](x),)
390+
prepare_args(::Tuple{}, ::Tuple) = ()
426391

427392
## Broadcasting utilities ##
428393

test/broadcast.jl

+16-3
Original file line numberDiff line numberDiff line change
@@ -774,14 +774,27 @@ let X = zeros(2, 3)
774774
end
775775

776776
# issue #27988: inference of Broadcast.flatten
777-
using .Broadcast: Broadcasted
777+
using .Broadcast: Broadcasted, cat_nested
778778
let
779779
bc = Broadcasted(+, (Broadcasted(*, (1, 2)), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
780-
@test @inferred(Broadcast.cat_nested(bc)) == (1,2,3,4,5)
780+
@test @inferred(cat_nested(bc)) == (1,2,3,4,5)
781781
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 62
782782
bc = Broadcasted(+, (Broadcasted(*, (1, Broadcasted(/, (2.0, 2.5)))), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
783-
@test @inferred(Broadcast.cat_nested(bc)) == (1,2.0,2.5,3,4,5)
783+
@test @inferred(cat_nested(bc)) == (1,2.0,2.5,3,4,5)
784784
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 60.8
785+
# 1 .* 1 .- 1 .* 1 .^2 .+ 1 .* 1 .+ 1 .^ 3
786+
bc = Broadcasted(+, (Broadcasted(+, (Broadcasted(-, (Broadcasted(*, (1, 1)), Broadcasted(*, (1, Broadcasted(Base.literal_pow, (Ref(^), 1, Ref(Val(2)))))))), Broadcasted(*, (1, 1)))), Broadcasted(Base.literal_pow, (Base.RefValue{typeof(^)}(^), 1, Base.RefValue{Val{3}}(Val{3}())))))
787+
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 2
788+
# @. 1 + 1 * (1 + 1 + 1 + 1)
789+
bc = Broadcasted(+, (1, Broadcasted(*, (1, Broadcasted(+, (1, 1, 1, 1))))))
790+
@test @inferred(cat_nested(bc)) == (1, 1, 1, 1, 1, 1) # `cat_nested` failed to infer this
791+
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == Broadcast.materialize(bc)
792+
# @. 1 + (1 + 1) + 1 + (1 + 1) + 1 + (1 + 1) + 1
793+
bc = Broadcasted(+, (1, Broadcasted(+, (1, 1)), 1, Broadcasted(+, (1, 1)), 1, Broadcasted(+, (1, 1)), 1))
794+
@test @inferred(cat_nested(bc)) == (1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
795+
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == Broadcast.materialize(bc)
796+
bc = Broadcasted(Float32, (Broadcasted(+, (1, 1)),))
797+
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == Broadcast.materialize(bc)
785798
end
786799

787800
let

0 commit comments

Comments
 (0)