-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathadd.jl
53 lines (40 loc) · 1.49 KB
/
add.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
####
# These are special routines to make operations involving +
# more efficient
####
# `Add{Factors}` matches an `Applied` object whose function is `+`; `Factors` is the
# tuple type of the summands, which are stored in `M.args`.
const Add{Factors<:Tuple} = Applied{<:Any, typeof(+), Factors}

# Shape queries are answered from the FIRST summand only; no check that the other
# summands have matching axes is made here.
size(M::Add, p::Int) = size(M)[p]
axes(M::Add, p::Int) = axes(M)[p]
ndims(M::Add) = ndims(first(M.args))
length(M::Add) = prod(size(M))
size(M::Add) = length.(axes(M))
axes(M::Add) = axes(first(M.args))

# Element type of the lazy sum: the inferred return type of `+` called with one element
# of each summand. `Base._return_type` takes a `Tuple{...}` *type* of argument types
# (the form Base's broadcast code uses), not a tuple of type values, so the summands'
# element types are wrapped in `Tuple{...}` here.
eltype(M::Add) = Base._return_type(+, Tuple{eltype.(M.args)...})
# An `ApplyArray` whose wrapped applied object is an `Add`, i.e. a lazily-represented
# sum exposed through the array interface.
const AddArray{T,N,Factors<:Tuple} = ApplyArray{T,N,<:Add{Factors}}
# One- and two-dimensional shorthands for `AddArray`.
const AddVector{T,Factors<:Tuple} = AddArray{T,1,Factors}
const AddMatrix{T,Factors<:Tuple} = AddArray{T,2,Factors}

# Wrap the lazy sum of `factors` in an `ApplyArray`.
function AddArray(factors...)
    return ApplyArray(+, factors...)
end
"""
Add(A1, A2, …, AN)
A lazy representation of `A1 + A2 + … + AN`; i.e., a shorthand for `applied(+, A1, A2, …, AN)`.
"""
Add(As...) = applied(+, As...)
# Scalar indexing into a lazy sum: index every summand at the same location and combine
# the entries with `+`. Plain `+` is used instead of `sum` so the result type matches
# `eltype(M)`, which is inferred from `+`. `sum` widens small integer types
# (e.g. `Int8` → `Int`, `UInt8` → `UInt`) and would disagree with `eltype` for them.
getindex(M::Add, k::Integer) = +(getindex.(M.args, k)...)
getindex(M::Add, k::Integer, j::Integer) = +(getindex.(M.args, k, j)...)
# Cartesian indices are unpacked into their integer components and forwarded.
getindex(M::Add, k::CartesianIndex{1}) = M[k[1]]
getindex(M::Add, kj::CartesianIndex{2}) = M[kj[1], kj[2]]
# In-place `C = α * A * B + β * C` where `A` is a lazy sum. The `+` is distributed over
# the product: `C` is scaled by `β` once, then `α * Aₖ * B` is accumulated for each
# summand `Aₖ` of `A`. The constraint `ApplyLayout{typeof(+)}` is on the first type
# parameter of the `MulAdd` type (presumably the layout of `M.A` — confirm).
# A single method on `MulAdd{<:ApplyLayout{typeof(+)}}` cannot "win" dispatch against
# the existing `MatMulMatAdd` and `MatMulVecAdd` methods, so one method per alias is
# generated with `@eval`.
for MulAdd_ in (MatMulMatAdd, MatMulVecAdd)
    @eval function materialize!(M::$MulAdd_{<:ApplyLayout{typeof(+)}})
        α, A, B, β, C = M.α, M.A, M.B, M.β, M.C
        # `C` is overwritten below; if it is the very same object as `B`, multiply by a
        # copy so later summands still see the original `B`.
        C ≡ B && (B = copy(B))
        lmul!(β, C)
        # `A.applied.args` holds the summands of the lazy sum `A`.
        for summand in A.applied.args
            C .= α .* Mul(summand, B) .+ C
        end
        return C
    end
end