Skip to content

Commit dbb86b8

Browse files
committed
make flatten more compiler friendly
1. add inference test 2. add more tests and remove `cat_nested`: `cat_nested` failed to infer in some cases. Its job has been folded into `make_makeargs`, so it is no longer needed.
1 parent 32a026a commit dbb86b8

File tree

2 files changed

+48
-74
lines changed

2 files changed

+48
-74
lines changed

base/broadcast.jl

+39-71
Original file line numberDiff line numberDiff line change
@@ -314,21 +314,21 @@ some cases.
314314
"""
315315
function flatten(bc::Broadcasted{Style}) where {Style}
316316
isflat(bc) && return bc
317-
# concatenate the nested arguments into {a, b, c, d}
318-
args = cat_nested(bc)
319-
# build a function `makeargs` that takes a "flat" argument list and
317+
# 1. concatenate the nested arguments into {a, b, c, d}
318+
# 2. build a function `makeargs` that takes a "flat" argument list and
320319
# creates the appropriate input arguments for `f`, e.g.,
321320
# makeargs = (w, x, y, z) -> (w, g(x, y), z)
322321
#
323322
# `makeargs` is built recursively and looks a bit like this:
324323
# makeargs(w, x, y, z) = (w, makeargs1(x, y, z)...)
325324
# = (w, g(x, y), makeargs2(z)...)
326325
# = (w, g(x, y), z)
327-
let makeargs = make_makeargs(()->(), bc.args), f = bc.f
328-
newf = @inline function(args::Vararg{Any,N}) where N
329-
f(makeargs(args...)...)
330-
end
331-
return Broadcasted{Style}(newf, args, bc.axes)
326+
let (makeargs, args) = make_makeargs((), bc.args), f = bc.f
327+
_make(::NTuple{N,Any}) where {N} =
328+
@inline function (args::Vararg{Any,N})
329+
f(makeargs(args...)...)
330+
end
331+
return Broadcasted{Style}(_make(args), args, bc.axes)
332332
end
333333
end
334334

@@ -338,79 +338,47 @@ _isflat(args::NestedTuple) = false
338338
_isflat(args::Tuple) = _isflat(tail(args))
339339
_isflat(args::Tuple{}) = true
340340

341-
cat_nested(t::Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...)
342-
cat_nested(t::Any, rest...) = (t, cat_nested(rest...)...)
343-
cat_nested() = ()
344-
345341
"""
346-
make_makeargs(makeargs_tail::Function, t::Tuple) -> Function
342+
make_makeargs(args::Tuple, t::Tuple) -> Function, Tuple
347343
348344
Each element of `t` is one (consecutive) node in a broadcast tree.
349-
Ignoring `makeargs_tail` for the moment, the job of `make_makeargs` is
350-
to return a function that takes in flattened argument list and returns a
351-
tuple (each entry corresponding to an entry in `t`, having evaluated
352-
the corresponding element in the broadcast tree). As an additional
353-
complication, the passed in tuple may be longer than the number of leaves
354-
in the subtree described by `t`. The `makeargs_tail` function should
355-
be called on such additional arguments (but not the arguments consumed
356-
by `t`).
345+
`args` contains the remaining arguments on the "right" side of `t`.
346+
The jobs of `make_makeargs` are:
347+
1. prepend the flattened arguments in `t` to `args`.
348+
2. return a function that takes in a flattened argument list and returns a
349+
tuple (each entry corresponding to an entry in `t`, having evaluated
350+
the corresponding element in the broadcast tree).
357351
"""
358-
@inline make_makeargs(makeargs_tail, t::Tuple{}) = makeargs_tail
359-
@inline function make_makeargs(makeargs_tail, t::Tuple)
360-
makeargs = make_makeargs(makeargs_tail, tail(t))
361-
(head, tail...)->(head, makeargs(tail...)...)
352+
@inline function make_makeargs(args, t::Tuple{})
353+
_make(::NTuple{N,Any}) where {N} = (args::Vararg{Any,N}) -> args
354+
_make(args), args
362355
end
363-
function make_makeargs(makeargs_tail, t::Tuple{<:Broadcasted, Vararg{Any}})
356+
@inline function make_makeargs(args, t::Tuple)
357+
makeargs, args′ = make_makeargs(args, tail(t))
358+
_make(::NTuple{N,Any}) where {N} =
359+
@inline function (head, tail::Vararg{Any,N})
360+
(head, makeargs(tail...)...)
361+
end
362+
_make(args′), (t[1], args′...)
363+
end
364+
function make_makeargs(args, t::Tuple{<:Broadcasted,Vararg{Any}})
364365
bc = t[1]
365366
# c.f. the same expression in the function on leaf nodes above. Here
366367
# we recurse into siblings in the broadcast tree.
367-
let makeargs_tail = make_makeargs(makeargs_tail, tail(t)),
368-
# Here we recurse into children. It would be valid to pass in makeargs_tail
369-
# here, and not use it below. However, in that case, our recursion is no
370-
# longer purely structural because we're building up one argument (the closure)
371-
# while destructuring another.
372-
makeargs_head = make_makeargs((args...)->args, bc.args),
373-
f = bc.f
374-
# Create two functions, one that splits off the first length(bc.args)
375-
# elements from the tuple and one that yields the remaining arguments.
376-
# N.B. We can't call headargs on `args...` directly because
377-
# args is flattened (i.e. our children have not been evaluated
378-
# yet).
379-
headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args)
380-
return @inline function(args::Vararg{Any,N}) where N
381-
args1 = makeargs_head(args...)
382-
a, b = headargs(args1...), makeargs_tail(tailargs(args1...)...)
383-
(f(a...), b...)
384-
end
385-
end
386-
end
387-
388-
@inline function make_headargs(t::Tuple)
389-
let headargs = make_headargs(tail(t))
390-
return @inline function(head, tail::Vararg{Any,N}) where N
391-
(head, headargs(tail...)...)
392-
end
368+
let (makeargs, args′) = make_makeargs(args, tail(t)), f = bc.f
369+
# Here we recurse into children. We can pass in `args′` here,
370+
# and get `args″` directly, but it is more compiler friendly to
371+
# treat `bc` as a new parent "node".
372+
makeargs_head, argsˢ = make_makeargs((), bc.args)
373+
args″ = (argsˢ..., args′...)
374+
_make(::NTuple{L,Any}, ::NTuple{N,Any}) where {L,N} =
375+
@inline function (args::Vararg{Any,N})
376+
a, b = Base.IteratorsMD.split(args, Val(L)) # split `args...` directly
377+
(f(makeargs_head(a...)...), makeargs(b...)...)
378+
end
379+
_make(argsˢ, args″), args″
393380
end
394381
end
395-
@inline function make_headargs(::Tuple{})
396-
return @inline function(tail::Vararg{Any,N}) where N
397-
()
398-
end
399-
end
400-
401-
@inline function make_tailargs(t::Tuple)
402-
let tailargs = make_tailargs(tail(t))
403-
return @inline function(head, tail::Vararg{Any,N}) where N
404-
tailargs(tail...)
405-
end
406-
end
407-
end
408-
@inline function make_tailargs(::Tuple{})
409-
return @inline function(tail::Vararg{Any,N}) where N
410-
tail
411-
end
412-
end
413-
414382
## Broadcasting utilities ##
415383

416384
## logic for deciding the BroadcastStyle

test/broadcast.jl

+9-3
Original file line numberDiff line numberDiff line change
@@ -772,13 +772,19 @@ end
772772

773773
# issue #27988: inference of Broadcast.flatten
774774
using .Broadcast: Broadcasted
775-
let
775+
let _cat_nested(bc) = Broadcast.flatten(bc).args
776776
bc = Broadcasted(+, (Broadcasted(*, (1, 2)), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
777-
@test @inferred(Broadcast.cat_nested(bc)) == (1,2,3,4,5)
777+
@test @inferred(_cat_nested(bc)) == (1,2,3,4,5)
778778
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 62
779779
bc = Broadcasted(+, (Broadcasted(*, (1, Broadcasted(/, (2.0, 2.5)))), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
780-
@test @inferred(Broadcast.cat_nested(bc)) == (1,2.0,2.5,3,4,5)
780+
@test @inferred(_cat_nested(bc)) == (1,2.0,2.5,3,4,5)
781781
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 60.8
782+
# 1 .* 1 .- 1 .* 1 .^2 .+ 1 .* 1 .+ 1 .^ 3
783+
bc = Broadcasted(+, (Broadcasted(+, (Broadcasted(-, (Broadcasted(*, (1, 1)), Broadcasted(*, (1, Broadcasted(Base.literal_pow, (Ref(^), 1, Ref(Val(2)))))))), Broadcasted(*, (1, 1)))), Broadcasted(Base.literal_pow, (Base.RefValue{typeof(^)}(^), 1, Base.RefValue{Val{3}}(Val{3}())))))
784+
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 2
785+
# @. 1 + 1 * (1 + 1 + 1 + 1)
786+
bc = Broadcasted(+, (1, Broadcasted(*, (1, Broadcasted(+, (1, 1, 1, 1))))))
787+
@test @inferred(_cat_nested(bc)) == (1,1,1,1,1,1) # `cat_nested` failed to infer this
782788
end
783789

784790
let

0 commit comments

Comments
 (0)