Skip to content

Commit b813bf6

Browse files
rework range partitioning to be independent of threadid
1 parent a0e17e9 commit b813bf6

File tree

2 files changed

+59
-71
lines changed

2 files changed

+59
-71
lines changed

base/threadingconstructs.jl

+50-66
Original file line numberDiff line numberDiff line change
@@ -22,83 +22,67 @@ See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
2222
"""
2323
nthreads() = Int(unsafe_load(cglobal(:jl_n_threads, Cint)))
2424

25-
function threading_run(func, static)
26-
static && ccall(:jl_enter_threaded_region, Cvoid, ())
27-
n = nthreads()
28-
tasks = Vector{Task}(undef, n)
29-
for i = 1:n
30-
t = Task(func)
31-
t.sticky = static
32-
static && ccall(:jl_set_task_tid, Cint, (Any, Cint), t, i-1)
33-
tasks[i] = t
34-
schedule(t)
35-
end
36-
try
37-
for i = 1:n
38-
wait(tasks[i])
39-
end
40-
finally
41-
static && ccall(:jl_exit_threaded_region, Cvoid, ())
42-
end
43-
end
44-
4525
function _threadsfor(iter, lbody, schedule)
4626
lidx = iter.args[1] # index
4727
range = iter.args[2]
4828
quote
49-
local threadsfor_fun
5029
let range = $(esc(range))
51-
function threadsfor_fun(onethread=false)
52-
r = range # Load into local variable
53-
lenr = length(r)
54-
# divide loop iterations among threads
55-
if onethread
56-
tid = 1
57-
len, rem = lenr, 0
30+
local range_len = length(range)
31+
range_len == 0 && return
32+
local ntasks = if $(schedule === :dynamic)
33+
nthreads()
34+
elseif threadid() != 1 || ccall(:jl_in_threaded_region, Cint, ()) != 0
35+
$(if schedule === :static
36+
:(error("`@threads :static` can only be used from thread 1 and not nested"))
37+
else
38+
# only use threads when called from thread 1, outside @threads :static
39+
1
40+
end)
5841
else
59-
tid = threadid()
60-
len, rem = divrem(lenr, nthreads())
42+
nthreads()
6143
end
62-
# not enough iterations for all the threads?
63-
if len == 0
64-
if tid > rem
65-
return
44+
local n_per_part = div(range_len, ntasks, RoundUp)
45+
if ntasks == 1
46+
# run locally as a simple loop
47+
for $(esc(lidx)) in range
48+
$(esc(lbody))
6649
end
67-
len, rem = 1, 0
68-
end
69-
# compute this thread's iterations
70-
f = firstindex(r) + ((tid-1) * len)
71-
l = f + len - 1
72-
# distribute remaining iterations evenly
73-
if rem > 0
74-
if tid <= rem
75-
f = f + (tid-1)
76-
l = l + tid
77-
else
78-
f = f + rem
79-
l = l + rem
50+
else
51+
local parts = Iterators.partition(range, n_per_part)
52+
local nparts = length(parts)
53+
$(if schedule !== :dynamic
54+
:(ccall(:jl_enter_threaded_region, Cvoid, ()))
55+
end)
56+
try
57+
local tasks = Vector{Task}(undef, nparts)
58+
local tid = 0
59+
for part in parts
60+
local t = @task begin
61+
local $(esc(lidx))
62+
for $(esc(lidx)) in part
63+
$(esc(lbody))
64+
end
65+
end
66+
$(if schedule === :dynamic
67+
:(t.sticky = false)
68+
else
69+
:(t.sticky = true)
70+
:(ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid))
71+
end)
72+
tasks[tid + 1] = t
73+
schedule(t)
74+
tid += 1
75+
end
76+
for i = 1:nparts
77+
wait(tasks[i])
78+
end
79+
finally
80+
$(if schedule !== :dynamic
81+
:(ccall(:jl_exit_threaded_region, Cvoid, ()))
82+
end)
8083
end
8184
end
82-
# run this thread's iterations
83-
for i = f:l
84-
local $(esc(lidx)) = @inbounds r[i]
85-
$(esc(lbody))
86-
end
87-
end
88-
end
89-
if $(schedule === :dynamic)
90-
threading_run(threadsfor_fun, false)
91-
elseif threadid() != 1 || ccall(:jl_in_threaded_region, Cint, ()) != 0
92-
$(if schedule === :static
93-
:(error("`@threads :static` can only be used from thread 1 and not nested"))
94-
else
95-
# only use threads when called from thread 1, outside @threads :static
96-
:(Base.invokelatest(threadsfor_fun, true))
97-
end)
98-
else
99-
threading_run(threadsfor_fun, true)
10085
end
101-
nothing
10286
end
10387
end
10488

test/threads_exec.jl

+9-5
Original file line numberDiff line numberDiff line change
@@ -733,15 +733,19 @@ end
733733
@test_throws TaskFailedException @threads for i = 1:1; _atthreads_static_schedule(); end
734734

735735
# dynamic schedule
736-
function _atthreads_dynamic_schedule()
736+
function _atthreads_dynamic_schedule(n)
737737
inc = Threads.Atomic{Int}(0)
738-
Threads.@threads :dynamic for _ = 1:nthreads()
738+
flags = falses(n)
739+
Threads.@threads :dynamic for i = 1:n
739740
Threads.atomic_add!(inc, 1)
741+
flags[i] = true
740742
end
741-
return inc[]
743+
return inc[], flags
742744
end
743-
inc = _atthreads_dynamic_schedule()
744-
@test inc == nthreads()
745+
@test _atthreads_dynamic_schedule(nthreads()) == (nthreads(), trues(nthreads()))
746+
@test _atthreads_dynamic_schedule(1) == (1, trues(1))
747+
@test _atthreads_dynamic_schedule(10) == (10, trues(10))
748+
@test _atthreads_dynamic_schedule(nthreads() * 2) == (nthreads() * 2, trues(nthreads() * 2))
745749

746750
# nested dynamic schedule
747751
function _atthreads_dynamic_dynamic_schedule()

0 commit comments

Comments
 (0)