Skip to content

Commit a5640ff

Browse files
Rabab53jpsamaroo
andcommitted
DArray: Implement in-place matrix-matrix multiply
Co-authored-by: Julian P Samaroo <[email protected]>
1 parent 26fc023 commit a5640ff

File tree

4 files changed

+436
-2
lines changed

4 files changed

+436
-2
lines changed

Diff for: src/Dagger.jl

+1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ include("array/matrix.jl")
7171
include("array/sparse_partition.jl")
7272
include("array/sort.jl")
7373
include("array/linalg.jl")
74+
include("array/mul.jl")
7475
include("array/cholesky.jl")
7576

7677
# Visualization

Diff for: src/array/matrix.jl

+1-2
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,9 @@ end
3333
size(x::MatMul) = mul_size(x.a, x.b)
3434
MatMul(a,b) =
3535
MatMul{promote_type(eltype(a), eltype(b)), length(mul_size(a,b))}(a,b)
36-
(*)(a::ArrayOp, b::ArrayOp) = _to_darray(MatMul(a,b))
36+
3737
# Bonus method for matrix-vector multiplication
3838
(*)(a::ArrayOp, b::Vector) = _to_darray(MatMul(a,PromotePartition(b)))
39-
(*)(a::AbstractArray, b::ArrayOp) = _to_darray(MatMul(PromotePartition(a), b))
4039

4140
function (*)(a::ArrayDomain{2}, b::ArrayDomain{2})
4241
if size(a, 2) != size(b, 1)

Diff for: src/array/mul.jl

+350
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,350 @@
1+
function LinearAlgebra.generic_matmatmul!(
2+
C::DMatrix{T},
3+
transA::Char,
4+
transB::Char,
5+
A::DMatrix{T},
6+
B::DMatrix{T},
7+
_add::LinearAlgebra.MulAddMul,
8+
) where {T}
9+
if all(in(('N', 'T', 'C')), (transA, transB))
10+
if (transA == 'T' || transA == 'C') && transB == 'N' && A === B
11+
return syrk_dagger!(C, transA, A, _add)
12+
elseif transA == 'N' && (transB == 'T' || transB == 'C') && A === B
13+
return syrk_dagger!(C, transA, A, _add)
14+
else
15+
return gemm_dagger!(C, transA, transB, A, B, _add)
16+
end
17+
end
18+
19+
# FIXME Add symm and hemm implementation (please note hemm will be inside symm as the case for syrk)
20+
21+
return gemm_dagger!(C, transA, transB, A, B, _add)
22+
end
23+
24+
"""
25+
Performs one of the matrix-matrix operations
26+
27+
C = alpha [op( A ) * op( B )] + beta C,
28+
29+
where op( X ) is one of
30+
31+
op( X ) = X or op( X ) = X' or op( X ) = g( X' )
32+
33+
alpha and beta are scalars, and A, B and C are matrices, with op( A )
34+
an m by k matrix, op( B ) a k by n matrix and C an m by n matrix.
35+
"""
36+
function gemm_dagger!(
37+
C::DMatrix{T},
38+
transA::Char,
39+
transB::Char,
40+
A::DMatrix{T},
41+
B::DMatrix{T},
42+
_add::LinearAlgebra.MulAddMul,
43+
) where {T}
44+
Ac = A.chunks
45+
Bc = B.chunks
46+
Cc = C.chunks
47+
Amt, Ant = size(Ac)
48+
Bmt, Bnt = size(Bc)
49+
Cmt, Cnt = size(Cc)
50+
51+
alpha = _add.alpha
52+
beta = _add.beta
53+
#=
54+
if Ant != Bmt
55+
throw(DimensionMismatch(lazy"A has number of blocks ($Amt,$Ant) but B has number of blocks ($Bmt,$Bnt)"))
56+
end
57+
=#
58+
59+
Dagger.spawn_datadeps() do
60+
for m in range(1, Cmt)
61+
for n in range(1, Cnt)
62+
if transA == 'N'
63+
if transB == 'N'
64+
# A: NoTrans / B: NoTrans
65+
for k in range(1, Amt)
66+
mzone = k == 1 ? beta : T(1.0)
67+
Dagger.@spawn BLAS.gemm!(
68+
transA,
69+
transB,
70+
alpha,
71+
In(Ac[m, k]),
72+
In(Bc[k, n]),
73+
mzone,
74+
InOut(Cc[m, n]),
75+
)
76+
end
77+
else
78+
# A: NoTrans / B: [Conj]Trans
79+
for k in range(1, Amt)
80+
mzone = k == 1 ? beta : T(1.0)
81+
Dagger.@spawn BLAS.gemm!(
82+
transA,
83+
transB,
84+
alpha,
85+
In(Ac[m, k]),
86+
In(Bc[n, k]),
87+
mzone,
88+
InOut(Cc[m, n]),
89+
)
90+
end
91+
end
92+
else
93+
if transB == 'N'
94+
# A: [Conj]Trans / B: NoTrans
95+
for k in range(1, Amt)
96+
mzone = k == 1 ? beta : T(1.0)
97+
Dagger.@spawn BLAS.gemm!(
98+
transA,
99+
transB,
100+
alpha,
101+
In(Ac[k, m]),
102+
In(Bc[k, n]),
103+
mzone,
104+
InOut(Cc[m, n]),
105+
)
106+
end
107+
else
108+
# A: [Conj]Trans / B: [Conj]Trans
109+
for k in range(1, Amt)
110+
mzone = k == 1 ? beta : T(1.0)
111+
Dagger.@spawn BLAS.gemm!(
112+
transA,
113+
transB,
114+
alpha,
115+
In(Ac[k, m]),
116+
In(Bc[n, k]),
117+
mzone,
118+
InOut(Cc[m, n]),
119+
)
120+
end
121+
end
122+
end
123+
end
124+
end
125+
end
126+
127+
return C
128+
end
129+
130+
"""
131+
Performs one of the symmetric/hermitian rank k operations
132+
133+
C = alpha [ op( A ) * g( op( A )' )] + beta C,
134+
135+
where op( X ) is one of
136+
137+
op( X ) = X or op( X ) = g( X' )
138+
139+
where alpha and beta are real scalars, C is an n-by-n symmetric/hermitian
140+
matrix and A is an n-by-k matrix in the first case and a k-by-n
141+
matrix in the second case.
142+
"""
143+
function syrk_dagger!(
144+
C::DMatrix{T},
145+
trans::Char,
146+
A::DMatrix{T},
147+
_add::LinearAlgebra.MulAddMul,
148+
) where {T}
149+
150+
Ac = A.chunks
151+
Cc = C.chunks
152+
Amt, Ant = size(Ac)
153+
Cmt, Cnt = size(Cc)
154+
155+
alpha = _add.alpha
156+
beta = _add.beta
157+
158+
uplo = 'L'
159+
#=
160+
if Ant != Bmt
161+
throw(DimensionMismatch(lazy"A has number of blocks ($Amt,$Ant) but B has number of blocks ($Bmt,$Bnt)"))
162+
end
163+
=#
164+
165+
iscomplex = T <: Complex
166+
transs = iscomplex ? 'C' : 'T'
167+
168+
Dagger.spawn_datadeps() do
169+
for n in range(1, Cnt)
170+
if trans == 'N'
171+
# NoTrans
172+
for k in range(1, Ant)
173+
mzone = k == 1 ? real(beta) : one(real(T))
174+
if iscomplex
175+
Dagger.@spawn BLAS.herk!(
176+
uplo,
177+
trans,
178+
real(alpha),
179+
In(Ac[n, k]),
180+
mzone,
181+
InOut(Cc[n, n]),
182+
)
183+
else
184+
Dagger.@spawn BLAS.syrk!(
185+
uplo,
186+
trans,
187+
alpha,
188+
In(Ac[n, k]),
189+
mzone,
190+
InOut(Cc[n, n]),
191+
)
192+
end
193+
end
194+
if uplo == 'L'
195+
# NoTrans / Lower
196+
for m in range(n + 1, Cmt)
197+
for k in range(1, Ant)
198+
mzone = k == 1 ? beta : one(T)
199+
Dagger.@spawn BLAS.gemm!(
200+
trans,
201+
transs,
202+
alpha,
203+
In(Ac[m, k]),
204+
In(Ac[n, k]),
205+
mzone,
206+
InOut(Cc[m, n]),
207+
)
208+
end
209+
end
210+
else
211+
# NoTrans / Upper
212+
for m in range(n + 1, Cmt)
213+
for k in range(1, Ant)
214+
mzone = k == 1 ? beta : one(T)
215+
Dagger.@spawn BLAS.gemm!(
216+
trans,
217+
transs,
218+
alpha,
219+
In(Ac[n, k]),
220+
In(Ac[m, k]),
221+
mzone,
222+
InOut(Cc[n, m]),
223+
)
224+
end
225+
end
226+
end
227+
else
228+
# [Conj]Trans
229+
for k in range(1, Amt)
230+
mzone = k == 1 ? real(beta) : one(real(T))
231+
if iscomplex
232+
Dagger.@spawn BLAS.herk!(
233+
uplo,
234+
transs,
235+
real(alpha),
236+
In(Ac[k, n]),
237+
mzone,
238+
InOut(Cc[n, n]),
239+
)
240+
else
241+
Dagger.@spawn BLAS.syrk!(
242+
uplo,
243+
trans,
244+
alpha,
245+
In(Ac[k, n]),
246+
mzone,
247+
InOut(Cc[n, n]),
248+
)
249+
end
250+
end
251+
if uplo == 'L'
252+
# [Conj]Trans / Lower
253+
for m in range(n + 1, Cmt)
254+
for k in range(1, Amt)
255+
mzone = k == 1 ? beta : one(T)
256+
Dagger.@spawn BLAS.gemm!(
257+
transs,
258+
'N',
259+
alpha,
260+
In(Ac[k, m]),
261+
In(Ac[k, n]),
262+
mzone,
263+
InOut(Cc[m, n]),
264+
)
265+
end
266+
end
267+
else
268+
# [Conj]Trans / Upper
269+
for m in range(n + 1, Cmt)
270+
for k in range(1, Amt)
271+
mzone = k == 1 ? beta : one(T)
272+
Dagger.@spawn BLAS.gemm!(
273+
transs,
274+
'N',
275+
alpha,
276+
In(Ac[k, n]),
277+
In(Ac[k, m]),
278+
mzone,
279+
InOut(Cc[n, m]),
280+
)
281+
end
282+
end
283+
end
284+
end
285+
end
286+
end
287+
288+
C = copytri!(C, 'L')
289+
return C
290+
end
291+
292+
293+
# copy transposed(adjoint) of upper(lower) side-diagonals.
294+
@inline function copytri!(A::DArray{T,2}, uplo::AbstractChar) where {T}
295+
#n = checksquare(A) FIXME find replacement in DArray
296+
297+
Ac = A.chunks
298+
Amt, Ant = size(Ac)
299+
300+
Dagger.spawn_datadeps() do
301+
if uplo == 'U'
302+
for i = 1:Amt, j = (i):Amt
303+
if (i == j)
304+
Dagger.@spawn copydiagtile!(Out(Ac[j, i]), In(Ac[i, j]), uplo)
305+
else
306+
Dagger.@spawn copytile!(Out(Ac[j, i]), In(Ac[i, j]))
307+
end
308+
end
309+
elseif uplo == 'L'
310+
for i = 1:Amt, j = (i):Amt
311+
if (i == j)
312+
Dagger.@spawn copydiagtile!(Out(Ac[i, j]), In(Ac[j, i]), uplo)
313+
else
314+
Dagger.@spawn copytile!(Out(Ac[i, j]), In(Ac[j, i]))
315+
end
316+
317+
end
318+
else
319+
throw(ArgumentError(lazy"uplo argument must be 'U' (upper) or 'L' (lower), got $uplo"))
320+
end
321+
end
322+
323+
return A
324+
end
325+
326+
@inline function copytile!(A, B)
327+
m, n = size(A)
328+
329+
for i = 1:m, j = 1:n
330+
A[j, i] = B[i, j]
331+
end
332+
end
333+
334+
@inline function copydiagtile!(A, B, uplo)
335+
m, n = size(A)
336+
337+
if uplo == 'U'
338+
for i = 1:m, j = 1:n
339+
if j >= i
340+
A[j, i] = B[i, j]
341+
end
342+
end
343+
elseif uplo == 'L'
344+
for i = 1:m, j = 1:n
345+
if j <= i
346+
A[j, i] = B[i, j]
347+
end
348+
end
349+
end
350+
end

0 commit comments

Comments
 (0)