Skip to content

Commit 6176dc7

Browse files
pszufejpsamaroo
authored andcommitted
Add rand! implementation
1 parent 95febd2 commit 6176dc7

File tree

10 files changed

+119
-1
lines changed

10 files changed

+119
-1
lines changed

Project.toml

+4
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,13 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2828
[weakdeps]
2929
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
3030
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
31+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
3132
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
3233
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
3334
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
3435

3536
[extensions]
37+
DistributionsExt = "Distributions"
3638
GraphVizExt = "GraphViz"
3739
GraphVizSimpleExt = "Colors"
3840
JSON3Ext = "JSON3"
@@ -43,6 +45,7 @@ Adapt = "4.0.4"
4345
Colors = "0.12"
4446
DataFrames = "1"
4547
DataStructures = "0.18"
48+
Distributions = "0.25"
4649
GraphViz = "0.2"
4750
Graphs = "1"
4851
MacroTools = "0.5"
@@ -61,6 +64,7 @@ julia = "1.8"
6164
[extras]
6265
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
6366
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
67+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
6468
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
6569
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
6670
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

docs/src/darray.md

+4
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,10 @@ From `Base`:
431431
- `map`/`reduce`/`mapreduce`
432432
- `sum`/`prod`
433433
- `minimum`/`maximum`/`extrema`
434+
- `map!`
435+
436+
From `Random`:
437+
- `rand!`/`randn!`
434438

435439
From `Statistics`:
436440
- `mean`

ext/DistributionsExt.jl

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
module DistributionsExt
2+
3+
if isdefined(Base, :get_extension)
4+
using Distributions
5+
else
6+
using ..Distributions
7+
end
8+
9+
using Dagger, Random
10+
11+
Random.rand!(s::Sampleable, A::DArray{T}) where T = map!(_ -> rand(s), A)
12+
Random.rand!(rng::AbstractRNG, s::Sampleable{Univariate}, A::DArray{T}) where T = map!(_ -> rand(rng, s), A)
13+
Random.rand!(rng::AbstractRNG, s::Sampleable{ArrayLikeVariate{M}}, A::DArray{T}) where {M, T} = map!(_ -> rand(rng, s), A)
14+
15+
end # module DistributionsExt

src/Dagger.jl

+6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import Distributed: Future, RemoteChannel, myid, workers, nworkers, procs, remot
1313

1414
import LinearAlgebra
1515
import LinearAlgebra: Adjoint, BLAS, Diagonal, Bidiagonal, Tridiagonal, LAPACK, LowerTriangular, PosDefException, Transpose, UpperTriangular, UnitLowerTriangular, UnitUpperTriangular, diagind, ishermitian, issymmetric
16+
import Random
17+
import Random: AbstractRNG
1618

1719
import UUIDs: UUID, uuid4
1820

@@ -82,6 +84,7 @@ include("array/sort.jl")
8284
include("array/linalg.jl")
8385
include("array/mul.jl")
8486
include("array/cholesky.jl")
87+
include("array/random.jl")
8588

8689
# Visualization
8790
include("visualization.jl")
@@ -101,6 +104,9 @@ function __init__()
101104
system_uuid()
102105

103106
@static if !isdefined(Base, :get_extension)
107+
@require Distributions="31c24e10-a181-5473-b8eb-7969acd0382f" begin
108+
include(joinpath(dirname(@__DIR__), "ext", "DistributionsExt.jl"))
109+
end
104110
@require Graphs="86223c79-3864-5bf0-83f7-82e725a168b6" begin
105111
@require GraphViz="f526b714-d49f-11e8-06ff-31ed36ee7ee0" begin
106112
include(joinpath(dirname(@__DIR__), "ext", "GraphVizExt.jl"))

src/array/operators.jl

+21
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,24 @@ end
111111

112112
Base.first(A::DArray) = A[begin]
113113
Base.last(A::DArray) = A[end]
114+
115+
# In-place operations
116+
117+
function Base.map!(f, a::DArray{T}) where T
118+
Dagger.spawn_datadeps() do
119+
for ca in chunks(a)
120+
Dagger.@spawn map!(f, InOut(ca), ca)
121+
end
122+
end
123+
return a
124+
end
125+
126+
function Base.map!(f, a::DArray{T}, b::AbstractArray{U}) where {T, U}
127+
b2 = view(b, a.partitioning)
128+
Dagger.spawn_datadeps() do
129+
for (c_a, c_b2) in zip(chunks(a), chunks(b2))
130+
Dagger.@spawn map!(f, InOut(c_a), c_b2)
131+
end
132+
end
133+
return a
134+
end

src/array/random.jl

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
Random.rand!(dm::DArray{T}) where T = map!(_ -> rand(T), dm)
2+
Random.randn!(dm::DArray{T}) where T = map!(_ -> randn(T), dm)
3+
4+
randfork(rng::R, n::Integer) where {R<:AbstractRNG} =
5+
R(abs(rand(copy(rng), Int) + n))
6+
7+
function Random.rand!(rng::AbstractRNG, A::DArray{T}) where T
8+
part_sz = prod(map(length, first(A.subdomains).indexes))
9+
Dagger.spawn_datadeps() do
10+
for Ac in chunks(A)
11+
rng = randfork(rng, part_sz)
12+
Dagger.@spawn map!(_->rand(rng, T), InOut(Ac), Ac)
13+
end
14+
end
15+
return A
16+
end
17+
function Random.randn!(rng::AbstractRNG, A::DArray{T}) where T
18+
part_sz = prod(map(length, first(A.subdomains).indexes))
19+
Dagger.spawn_datadeps() do
20+
for Ac in chunks(A)
21+
rng = randfork(rng, part_sz)
22+
Dagger.@spawn map!(_->randn(rng, T), InOut(Ac), Ac)
23+
end
24+
end
25+
return A
26+
end

test/Project.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@ ArgParse = "c7e460c6-2fb9-53a9-8c5b-16f535851c63"
33
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
44
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
55
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
6-
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
6+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
77
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
8+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
89
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
910
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011
MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94"

test/array/core.jl

+10
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,16 @@ end
1212
@test collect(X1) .+ 1 == collect(X2)
1313
end
1414

15+
@testset "map!" begin
16+
X = zeros(Blocks(10,10), Int, 40, 40)
17+
18+
map!(_ -> 1, X)
19+
@test sum(X) == length(X)
20+
21+
map!(x -> 2x+1, X, Matrix(X))
22+
@test sum(X) == 3*length(X)
23+
end
24+
1525
@testset "DiffEq support" begin
1626
X = ones(Blocks(10), 100)
1727
X0 = zero(X)

test/array/random.jl

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
using Distributions, Statistics
2+
3+
@testset "rand!" begin
4+
X = zeros(Blocks(10,10), Float64, 40, 40)
5+
6+
Random.rand!(X)
7+
@test all(0.0 .<= X .<= 1.0)
8+
@test sum(X) > 0.01
9+
10+
Random.rand!(Uniform(1, 10), X)
11+
@test all(1.0 .<= X .<= 10.0)
12+
@test sum(X) > length(X) + 0.1
13+
14+
for RNGT in (MersenneTwister, Xoshiro)
15+
Random.rand!(RNGT(1234), X)
16+
@test all(0.0 .<= X .<= 1.0)
17+
@test sum(X) > 0.01
18+
19+
Random.rand!(RNGT(1234), Uniform(1,10), X)
20+
@test all(1.0 .<= X .<= 10.0)
21+
@test sum(X) > length(X) + 0.1
22+
23+
Random.randn!(X)
24+
@test sum(X) <= 2*length(X)
25+
26+
Random.randn!(RNGT(1234), X)
27+
@test sum(X) <= 2*length(X)
28+
@test all(X .!= 0.0)
29+
end
30+
end

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ tests = [
1818
("Array - MapReduce", "array/mapreduce.jl"),
1919
("Array - LinearAlgebra - Matmul", "array/linalg/matmul.jl"),
2020
("Array - LinearAlgebra - Cholesky", "array/linalg/cholesky.jl"),
21+
("Array - Random", "array/random.jl"),
2122
("Caching", "cache.jl"),
2223
("Disk Caching", "diskcaching.jl"),
2324
("File IO", "file-io.jl"),

0 commit comments

Comments
 (0)