-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathlazybroadcasting.jl
96 lines (73 loc) · 4.1 KB
/
lazybroadcasting.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
    LazyArrayStyle{N}

Broadcast style for `N`-dimensional arrays; a subtype of `AbstractArrayStyle{N}`.
"""
struct LazyArrayStyle{N} <: AbstractArrayStyle{N} end

# Build the style whose dimension is the one carried by the `Val`.
function LazyArrayStyle(::Val{N}) where N
    return LazyArrayStyle{N}()
end

# Called on an already-parameterised style: the dimension in the `Val` replaces `M`.
function LazyArrayStyle{M}(::Val{N}) where {N,M}
    return LazyArrayStyle{N}()
end
"""
    BroadcastArray{T,N,BRD<:Broadcasted}

An `AbstractArray{T,N}` whose only field, `broadcasted`, is a `Broadcasted` object.
"""
struct BroadcastArray{T, N, BRD<:Broadcasted} <: AbstractArray{T, N}
    broadcasted::BRD
end

# Element type and dimension are supplied by the caller; the third type parameter
# is the concrete type of the wrapped `Broadcasted`.
function BroadcastArray{T,N}(bc::Broadcasted) where {T,N}
    return BroadcastArray{T,N,typeof(bc)}(bc)
end

# Only the element type is supplied; the dimension `N` is the length of the axes
# tuple type of `bc`, so `bc` must carry a tuple of axes to match this method.
function BroadcastArray{T}(
    bc::Broadcasted{<:Union{Nothing,BroadcastStyle},<:Tuple{Vararg{Any,N}},<:Any,<:Tuple},
) where {T,N}
    return BroadcastArray{T,N}(bc)
end
# Convert a list of broadcast arguments into a tuple in which every `Broadcasted`
# argument has been wrapped by `BroadcastArray`; all other arguments pass through
# untouched. The arguments are processed one at a time by recursing on the tail, so
# a nested `Broadcasted` is wrapped no matter which position it occupies, and an
# empty argument list yields an empty tuple.
_broadcast2broadcastarray() = ()
_broadcast2broadcastarray(a, b...) = tuple(a, _broadcast2broadcastarray(b...)...)
_broadcast2broadcastarray(a::Broadcasted, b...) =
    tuple(BroadcastArray(a), _broadcast2broadcastarray(b...)...)
# Wrap `bc` using the element type that `combine_eltypes` reports for its function
# and arguments.
function _BroadcastArray(bc::Broadcasted)
    T = combine_eltypes(bc.f, bc.args)
    return BroadcastArray{T}(bc)
end

# Build a `BroadcastArray` from a `Broadcasted`: the arguments are first passed
# through `_broadcast2broadcastarray`, a `Broadcasted` with the same style `S` is
# rebuilt from them and instantiated, and the result is wrapped.
function BroadcastArray(bc::Broadcasted{S}) where S
    args = _broadcast2broadcastarray(bc.args...)
    rebuilt = Broadcasted{S}(bc.f, args)
    return _BroadcastArray(instantiate(rebuilt))
end

# A `BroadcastArray` is returned as is.
function BroadcastArray(b::BroadcastArray)
    return b
end

# Convenience form: `BroadcastArray(f, A, As...)` wraps `broadcasted(f, A, As...)`.
function BroadcastArray(f, A, As...)
    return BroadcastArray(broadcasted(f, A, As...))
end
# The axes of a `BroadcastArray` are those of the wrapped `Broadcasted`.
function axes(A::BroadcastArray)
    return axes(A.broadcasted)
end

# The size is the length of each axis.
function size(A::BroadcastArray)
    return map(length, axes(A))
end
# One-dimensional `BroadcastArray`s use linear indexing.
# The trait is attached to the array *type*, which is what the AbstractArray
# interface specifies; queries on an instance reach this method through Base's
# `IndexStyle(A::AbstractArray) = IndexStyle(typeof(A))` fallback, so both
# `IndexStyle(A)` and `IndexStyle(typeof(A))` report `IndexLinear()`.
IndexStyle(::Type{<:BroadcastArray{<:Any,1}}) = IndexLinear()
# Scalar indexing with `Int` indices is forwarded to the wrapped `Broadcasted`.
@propagate_inbounds function getindex(A::BroadcastArray, kj::Int...)
    return A.broadcasted[kj...]
end
# Index one broadcast argument with the index collection `I`.
# Scalar-likes (`Ref`, zero-dimensional arrays, numbers) ignore `I` and are unwrapped
# with `A[]`.
@propagate_inbounds function _broadcast_getindex_range(
    A::Union{Ref,AbstractArray{<:Any,0},Number}, I
)
    return A[]
end

# Every other argument is indexed directly with `I`.
@propagate_inbounds function _broadcast_getindex_range(A, I)
    return A[I]
end
# Indexing a vector `BroadcastArray` with a vector of integers stays lazy: each
# argument of the wrapped broadcast is indexed via `_broadcast_getindex_range`, and a
# new `BroadcastArray` is built from the same function and the indexed arguments.
function getindex(B::BroadcastArray{<:Any,1}, kr::AbstractVector{<:Integer})
    bc = B.broadcasted
    indexed = map(bc.args) do a
        _broadcast_getindex_range(a, kr)
    end
    return BroadcastArray(bc.f, indexed...)
end
# `copy` of a `Broadcasted` whose style is a `LazyArrayStyle` wraps it in a
# `BroadcastArray`.
function copy(bc::Broadcasted{<:LazyArrayStyle})
    return BroadcastArray(bc)
end
# issue 16: sum(b, dims=(1,2,3)) faster than sum(b)
# A full reduction (`dims = :`) is therefore expressed as a reduction over every
# dimension `1:N`, and the first entry of that result is returned.
function Base._sum(b::BroadcastArray{T,N}, ::Colon) where {T,N}
    alldims = ntuple(identity, N)
    return first(Base._sum(b, alldims))
end

function Base._prod(b::BroadcastArray{T,N}, ::Colon) where {T,N}
    alldims = ntuple(identity, N)
    return first(Base._prod(b, alldims))
end
# An `N`-dimensional `BroadcastArray` broadcasts with `LazyArrayStyle{N}`.
function BroadcastStyle(::Type{<:BroadcastArray{<:Any,N}}) where N
    return LazyArrayStyle{N}()
end

# Combined with a `StaticArrayStyle` of the same dimension, the lazy style is chosen
# whichever side it appears on.
function BroadcastStyle(L::LazyArrayStyle{N}, ::StaticArrayStyle{N}) where N
    return L
end

function BroadcastStyle(::StaticArrayStyle{N}, L::LazyArrayStyle{N}) where N
    return L
end

"""
    BroadcastLayout(f, layouts)

The `MemoryLayout` returned by `MemoryLayout(A)` when `A` is a `BroadcastArray`.
`f` is the function being broadcast, and `layouts` holds the `MemoryLayout` of each
broadcast argument.
"""
struct BroadcastLayout{F, LAY} <: MemoryLayout
    f::F
    layouts::LAY
end

# The layout records the broadcast function together with the layout of every argument.
function MemoryLayout(A::BroadcastArray)
    bc = A.broadcasted
    arglayouts = MemoryLayout.(bc.args)
    return BroadcastLayout(bc.f, arglayouts)
end
## scalar-range broadcast operations ##
# Ranges already support smart broadcasting, so a unary `+`, `-` or `big` applied to
# a range under `LazyArrayStyle{1}` is handed to `broadcast` with
# `DefaultArrayStyle{1}()` in place of the lazy style.
for op in (+, -, big)
    @eval function broadcasted(::LazyArrayStyle{1}, ::typeof($op), r::AbstractRange)
        return broadcast(DefaultArrayStyle{1}(), $op, r)
    end
end
# Range on the left, real scalar on the right: `-`, `+`, `*` and `/` are handed to
# `broadcast` with `DefaultArrayStyle{1}()` in place of the lazy style.
for op in (-, +, *, /)
    @eval function broadcasted(
        ::LazyArrayStyle{1}, ::typeof($op), r::AbstractRange, x::Real
    )
        return broadcast(DefaultArrayStyle{1}(), $op, r, x)
    end
end
# Real scalar on the left, range on the right: `-`, `+`, `*` and `\` are handed to
# `broadcast` with `DefaultArrayStyle{1}()` in place of the lazy style.
for op in (-, +, *, \)
    @eval function broadcasted(
        ::LazyArrayStyle{1}, ::typeof($op), x::Real, r::AbstractRange
    )
        return broadcast(DefaultArrayStyle{1}(), $op, x, r)
    end
end
# Broadcasts under `LazyArrayStyle{N}` whose array arguments are all `N`-dimensional
# `AbstractFill`s (optionally paired with a `Number`) are handed to `broadcast` with
# `DefaultArrayStyle{N}()` in place of the lazy style.

# Unary operation on a fill.
function broadcasted(::LazyArrayStyle{N}, op, r::AbstractFill{T,N}) where {T,N}
    return broadcast(DefaultArrayStyle{N}(), op, r)
end

# Fill on the left, number on the right.
function broadcasted(::LazyArrayStyle{N}, op, r::AbstractFill{T,N}, x::Number) where {T,N}
    return broadcast(DefaultArrayStyle{N}(), op, r, x)
end

# Number on the left, fill on the right.
function broadcasted(::LazyArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {T,N}
    return broadcast(DefaultArrayStyle{N}(), op, x, r)
end

# Two fills of the same dimension.
function broadcasted(
    ::LazyArrayStyle{N}, op, r1::AbstractFill{T,N}, r2::AbstractFill{V,N}
) where {T,V,N}
    return broadcast(DefaultArrayStyle{N}(), op, r1, r2)
end
# Elementwise products under `LazyArrayStyle{N}` in which at least one factor is an
# `N`-dimensional `Zeros` are handed to `broadcast` with `DefaultArrayStyle{N}()` in
# place of the lazy style.
broadcasted(::LazyArrayStyle{N}, ::typeof(*), a::Zeros{T,N}, b::Zeros{V,N}) where {T,V,N} =
    broadcast(DefaultArrayStyle{N}(), *, a, b)
broadcasted(::LazyArrayStyle{N}, ::typeof(*), a::AbstractArray{T,N}, b::Zeros{V,N}) where {T,V,N} =
    broadcast(DefaultArrayStyle{N}(), *, a, b)
broadcasted(::LazyArrayStyle{N}, ::typeof(*), a::Zeros{T,N}, b::AbstractArray{V,N}) where {T,V,N} =
    broadcast(DefaultArrayStyle{N}(), *, a, b)
# Disambiguation: a product of a non-`Zeros` fill with a `Zeros` matches both the
# array-`Zeros` methods above and the generic `(op, AbstractFill, AbstractFill)`
# method, and neither is more specific than the other. These two methods are more
# specific than both, so such calls dispatch instead of raising an ambiguity error.
broadcasted(::LazyArrayStyle{N}, ::typeof(*), a::AbstractFill{T,N}, b::Zeros{V,N}) where {T,V,N} =
    broadcast(DefaultArrayStyle{N}(), *, a, b)
broadcasted(::LazyArrayStyle{N}, ::typeof(*), a::Zeros{T,N}, b::AbstractFill{V,N}) where {T,V,N} =
    broadcast(DefaultArrayStyle{N}(), *, a, b)