@@ -314,21 +314,21 @@ some cases.
314
314
"""
315
315
function flatten (bc:: Broadcasted{Style} ) where {Style}
316
316
isflat (bc) && return bc
317
- # concatenate the nested arguments into {a, b, c, d}
318
- args = cat_nested (bc)
319
- # build a function `makeargs` that takes a "flat" argument list and
317
+ # 1. concatenate the nested arguments into {a, b, c, d}
318
+ # 2. build a function `makeargs` that takes a "flat" argument list and
320
319
# and creates the appropriate input arguments for `f`, e.g.,
321
320
# makeargs = (w, x, y, z) -> (w, g(x, y), z)
322
321
#
323
322
# `makeargs` is built recursively and looks a bit like this:
324
323
# makeargs(w, x, y, z) = (w, makeargs1(x, y, z)...)
325
324
# = (w, g(x, y), makeargs2(z)...)
326
325
# = (w, g(x, y), z)
327
- let makeargs = make_makeargs (()-> (), bc. args), f = bc. f
328
- newf = @inline function (args:: Vararg{Any,N} ) where N
329
- f (makeargs (args... )... )
330
- end
331
- return Broadcasted {Style} (newf, args, bc. axes)
326
+ let (makeargs, args) = make_makeargs ((), bc. args), f = bc. f
327
+ _make (:: NTuple{N,Any} ) where {N} =
328
+ @inline function (args:: Vararg{Any,N} )
329
+ f (makeargs (args... )... )
330
+ end
331
+ return Broadcasted {Style} (_make (args), args, bc. axes)
332
332
end
333
333
end
334
334
@@ -338,79 +338,47 @@ _isflat(args::NestedTuple) = false
338
338
_isflat (args:: Tuple ) = _isflat (tail (args))
339
339
_isflat (args:: Tuple{} ) = true
340
340
341
- cat_nested (t:: Broadcasted , rest... ) = (cat_nested (t. args... )... , cat_nested (rest... )... )
342
- cat_nested (t:: Any , rest... ) = (t, cat_nested (rest... )... )
343
- cat_nested () = ()
344
-
345
341
"""
346
- make_makeargs(makeargs_tail::Function , t::Tuple) -> Function
342
+ make_makeargs(args::Tuple , t::Tuple) -> Function, Tuple
347
343
348
344
Each element of `t` is one (consecutive) node in a broadcast tree.
349
- Ignoring `makeargs_tail` for the moment, the job of `make_makeargs` is
350
- to return a function that takes in flattened argument list and returns a
351
- tuple (each entry corresponding to an entry in `t`, having evaluated
352
- the corresponding element in the broadcast tree). As an additional
353
- complication, the passed in tuple may be longer than the number of leaves
354
- in the subtree described by `t`. The `makeargs_tail` function should
355
- be called on such additional arguments (but not the arguments consumed
356
- by `t`).
345
+ `args` contains the rest arguments on the "right" side of `t`.
346
+ The jobs of `make_makeargs` are:
347
+ 1. append the flattened arguments in `t` at the beginning of `args`.
348
+ 2. return a function that takes in flattened argument list and returns a
349
+ tuple (each entry corresponding to an entry in `t`, having evaluated
350
+ the corresponding element in the broadcast tree).
357
351
"""
358
- @inline make_makeargs (makeargs_tail, t:: Tuple{} ) = makeargs_tail
359
- @inline function make_makeargs (makeargs_tail, t:: Tuple )
360
- makeargs = make_makeargs (makeargs_tail, tail (t))
361
- (head, tail... )-> (head, makeargs (tail... )... )
352
+ @inline function make_makeargs (args, t:: Tuple{} )
353
+ _make (:: NTuple{N,Any} ) where {N} = (args:: Vararg{Any,N} ) -> args
354
+ _make (args), args
362
355
end
363
- function make_makeargs (makeargs_tail, t:: Tuple{<:Broadcasted, Vararg{Any}} )
356
+ @inline function make_makeargs (args, t:: Tuple )
357
+ makeargs, args′ = make_makeargs (args, tail (t))
358
+ _make (:: NTuple{N,Any} ) where {N} =
359
+ @inline function (head, tail:: Vararg{Any,N} )
360
+ (head, makeargs (tail... )... )
361
+ end
362
+ _make (args′), (t[1 ], args′... )
363
+ end
364
+ function make_makeargs (args, t:: Tuple{<:Broadcasted,Vararg{Any}} )
364
365
bc = t[1 ]
365
366
# c.f. the same expression in the function on leaf nodes above. Here
366
367
# we recurse into siblings in the broadcast tree.
367
- let makeargs_tail = make_makeargs (makeargs_tail, tail (t)),
368
- # Here we recurse into children. It would be valid to pass in makeargs_tail
369
- # here, and not use it below. However, in that case, our recursion is no
370
- # longer purely structural because we're building up one argument (the closure)
371
- # while destructuing another.
372
- makeargs_head = make_makeargs ((args... )-> args, bc. args),
373
- f = bc. f
374
- # Create two functions, one that splits of the first length(bc.args)
375
- # elements from the tuple and one that yields the remaining arguments.
376
- # N.B. We can't call headargs on `args...` directly because
377
- # args is flattened (i.e. our children have not been evaluated
378
- # yet).
379
- headargs, tailargs = make_headargs (bc. args), make_tailargs (bc. args)
380
- return @inline function (args:: Vararg{Any,N} ) where N
381
- args1 = makeargs_head (args... )
382
- a, b = headargs (args1... ), makeargs_tail (tailargs (args1... )... )
383
- (f (a... ), b... )
384
- end
385
- end
386
- end
387
-
388
- @inline function make_headargs (t:: Tuple )
389
- let headargs = make_headargs (tail (t))
390
- return @inline function (head, tail:: Vararg{Any,N} ) where N
391
- (head, headargs (tail... )... )
392
- end
368
+ let (makeargs, args′) = make_makeargs (args, tail (t)), f = bc. f
369
+ # Here we recurse into children. We can pass in `args′` here,
370
+ # and get `args″` directly, but it is more compiler frendly to
371
+ # treat `bc` as a new parent "node".
372
+ makeargs_head, argsˢ = make_makeargs ((), bc. args)
373
+ args″ = (argsˢ... , args′... )
374
+ _make (:: NTuple{L,Any} , :: NTuple{N,Any} ) where {L,N} =
375
+ @inline function (args:: Vararg{Any,N} )
376
+ a, b = Base. IteratorsMD. split (args, Val (L)) # split `args...` directly
377
+ (f (makeargs_head (a... )... ), makeargs (b... )... )
378
+ end
379
+ _make (argsˢ, args″), args″
393
380
end
394
381
end
395
- @inline function make_headargs (:: Tuple{} )
396
- return @inline function (tail:: Vararg{Any,N} ) where N
397
- ()
398
- end
399
- end
400
-
401
- @inline function make_tailargs (t:: Tuple )
402
- let tailargs = make_tailargs (tail (t))
403
- return @inline function (head, tail:: Vararg{Any,N} ) where N
404
- tailargs (tail... )
405
- end
406
- end
407
- end
408
- @inline function make_tailargs (:: Tuple{} )
409
- return @inline function (tail:: Vararg{Any,N} ) where N
410
- tail
411
- end
412
- end
413
-
414
382
# # Broadcasting utilities ##
415
383
416
384
# # logic for deciding the BroadcastStyle
0 commit comments