@@ -22,83 +22,67 @@ See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
22
22
"""
23
23
nthreads () = Int (unsafe_load (cglobal (:jl_n_threads , Cint)))
24
24
25
- function threading_run (func, static)
26
- static && ccall (:jl_enter_threaded_region , Cvoid, ())
27
- n = nthreads ()
28
- tasks = Vector {Task} (undef, n)
29
- for i = 1 : n
30
- t = Task (func)
31
- t. sticky = static
32
- static && ccall (:jl_set_task_tid , Cint, (Any, Cint), t, i- 1 )
33
- tasks[i] = t
34
- schedule (t)
35
- end
36
- try
37
- for i = 1 : n
38
- wait (tasks[i])
39
- end
40
- finally
41
- static && ccall (:jl_exit_threaded_region , Cvoid, ())
42
- end
43
- end
44
-
45
25
function _threadsfor (iter, lbody, schedule)
46
26
lidx = iter. args[1 ] # index
47
27
range = iter. args[2 ]
48
28
quote
49
- local threadsfor_fun
50
29
let range = $ (esc (range))
51
- function threadsfor_fun (onethread= false )
52
- r = range # Load into local variable
53
- lenr = length (r)
54
- # divide loop iterations among threads
55
- if onethread
56
- tid = 1
57
- len, rem = lenr, 0
30
+ local range_len = length (range)
31
+ range_len == 0 && return
32
+ local ntasks = if $ (schedule === :dynamic )
33
+ nthreads ()
34
+ elseif threadid () != 1 || ccall (:jl_in_threaded_region , Cint, ()) != 0
35
+ $ (if schedule === :static
36
+ :(error (" `@threads :static` can only be used from thread 1 and not nested" ))
37
+ else
38
+ # only use threads when called from thread 1, outside @threads :static
39
+ 1
40
+ end )
58
41
else
59
- tid = threadid ()
60
- len, rem = divrem (lenr, nthreads ())
42
+ nthreads ()
61
43
end
62
- # not enough iterations for all the threads?
63
- if len == 0
64
- if tid > rem
65
- return
44
+ local n_per_part = div (range_len, ntasks, RoundUp)
45
+ if ntasks == 1
46
+ # run locally as a simple loop
47
+ for $ (esc (lidx)) in range
48
+ $ (esc (lbody))
66
49
end
67
- len, rem = 1 , 0
68
- end
69
- # compute this thread's iterations
70
- f = firstindex (r) + ((tid- 1 ) * len)
71
- l = f + len - 1
72
- # distribute remaining iterations evenly
73
- if rem > 0
74
- if tid <= rem
75
- f = f + (tid- 1 )
76
- l = l + tid
77
- else
78
- f = f + rem
79
- l = l + rem
50
+ else
51
+ local parts = Iterators. partition (range, n_per_part)
52
+ local nparts = length (parts)
53
+ $ (if schedule != = :dynamic
54
+ :(ccall (:jl_enter_threaded_region , Cvoid, ()))
55
+ end )
56
+ try
57
+ local tasks = Vector {Task} (undef, nparts)
58
+ local tid = 0
59
+ for part in parts
60
+ local t = @task begin
61
+ local $ (esc (lidx))
62
+ for $ (esc (lidx)) in part
63
+ $ (esc (lbody))
64
+ end
65
+ end
66
+ $ (if schedule === :dynamic
67
+ :(t. sticky = false )
68
+ else
69
+ :(t. sticky = true )
70
+ :(ccall (:jl_set_task_tid , Cint, (Any, Cint), t, tid))
71
+ end )
72
+ tasks[tid + 1 ] = t
73
+ schedule (t)
74
+ tid += 1
75
+ end
76
+ for i = 1 : nparts
77
+ wait (tasks[i])
78
+ end
79
+ finally
80
+ $ (if schedule != = :dynamic
81
+ :(ccall (:jl_exit_threaded_region , Cvoid, ()))
82
+ end )
80
83
end
81
84
end
82
- # run this thread's iterations
83
- for i = f: l
84
- local $ (esc (lidx)) = @inbounds r[i]
85
- $ (esc (lbody))
86
- end
87
- end
88
- end
89
- if $ (schedule === :dynamic )
90
- threading_run (threadsfor_fun, false )
91
- elseif threadid () != 1 || ccall (:jl_in_threaded_region , Cint, ()) != 0
92
- $ (if schedule === :static
93
- :(error (" `@threads :static` can only be used from thread 1 and not nested" ))
94
- else
95
- # only use threads when called from thread 1, outside @threads :static
96
- :(Base. invokelatest (threadsfor_fun, true ))
97
- end )
98
- else
99
- threading_run (threadsfor_fun, true )
100
85
end
101
- nothing
102
86
end
103
87
end
104
88
0 commit comments