Skip to content

Commit 5100430

Browse files
committed
remove threadedregion and move jl_threading_run to julia
1 parent 29c08a7 commit 5100430

File tree

5 files changed

+26
-111
lines changed

5 files changed

+26
-111
lines changed

base/task.jl

+8
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,14 @@ function __preinit_threads__()
431431
nothing
432432
end
433433

434+
function _run_on(t::Task, tid)
435+
@assert !istaskstarted(t)
436+
t.sticky = true
437+
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1)
438+
schedule(t)
439+
return t
440+
end
441+
434442
function enq_work(t::Task)
435443
(t.state == :runnable && t.queue === nothing) || error("schedule: Task not runnable")
436444
tid = Threads.threadid(t)

base/threadingconstructs.jl

+13-21
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,18 @@ on `threadid()`.
1818
"""
1919
nthreads() = Int(unsafe_load(cglobal(:jl_n_threads, Cint)))
2020

21-
# Only read/written by the main thread
22-
const in_threaded_loop = Ref(false)
23-
2421
function _threadsfor(iter,lbody)
2522
lidx = iter.args[1] # index
2623
range = iter.args[2]
2724
quote
2825
local threadsfor_fun
2926
let range = $(esc(range))
30-
function threadsfor_fun(onethread=false)
27+
function threadsfor_fun()
3128
r = range # Load into local variable
3229
lenr = length(r)
3330
# divide loop iterations among threads
34-
if onethread
35-
tid = 1
36-
len, rem = lenr, 0
37-
else
38-
tid = threadid()
39-
len, rem = divrem(lenr, nthreads())
40-
end
31+
tid = threadid()
32+
len, rem = divrem(lenr, nthreads())
4133
# not enough iterations for all the threads?
4234
if len == 0
4335
if tid > rem
@@ -65,20 +57,20 @@ function _threadsfor(iter,lbody)
6557
end
6658
end
6759
end
68-
# Hack to make nested threaded loops kinda work
69-
if threadid() != 1 || in_threaded_loop[]
70-
# We are in a nested threaded loop
71-
Base.invokelatest(threadsfor_fun, true)
72-
else
73-
in_threaded_loop[] = true
74-
# the ccall is not expected to throw
75-
ccall(:jl_threading_run, Cvoid, (Any,), threadsfor_fun)
76-
in_threaded_loop[] = false
77-
end
60+
threading_run(threadsfor_fun)
7861
nothing
7962
end
8063
end
8164

65+
function threading_run(func)
66+
tasks = Vector{Task}(undef, nthreads())
67+
for tid = 1:nthreads()
68+
tasks[tid] = Base._run_on(Task(func), tid)
69+
end
70+
foreach(wait, tasks)
71+
return nothing
72+
end
73+
8274
"""
8375
Threads.@threads
8476

src/jl_uv.c

+2-4
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,10 @@ JL_DLLEXPORT void jl_uv_req_set_data(uv_req_t *req, void *data) { req->data = da
201201
JL_DLLEXPORT void *jl_uv_handle_data(uv_handle_t *handle) { return handle->data; }
202202
JL_DLLEXPORT void *jl_uv_write_handle(uv_write_t *req) { return req->handle; }
203203

204-
extern volatile unsigned _threadedregion;
205-
206204
JL_DLLEXPORT int jl_run_once(uv_loop_t *loop)
207205
{
208206
jl_ptls_t ptls = jl_get_ptls_states();
209-
if (loop && (_threadedregion || ptls->tid == 0)) {
207+
if (loop) {
210208
jl_gc_safepoint_(ptls);
211209
JL_UV_LOCK();
212210
loop->stop_flag = 0;
@@ -220,7 +218,7 @@ JL_DLLEXPORT int jl_run_once(uv_loop_t *loop)
220218
JL_DLLEXPORT int jl_process_events(uv_loop_t *loop)
221219
{
222220
jl_ptls_t ptls = jl_get_ptls_states();
223-
if (loop && (_threadedregion || ptls->tid == 0)) {
221+
if (loop) {
224222
jl_gc_safepoint_(ptls);
225223
if (jl_mutex_trylock(&jl_uv_mutex)) {
226224
loop->stop_flag = 0;

src/partr.c

+3-14
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,6 @@ static int may_sleep(jl_ptls_t ptls)
412412
return jl_atomic_load(&sleep_check_state) == sleeping && jl_atomic_load(&ptls->sleep_check_state) == sleeping;
413413
}
414414

415-
extern volatile unsigned _threadedregion;
416-
417415
JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *getsticky)
418416
{
419417
jl_ptls_t ptls = jl_get_ptls_states();
@@ -435,7 +433,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *getsticky)
435433
#endif
436434

437435
jl_cpu_pause();
438-
if (sleep_check_after_threshold(&start_cycles) || (!_threadedregion && ptls->tid == 0)) {
436+
if (sleep_check_after_threshold(&start_cycles)) {
439437
if (!sleep_check_now(ptls->tid))
440438
continue;
441439
jl_atomic_store(&ptls->sleep_check_state, sleeping); // acquire sleep-check lock
@@ -447,14 +445,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *getsticky)
447445
// although none are allowed to create new ones
448446
// outside of threaded regions, all IO is permitted,
449447
// but only on thread 1
450-
int uvlock = 0;
451-
if (_threadedregion) {
452-
uvlock = jl_mutex_trylock(&jl_uv_mutex);
453-
}
454-
else if (ptls->tid == 0) {
455-
uvlock = 1;
456-
JL_UV_LOCK();
457-
}
448+
int uvlock = jl_mutex_trylock(&jl_uv_mutex);
458449
if (uvlock) {
459450
int active = 1;
460451
if (jl_atomic_load(&jl_uv_n_waiters) != 0) {
@@ -484,9 +475,7 @@ JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *getsticky)
484475
// to the last thread to do an explicit operation,
485476
// which may starve other threads of critical work
486477
}
487-
if (!_threadedregion && active && ptls->tid == 0) {
488-
// thread 0 is the only thread permitted to run the event loop
489-
// so it needs to stay alive
478+
if (active) {
490479
start_cycles = 0;
491480
continue;
492481
}

src/threading.c

-72
Original file line numberDiff line numberDiff line change
@@ -472,78 +472,6 @@ void jl_start_threads(void)
472472

473473
#endif
474474

475-
unsigned volatile _threadedregion; // HACK: keep track of whether it is safe to do IO
476-
477-
// simple fork/join mode code
478-
JL_DLLEXPORT void jl_threading_run(jl_value_t *func)
479-
{
480-
jl_ptls_t ptls = jl_get_ptls_states();
481-
int8_t gc_state = jl_gc_unsafe_enter(ptls);
482-
size_t world = jl_world_counter;
483-
jl_method_instance_t *mfunc = jl_lookup_generic(&func, 1, jl_int32hash_fast(jl_return_address()), world);
484-
// Ignore constant return value for now.
485-
jl_code_instance_t *fptr = jl_compile_method_internal(mfunc, world);
486-
if (fptr->invoke == jl_fptr_const_return)
487-
return;
488-
489-
size_t nthreads = jl_n_threads;
490-
jl_svec_t *ts = jl_alloc_svec(nthreads);
491-
JL_GC_PUSH1(&ts);
492-
jl_value_t *wait_func = jl_get_global(jl_base_module, jl_symbol("wait"));
493-
jl_value_t *schd_func = jl_get_global(jl_base_module, jl_symbol("schedule"));
494-
// create and schedule all tasks
495-
_threadedregion += 1;
496-
for (int i = 0; i < nthreads; i++) {
497-
jl_value_t *args2[2];
498-
args2[0] = (jl_value_t*)jl_task_type;
499-
args2[1] = func;
500-
jl_task_t *t = (jl_task_t*)jl_apply(args2, 2);
501-
jl_svecset(ts, i, t);
502-
t->sticky = 1;
503-
t->tid = i;
504-
args2[0] = schd_func;
505-
args2[1] = (jl_value_t*)t;
506-
jl_apply(args2, 2);
507-
#ifdef JULIA_ENABLE_THREADING
508-
if (i == 1) {
509-
// let threads know work is coming (optimistic)
510-
jl_wakeup_thread(-1);
511-
}
512-
#endif
513-
}
514-
#ifdef JULIA_ENABLE_THREADING
515-
if (nthreads > 2) {
516-
// let threads know work is ready (guaranteed)
517-
jl_wakeup_thread(-1);
518-
}
519-
#endif
520-
// join with all tasks
521-
JL_TRY {
522-
for (int i = 0; i < nthreads; i++) {
523-
jl_value_t *t = jl_svecref(ts, i);
524-
jl_value_t *args[2] = { wait_func, t };
525-
jl_apply(args, 2);
526-
}
527-
}
528-
JL_CATCH {
529-
_threadedregion -= 1;
530-
jl_wake_libuv();
531-
JL_UV_LOCK();
532-
JL_UV_UNLOCK();
533-
jl_rethrow();
534-
}
535-
// make sure no threads are sitting in the event loop
536-
_threadedregion -= 1;
537-
jl_wake_libuv();
538-
// make sure no more callbacks will run while user code continues
539-
// outside thread region and might touch an I/O object.
540-
JL_UV_LOCK();
541-
JL_UV_UNLOCK();
542-
JL_GC_POP();
543-
jl_gc_unsafe_leave(ptls, gc_state);
544-
}
545-
546-
547475
#ifndef JULIA_ENABLE_THREADING
548476

549477
void jl_init_threading(void)

0 commit comments

Comments
 (0)