Skip to content

Commit 67a7435

Browse files
authored
Merge pull request #546 from pszufe/psz/daggerrand
Add rand! implementation
2 parents 95febd2 + c65328a commit 67a7435

14 files changed

+255
-22
lines changed

Diff for: Project.toml

+4
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,13 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2828
[weakdeps]
2929
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
3030
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
31+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
3132
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
3233
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
3334
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
3435

3536
[extensions]
37+
DistributionsExt = "Distributions"
3638
GraphVizExt = "GraphViz"
3739
GraphVizSimpleExt = "Colors"
3840
JSON3Ext = "JSON3"
@@ -43,6 +45,7 @@ Adapt = "4.0.4"
4345
Colors = "0.12"
4446
DataFrames = "1"
4547
DataStructures = "0.18"
48+
Distributions = "0.25"
4649
GraphViz = "0.2"
4750
Graphs = "1"
4851
MacroTools = "0.5"
@@ -61,6 +64,7 @@ julia = "1.8"
6164
[extras]
6265
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
6366
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
67+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
6468
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
6569
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
6670
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

Diff for: docs/src/darray.md

+4
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,10 @@ From `Base`:
431431
- `map`/`reduce`/`mapreduce`
432432
- `sum`/`prod`
433433
- `minimum`/`maximum`/`extrema`
434+
- `map!`
435+
436+
From `Random`:
437+
- `rand!`/`randn!`
434438

435439
From `Statistics`:
436440
- `mean`

Diff for: ext/DistributionsExt.jl

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module DistributionsExt
2+
3+
if isdefined(Base, :get_extension)
4+
using Distributions
5+
else
6+
using ..Distributions
7+
end
8+
9+
using Dagger, Random
10+
11+
Random.rand!(s::Sampleable, A::DArray{T}) where T = map!(_ -> rand(s), A)
12+
Random.rand!(rng::AbstractRNG, s::Sampleable{Univariate}, A::DArray{T}) where T = map!(_ -> rand(rng, s), A)
13+
Random.rand!(rng::AbstractRNG, s::Sampleable{ArrayLikeVariate{M}}, A::DArray{T}) where {M, T} = map!(_ -> rand(rng, s), A)
14+
15+
end # module DistributionsExt

Diff for: src/Dagger.jl

+6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import Distributed: Future, RemoteChannel, myid, workers, nworkers, procs, remot
1313

1414
import LinearAlgebra
1515
import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, diagind, ishermitian, issymmetric
16+
import Random
17+
import Random: AbstractRNG
1618

1719
import UUIDs: UUID, uuid4
1820

@@ -82,6 +84,7 @@ include("array/sort.jl")
8284
include("array/linalg.jl")
8385
include("array/mul.jl")
8486
include("array/cholesky.jl")
87+
include("array/random.jl")
8588

8689
# Visualization
8790
include("visualization.jl")
@@ -101,6 +104,9 @@ function __init__()
101104
system_uuid()
102105

103106
@static if !isdefined(Base, :get_extension)
107+
@require Distributions="31c24e10-a181-5473-b8eb-7969acd0382f" begin
108+
include(joinpath(dirname(@__DIR__), "ext", "DistributionsExt.jl"))
109+
end
104110
@require Graphs="86223c79-3864-5bf0-83f7-82e725a168b6" begin
105111
@require GraphViz="f526b714-d49f-11e8-06ff-31ed36ee7ee0" begin
106112
include(joinpath(dirname(@__DIR__), "ext", "GraphVizExt.jl"))

Diff for: src/array/cholesky.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ function LinearAlgebra._chol!(A::DArray{T,2}, ::Type{UpperTriangular}) where T
2222

2323
info = [convert(LinearAlgebra.BlasInt, 0)]
2424
try
25-
Dagger.spawn_datadeps(;aliasing=true) do
25+
Dagger.spawn_datadeps() do
2626
for k in range(1, mt)
2727
Dagger.@spawn potrf_checked!(uplo, InOut(Ac[k, k]), Out(info))
2828
for n in range(k+1, nt)

Diff for: src/array/operators.jl

+21
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,24 @@ end
111111

112112
Base.first(A::DArray) = A[begin]
113113
Base.last(A::DArray) = A[end]
114+
115+
# In-place operations
116+
117+
function Base.map!(f, a::DArray{T}) where T
118+
Dagger.spawn_datadeps() do
119+
for ca in chunks(a)
120+
Dagger.@spawn map!(f, InOut(ca), ca)
121+
end
122+
end
123+
return a
124+
end
125+
126+
function Base.map!(f, a::DArray{T}, b::AbstractArray{U}) where {T, U}
127+
b2 = view(b, a.partitioning)
128+
Dagger.spawn_datadeps() do
129+
for (c_a, c_b2) in zip(chunks(a), chunks(b2))
130+
Dagger.@spawn map!(f, InOut(c_a), c_b2)
131+
end
132+
end
133+
return a
134+
end

Diff for: src/array/random.jl

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
Random.rand!(dm::DArray{T}) where T = map!(_ -> rand(T), dm)
2+
Random.randn!(dm::DArray{T}) where T = map!(_ -> randn(T), dm)
3+
4+
randfork(rng::R, n::Integer) where {R<:AbstractRNG} =
5+
R(abs(rand(copy(rng), Int) + n))
6+
7+
function Random.rand!(rng::AbstractRNG, A::DArray{T}) where T
8+
part_sz = prod(map(length, first(A.subdomains).indexes))
9+
Dagger.spawn_datadeps() do
10+
for Ac in chunks(A)
11+
rng = randfork(rng, part_sz)
12+
Dagger.@spawn map!(_->rand(rng, T), InOut(Ac), Ac)
13+
end
14+
end
15+
return A
16+
end
17+
function Random.randn!(rng::AbstractRNG, A::DArray{T}) where T
18+
part_sz = prod(map(length, first(A.subdomains).indexes))
19+
Dagger.spawn_datadeps() do
20+
for Ac in chunks(A)
21+
rng = randfork(rng, part_sz)
22+
Dagger.@spawn map!(_->randn(rng, T), InOut(Ac), Ac)
23+
end
24+
end
25+
return A
26+
end

Diff for: src/datadeps.jl

+30-17
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ struct DataDepsAliasingState
9898
ainfos_readers::Dict{AbstractAliasing,Vector{Pair{DTask,Int}}}
9999
ainfos_overlaps::Dict{AbstractAliasing,Set{AbstractAliasing}}
100100

101+
# Cache ainfo lookups
102+
ainfo_cache::Dict{Tuple{Any,Any},AbstractAliasing}
103+
101104
function DataDepsAliasingState()
102105
data_origin = Dict{AbstractAliasing,MemorySpace}()
103106
data_locality = Dict{AbstractAliasing,MemorySpace}()
@@ -106,8 +109,11 @@ struct DataDepsAliasingState
106109
ainfos_readers = Dict{AbstractAliasing,Vector{Pair{DTask,Int}}}()
107110
ainfos_overlaps = Dict{AbstractAliasing,Set{AbstractAliasing}}()
108111

112+
ainfo_cache = Dict{Tuple{Any,Any},AbstractAliasing}()
113+
109114
return new(data_origin, data_locality,
110-
ainfos_owner, ainfos_readers, ainfos_overlaps)
115+
ainfos_owner, ainfos_readers, ainfos_overlaps,
116+
ainfo_cache)
111117
end
112118
end
113119
struct DataDepsNonAliasingState
@@ -156,6 +162,12 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
156162
end
157163
end
158164

165+
function aliasing(astate::DataDepsAliasingState, arg, dep_mod)
166+
return get!(astate.ainfo_cache, (arg, dep_mod)) do
167+
return aliasing(arg, dep_mod)
168+
end
169+
end
170+
159171
# Determine which arguments could be written to, and thus need tracking
160172

161173
"Whether `arg` has any writedep in this datadeps region."
@@ -190,7 +202,7 @@ function has_writedep(state::DataDepsState, arg, deps, task::DTask)
190202
for (readdep, writedep, other_ainfo, _, _) in other_taskdeps
191203
writedep || continue
192204
for (dep_mod, _, _) in deps
193-
ainfo = aliasing(arg, dep_mod)
205+
ainfo = aliasing(state.alias_state, arg, dep_mod)
194206
if will_alias(ainfo, other_ainfo)
195207
return true
196208
end
@@ -221,7 +233,7 @@ function is_writedep(arg, deps, task::DTask)
221233
end
222234

223235
# Aliasing state setup
224-
function populate_task_info!(state::DataDepsState, spec, task)
236+
function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
225237
# Populate task dependencies
226238
dependencies_to_add = Vector{Tuple{Bool,Bool,AbstractAliasing,<:Any,<:Any}}()
227239

@@ -233,13 +245,13 @@ function populate_task_info!(state::DataDepsState, spec, task)
233245
# Unwrap the Chunk underlying any DTask arguments
234246
arg = arg isa DTask ? fetch(arg; raw=true) : arg
235247

236-
# Skip non-mutable arguments
237-
Base.datatype_pointerfree(typeof(arg)) && continue
248+
# Skip non-aliasing arguments
249+
type_may_alias(typeof(arg)) || continue
238250

239251
# Add all aliasing dependencies
240252
for (dep_mod, readdep, writedep) in deps
241253
if state.aliasing
242-
ainfo = aliasing(arg, dep_mod)
254+
ainfo = aliasing(state.alias_state, arg, dep_mod)
243255
else
244256
ainfo = UnknownAliasing()
245257
end
@@ -251,15 +263,16 @@ function populate_task_info!(state::DataDepsState, spec, task)
251263
end
252264

253265
# Track the task result too
254-
push!(dependencies_to_add, (true, true, UnknownAliasing(), identity, task))
266+
# N.B. We state no readdep/writedep because, while we can't model the aliasing info for the task result yet, we don't want to synchronize because of this
267+
push!(dependencies_to_add, (false, false, UnknownAliasing(), identity, task))
255268

256269
# Record argument/result dependencies
257270
push!(state.dependencies, task => dependencies_to_add)
258271
end
259272
function populate_argument_info!(state::DataDepsState{DataDepsAliasingState}, arg, deps)
260273
astate = state.alias_state
261274
for (dep_mod, readdep, writedep) in deps
262-
ainfo = aliasing(arg, dep_mod)
275+
ainfo = aliasing(astate, arg, dep_mod)
263276

264277
# Initialize owner and readers
265278
if !haskey(astate.ainfos_owner, ainfo)
@@ -405,7 +418,7 @@ function generate_slot!(state::DataDepsState, dest_space, data)
405418
data_converted = move(from_proc, to_proc, data)
406419
data_chunk = tochunk(data_converted, to_proc)
407420
@assert processor(data_chunk) in processors(dest_space)
408-
@assert memory_space(data_chunk) == memory_space(data_converted)
421+
@assert memory_space(data_converted) == memory_space(data_chunk) "space mismatch! $(memory_space(data_converted)) != $(memory_space(data_chunk)) ($(typeof(data_converted)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)"
409422
@assert orig_space != memory_space(data_chunk) "space preserved! $orig_space != $(memory_space(data_chunk)) ($(typeof(data)) vs. $(typeof(data_chunk))), spaces ($orig_space -> $dest_space)"
410423
return data_chunk
411424
end
@@ -664,7 +677,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
664677
# Is the data written previously or now?
665678
arg, deps = unwrap_inout(arg)
666679
arg = arg isa DTask ? fetch(arg; raw=true) : arg
667-
if Base.datatype_pointerfree(typeof(arg)) || !has_writedep(state, arg, deps, task)
680+
if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps, task)
668681
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (unwritten)"
669682
spec.args[idx] = pos => arg
670683
continue
@@ -676,7 +689,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
676689
end
677690
if queue.aliasing
678691
for (dep_mod, _, _) in deps
679-
ainfo = aliasing(arg, dep_mod)
692+
ainfo = aliasing(astate, arg, dep_mod)
680693
data_space = astate.data_locality[ainfo]
681694
nonlocal = our_space != data_space
682695
if nonlocal
@@ -736,10 +749,10 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
736749
for (idx, (_, arg)) in enumerate(task_args)
737750
arg, deps = unwrap_inout(arg)
738751
arg = arg isa DTask ? fetch(arg; raw=true) : arg
739-
Base.datatype_pointerfree(typeof(arg)) && continue
752+
type_may_alias(typeof(arg)) || continue
740753
if queue.aliasing
741754
for (dep_mod, _, writedep) in deps
742-
ainfo = aliasing(arg, dep_mod)
755+
ainfo = aliasing(astate, arg, dep_mod)
743756
if writedep
744757
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] Syncing as writer"
745758
get_write_deps!(state, ainfo, task, write_num, syncdeps)
@@ -769,10 +782,10 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
769782
for (idx, (_, arg)) in enumerate(task_args)
770783
arg, deps = unwrap_inout(arg)
771784
arg = arg isa DTask ? fetch(arg; raw=true) : arg
772-
Base.datatype_pointerfree(typeof(arg)) && continue
785+
type_may_alias(typeof(arg)) || continue
773786
if queue.aliasing
774787
for (dep_mod, _, writedep) in deps
775-
ainfo = aliasing(arg, dep_mod)
788+
ainfo = aliasing(astate, arg, dep_mod)
776789
if writedep
777790
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx][$dep_mod] Set as owner"
778791
add_writer!(state, ainfo, task, write_num)
@@ -864,7 +877,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
864877
for arg in keys(astate.data_origin)
865878
# Is the data previously written?
866879
arg, deps = unwrap_inout(arg)
867-
if Base.datatype_pointerfree(typeof(arg)) || !has_writedep(state, arg, deps)
880+
if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps)
868881
@dagdebug nothing :spawn_datadeps "Skipped copy-from (unwritten)"
869882
end
870883

@@ -935,7 +948,7 @@ function spawn_datadeps(f::Base.Callable; static::Bool=true,
935948
throw(ArgumentError("Dynamic scheduling is no longer available"))
936949
end
937950
wait_all(; check_errors=true) do
938-
scheduler = something(scheduler, DATADEPS_SCHEDULER[], :naive)::Symbol
951+
scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol
939952
launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool
940953
if launch_wait
941954
result = spawn_bulk() do

0 commit comments

Comments
 (0)