Skip to content

Commit 31eeec3

Browse files
committed
remove threadedregion and move jl_threading_run to julia
1 parent cb8e1cf commit 31eeec3

File tree

5 files changed

+29
-110
lines changed

5 files changed

+29
-110
lines changed

base/threadingconstructs.jl

+23-21
Original file line numberDiff line numberDiff line change
@@ -22,34 +22,28 @@ function _threadsfor(iter,lbody)
2222
lidx = iter.args[1] # index
2323
range = iter.args[2]
2424
quote
25-
local threadsfor_fun
2625
let range = $(esc(range))
27-
function threadsfor_fun(onethread=false)
26+
function threadsfor_fun(grain)
2827
r = range # Load into local variable
2928
lenr = length(r)
30-
# divide loop iterations among threads
31-
if onethread
32-
tid = 1
33-
len, rem = lenr, 0
34-
else
35-
tid = threadid()
36-
len, rem = divrem(lenr, nthreads())
37-
end
29+
# divide loop iterations among tasks
30+
ngrains = min(nthreads(), lenr)
31+
len, rem = divrem(lenr, ngrains)
3832
# not enough iterations for all the threads?
3933
if len == 0
40-
if tid > rem
34+
if grain > rem
4135
return
4236
end
4337
len, rem = 1, 0
4438
end
4539
# compute this thread's iterations
46-
f = firstindex(r) + ((tid-1) * len)
40+
f = firstindex(r) + ((grain-1) * len)
4741
l = f + len - 1
4842
# distribute remaining iterations evenly
4943
if rem > 0
50-
if tid <= rem
51-
f = f + (tid-1)
52-
l = l + tid
44+
if grain <= rem
45+
f = f + (grain-1)
46+
l = l + grain
5347
else
5448
f = f + rem
5549
l = l + rem
@@ -61,17 +55,25 @@ function _threadsfor(iter,lbody)
6155
$(esc(lbody))
6256
end
6357
end
64-
end
65-
if threadid() != 1
66-
# only thread 1 can enter/exit _threadedregion
67-
Base.invokelatest(threadsfor_fun, true)
68-
else
69-
ccall(:jl_threading_run, Cvoid, (Any,), threadsfor_fun)
58+
threading_run(threadsfor_fun, length(range))
7059
end
7160
nothing
7261
end
7362
end
7463

64+
function threading_run(func, len)
65+
ngrains = min(nthreads(), len)
66+
tasks = Vector{Task}(undef, ngrains)
67+
for grain in 1:ngrains
68+
t = Task(()->func(grain))
69+
t.sticky = false
70+
tasks[grain] = t
71+
schedule(t)
72+
end
73+
Base.sync_end(tasks)
74+
return nothing
75+
end
76+
7577
"""
7678
Threads.@threads
7779

src/jl_uv.c

+1-3
Original file line numberDiff line numberDiff line change
@@ -201,13 +201,11 @@ JL_DLLEXPORT void jl_uv_req_set_data(uv_req_t *req, void *data) { req->data = da
201201
JL_DLLEXPORT void *jl_uv_handle_data(uv_handle_t *handle) { return handle->data; }
202202
JL_DLLEXPORT void *jl_uv_write_handle(uv_write_t *req) { return req->handle; }
203203

204-
extern volatile unsigned _threadedregion;
205-
206204
JL_DLLEXPORT int jl_process_events(void)
207205
{
208206
jl_ptls_t ptls = jl_get_ptls_states();
209207
uv_loop_t *loop = jl_io_loop;
210-
if (loop && (_threadedregion || ptls->tid == 0)) {
208+
if (loop) {
211209
jl_gc_safepoint_(ptls);
212210
if (jl_mutex_trylock(&jl_uv_mutex)) {
213211
loop->stop_flag = 0;

src/partr.c

+3-14
Original file line numberDiff line numberDiff line change
@@ -392,8 +392,6 @@ static int may_sleep(jl_ptls_t ptls)
392392
return jl_atomic_load(&sleep_check_state) == sleeping && jl_atomic_load(&ptls->sleep_check_state) == sleeping;
393393
}
394394

395-
extern volatile unsigned _threadedregion;
396-
397395
JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q)
398396
{
399397
jl_ptls_t ptls = jl_get_ptls_states();
@@ -413,7 +411,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q)
413411
}
414412

415413
jl_cpu_pause();
416-
if (sleep_check_after_threshold(&start_cycles) || (!_threadedregion && ptls->tid == 0)) {
414+
if (sleep_check_after_threshold(&start_cycles)) {
417415
if (!sleep_check_now(ptls->tid))
418416
continue;
419417
jl_atomic_store(&ptls->sleep_check_state, sleeping); // acquire sleep-check lock
@@ -425,14 +423,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q)
425423
// although none are allowed to create new ones
426424
// outside of threaded regions, all IO is permitted,
427425
// but only on thread 1
428-
int uvlock = 0;
429-
if (_threadedregion) {
430-
uvlock = jl_mutex_trylock(&jl_uv_mutex);
431-
}
432-
else if (ptls->tid == 0) {
433-
uvlock = 1;
434-
JL_UV_LOCK();
435-
}
426+
int uvlock = jl_mutex_trylock(&jl_uv_mutex);
436427
if (uvlock) {
437428
int active = 1;
438429
if (jl_atomic_load(&jl_uv_n_waiters) != 0) {
@@ -462,9 +453,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q)
462453
// to the last thread to do an explicit operation,
463454
// which may starve other threads of critical work
464455
}
465-
if (!_threadedregion && active && ptls->tid == 0) {
466-
// thread 0 is the only thread permitted to run the event loop
467-
// so it needs to stay alive
456+
if (active) {
468457
start_cycles = 0;
469458
continue;
470459
}

src/threading.c

-68
Original file line numberDiff line numberDiff line change
@@ -475,74 +475,6 @@ void jl_start_threads(void)
475475
uv_barrier_wait(&thread_init_done);
476476
}
477477

478-
unsigned volatile _threadedregion; // HACK: keep track of whether it is safe to do IO
479-
480-
// simple fork/join mode code
481-
JL_DLLEXPORT void jl_threading_run(jl_value_t *func)
482-
{
483-
jl_ptls_t ptls = jl_get_ptls_states();
484-
int8_t gc_state = jl_gc_unsafe_enter(ptls);
485-
size_t world = jl_world_counter;
486-
jl_method_instance_t *mfunc = jl_lookup_generic(&func, 1, jl_int32hash_fast(jl_return_address()), world);
487-
// Ignore constant return value for now.
488-
jl_code_instance_t *fptr = jl_compile_method_internal(mfunc, world);
489-
if (fptr->invoke == jl_fptr_const_return)
490-
return;
491-
492-
size_t nthreads = jl_n_threads;
493-
jl_svec_t *ts = jl_alloc_svec(nthreads);
494-
JL_GC_PUSH1(&ts);
495-
jl_value_t *wait_func = jl_get_global(jl_base_module, jl_symbol("wait"));
496-
jl_value_t *schd_func = jl_get_global(jl_base_module, jl_symbol("schedule"));
497-
// create and schedule all tasks
498-
_threadedregion += 1;
499-
for (int i = 0; i < nthreads; i++) {
500-
jl_value_t *args2[2];
501-
args2[0] = (jl_value_t*)jl_task_type;
502-
args2[1] = func;
503-
jl_task_t *t = (jl_task_t*)jl_apply(args2, 2);
504-
jl_svecset(ts, i, t);
505-
t->sticky = 1;
506-
t->tid = i;
507-
args2[0] = schd_func;
508-
args2[1] = (jl_value_t*)t;
509-
jl_apply(args2, 2);
510-
if (i == 1) {
511-
// let threads know work is coming (optimistic)
512-
jl_wakeup_thread(-1);
513-
}
514-
}
515-
if (nthreads > 2) {
516-
// let threads know work is ready (guaranteed)
517-
jl_wakeup_thread(-1);
518-
}
519-
// join with all tasks
520-
JL_TRY {
521-
for (int i = 0; i < nthreads; i++) {
522-
jl_value_t *t = jl_svecref(ts, i);
523-
jl_value_t *args[2] = { wait_func, t };
524-
jl_apply(args, 2);
525-
}
526-
}
527-
JL_CATCH {
528-
_threadedregion -= 1;
529-
jl_wake_libuv();
530-
JL_UV_LOCK();
531-
JL_UV_UNLOCK();
532-
jl_rethrow();
533-
}
534-
// make sure no threads are sitting in the event loop
535-
_threadedregion -= 1;
536-
jl_wake_libuv();
537-
// make sure no more callbacks will run while user code continues
538-
// outside thread region and might touch an I/O object.
539-
JL_UV_LOCK();
540-
JL_UV_UNLOCK();
541-
JL_GC_POP();
542-
jl_gc_unsafe_leave(ptls, gc_state);
543-
}
544-
545-
546478
// Make gc alignment available for threading
547479
// see threads.jl alignment
548480
JL_DLLEXPORT int jl_alignment(size_t sz)

test/threads_exec.jl

+2-4
Original file line numberDiff line numberDiff line change
@@ -411,11 +411,9 @@ for period in (0.06, Dates.Millisecond(60))
411411
t = Timer(period)
412412
wait(t)
413413
ccall(:uv_async_send, Cvoid, (Ptr{Cvoid},), async)
414-
ccall(:uv_async_send, Cvoid, (Ptr{Cvoid},), async)
415414
wait(c)
416415
sleep(period)
417416
ccall(:uv_async_send, Cvoid, (Ptr{Cvoid},), async)
418-
ccall(:uv_async_send, Cvoid, (Ptr{Cvoid},), async)
419417
end))
420418
wait(c)
421419
notify(c)
@@ -700,8 +698,8 @@ function _atthreads_with_error(a, err)
700698
end
701699
a
702700
end
703-
@test_throws TaskFailedException _atthreads_with_error(zeros(nthreads()), true)
701+
@test_throws CompositeException _atthreads_with_error(zeros(nthreads()), true)
704702
let a = zeros(nthreads())
705703
_atthreads_with_error(a, false)
706-
@test a == [1:nthreads();]
704+
@test all(n->(1 <= n <= nthreads()), a)
707705
end

0 commit comments

Comments
 (0)