Skip to content

Commit 03255f8

Browse files
add :dynamic scheduling option for Threads.@threads
1 parent 2682819 commit 03255f8

File tree

2 files changed

+88
-13
lines changed

2 files changed

+88
-13
lines changed

base/threadingconstructs.jl

+48-13
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
2222
"""
2323
nthreads() = Int(unsafe_load(cglobal(:jl_n_threads, Cint)))
2424

25-
function threading_run(func)
26-
ccall(:jl_enter_threaded_region, Cvoid, ())
25+
function threading_run(func, static)
26+
static && ccall(:jl_enter_threaded_region, Cvoid, ())
2727
n = nthreads()
2828
tasks = Vector{Task}(undef, n)
2929
for i = 1:n
3030
t = Task(func)
31-
t.sticky = true
32-
ccall(:jl_set_task_tid, Cint, (Any, Cint), t, i-1)
31+
t.sticky = static
32+
static && ccall(:jl_set_task_tid, Cint, (Any, Cint), t, i-1)
3333
tasks[i] = t
3434
schedule(t)
3535
end
@@ -38,7 +38,7 @@ function threading_run(func)
3838
wait(tasks[i])
3939
end
4040
finally
41-
ccall(:jl_exit_threaded_region, Cvoid, ())
41+
static && ccall(:jl_exit_threaded_region, Cvoid, ())
4242
end
4343
end
4444

@@ -86,15 +86,17 @@ function _threadsfor(iter, lbody, schedule)
8686
end
8787
end
8888
end
89-
if threadid() != 1 || ccall(:jl_in_threaded_region, Cint, ()) != 0
89+
if $(schedule === :dynamic)
90+
threading_run(threadsfor_fun, false)
91+
elseif threadid() != 1 || ccall(:jl_in_threaded_region, Cint, ()) != 0
9092
$(if schedule === :static
9193
:(error("`@threads :static` can only be used from thread 1 and not nested"))
9294
else
93-
# only use threads when called from thread 1, outside @threads
95+
# only use threads when called from thread 1, outside @threads :static
9496
:(Base.invokelatest(threadsfor_fun, true))
95-
end)
97+
end)
9698
else
97-
threading_run(threadsfor_fun)
99+
threading_run(threadsfor_fun, true)
98100
end
99101
nothing
100102
end
@@ -110,15 +112,48 @@ A barrier is placed at the end of the loop which waits for all tasks to finish
110112
execution.
111113
112114
The `schedule` argument can be used to request a particular scheduling policy.
113-
The only currently supported value is `:static`, which creates one task per thread
114-
and divides the iterations equally among them. Specifying `:static` is an error
115-
if used from inside another `@threads` loop or from a thread other than 1.
115+
Options are:
116+
- `:static` is the default schedule which creates one task per thread and divides the
117+
iterations equally among them, assigning each task specifically to each thread.
118+
Specifying `:static` is an error if used from inside another `@threads` loop
119+
or from a thread other than 1.
120+
- `:dynamic` is like `:static` except the tasks are assigned to threads dynamically,
121+
allowing more flexible scheduling if other tasks are active on other threads.
122+
Specifying `:dynamic` is allowed from inside another `@threads` loop and from
123+
threads other than 1.
116124
117125
The default schedule (used when no `schedule` argument is present) is subject to change.
118126
127+
For example, here an illustration of the different scheduling strategies, where `busywait`
128+
is a non-yielding timed loop that runs for a number of seconds.
129+
130+
```julia-repl
131+
julia> @time begin
132+
Threads.@spawn busywait(5)
133+
Threads.@threads :static for i in 1:Threads.nthreads()
134+
busywait(1)
135+
end
136+
end
137+
6.003001 seconds (16.33 k allocations: 899.255 KiB, 0.25% compilation time)
138+
139+
julia> @time begin
140+
Threads.@spawn busywait(5)
141+
Threads.@threads :dynamic for i in 1:Threads.nthreads()
142+
busywait(1)
143+
end
144+
end
145+
2.012056 seconds (16.05 k allocations: 883.919 KiB, 0.66% compilation time)
146+
```
147+
148+
The `:dynamic` example takes 2 seconds because one of the non-occupied threads is able
149+
to run two of the 1-second iterations to complete the for loop.
150+
119151
!!! compat "Julia 1.5"
120152
The `schedule` argument is available as of Julia 1.5.
121153
154+
!!! compat "Julia 1.8"
155+
The `:dynamic` option for the `schedule` argument is available as of Julia 1.8.
156+
122157
See also: [`@spawn`](@ref Threads.@spawn), [`nthreads()`](@ref Threads.nthreads),
123158
[`threadid()`](@ref Threads.threadid), `pmap` in [`Distributed`](@ref man-distributed),
124159
`BLAS.set_num_threads` in [`LinearAlgebra`](@ref man-linalg).
@@ -133,7 +168,7 @@ macro threads(args...)
133168
# for now only allow quoted symbols
134169
sched = nothing
135170
end
136-
if sched !== :static
171+
if sched !== :static && sched !== :dynamic
137172
throw(ArgumentError("unsupported schedule argument in @threads"))
138173
end
139174
elseif na == 1

test/threads_exec.jl

+40
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,46 @@ end
732732
@test _atthreads_static_schedule() == [1:nthreads();]
733733
@test_throws TaskFailedException @threads for i = 1:1; _atthreads_static_schedule(); end
734734

735+
# dynamic schedule
736+
function _atthreads_dynamic_schedule()
737+
inc = Threads.Atomic{Int}(0)
738+
Threads.@threads :dynamic for _ = 1:nthreads()
739+
Threads.atomic_add!(inc, 1)
740+
end
741+
return inc
742+
end
743+
inc = _atthreads_dynamic_schedule()
744+
@test inc[] == nthreads()
745+
746+
# nested dynamic schedule
747+
function _atthreads_dynamic_dynamic_schedule()
748+
inc = Threads.Atomic{Int}(0)
749+
Threads.@threads :dynamic for _ = 1:nthreads()
750+
Threads.@threads :dynamic for _ = 1:nthreads()
751+
Threads.atomic_add!(inc, 1)
752+
end
753+
end
754+
755+
return inc
756+
end
757+
inc = _atthreads_dynamic_dynamic_schedule()
758+
@test inc[] == nthreads() * nthreads()
759+
760+
function _atthreads_static_dynamic_schedule()
761+
ids = zeros(Int, nthreads())
762+
inc = Threads.Atomic{Int}(0)
763+
Threads.@threads :static for i = 1:nthreads()
764+
ids[i] = Threads.threadid()
765+
Threads.@threads :dynamic for _ = 1:nthreads()
766+
Threads.atomic_add!(inc, 1)
767+
end
768+
end
769+
return ids, inc
770+
end
771+
ids, inc = _atthreads_static_dynamic_schedule()
772+
@test ids == [1:nthreads();]
773+
@test inc[] == nthreads() * nthreads()
774+
735775
try
736776
@macroexpand @threads(for i = 1:10, j = 1:10; end)
737777
catch ex

0 commit comments

Comments
 (0)