Skip to content

Commit 541c3bd

Browse files
authored
Merge pull request #492 from Rabab53/matmul
DArray: Implement efficient in-place matrix-matrix multiply
2 parents 10e1307 + 3c168b2 commit 541c3bd

26 files changed

+1182
-337
lines changed

Diff for: .buildkite/pipeline.yml

+8-7
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
queue: "juliaecosystem"
55
sandbox_capable: "true"
66
os: linux
7+
arch: x86_64
78
command: "julia --project -e 'using Pkg; Pkg.develop(;path=\"lib/TimespanLogging\")'"
89
.bench: &bench
910
if: build.message =~ /\[run benchmarks\]/
@@ -15,7 +16,7 @@
1516
num_cpus: 16
1617
steps:
1718
- label: Julia 1.8
18-
timeout_in_minutes: 60
19+
timeout_in_minutes: 90
1920
<<: *test
2021
plugins:
2122
- JuliaCI/julia#v1:
@@ -25,7 +26,7 @@ steps:
2526
- JuliaCI/julia-coverage#v1:
2627
codecov: true
2728
- label: Julia 1.9
28-
timeout_in_minutes: 60
29+
timeout_in_minutes: 90
2930
<<: *test
3031
plugins:
3132
- JuliaCI/julia#v1:
@@ -35,7 +36,7 @@ steps:
3536
- JuliaCI/julia-coverage#v1:
3637
codecov: true
3738
- label: Julia 1.10
38-
timeout_in_minutes: 60
39+
timeout_in_minutes: 90
3940
<<: *test
4041
plugins:
4142
- JuliaCI/julia#v1:
@@ -45,7 +46,7 @@ steps:
4546
- JuliaCI/julia-coverage#v1:
4647
codecov: true
4748
- label: Julia nightly
48-
timeout_in_minutes: 60
49+
timeout_in_minutes: 90
4950
<<: *test
5051
plugins:
5152
- JuliaCI/julia#v1:
@@ -55,7 +56,7 @@ steps:
5556
- JuliaCI/julia-coverage#v1:
5657
codecov: true
5758
- label: Julia 1.8 (macOS)
58-
timeout_in_minutes: 60
59+
timeout_in_minutes: 90
5960
<<: *test
6061
agents:
6162
queue: "juliaecosystem"
@@ -69,7 +70,7 @@ steps:
6970
- JuliaCI/julia-coverage#v1:
7071
codecov: true
7172
- label: Julia 1.8 - TimespanLogging
72-
timeout_in_minutes: 60
73+
timeout_in_minutes: 20
7374
<<: *test
7475
plugins:
7576
- JuliaCI/julia#v1:
@@ -78,7 +79,7 @@ steps:
7879
codecov: true
7980
command: "julia --project -e 'using Pkg; Pkg.instantiate(); Pkg.develop(;path=\"lib/TimespanLogging\"); Pkg.test(\"TimespanLogging\")'"
8081
- label: Julia 1.8 - DaggerWebDash
81-
timeout_in_minutes: 60
82+
timeout_in_minutes: 20
8283
<<: *test
8384
plugins:
8485
- JuliaCI/julia#v1:

Diff for: Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ SharedArrays = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
2020
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2121
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2222
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
23+
TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34"
2324
TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63"
2425
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2526

@@ -34,6 +35,7 @@ Requires = "1"
3435
ScopedValues = "1.1"
3536
Statistics = "1"
3637
StatsBase = "0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34"
38+
TaskLocalValues = "0.1"
3739
TimespanLogging = "0.1"
3840
julia = "1.8"
3941

Diff for: docs/src/darray.md

+10-7
Original file line numberDiff line numberDiff line change
@@ -134,12 +134,7 @@ julia> DZ = DY .* 3
134134
Dagger.DArray{Float64, 2, Blocks{2}, typeof(cat)}(100, 100)
135135
```
136136

137-
Now, `DZ` will contain the result of computing `(DX .+ DX) .* 3`. Note that
138-
`DArray` objects are immutable, and operations on them are thus functional
139-
transformations of their input `DArray`.
140-
141-
!!! note
142-
Support for mutation of `DArray`s is planned for a future release
137+
Now, `DZ` will contain the result of computing `(DX .+ DX) .* 3`.
143138

144139
```
145140
julia> Dagger.chunks(DZ)
@@ -208,15 +203,17 @@ julia> collect(DZ)
208203
```
209204

210205
A variety of other operations exist on the `DArray`, and it should generally
211-
behavior otherwise similar to any other `AbstractArray` type. If you find that
206+
behave otherwise similarly to any other `AbstractArray` type. If you find that
212207
it's missing an operation that you need, please file an issue!
213208

214209
### Known Supported Operations
215210

216211
This list is not exhaustive, but documents operations which are known to work well with the `DArray`:
217212

218213
From `Base`:
214+
- `getindex`/`setindex!`
219215
- Broadcasting
216+
- `similar`/`copy`/`copyto!`
220217
- `map`/`reduce`/`mapreduce`
221218
- `sum`/`prod`
222219
- `minimum`/`maximum`/`extrema`
@@ -225,3 +222,9 @@ From `Statistics`:
225222
- `mean`
226223
- `var`
227224
- `std`
225+
226+
From `LinearAlgebra`:
227+
- `transpose`/`adjoint` (Out-of-place transpose)
228+
- `*` (Out-of-place Matrix-(Matrix/Vector) multiply)
229+
- `mul!` (In-place Matrix-Matrix multiply)
230+
- `cholesky`/`cholesky!` (Out-of-place/In-place Cholesky factorization)

Diff for: src/Dagger.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,19 @@ include("datadeps.jl")
6060
include("array/darray.jl")
6161
include("array/alloc.jl")
6262
include("array/map-reduce.jl")
63+
include("array/copy.jl")
6364

6465
# File IO
6566
include("file-io.jl")
6667

6768
include("array/operators.jl")
68-
include("array/getindex.jl")
69+
include("array/indexing.jl")
6970
include("array/setindex.jl")
7071
include("array/matrix.jl")
7172
include("array/sparse_partition.jl")
7273
include("array/sort.jl")
7374
include("array/linalg.jl")
75+
include("array/mul.jl")
7476
include("array/cholesky.jl")
7577

7678
# Visualization

Diff for: src/array/copy.jl

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copy Buffering
2+
3+
"""
    maybe_copy_buffered(f, args...)

Call `f` on the `DArray`s supplied as `DArray => Blocks` pairs in `args`.

If every array's `partitioning` already equals the `Blocks` it is paired with,
`f` is called directly on the arrays. Otherwise the call is forwarded to
`copy_buffered(f, args...)`.
"""
function maybe_copy_buffered(f, args...)
    @assert all(pair->pair isa Pair{<:DArray,<:Blocks}, args) "maybe_copy_buffered only supports `DArray`=>`Blocks`"
    # Buffering is needed as soon as one array is not partitioned as requested
    needs_buffering = any(args) do pair
        first(pair).partitioning != last(pair)
    end
    needs_buffering && return copy_buffered(f, args...)
    return f(map(first, args)...)
end
11+
"""
    copy_buffered(f, args...)

Run `f` on temporary buffers instead of on the arrays themselves.

Each element of `args` is indexed as `(array, partitioning)`. A buffer is
allocated per array with `allocate_copy_buffer(partitioning, array)`, the array
is copied into it, `f` is called on all buffers, and every buffer is then copied
back into its array. Returns whatever `f` returned.
"""
function copy_buffered(f, args...)
    targets = map(first, args)
    buffers = map(args) do pair
        allocate_copy_buffer(pair[2], pair[1])
    end
    # Stage the inputs into the re-partitioned buffers
    foreach(copyto!, buffers, targets)
    result = f(buffers...)
    # Write the (possibly mutated) buffers back into the original arrays
    foreach(copyto!, targets, buffers)
    return result
end
23+
# Allocate the buffer used by `copy_buffered`: the result of `zeros` called with
# the partitioning `part`, the element type of `A`, and the size of `A`.
# FIXME: undef initializer
allocate_copy_buffer(part::Blocks{N}, A::DArray{T,N}) where {T,N} =
    zeros(part, T, size(A))
27+
"""
    copyto!(B::DArray{T,N}, A::DArray{T,N}) -> B

Copy the contents of `A` into `B` and return `B`.

The arrays must have equal `size`, but their chunk layouts may differ. For each
chunk of `B`, the span of chunks of `A` whose subdomains overlap it is located,
and one `copyto_view!` task is spawned per overlapping pair inside a
`Dagger.spawn_datadeps` region.

Throws a `DimensionMismatch` if `size(B) != size(A)`.
"""
function Base.copyto!(B::DArray{T,N}, A::DArray{T,N}) where {T,N}
    if size(B) != size(A)
        throw(DimensionMismatch("Cannot copy from array of size $(size(A)) to array of size $(size(B))"))
    end

    Bc = B.chunks
    Ac = A.chunks
    Asd_all = A.subdomains::DomainBlocks{N}
    # Loop-invariant: the valid chunk indices of A
    Ac_inds = CartesianIndices(Ac)

    Dagger.spawn_datadeps() do
        for Bidx in CartesianIndices(Bc)
            Bpart = Bc[Bidx]
            Bsd = B.subdomains[Bidx]

            # Find the first overlapping subdomain of A
            if A.partitioning isa Blocks
                # Block index containing the first element of Bsd along each dimension
                Aidx = CartesianIndex(ntuple(i->fld1(Bsd.indexes[i].start, A.partitioning.blocksize[i]), N))
            else
                # Fallback just in case of non-dense partitioning
                Aidx = first(Ac_inds)
                Asd = first(Asd_all)
                for dim in 1:N
                    while Asd.indexes[dim].stop < Bsd.indexes[dim].start
                        Aidx += CartesianIndex(ntuple(i->i==dim, N))
                        Asd = Asd_all[Aidx]
                    end
                end
            end
            Aidx_start = Aidx

            # Find the last overlapping subdomain of A
            for dim in 1:N
                unit_step = CartesianIndex(ntuple(i->i==dim, N))
                while true
                    Aidx_next = Aidx + unit_step
                    Aidx_next in Ac_inds || break
                    Asd_all[Aidx_next].indexes[dim].start <= Bsd.indexes[dim].stop || break
                    Aidx = Aidx_next
                end
            end
            Aidx_end = Aidx

            # Copy all overlapping subdomains of A. The loop uses its own
            # single-assignment names so the closures below do not capture a
            # variable that is reassigned elsewhere in this function.
            for src_idx in Aidx_start:Aidx_end
                src_sd = Asd_all[src_idx]
                Apart = Ac[src_idx]

                # Compute the true (global) range shared by both subdomains
                range_start = CartesianIndex(ntuple(i->max(Bsd.indexes[i].start, src_sd.indexes[i].start), N))
                range_end = CartesianIndex(ntuple(i->min(Bsd.indexes[i].stop, src_sd.indexes[i].stop), N))
                range_diff = range_end - range_start

                # Compute the offset range into Apart
                Aorigin = CartesianIndex(ntuple(i->src_sd.indexes[i].start, N))
                Afirst = range_start - Aorigin + CartesianIndex{N}(1)
                Arange = range(Afirst, Afirst + range_diff)

                # Compute the offset range into Bpart
                Borigin = CartesianIndex(ntuple(i->Bsd.indexes[i].start, N))
                Bfirst = range_start - Borigin + CartesianIndex{N}(1)
                Brange = range(Bfirst, Bfirst + range_diff)

                # Perform view copy
                Dagger.@spawn copyto_view!(Out(Bpart), Brange, In(Apart), Arange)
            end
        end
    end

    return B
end
108+
"""
    copyto_view!(Bpart, Brange, Apart, Arange)

Copy the elements of `Apart` selected by `Arange` into the elements of `Bpart`
selected by `Brange`. Both selections are taken as views, so `Bpart` is mutated
in place. Returns `nothing`.
"""
function copyto_view!(Bpart, Brange, Apart, Arange)
    dest = view(Bpart, Brange)
    src = view(Apart, Arange)
    copyto!(dest, src)
    return nothing
end

Diff for: src/array/darray.jl

+44-16
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,20 @@ import Serialization: serialize, deserialize
99
1010
An `N`-dimensional domain over an array.
1111
"""
12-
struct ArrayDomain{N}
13-
indexes::NTuple{N, Any}
12+
struct ArrayDomain{N,T<:Tuple}
13+
indexes::T
1414
end
15-
include("../lib/domain-blocks.jl")
16-
1715

18-
ArrayDomain(xs...) = ArrayDomain(xs)
16+
ArrayDomain(xs::T) where T<:Tuple = ArrayDomain{length(xs),T}(xs)
17+
ArrayDomain(xs::NTuple{N,Base.OneTo}) where N =
18+
ArrayDomain{N,NTuple{N,UnitRange{Int}}}(ntuple(i->UnitRange(xs[i]), N))
19+
ArrayDomain(xs::NTuple{N,Int}) where N =
20+
ArrayDomain{N,NTuple{N,UnitRange{Int}}}(ntuple(i->xs[i]:xs[i], N))
21+
ArrayDomain(xs...) = ArrayDomain((xs...,))
1922
ArrayDomain(xs::Array) = ArrayDomain((xs...,))
2023

24+
include("../lib/domain-blocks.jl")
25+
2126
indexes(a::ArrayDomain) = a.indexes
2227
chunks(a::ArrayDomain{N}) where {N} = DomainBlocks(
2328
ntuple(i->first(indexes(a)[i]), Val(N)), map(x->[length(x)], indexes(a)))
@@ -117,6 +122,7 @@ indicates the number of dimensions in the resulting array.
117122
"""
118123
Blocks(xs::Int...) = Blocks(xs)
119124

125+
const DArrayDomain{N} = ArrayDomain{N, NTuple{N, UnitRange{Int}}}
120126

121127
"""
122128
DArray{T,N,F}(domain, subdomains, chunks, concat)
@@ -133,8 +139,8 @@ An N-dimensional distributed array of element type T, with a concatenation funct
133139
and concatenates them along dimension `d`. `cat` is used by default.
134140
"""
135141
mutable struct DArray{T,N,B<:AbstractBlocks{N},F} <: ArrayOp{T, N}
136-
domain::ArrayDomain{N}
137-
subdomains::AbstractArray{ArrayDomain{N}, N}
142+
domain::DArrayDomain{N}
143+
subdomains::AbstractArray{DArrayDomain{N}, N}
138144
chunks::AbstractArray{Any, N}
139145
partitioning::B
140146
concat::F
@@ -143,20 +149,27 @@ mutable struct DArray{T,N,B<:AbstractBlocks{N},F} <: ArrayOp{T, N}
143149
end
144150
end
145151

152+
# A `DArray` or a lazy `Transpose`/`Adjoint` wrapper around one.
# `Transpose{T,S}` and `Adjoint{T,S}` take the element type first and the parent
# array type second, so the `DArray` constraint belongs on the second parameter.
WrappedDArray{T,N} = Union{<:DArray{T,N}, Transpose{<:Any,<:DArray{T,N}}, Adjoint{<:Any,<:DArray{T,N}}}
WrappedDMatrix{T} = WrappedDArray{T,2}
WrappedDVector{T} = WrappedDArray{T,1}
# Convenience aliases for two- and one-dimensional `DArray`s
DMatrix{T} = DArray{T,2}
DVector{T} = DArray{T,1}
157+
158+
146159
# mainly for backwards-compatibility
147160
DArray{T, N}(domain, subdomains, chunks, partitioning, concat=cat) where {T,N} =
148161
DArray(T, domain, subdomains, chunks, partitioning, concat)
149162

150-
function DArray(T, domain::ArrayDomain{N},
151-
subdomains::AbstractArray{ArrayDomain{N}, N},
152-
chunks::AbstractArray{<:Any, N}, partitioning::B, concat=cat) where {N,B<:AbstractMultiBlocks{N}}
163+
function DArray(T, domain::DArrayDomain{N},
164+
subdomains::AbstractArray{DArrayDomain{N}, N},
165+
chunks::AbstractArray{<:Any, N}, partitioning::B, concat=cat) where {N,B<:AbstractBlocks{N}}
153166
DArray{T,N,B,typeof(concat)}(domain, subdomains, chunks, partitioning, concat)
154167
end
155168

156-
function DArray(T, domain::ArrayDomain{N},
157-
subdomains::ArrayDomain{N},
169+
function DArray(T, domain::DArrayDomain{N},
170+
subdomains::DArrayDomain{N},
158171
chunks::Any, partitioning::B, concat=cat) where {N,B<:AbstractSingleBlocks{N}}
159-
_subdomains = Array{ArrayDomain{N}, N}(undef, ntuple(i->1, N)...)
172+
_subdomains = Array{DArrayDomain{N}, N}(undef, ntuple(i->1, N)...)
160173
_subdomains[1] = subdomains
161174
_chunks = Array{Any, N}(undef, ntuple(i->1, N)...)
162175
_chunks[1] = chunks
@@ -201,6 +214,13 @@ function Base.similar(x::DArray{T,N}) where {T,N}
201214
return DArray(T, x.domain, x.subdomains, thunks, x.partitioning, x.concat)
202215
end
203216

217+
"""
    similar(A::DArray, ::Type{S}, dims::Dims{N})

Build an array with element type `S` and size `dims` that reuses the
partitioning of `A`. Each chunk is allocated uninitialized
(`Array{S,N}(undef, ...)`); the result is whatever `_to_darray` produces from
the `AllocateArray` operation.
"""
function Base.similar(A::DArray{T,N} where T, ::Type{S}, dims::Dims{N}) where {S,N}
    blocks = A.partitioning
    dom = ArrayDomain(map(len->1:len, dims))
    # Per-chunk allocator: ignores its first two arguments and uses the rest as the chunk size
    alloc_chunk = (_, _, sz...) -> Array{S,N}(undef, sz...)
    op = AllocateArray(S, alloc_chunk, dom, partition(blocks, dom), blocks)
    return _to_darray(op)
end
223+
204224
Base.copy(x::DArray{T,N,B,F}) where {T,N,B,F} =
205225
map(identity, x)::DArray{T,N,B,F}
206226

@@ -214,7 +234,7 @@ Base.:(/)(x::DArray{T,N,B,F}, y::U) where {T<:Real,U<:Real,N,B,F} =
214234
A `view` of a `DArray` chunk returns a `DArray` of `Thunk`s.
215235
"""
216236
function Base.view(c::DArray, d)
217-
subchunks, subdomains = lookup_parts(chunks(c), domainchunks(c), d)
237+
subchunks, subdomains = lookup_parts(c, chunks(c), domainchunks(c), d)
218238
d1 = alignfirst(d)
219239
DArray(eltype(c), d1, subdomains, subchunks, c.partitioning, c.concat)
220240
end
@@ -252,7 +272,7 @@ function group_indices(cumlength, idxs::AbstractRange)
252272
end
253273

254274
_cumsum(x::AbstractArray) = length(x) == 0 ? Int[] : cumsum(x)
255-
function lookup_parts(ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{N}) where N
275+
function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{N}) where N
256276
groups = map(group_indices, subdmns.cumlength, indexes(d))
257277
sz = map(length, groups)
258278
pieces = Array{Any}(undef, sz)
@@ -264,7 +284,15 @@ function lookup_parts(ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomai
264284
end
265285
out_cumlength = map(g->_cumsum(map(x->length(x[2]), g)), groups)
266286
out_dmn = DomainBlocks(ntuple(x->1,Val(N)), out_cumlength)
267-
pieces, out_dmn
287+
return pieces, out_dmn
288+
end
289+
# `lookup_parts` method for a domain whose dimensionality `S` is a separate type
# parameter from the array's `N`. Only one-dimensional domains are accepted;
# any other `S` throws a `BoundsError`. The indices are mapped through
# `CartesianIndices(A)`, and the first and last resulting indices become the
# corners of an `N`-dimensional domain that is passed back to `lookup_parts`.
# NOTE(review): the rebuilt domain is the box between those two corners, which
# looks like it could cover more elements than the original linear range —
# confirm against callers.
function lookup_parts(A::DArray, ps::AbstractArray, subdmns::DomainBlocks{N}, d::ArrayDomain{S}) where {N,S}
    S == 1 || throw(BoundsError(A, d.indexes))
    inds = CartesianIndices(A)[d.indexes...]
    new_d = ntuple(N) do i
        first(inds).I[i]:last(inds).I[i]
    end
    return lookup_parts(A, ps, subdmns, ArrayDomain(new_d))
end
269297

270298
"""

0 commit comments

Comments
 (0)