@@ -323,13 +323,9 @@ function flatten(bc::Broadcasted{Style}) where {Style}
323
323
# makeargs(w, x, y, z) = (w, makeargs1(x, y, z)...)
324
324
# = (w, g(x, y), makeargs2(z)...)
325
325
# = (w, g(x, y), z)
326
- let (makeargs, args) = make_makeargs ((), bc. args), f = bc. f
327
- _make (:: NTuple{N,Any} ) where {N} =
328
- @inline function (args:: Vararg{Any,N} )
329
- f (makeargs (args... )... )
330
- end
331
- return Broadcasted {Style} (_make (args), args, bc. axes)
332
- end
326
+ headf, args = make_makeargs (bc. args, ())
327
+ newf = RootNode (bc. f, headf)
328
+ Broadcasted {Style} (newf, args, bc. axes)
333
329
end
334
330
335
331
const NestedTuple = Tuple{<: Broadcasted ,Vararg{Any}}
@@ -339,7 +335,7 @@ _isflat(args::Tuple) = _isflat(tail(args))
339
335
_isflat (args:: Tuple{} ) = true
340
336
341
337
"""
342
- make_makeargs(args ::Tuple, t ::Tuple) -> Function, Tuple
338
+ make_makeargs(t ::Tuple, args ::Tuple) -> Function, Tuple
343
339
344
340
Each element of `t` is one (consecutive) node in a broadcast tree.
345
341
`args` contains the rest arguments on the "right" side of `t`.
@@ -349,36 +345,57 @@ The jobs of `make_makeargs` are:
349
345
tuple (each entry corresponding to an entry in `t`, having evaluated
350
346
the corresponding element in the broadcast tree).
351
347
"""
352
- @inline function make_makeargs (args, t:: Tuple{} )
353
- _make (:: NTuple{N,Any} ) where {N} = (args:: Vararg{Any,N} ) -> args
354
- _make (args), args
355
- end
356
- @inline function make_makeargs (args, t:: Tuple )
357
- makeargs, args′ = make_makeargs (args, tail (t))
358
- _make (:: NTuple{N,Any} ) where {N} =
359
- @inline function (head, tail:: Vararg{Any,N} )
360
- (head, makeargs (tail... )... )
361
- end
362
- _make (args′), (t[1 ], args′... )
348
+ make_makeargs (:: Tuple{} , args) = tuple, args
349
+
350
+ function make_makeargs (t:: Tuple , args)
351
+ tailf, args′ = make_makeargs (tail (t), args)
352
+ newf = tailf === tuple ? tuple : FlatNode (tailf) # avoid unneeded recursion
353
+ newf, (t[1 ], args′... )
363
354
end
364
- function make_makeargs (args, t:: Tuple{<:Broadcasted,Vararg{Any}} )
355
+
356
+ function make_makeargs (t:: NestedTuple , args)
365
357
bc = t[1 ]
366
- # c.f. the same expression in the function on leaf nodes above. Here
367
- # we recurse into siblings in the broadcast tree.
368
- let (makeargs, args′) = make_makeargs (args, tail (t)), f = bc. f
369
- # Here we recurse into children. We can pass in `args′` here,
370
- # and get `args″` directly, but it is more compiler frendly to
371
- # treat `bc` as a new parent "node".
372
- makeargs_head, argsˢ = make_makeargs ((), bc. args)
373
- args″ = (argsˢ... , args′... )
374
- _make (:: NTuple{L,Any} , :: NTuple{N,Any} ) where {L,N} =
375
- @inline function (args:: Vararg{Any,N} )
376
- a, b = Base. IteratorsMD. split (args, Val (L)) # split `args...` directly
377
- (f (makeargs_head (a... )... ), makeargs (b... )... )
378
- end
379
- _make (argsˢ, args″), args″
380
- end
358
+ # Here we recurse into siblings in the broadcast tree.
359
+ tailf, args′ = make_makeargs (tail (t), args)
360
+ # Here we recurse into children.
361
+ # It is more compiler frendly to treat `bc` as a new parent "node".
362
+ headf, argsˢ = make_makeargs (bc. args, ())
363
+ NestedNode {length(argsˢ)} (bc. f, headf, tailf), (argsˢ... , args′... )
364
+ end
365
+
366
+ # Some help structs to flatten `Broadcasted`.
367
+ # TODO : make them better printed in REPL.
368
+ struct RootNode{F,H} <: Function
369
+ f:: F
370
+ prepare:: H
371
+ end
372
+ RootNode (:: Type{F} , prepare:: H ) where {F,H} = RootNode {Type{F},H} (F, prepare)
373
+ @inline (f:: RootNode )(args:: Vararg{Any} ) = f. f (f. prepare (args... )... )
374
+
375
+ struct FlatNode{T} <: Function
376
+ rest:: T
377
+ end
378
+ @inline (f:: FlatNode )(x, args:: Vararg{Any} ) = (x, f. rest (args... )... )
379
+
380
+ struct NestedNode{L,F,H,T} <: Function
381
+ f:: F
382
+ prepare:: H
383
+ rest:: T
381
384
end
385
+ NestedNode {L} (f:: F , prepare:: H , rest:: T ) where {L,F,T,H} = NestedNode {L,F,H,T} (f, prepare, rest)
386
+ NestedNode {L} (:: Type{F} , prepare:: H , rest:: T ) where {L,F,T,H} = NestedNode {L,Type{F},H,T} (F, prepare, rest)
387
+
388
+ # Specialize small `L` manually.
389
+ @inline (f:: NestedNode{1} )(x, args:: Vararg{Any} ) = (f. f (f. prepare (x)... ), f. rest (args... )... )
390
+ @inline (f:: NestedNode{2} )(x1, x2, args:: Vararg{Any} ) = (f. f (f. prepare (x1, x2)... ), f. rest (args... )... )
391
+ @inline (f:: NestedNode{3} )(x1, x2, x3, args:: Vararg{Any} ) = (f. f (f. prepare (x1, x2, x3)... ), f. rest (args... )... )
392
+ @inline (f:: NestedNode{4} )(x1, x2, x3, x4, args:: Vararg{Any} ) = (f. f (f. prepare (x1, x2, x3, x4)... ), f. rest (args... )... )
393
+ # Split based fallback.
394
+ @inline function (f:: NestedNode{L} )(args:: Vararg{Any} ) where {L}
395
+ head, tail = Base. IteratorsMD. split (args, Val (L))
396
+ (f. f (f. prepare (head... )... ), f. rest (tail... )... )
397
+ end
398
+
382
399
# # Broadcasting utilities ##
383
400
384
401
# # logic for deciding the BroadcastStyle
0 commit comments