Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify computation of return type in broadcast #39295

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 2 additions & 14 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -682,15 +682,6 @@ broadcastable(x::Union{AbstractArray,Number,Ref,Tuple,Broadcasted}) = x
broadcastable(x) = collect(x)
broadcastable(::Union{AbstractDict, NamedTuple}) = throw(ArgumentError("broadcasting over dictionaries and `NamedTuple`s is reserved"))

## Computation of inferred result type, for empty and concretely inferred cases only
_broadcast_getindex_eltype(bc::Broadcasted) = Base._return_type(bc.f, eltypes(bc.args))
_broadcast_getindex_eltype(A) = eltype(A) # Tuple, Array, etc.

eltypes(::Tuple{}) = Tuple{}
eltypes(t::Tuple{Any}) = Tuple{_broadcast_getindex_eltype(t[1])}
eltypes(t::Tuple{Any,Any}) = Tuple{_broadcast_getindex_eltype(t[1]), _broadcast_getindex_eltype(t[2])}
eltypes(t::Tuple) = Tuple{_broadcast_getindex_eltype(t[1]), eltypes(tail(t)).types...}

function promote_typejoin_union(::Type{T}) where T
if T === Union{}
return Union{}
Expand Down Expand Up @@ -735,10 +726,6 @@ end
return Base.rewrap_unionall(Tuple{c...}, T)
end

# Inferred eltype of result of broadcast(f, args...)
combine_eltypes(f, args::Tuple) =
promote_typejoin_union(Base._return_type(f, eltypes(args)))

## Broadcasting core

"""
Expand Down Expand Up @@ -901,7 +888,8 @@ copy(bc::Broadcasted{<:Union{Nothing,Unknown}}) =
const NonleafHandlingStyles = Union{DefaultArrayStyle,ArrayConflict}

@inline function copy(bc::Broadcasted{Style}) where {Style}
ElType = combine_eltypes(bc.f, bc.args)
ElType = promote_typejoin_union(Base._return_type(_broadcast_getindex,
Tuple{typeof(bc), Int}))
Copy link
Member

@mbauman mbauman Jan 20, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this needs to be:

Suggested change
Tuple{typeof(bc), Int}))
Tuple{typeof(bc), ndims(bc) == 1 ? eltype(axes(bc)[1]) : CartesianIndex{ndims(bc)}})

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or is it Base._return_type(iterate, Base._return_type(eachindex, Tuple{typeof(bc)})) ?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, dropped a function. I meant:

index_type(bc) = iterate(eachindex(bc))[1]
Base._return_type(index_type, Tuple{typeof(bc)})

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Putting that all together:

Suggested change
Tuple{typeof(bc), Int}))
_broadcast_getindex_eltype(bc) = _broadcast_getindex(bc, iterate(eachindex(bc))[1])
ElType = promote_typejoin_union(Base._return_type(_broadcast_getindex_eltype, Tuple{typeof(bc)}))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely sure this is better than the existing code, which does pretty much the same calls, but bases it on calling eltype, instead of inference, which has at least different tradeoffs for better or worse 🤔

Copy link
Member

@vtjnash vtjnash Apr 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to try to proceed with this PR / design (inferring iterate), or keep the current one (call eltype)?

if Base.isconcretetype(ElType)
# We can trust it and defer to the simpler `copyto!`
return copyto!(similar(bc, ElType), bc)
Expand Down
9 changes: 5 additions & 4 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,8 @@ StrangeType18623(x,y) = (x,y)
let
f(A, n) = broadcast(x -> +(x, n), A)
@test @inferred(f([1.0], 1)) == [2.0]
g() = (a = 1; Broadcast.combine_eltypes(x -> x + a, (1.0,)))
@test @inferred(g()) === Float64
g() = (a = 1; x -> x + a)
@test @inferred(broadcast(g(), 1.0)) === 2.0
Comment on lines -420 to +421
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pabloferz @Sacha0 Since you worked on these tests (this one and the one below), could you confirm that the new ones covers the same use case as the old ones? That wasn't completely clear to me.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regrettably sufficient time has elapsed since I looked at these tests and such that I no longer have much memory of them. Sorry Milan! :)

end

# Ref as 0-dimensional array for broadcast
Expand Down Expand Up @@ -576,8 +576,9 @@ end

# Test that broadcast's promotion mechanism handles closures accepting more than one argument.
# (See issue #19641 and referenced issues and pull requests.)
let f() = (a = 1; Broadcast.combine_eltypes((x, y) -> x + y + a, (1.0, 1.0)))
@test @inferred(f()) == Float64
let
f() = (a = 1; (x, y) -> x + y + a)
@test @inferred(broadcast(f(), 1.0, 1.0)) === 3.0
end

@testset "broadcast resulting in BitArray" begin
Expand Down