@@ -341,20 +341,16 @@ function flatten(bc::Broadcasted)
341
341
isflat (bc) && return bc
342
342
# concatenate the nested arguments into {a, b, c, d}
343
343
args = cat_nested (bc)
344
- # build a function `makeargs` that takes a "flat" argument list and
345
- # and creates the appropriate input arguments for `f`, e.g.,
346
- # makeargs = (w, x, y, z) -> (w, g(x, y), z)
347
- #
348
- # `makeargs` is built recursively and looks a bit like this:
349
- # makeargs(w, x, y, z) = (w, makeargs1(x, y, z)...)
350
- # = (w, g(x, y), makeargs2(z)...)
351
- # = (w, g(x, y), z)
352
- let makeargs = make_makeargs (()-> (), bc. args), f = bc. f
353
- newf = @inline function (args:: Vararg{Any,N} ) where N
354
- f (makeargs (args... )... )
355
- end
356
- return Broadcasted (bc. style, newf, args, bc. axes)
357
- end
344
+ # build a tuple of functions `makeargs`. Its elements take
345
+ # the whole "flat" argument list and and generate the appropriate
346
+ # input arguments for the broadcasted function `f`, e.g.,
347
+ # makeargs[1] = ((w, x, y, z)) -> w
348
+ # makeargs[2] = ((w, x, y, z)) -> g(x, y)
349
+ # makeargs[3] = ((w, x, y, z)) -> z
350
+ makeargs = make_makeargs (bc. args)
351
+ f = Base. maybeconstructor (bc. f)
352
+ newf = (args... ) -> (@inline ; f (prepare_args (makeargs, args)... ))
353
+ return Broadcasted (bc. style, newf, args, bc. axes)
358
354
end
359
355
360
356
const NestedTuple = Tuple{<: Broadcasted ,Vararg{Any}}
@@ -363,78 +359,47 @@ _isflat(args::NestedTuple) = false
363
359
_isflat (args:: Tuple ) = _isflat (tail (args))
364
360
_isflat (args:: Tuple{} ) = true
365
361
366
- cat_nested (t:: Broadcasted , rest... ) = (cat_nested (t. args... )... , cat_nested (rest... )... )
367
- cat_nested (t:: Any , rest... ) = (t, cat_nested (rest... )... )
368
- cat_nested () = ()
362
+ cat_nested (bc:: Broadcasted ) = cat_nested_args (bc. args)
363
+ cat_nested_args (:: Tuple{} ) = ()
364
+ cat_nested_args (t:: Tuple{Any} ) = cat_nested (t[1 ])
365
+ cat_nested_args (t:: Tuple ) = (cat_nested (t[1 ])... , cat_nested_args (tail (t))... )
366
+ cat_nested (a) = (a,)
369
367
370
368
"""
371
- make_makeargs(makeargs_tail::Function, t::Tuple) -> Function
369
+ make_makeargs(t::Tuple) -> Tuple{Vararg{ Function}}
372
370
373
371
Each element of `t` is one (consecutive) node in a broadcast tree.
374
- Ignoring `makeargs_tail` for the moment, the job of `make_makeargs` is
375
- to return a function that takes in flattened argument list and returns a
376
- tuple (each entry corresponding to an entry in `t`, having evaluated
377
- the corresponding element in the broadcast tree). As an additional
378
- complication, the passed in tuple may be longer than the number of leaves
379
- in the subtree described by `t`. The `makeargs_tail` function should
380
- be called on such additional arguments (but not the arguments consumed
381
- by `t`).
372
+ The returned `Tuple` are functions which take in the (whole) flattened
373
+ list and generate the inputs for the corresponding broadcasted function.
382
374
"""
383
- @inline make_makeargs (makeargs_tail, t:: Tuple{} ) = makeargs_tail
384
- @inline function make_makeargs (makeargs_tail, t:: Tuple )
385
- makeargs = make_makeargs (makeargs_tail, tail (t))
386
- (head, tail... )-> (head, makeargs (tail... )... )
375
+ make_makeargs (args:: Tuple ) = _make_makeargs (args, 1 )[1 ]
376
+
377
+ # We build `makeargs` by traversing the broadcast nodes recursively.
378
+ # note: `n` indicates the flattened index of the next unused argument.
379
+ @inline function _make_makeargs (args:: Tuple , n:: Int )
380
+ head, n = _make_makeargs1 (args[1 ], n)
381
+ rest, n = _make_makeargs (tail (args), n)
382
+ (head, rest... ), n
387
383
end
388
- function make_makeargs (makeargs_tail, t:: Tuple{<:Broadcasted, Vararg{Any}} )
389
- bc = t[1 ]
390
- # c.f. the same expression in the function on leaf nodes above. Here
391
- # we recurse into siblings in the broadcast tree.
392
- let makeargs_tail = make_makeargs (makeargs_tail, tail (t)),
393
- # Here we recurse into children. It would be valid to pass in makeargs_tail
394
- # here, and not use it below. However, in that case, our recursion is no
395
- # longer purely structural because we're building up one argument (the closure)
396
- # while destructuing another.
397
- makeargs_head = make_makeargs ((args... )-> args, bc. args),
398
- f = bc. f
399
- # Create two functions, one that splits of the first length(bc.args)
400
- # elements from the tuple and one that yields the remaining arguments.
401
- # N.B. We can't call headargs on `args...` directly because
402
- # args is flattened (i.e. our children have not been evaluated
403
- # yet).
404
- headargs, tailargs = make_headargs (bc. args), make_tailargs (bc. args)
405
- return @inline function (args:: Vararg{Any,N} ) where N
406
- args1 = makeargs_head (args... )
407
- a, b = headargs (args1... ), makeargs_tail (tailargs (args1... )... )
408
- (f (a... ), b... )
409
- end
410
- end
384
+ _make_makeargs (:: Tuple{} , n:: Int ) = (), n
385
+
386
+ # A help struct to store the flattened index staticly
387
+ struct Pick{N} <: Function end
388
+ (:: Pick{N} )(@nospecialize (args:: Tuple )) where {N} = args[N]
389
+
390
+ # For flat nodes, we just consume one argument (n += 1), and return the "Pick" function
391
+ @inline _make_makeargs1 (_, n:: Int ) = Pick {n} (), n + 1
392
+ # For nested nodes, we form the `makeargs1` based on the child `makeargs` (n += length(cat_nested(bc)))
393
+ @inline function _make_makeargs1 (bc:: Broadcasted , n:: Int )
394
+ makeargs, n = _make_makeargs (bc. args, n)
395
+ f = Base. maybeconstructor (bc. f)
396
+ makeargs1 = (args:: Tuple ) -> (@inline ; f (prepare_args (makeargs, args)... ))
397
+ makeargs1, n
411
398
end
412
399
413
- @inline function make_headargs (t:: Tuple )
414
- let headargs = make_headargs (tail (t))
415
- return @inline function (head, tail:: Vararg{Any,N} ) where N
416
- (head, headargs (tail... )... )
417
- end
418
- end
419
- end
420
- @inline function make_headargs (:: Tuple{} )
421
- return @inline function (tail:: Vararg{Any,N} ) where N
422
- ()
423
- end
424
- end
425
-
426
- @inline function make_tailargs (t:: Tuple )
427
- let tailargs = make_tailargs (tail (t))
428
- return @inline function (head, tail:: Vararg{Any,N} ) where N
429
- tailargs (tail... )
430
- end
431
- end
432
- end
433
- @inline function make_tailargs (:: Tuple{} )
434
- return @inline function (tail:: Vararg{Any,N} ) where N
435
- tail
436
- end
437
- end
400
+ @inline prepare_args (makeargs:: Tuple , @nospecialize (x:: Tuple )) = (makeargs[1 ](x), prepare_args (tail (makeargs), x)... )
401
+ @inline prepare_args (makeargs:: Tuple{Any} , @nospecialize (x:: Tuple )) = (makeargs[1 ](x),)
402
+ prepare_args (:: Tuple{} , :: Tuple ) = ()
438
403
439
404
# # Broadcasting utilities ##
440
405
0 commit comments