Skip to content

Commit 726ccff

Browse files
committed
Make flattened Broadcasted better inlined
Similar to #41139, but avoid unnecessary extra methods.
1 parent 955f05d commit 726ccff

File tree

1 file changed

+52
-35
lines changed

1 file changed

+52
-35
lines changed

base/broadcast.jl

+52-35
Original file line numberDiff line numberDiff line change
@@ -323,13 +323,9 @@ function flatten(bc::Broadcasted{Style}) where {Style}
323323
# makeargs(w, x, y, z) = (w, makeargs1(x, y, z)...)
324324
# = (w, g(x, y), makeargs2(z)...)
325325
# = (w, g(x, y), z)
326-
let (makeargs, args) = make_makeargs((), bc.args), f = bc.f
327-
_make(::NTuple{N,Any}) where {N} =
328-
@inline function (args::Vararg{Any,N})
329-
f(makeargs(args...)...)
330-
end
331-
return Broadcasted{Style}(_make(args), args, bc.axes)
332-
end
326+
headf, args = make_makeargs(bc.args, ())
327+
newf = RootNode(bc.f, headf)
328+
Broadcasted{Style}(newf, args, bc.axes)
333329
end
334330

335331
const NestedTuple = Tuple{<:Broadcasted,Vararg{Any}}
@@ -339,7 +335,7 @@ _isflat(args::Tuple) = _isflat(tail(args))
339335
_isflat(args::Tuple{}) = true
340336

341337
"""
342-
make_makeargs(args::Tuple, t::Tuple) -> Function, Tuple
338+
make_makeargs(t::Tuple, args::Tuple) -> Function, Tuple
343339
344340
Each element of `t` is one (consecutive) node in a broadcast tree.
345341
`args` contains the remaining arguments on the "right" side of `t`.
@@ -349,36 +345,57 @@ The jobs of `make_makeargs` are:
349345
tuple (each entry corresponding to an entry in `t`, having evaluated
350346
the corresponding element in the broadcast tree).
351347
"""
352-
@inline function make_makeargs(args, t::Tuple{})
353-
_make(::NTuple{N,Any}) where {N} = (args::Vararg{Any,N}) -> args
354-
_make(args), args
355-
end
356-
@inline function make_makeargs(args, t::Tuple)
357-
makeargs, args′ = make_makeargs(args, tail(t))
358-
_make(::NTuple{N,Any}) where {N} =
359-
@inline function (head, tail::Vararg{Any,N})
360-
(head, makeargs(tail...)...)
361-
end
362-
_make(args′), (t[1], args′...)
348+
make_makeargs(::Tuple{}, args) = tuple, args
349+
350+
function make_makeargs(t::Tuple, args)
351+
tailf, args′ = make_makeargs(tail(t), args)
352+
newf = tailf === tuple ? tuple : FlatNode(tailf) # avoid unneeded recursion
353+
newf, (t[1], args′...)
363354
end
364-
function make_makeargs(args, t::Tuple{<:Broadcasted,Vararg{Any}})
355+
356+
function make_makeargs(t::NestedTuple, args)
365357
bc = t[1]
366-
# c.f. the same expression in the function on leaf nodes above. Here
367-
# we recurse into siblings in the broadcast tree.
368-
let (makeargs, args′) = make_makeargs(args, tail(t)), f = bc.f
369-
# Here we recurse into children. We can pass in `args′` here,
370-
# and get `args″` directly, but it is more compiler friendly to
371-
# treat `bc` as a new parent "node".
372-
makeargs_head, argsˢ = make_makeargs((), bc.args)
373-
args″ = (argsˢ..., args′...)
374-
_make(::NTuple{L,Any}, ::NTuple{N,Any}) where {L,N} =
375-
@inline function (args::Vararg{Any,N})
376-
a, b = Base.IteratorsMD.split(args, Val(L)) # split `args...` directly
377-
(f(makeargs_head(a...)...), makeargs(b...)...)
378-
end
379-
_make(argsˢ, args″), args″
380-
end
358+
# Here we recurse into siblings in the broadcast tree.
359+
tailf, args′ = make_makeargs(tail(t), args)
360+
# Here we recurse into children.
361+
# It is more compiler friendly to treat `bc` as a new parent "node".
362+
headf, argsˢ = make_makeargs(bc.args, ())
363+
NestedNode{length(argsˢ)}(bc.f, headf, tailf), (argsˢ..., args′...)
364+
end
365+
366+
# Some helper structs used to flatten `Broadcasted`.
367+
# TODO: make them print better in the REPL.
368+
struct RootNode{F,H} <: Function
369+
f::F
370+
prepare::H
371+
end
372+
RootNode(::Type{F}, prepare::H) where {F,H} = RootNode{Type{F},H}(F, prepare)
373+
@inline (f::RootNode)(args::Vararg{Any}) = f.f(f.prepare(args...)...)
374+
375+
struct FlatNode{T} <: Function
376+
rest::T
377+
end
378+
@inline (f::FlatNode)(x, args::Vararg{Any}) = (x, f.rest(args...)...)
379+
380+
struct NestedNode{L,F,H,T} <: Function
381+
f::F
382+
prepare::H
383+
rest::T
381384
end
385+
NestedNode{L}(f::F, prepare::H, rest::T) where {L,F,T,H} = NestedNode{L,F,H,T}(f, prepare, rest)
386+
NestedNode{L}(::Type{F}, prepare::H, rest::T) where {L,F,T,H} = NestedNode{L,Type{F},H,T}(F, prepare, rest)
387+
388+
# Specialize small `L` manually.
389+
@inline (f::NestedNode{1})(x, args::Vararg{Any}) = (f.f(f.prepare(x)...), f.rest(args...)...)
390+
@inline (f::NestedNode{2})(x1, x2, args::Vararg{Any}) = (f.f(f.prepare(x1, x2)...), f.rest(args...)...)
391+
@inline (f::NestedNode{3})(x1, x2, x3, args::Vararg{Any}) = (f.f(f.prepare(x1, x2, x3)...), f.rest(args...)...)
392+
@inline (f::NestedNode{4})(x1, x2, x3, x4, args::Vararg{Any}) = (f.f(f.prepare(x1, x2, x3, x4)...), f.rest(args...)...)
393+
# Split based fallback.
394+
@inline function (f::NestedNode{L})(args::Vararg{Any}) where {L}
395+
head, tail = Base.IteratorsMD.split(args, Val(L))
396+
(f.f(f.prepare(head...)...), f.rest(tail...)...)
397+
end
398+
382399
## Broadcasting utilities ##
383400

384401
## logic for deciding the BroadcastStyle

0 commit comments

Comments
 (0)