Skip to content

Commit 3cbd0c7

Browse files
committed
Fast-track @threads when nthreads() == 1
This avoids overhead when threading is disabled. Example benchmark:

```
using BenchmarkTools, Base.Threads

function func(val)
    local sum = 0*(1 .^ val)
    for idx in 1:100
        sum += idx.^val
    end
    return sum
end

function func_threaded(val)
    local sum = 0*(1 .^ val)
    @threads for idx in 1:100
        sum += idx.^val
    end
    return sum
end

@show @benchmark func(2.0)
@show @benchmark func_threaded(2.0)
```

Before change:

```
@benchmark(func(2.0)) = Trial(81.890 ns)
@benchmark(func_threaded(2.0)) = Trial(7.999 μs)
```

After change:

```
@benchmark(func(2.0)) = Trial(81.234 ns)
@benchmark(func_threaded(2.0)) = Trial(3.818 μs)
```

The rest of the overhead is attributed to the `Box` introduced by #15276.
1 parent da5f986 commit 3cbd0c7

File tree

1 file changed

+51
-44
lines changed

1 file changed

+51
-44
lines changed

base/threadingconstructs.jl

+51-44
Original file line numberDiff line numberDiff line change
@@ -25,55 +25,62 @@ function _threadsfor(iter,lbody)
2525
lidx = iter.args[1] # index
2626
range = iter.args[2]
2727
quote
28-
local threadsfor_fun
29-
let range = $(esc(range))
30-
function threadsfor_fun(onethread=false)
31-
r = range # Load into local variable
32-
lenr = length(r)
33-
# divide loop iterations among threads
34-
if onethread
35-
tid = 1
36-
len, rem = lenr, 0
37-
else
38-
tid = threadid()
39-
len, rem = divrem(lenr, nthreads())
40-
end
41-
# not enough iterations for all the threads?
42-
if len == 0
43-
if tid > rem
44-
return
45-
end
46-
len, rem = 1, 0
28+
# Fast-track serial execution for case of nthreads == 1
29+
if nthreads() == 1
30+
for $(esc(lidx)) in $(esc(range))
31+
$(esc(lbody))
4732
end
48-
# compute this thread's iterations
49-
f = 1 + ((tid-1) * len)
50-
l = f + len - 1
51-
# distribute remaining iterations evenly
52-
if rem > 0
53-
if tid <= rem
54-
f = f + (tid-1)
55-
l = l + tid
33+
else
34+
local threadsfor_fun
35+
let range = $(esc(range))
36+
function threadsfor_fun(onethread=false)
37+
r = range # Load into local variable
38+
lenr = length(r)
39+
# divide loop iterations among threads
40+
if onethread
41+
tid = 1
42+
len, rem = lenr, 0
5643
else
57-
f = f + rem
58-
l = l + rem
44+
tid = threadid()
45+
len, rem = divrem(lenr, nthreads())
46+
end
47+
# not enough iterations for all the threads?
48+
if len == 0
49+
if tid > rem
50+
return
51+
end
52+
len, rem = 1, 0
53+
end
54+
# compute this thread's iterations
55+
f = 1 + ((tid-1) * len)
56+
l = f + len - 1
57+
# distribute remaining iterations evenly
58+
if rem > 0
59+
if tid <= rem
60+
f = f + (tid-1)
61+
l = l + tid
62+
else
63+
f = f + rem
64+
l = l + rem
65+
end
66+
end
67+
# run this thread's iterations
68+
for i = f:l
69+
local $(esc(lidx)) = Base.unsafe_getindex(r,i)
70+
$(esc(lbody))
5971
end
6072
end
61-
# run this thread's iterations
62-
for i = f:l
63-
local $(esc(lidx)) = Base.unsafe_getindex(r,i)
64-
$(esc(lbody))
6573
end
66-
end
67-
end
68-
# Hack to make nested threaded loops kinda work
69-
if threadid() != 1 || in_threaded_loop[]
70-
# We are in a nested threaded loop
71-
Base.invokelatest(threadsfor_fun, true)
72-
else
73-
in_threaded_loop[] = true
74-
# the ccall is not expected to throw
75-
ccall(:jl_threading_run, Cvoid, (Any,), threadsfor_fun)
76-
in_threaded_loop[] = false
74+
# Hack to make nested threaded loops kinda work
75+
if threadid() != 1 || in_threaded_loop[]
76+
# We are in a nested threaded loop
77+
Base.invokelatest(threadsfor_fun, true)
78+
else
79+
in_threaded_loop[] = true
80+
# the ccall is not expected to throw
81+
ccall(:jl_threading_run, Cvoid, (Any,), threadsfor_fun)
82+
in_threaded_loop[] = false
83+
end
7784
end
7885
nothing
7986
end

0 commit comments

Comments
 (0)