
Commit 3aa9a11

Add threading support

Adds `MPI.Init_thread` and the `ThreadLevel` enum, along with a threaded test. Additionally, sets the `UCX_ERROR_SIGNALS` environment variable (if not already set) to fix #337.

1 parent ac4ed7a · commit 3aa9a11

9 files changed: +136 −6 lines
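Before the per-file diffs, a minimal sketch (not part of the commit) of how the new API is meant to be used, based on the `Init_thread` docstring added below; the error handling is up to the caller:

```julia
using MPI

# Request full multi-threaded MPI support; Init_thread returns the
# ThreadLevel the MPI library actually provides.
provided = MPI.Init_thread(MPI.THREAD_MULTIPLE)

# ThreadLevel values are ordered, so they can be compared directly.
if provided < MPI.THREAD_MULTIPLE
    error("MPI library does not provide MPI_THREAD_MULTIPLE")
end

# ... multi-threaded communication ...

MPI.Finalize()
```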

deps/gen_consts.jl (+4)

@@ -86,6 +86,10 @@ MPI_handle = [
 ]
 
 MPI_Cints = [
+    :MPI_THREAD_SINGLE,
+    :MPI_THREAD_FUNNELED,
+    :MPI_THREAD_SERIALIZED,
+    :MPI_THREAD_MULTIPLE,
     :MPI_PROC_NULL,
     :MPI_ANY_SOURCE,
     :MPI_ANY_TAG,

docs/src/environment.md (+1)

@@ -11,6 +11,7 @@ mpiexec
 ```@docs
 MPI.Abort
 MPI.Init
+MPI.Init_thread
 MPI.Initialized
 MPI.Finalize
 MPI.Finalized

src/MPI.jl (+8)

@@ -86,6 +86,14 @@ function __init__()
     ENV["UCX_MEM_MALLOC_RELOC"] = "no"
     ENV["UCX_MEM_EVENTS"] = "no"
 
+    # Julia multithreading uses SIGSEGV to sync threads
+    # https://docs.julialang.org/en/v1/devdocs/debuggingtips/#Dealing-with-signals-1
+    # By default, UCX will error if this occurs (issue #337)
+    if !haskey(ENV, "UCX_ERROR_SIGNALS")
+        # default is "SIGILL,SIGSEGV,SIGBUS,SIGFPE"
+        ENV["UCX_ERROR_SIGNALS"] = "SIGILL,SIGBUS,SIGFPE"
+    end
+
     @require CuArrays="3a865a2d-5b23-5a0f-bc46-62713ec82fae" include("cuda.jl")
 end
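Since the workaround only applies when `UCX_ERROR_SIGNALS` is unset, a user can still opt out or pick a different value by setting the variable before `MPI` is loaded. A sketch (not part of the commit):

```julia
# Must be set before `using MPI` runs __init__; this restores UCX's
# default behaviour of trapping SIGSEGV as well.
ENV["UCX_ERROR_SIGNALS"] = "SIGILL,SIGSEGV,SIGBUS,SIGFPE"

using MPI
```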

src/consts/microsoftmpi.jl (+6)

@@ -1,3 +1,5 @@
+# https://github.com/microsoft/Microsoft-MPI/blob/master/src/include/mpi.h
+
 const MPI_Aint = Int
 const MPI_Offset = Int64
 const MPI_Count = Int64
@@ -64,6 +66,10 @@ const MPI_C_DOUBLE_COMPLEX = reinterpret(Cint, 0x4c001014)
 const MPI_File = Cint
 const MPI_FILE_NULL = Cint(0)
 
+const MPI_THREAD_SINGLE = Cint(0)
+const MPI_THREAD_FUNNELED = Cint(1)
+const MPI_THREAD_SERIALIZED = Cint(2)
+const MPI_THREAD_MULTIPLE = Cint(3)
 const MPI_PROC_NULL = Cint(-1)
 const MPI_ANY_SOURCE = Cint(-2)
 const MPI_ANY_TAG = Cint(-1)

src/consts/mpich.jl (+6)

@@ -1,3 +1,5 @@
+# https://github.com/pmodels/mpich/blob/master/src/include/mpi.h.in
+
 const MPI_Aint = Int
 const MPI_Count = Int64
 const MPI_Offset = Int64
@@ -69,6 +71,10 @@ const MPI_UINT64_T = Cint(1275070526)
 const MPI_C_FLOAT_COMPLEX = Cint(1275070528)
 const MPI_C_DOUBLE_COMPLEX = Cint(1275072577)
 
+const MPI_THREAD_SINGLE = Cint(0)
+const MPI_THREAD_FUNNELED = Cint(1)
+const MPI_THREAD_SERIALIZED = Cint(2)
+const MPI_THREAD_MULTIPLE = Cint(3)
 const MPI_PROC_NULL = Cint(-1)
 const MPI_ANY_SOURCE = Cint(-2)
 const MPI_ANY_TAG = Cint(-1)

src/consts/openmpi.jl (+6)

@@ -1,3 +1,5 @@
+# https://github.com/open-mpi/ompi/blob/master/ompi/include/mpi.h.in
+
 const MPI_Aint = Int
 const MPI_Count = Int64
 const MPI_Offset = Int64
@@ -75,6 +77,10 @@ const MPI_UINT64_T = Cint(65)
 const MPI_C_FLOAT_COMPLEX = Cint(69)
 const MPI_C_DOUBLE_COMPLEX = Cint(70)
 
+const MPI_THREAD_SINGLE = Cint(0)
+const MPI_THREAD_FUNNELED = Cint(1)
+const MPI_THREAD_SERIALIZED = Cint(2)
+const MPI_THREAD_MULTIPLE = Cint(3)
 const MPI_PROC_NULL = Cint(-2)
 const MPI_ANY_SOURCE = Cint(-1)
 const MPI_ANY_TAG = Cint(-1)

src/environment.jl (+62 −2)

@@ -34,9 +34,11 @@ end
 
 Initialize MPI in the current process.
 
-All MPI programs must contain exactly one call to `MPI.Init()`. In particular, note that it is not valid to call `MPI.Init` again after calling [`MPI.Finalize`](@ref).
+All MPI programs must contain exactly one call to `MPI.Init` or
+[`MPI.Init_thread`](@ref). In particular, note that it is not valid to call `MPI.Init` or
+`MPI.Init_thread` again after calling [`MPI.Finalize`](@ref).
 
-The only MPI functions that may be called before `MPI.Init()` are
+The only MPI functions that may be called before `MPI.Init`/`MPI.Init_thread` are
 [`MPI.Initialized`](@ref) and [`MPI.Finalized`](@ref).
 
 # External links
@@ -53,6 +55,64 @@ function Init()
     end
 end
 
+@enum ThreadLevel begin
+    THREAD_SINGLE = MPI_THREAD_SINGLE
+    THREAD_FUNNELED = MPI_THREAD_FUNNELED
+    THREAD_SERIALIZED = MPI_THREAD_SERIALIZED
+    THREAD_MULTIPLE = MPI_THREAD_MULTIPLE
+end
+
+
+"""
+    Init_thread(required::ThreadLevel)
+
+Initialize MPI and the MPI thread environment in the current process. The argument
+specifies the required thread level, which is one of the following:
+
+- `MPI.THREAD_SINGLE`: Only one thread will execute.
+- `MPI.THREAD_FUNNELED`: The process may be multi-threaded, but the application must ensure that only the main thread makes MPI calls.
+- `MPI.THREAD_SERIALIZED`: The process may be multi-threaded, and multiple threads may make MPI calls, but only one at a time (i.e. all MPI calls are serialized).
+- `MPI.THREAD_MULTIPLE`: Multiple threads may call MPI, with no restrictions.
+
+The function will return the provided `ThreadLevel`, and values may be compared via inequalities, i.e.
+```julia
+if Init_thread(required) < required
+    error("Insufficient threading")
+end
+```
+
+All MPI programs must contain exactly one call to [`MPI.Init`](@ref) or
+`MPI.Init_thread`. In particular, note that it is not valid to call `MPI.Init` or
+`MPI.Init_thread` again after calling [`MPI.Finalize`](@ref).
+
+The only MPI functions that may be called before `MPI.Init`/`MPI.Init_thread` are
+[`MPI.Initialized`](@ref) and [`MPI.Finalized`](@ref).
+
+# External links
+$(_doc_external("MPI_Init_thread"))
+"""
+function Init_thread(required::ThreadLevel)
+    REFCOUNT[] == -1 || error("MPI.REFCOUNT in incorrect state: MPI may only be initialized once per session.")
+    r_provided = Ref{ThreadLevel}()
+    # int MPI_Init_thread(int *argc, char ***argv, int required, int *provided)
+    @mpichk ccall((:MPI_Init_thread, libmpi), Cint,
+                  (Ptr{Cint}, Ptr{Cvoid}, ThreadLevel, Ref{ThreadLevel}),
+                  C_NULL, C_NULL, required, r_provided)
+    provided = r_provided[]
+    if provided < required
+        @warn "Thread level requested = $required, provided = $provided"
+    end
+
+    REFCOUNT[] = 1
+    atexit(refcount_dec)
+
+    for f in mpi_init_hooks
+        f()
+    end
+    return provided
+end
+
+
 """
     Finalize()
 

test/runtests.jl (+5 −4)

@@ -17,9 +17,6 @@ const coverage_opts =
          JL_LOG_USER => "user",
          JL_LOG_ALL => "all")
 
-# Files to run with mpiexec -n 1
-singlefiles = ["test_spawn.jl"]
-
 function runtests()
     nprocs = clamp(Sys.CPU_THREADS, 2, 4)
     exename = joinpath(Sys.BINDIR, Base.julia_exename())
@@ -33,8 +30,12 @@ function runtests()
     for f in testfiles
         coverage_opt = coverage_opts[Base.JLOptions().code_coverage]
         mpiexec() do cmd
-            if f ∈ singlefiles
+            if f == "test_spawn.jl"
                 run(`$cmd -n 1 $exename --code-coverage=$coverage_opt $(joinpath(testdir, f))`)
+            elseif f == "test_threads.jl"
+                withenv("JULIA_NUM_THREADS" => "4") do
+                    run(`$cmd -n $nprocs $exename --code-coverage=$coverage_opt $(joinpath(testdir, f))`)
+                end
             else
                 run(`$cmd -n $nprocs $exename --code-coverage=$coverage_opt $(joinpath(testdir, f))`)
             end
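The threaded test can also be launched by hand with the same pattern. A sketch, assuming the `mpiexec` helper used by `runtests.jl` is reachable as `MPI.mpiexec`; the process and thread counts are arbitrary:

```julia
using MPI

# Run test_threads.jl on 2 MPI ranks, each with 4 Julia threads.
MPI.mpiexec() do cmd
    withenv("JULIA_NUM_THREADS" => "4") do
        run(`$cmd -n 2 $(Base.julia_cmd()) test/test_threads.jl`)
    end
end
```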

test/test_threads.jl (new file, +38)

@@ -0,0 +1,38 @@
+using Test, Pkg
+using MPI
+
+if get(ENV, "JULIA_MPI_TEST_ARRAYTYPE", "") == "CuArray"
+    using CuArrays
+    ArrayType = CuArray
+else
+    ArrayType = Array
+end
+
+provided = MPI.Init_thread(MPI.THREAD_MULTIPLE)
+
+comm = MPI.COMM_WORLD
+size = MPI.Comm_size(comm)
+rank = MPI.Comm_rank(comm)
+
+const N = 10
+
+dst = mod(rank+1, size)
+src = mod(rank-1, size)
+
+if provided == MPI.THREAD_MULTIPLE
+    send_arr = collect(1.0:N)
+    recv_arr = zeros(N)
+
+    reqs = Array{MPI.Request}(undef, 2N)
+
+    Threads.@threads for i = 1:N
+        reqs[N+i] = MPI.Irecv!(@view(recv_arr[i:i]), src, i, comm)
+        reqs[i] = MPI.Isend(@view(send_arr[i:i]), dst, i, comm)
+    end
+
+    MPI.Waitall!(reqs)
+
+    @test recv_arr == send_arr
+end
+
+MPI.Finalize()
