Skip to content

Commit 3c80e3a

Browse files
committed
tests: Split array testsets
1 parent 912aab8 commit 3c80e3a

File tree

5 files changed

+65
-67
lines changed

5 files changed

+65
-67
lines changed

Diff for: test/array.jl renamed to test/array/core.jl

+11-63
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
using LinearAlgebra, SparseArrays, Random, SharedArrays
2-
import Dagger: DArray, chunks, domainchunks, treereduce_nd
3-
import Distributed: myid, procs
4-
import Statistics: mean, var, std
5-
import OnlineStats
6-
71
@testset "treereduce_nd" begin
82
xs = rand(1:10, 8,8,8)
93
concats = [(x...)->cat(x..., dims=n) for n in 1:3]
@@ -80,52 +74,6 @@ end
8074
end
8175
end
8276

83-
function test_mapreduce(f, init_func; no_init=true, zero_init=zero,
84-
types=(Int32, Int64, Float32, Float64),
85-
cmp=isapprox)
86-
@testset "$T" for T in types
87-
X = init_func(Blocks(10, 10), T, 100, 100)
88-
inits = ()
89-
if no_init
90-
inits = (inits..., nothing)
91-
end
92-
if zero_init !== nothing
93-
inits = (inits..., zero_init(T))
94-
end
95-
@testset "dims=$dims" for dims in (Colon(), 1, 2, (1,), (2,))
96-
@testset "init=$init" for init in inits
97-
if init === nothing
98-
if dims == Colon()
99-
@test cmp(f(X; dims), f(collect(X); dims))
100-
else
101-
@test cmp(collect(f(X; dims)), f(collect(X); dims))
102-
end
103-
else
104-
if dims == Colon()
105-
@test cmp(f(X; dims, init), f(collect(X); dims, init))
106-
else
107-
@test cmp(collect(f(X; dims, init)), f(collect(X); dims, init))
108-
end
109-
end
110-
end
111-
end
112-
end
113-
end
114-
115-
# Base
116-
@testset "reduce" test_mapreduce((X; dims, init=Base._InitialValue())->reduce(+, X; dims, init), ones)
117-
@testset "mapreduce" test_mapreduce((X; dims, init=Base._InitialValue())->mapreduce(x->x+1, +, X; dims, init), ones)
118-
@testset "sum" test_mapreduce(sum, ones)
119-
@testset "prod" test_mapreduce(prod, rand)
120-
@testset "minimum" test_mapreduce(minimum, rand)
121-
@testset "maximum" test_mapreduce(maximum, rand)
122-
@testset "extrema" test_mapreduce(extrema, rand; cmp=Base.:(==), zero_init=T->(zero(T), zero(T)))
123-
124-
# Statistics
125-
@testset "mean" test_mapreduce(mean, rand; zero_init=nothing, types=(Float32, Float64))
126-
@testset "var" test_mapreduce(var, rand; zero_init=nothing, types=(Float32, Float64))
127-
@testset "std" test_mapreduce(std, rand; zero_init=nothing, types=(Float32, Float64))
128-
12977
@testset "broadcast" begin
13078
X1 = rand(Blocks(10), 100)
13179
X2 = X1 .* 3.4
@@ -138,7 +86,7 @@ end
13886

13987
@testset "distributing an array" begin
14088
function test_dist(X)
141-
X1 = distribute(X, Blocks(10, 20))
89+
X1 = Distribute(Blocks(10, 20), X)
14290
Xc = fetch(X1)
14391
@test Xc isa DArray{eltype(X),ndims(X)}
14492
@test Xc == X
@@ -147,7 +95,7 @@ end
14795
@test map(x->size(x) == (10, 20), domainchunks(Xc)) |> all
14896
end
14997
x = [1 2; 3 4]
150-
@test distribute(x, Blocks(1,1)) == x
98+
@test Distribute(Blocks(1,1), x) == x
15199
test_dist(rand(100, 100))
152100
test_dist(sprand(100, 100, 0.1))
153101

@@ -174,7 +122,7 @@ end
174122
@testset "matrix-matrix multiply" begin
175123
function test_mul(X)
176124
tol = 1e-12
177-
X1 = distribute(X, Blocks(10, 20))
125+
X1 = Distribute(Blocks(10, 20), X)
178126
@test_throws DimensionMismatch X1*X1
179127
X2 = X1'*X1
180128
X3 = X1*X1'
@@ -188,7 +136,7 @@ end
188136
test_mul(rand(40, 40))
189137

190138
x = rand(10,10)
191-
X = distribute(x, Blocks(3,3))
139+
X = Distribute(Blocks(3,3), x)
192140
y = rand(10)
193141
@test norm(collect(X*y) - x*y) < 1e-13
194142
end
@@ -202,24 +150,24 @@ end
202150

203151
@testset "concat" begin
204152
m = rand(75,75)
205-
x = distribute(m, Blocks(10,20))
206-
y = distribute(m, Blocks(10,10))
153+
x = Distribute(Blocks(10,20), m)
154+
y = Distribute(Blocks(10,10), m)
207155
@test hcat(m,m) == collect(hcat(x,x)) == collect(hcat(x,y))
208156
@test vcat(m,m) == collect(vcat(x,x))
209157
@test_throws DimensionMismatch vcat(x,y)
210158
end
211159

212160
@testset "scale" begin
213161
x = rand(10,10)
214-
X = distribute(x, Blocks(3,3))
162+
X = Distribute(Blocks(3,3), x)
215163
y = rand(10)
216164

217165
@test Diagonal(y)*x == collect(Diagonal(y)*X)
218166
end
219167

220168
@testset "Getindex" begin
221169
function test_getindex(x)
222-
X = distribute(x, Blocks(3,3))
170+
X = Distribute(Blocks(3,3), x)
223171
@test collect(X[3:8, 2:7]) == x[3:8, 2:7]
224172
ragged_idx = [1,2,9,7,6,2,4,5]
225173
@test collect(X[ragged_idx, 2:7]) == x[ragged_idx, 2:7]
@@ -248,7 +196,7 @@ end
248196

249197

250198
@testset "cleanup" begin
251-
X = distribute(rand(10,10), Blocks(10,10))
199+
X = Distribute(Blocks(10,10), rand(10,10))
252200
@test collect(sin.(X)) == collect(sin.(X))
253201
end
254202

@@ -269,7 +217,7 @@ end
269217
x=rand(10,10)
270218
y=copy(x)
271219
y[3:8, 2:7] .= 1.0
272-
X = distribute(x, Blocks(3,3))
220+
X = Distribute(Blocks(3,3), x)
273221
@test collect(setindex(X,1.0, 3:8, 2:7)) == y
274222
@test collect(X) == x
275223
end
@@ -292,7 +240,7 @@ end
292240
@test collect(sort(y)) == x
293241

294242
x = ones(10)
295-
y = distribute(x, Blocks(3))
243+
y = Distribute(Blocks(3), x)
296244
@test_broken map(x->length(collect(x)), sort(y).chunks) == [3,3,3,1]
297245
end
298246

Diff for: test/linalg.jl renamed to test/array/linalg.jl

-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using LinearAlgebra
2-
31
@testset "Linear Algebra" begin
42
@testset "GEMM: $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
53
A = rand(T, 128, 128)

Diff for: test/array/mapreduce.jl

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
function test_mapreduce(f, init_func; no_init=true, zero_init=zero,
2+
types=(Int32, Int64, Float32, Float64),
3+
cmp=isapprox)
4+
@testset "$T" for T in types
5+
X = init_func(Blocks(10, 10), T, 100, 100)
6+
inits = ()
7+
if no_init
8+
inits = (inits..., nothing)
9+
end
10+
if zero_init !== nothing
11+
inits = (inits..., zero_init(T))
12+
end
13+
@testset "dims=$dims" for dims in (Colon(), 1, 2, (1,), (2,))
14+
@testset "init=$init" for init in inits
15+
if init === nothing
16+
if dims == Colon()
17+
@test cmp(f(X; dims), f(collect(X); dims))
18+
else
19+
@test cmp(collect(f(X; dims)), f(collect(X); dims))
20+
end
21+
else
22+
if dims == Colon()
23+
@test cmp(f(X; dims, init), f(collect(X); dims, init))
24+
else
25+
@test cmp(collect(f(X; dims, init)), f(collect(X); dims, init))
26+
end
27+
end
28+
end
29+
end
30+
end
31+
end
32+
33+
# Base
34+
@testset "reduce" test_mapreduce((X; dims, init=Base._InitialValue())->reduce(+, X; dims, init), ones)
35+
@testset "mapreduce" test_mapreduce((X; dims, init=Base._InitialValue())->mapreduce(x->x+1, +, X; dims, init), ones)
36+
@testset "sum" test_mapreduce(sum, ones)
37+
@testset "prod" test_mapreduce(prod, rand)
38+
@testset "minimum" test_mapreduce(minimum, rand)
39+
@testset "maximum" test_mapreduce(maximum, rand)
40+
@testset "extrema" test_mapreduce(extrema, rand; cmp=Base.:(==), zero_init=T->(zero(T), zero(T)))
41+
42+
# Statistics
43+
@testset "mean" test_mapreduce(mean, rand; zero_init=nothing, types=(Float32, Float64))
44+
@testset "var" test_mapreduce(var, rand; zero_init=nothing, types=(Float32, Float64))
45+
@testset "std" test_mapreduce(std, rand; zero_init=nothing, types=(Float32, Float64))

Diff for: test/imports.jl

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
using LinearAlgebra, SparseArrays, Random, SharedArrays
2+
import Dagger: DArray, chunks, domainchunks, treereduce_nd
3+
import Distributed: myid, procs
4+
import Statistics: mean, var, std
5+
import OnlineStats

Diff for: test/runtests.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ tests = [
1111
("Task Queues", "task-queues.jl"),
1212
("Datadeps", "datadeps.jl"),
1313
("Domain Utilities", "domain.jl"),
14-
("Array", "array.jl"),
15-
("Linear Algebra", "linalg.jl"),
14+
("Array - Core", "array/core.jl"),
15+
("Array - MapReduce", "array/mapreduce.jl"),
16+
("Array - LinearAlgebra", "array/linalg.jl"),
1617
("Caching", "cache.jl"),
1718
("Disk Caching", "diskcaching.jl"),
1819
("File IO", "file-io.jl"),
@@ -70,6 +71,7 @@ end
7071
using Distributed
7172
addprocs(3)
7273

74+
include("imports.jl")
7375
include("util.jl")
7476
include("fakeproc.jl")
7577

0 commit comments

Comments
 (0)