Skip to content

Commit 7429b48

Browse files
committed
add more test and remove cat_nested
`cat_nested` failed to infer in some cases. It has been inserted into `make_makeargs`, So I remove it.
1 parent fd6721b commit 7429b48

File tree

2 files changed

+11
-15
lines changed

2 files changed

+11
-15
lines changed

base/broadcast.jl

+4-11
Original file line numberDiff line numberDiff line change
@@ -314,9 +314,8 @@ some cases.
314314
"""
315315
function flatten(bc::Broadcasted{Style}) where {Style}
316316
isflat(bc) && return bc
317-
# concatenate the nested arguments into {a, b, c, d}
318-
# args = cat_nested(bc)
319-
# build a function `makeargs` that takes a "flat" argument list and
317+
# 1. concatenate the nested arguments into {a, b, c, d}
318+
# 2. build a function `makeargs` that takes a "flat" argument list and
320319
# and creates the appropriate input arguments for `f`, e.g.,
321320
# makeargs = (w, x, y, z) -> (w, g(x, y), z)
322321
#
@@ -329,8 +328,7 @@ function flatten(bc::Broadcasted{Style}) where {Style}
329328
@inline function (args::Vararg{Any,N})
330329
f(makeargs(args...)...)
331330
end
332-
newf = _make(args)
333-
return Broadcasted{Style}(newf, args, bc.axes)
331+
return Broadcasted{Style}(_make(args), args, bc.axes)
334332
end
335333
end
336334

@@ -340,18 +338,13 @@ _isflat(args::NestedTuple) = false
340338
_isflat(args::Tuple) = _isflat(tail(args))
341339
_isflat(args::Tuple{}) = true
342340

343-
cat_nested(t::Broadcasted, rest...) = (cat_nested(t.args...)..., cat_nested(rest...)...)
344-
cat_nested(t::Any, rest...) = (t, cat_nested(rest...)...)
345-
cat_nested() = ()
346-
347341
"""
348342
make_makeargs(args::Tuple, t::Tuple) -> Function, Tuple
349343
350344
Each element of `t` is one (consecutive) node in a broadcast tree.
351345
`args` contains the rest arguments on the "right" side of `t`.
352346
The jobs of `make_makeargs` are:
353-
1. append the flattened arguments in `t` at the beginning of `args`, i.e.
354-
`(cat_nested(t)..., args...)`
347+
1. append the flattened arguments in `t` at the beginning of `args`.
355348
2. return a function that takes in flattened argument list and returns a
356349
tuple (each entry corresponding to an entry in `t`, having evaluated
357350
the corresponding element in the broadcast tree).

test/broadcast.jl

+7-4
Original file line numberDiff line numberDiff line change
@@ -772,16 +772,19 @@ end
772772

773773
# issue #27988: inference of Broadcast.flatten
774774
using .Broadcast: Broadcasted
775-
let
775+
let _cat_nested(bc) = Broadcast.flatten(bc).args
776776
bc = Broadcasted(+, (Broadcasted(*, (1, 2)), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
777-
@test @inferred(Broadcast.cat_nested(bc)) == (1,2,3,4,5)
777+
@test @inferred(_cat_nested(bc)) == (1,2,3,4,5)
778778
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 62
779779
bc = Broadcasted(+, (Broadcasted(*, (1, Broadcasted(/, (2.0, 2.5)))), Broadcasted(*, (Broadcasted(*, (3, 4)), 5))))
780-
@test @inferred(Broadcast.cat_nested(bc)) == (1,2.0,2.5,3,4,5)
780+
@test @inferred(_cat_nested(bc)) == (1,2.0,2.5,3,4,5)
781781
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 60.8
782782
# 1 .* 1 .- 1 .* 1 .^2 .+ 1 .* 1 .+ 1 .^ 3
783-
bc = Base.Broadcast.Broadcasted(+, (Base.Broadcast.Broadcasted(+, (Base.Broadcast.Broadcasted(-, (Base.Broadcast.Broadcasted(*, (1, 1)), Base.Broadcast.Broadcasted(*, (1, Base.Broadcast.Broadcasted(Base.literal_pow, (Base.RefValue{typeof(^)}(^), 1, Base.RefValue{Val{2}}(Val{2}()))))))), Base.Broadcast.Broadcasted(*, (1, 1)))), Base.Broadcast.Broadcasted(Base.literal_pow, (Base.RefValue{typeof(^)}(^), 1, Base.RefValue{Val{3}}(Val{3}())))))
783+
bc = Broadcasted(+, (Broadcasted(+, (Broadcasted(-, (Broadcasted(*, (1, 1)), Broadcasted(*, (1, Broadcasted(Base.literal_pow, (Ref(^), 1, Ref(Val(2)))))))), Broadcasted(*, (1, 1)))), Broadcasted(Base.literal_pow, (Base.RefValue{typeof(^)}(^), 1, Base.RefValue{Val{3}}(Val{3}())))))
784784
@test @inferred(Broadcast.materialize(Broadcast.flatten(bc))) == @inferred(Broadcast.materialize(bc)) == 2
785+
# @. 1 + 1 * (1 + 1 + 1 + 1)
786+
bc = Broadcasted(+, (1, Broadcasted(*, (1, Broadcasted(+, (1, 1, 1, 1))))))
787+
@test @inferred(_cat_nested(bc)) == (1,1,1,1,1,1) # `cat_nested` failed to infer this
785788
end
786789

787790
let

0 commit comments

Comments
 (0)